@emileferreira/

image-classifier

Java

ML image classifier that identifies types of flowers by comparing an image's average pixel colour with reference colour models in 3D (RGB) spatial colour space.

fork
loading
Files
  • Main.java
  • images
  • test.jpg
  • test2.jpg
Main.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import javax.imageio.ImageIO;

/**
 * Command-line flower classifier: averages an image's pixel colour and reports which
 * reference colour ("model") it is closest to in RGB space.
 */
class Main {

  // Reference average colours, one {red, green, blue} triple per row.
  // modelNames[i] is the label for modelValues[i]; keep both arrays in the same order.
  static int[][] modelValues = {{158, 68, 62},
                                {97, 37, 40},
                                {51, 4, 3},
                                {112, 120, 86},
                                {125, 119, 47},
                                {162, 157, 96},
                                {166, 149, 48},
                                {97, 113, 75},
                                {75, 94, 61},
                                {91, 122, 93}};
  static String[] modelNames = {"rose",
                                "rose",
                                "rose",
                                "sunflower",
                                "sunflower",
                                "sunflower",
                                "sunflower",
                                "lily",
                                "lily",
                                "lily"};

  /**
   * Prompts for an image path on standard input, classifies the image and prints the result.
   *
   * @param args ignored
   * @throws IOException if the file cannot be read or does not contain a decodable image
   */
  public static void main(String[] args) throws IOException {
    System.out.println("\nWelcome to Emile's image classifier. \nPlease input the image file path. (e.g. test.jpg)");

    // System.console() is null whenever stdin is not an interactive terminal (IDEs, pipes),
    // so read from System.in directly. Deliberately not closed: that would close System.in.
    BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
    String path = in.readLine();
    if (path == null || path.trim().isEmpty()) {
      System.err.println("No image file path was given.");
      return;
    }
    File file = new File(path.trim());

    // ImageIO.read returns null (instead of throwing) when no installed reader supports the file.
    BufferedImage image = ImageIO.read(file);
    if (image == null) {
      throw new IOException("Not a supported image format: " + file);
    }

    // Create the colour model of the current image.
    int[] currentValues = convertToAvgRGB(image);

    System.out.println(currentValues[0] + "," + currentValues[1] + "," + currentValues[2]);
    System.out.println("Comparing to image models...");

    // Compare to the learnt models.
    int best = nearestModelIndex(currentValues);
    String result = modelNames[best];
    // Truncated to a whole number, as before, so the reported confidence is unchanged.
    int min = (int) Math.sqrt(squaredDistance(currentValues, modelValues[best]));

    // Heuristic score: 10 / (0.001*d + 0.1) == 100 / (1 + d/100), i.e. 100 at d == 0, 50 at d == 100.
    double confidence = Math.round(1 / ((min * 0.001) + 0.1) * 10);
    System.out.println("\nIt's a " + result + " (" + confidence + "% confidence)");
  }

  /**
   * Returns the index of the model colour nearest to {@code rgb} by Euclidean distance.
   * Exact squared distances are compared, so models are never conflated by truncation;
   * ties go to the lowest index.
   *
   * @param rgb a {red, green, blue} triple
   * @return index into {@link #modelValues} / {@link #modelNames}
   * @throws IllegalStateException if no models are defined
   */
  private static int nearestModelIndex(int[] rgb) {
    int modelCount = Math.min(modelValues.length, modelNames.length);
    if (modelCount == 0) {
      throw new IllegalStateException("No image models are defined");
    }
    int best = 0;
    int bestSq = squaredDistance(rgb, modelValues[0]);
    for (int i = 1; i < modelCount; i++) {
      int sq = squaredDistance(rgb, modelValues[i]);
      if (sq < bestSq) {
        best = i;
        bestSq = sq;
      }
    }
    return best;
  }

  /** Returns the squared Euclidean distance between two {red, green, blue} triples. */
  private static int squaredDistance(int[] a, int[] b) {
    int dr = a[0] - b[0];
    int dg = a[1] - b[1];
    int db = a[2] - b[2];
    return dr * dr + dg * dg + db * db;
  }

  /**
   * Computes the mean red, green and blue channel values over every pixel of an image.
   * Pixels are read via {@link BufferedImage#getRGB} (default sRGB); alpha is ignored.
   *
   * @param image the image to average
   * @return a new array {red, green, blue}, each the truncated integer mean (0..255)
   */
  static int[] convertToAvgRGB(BufferedImage image) {
    System.out.println("\nLooking at image...");

    int width = image.getWidth();
    int height = image.getHeight();
    // long sums: int would overflow once 255 * pixelCount exceeds Integer.MAX_VALUE (~8.4 MP).
    long red = 0;
    long green = 0;
    long blue = 0;
    // Read one row at a time so memory stays O(width) rather than copying the whole image.
    int[] rowPixels = new int[width];

    System.out.println("Scratching head...");

    for (int row = 0; row < height; row++) {
      image.getRGB(0, row, width, 1, rowPixels, 0, width);
      for (int argb : rowPixels) {
        red += (argb >> 16) & 0xFF;
        green += (argb >> 8) & 0xFF;
        blue += argb & 0xFF;
      }
    }

    long pixelCount = (long) width * height;
    return new int[] {
      (int) (red / pixelCount), (int) (green / pixelCount), (int) (blue / pixelCount)
    };
  }
}