-
Notifications
You must be signed in to change notification settings - Fork 0
/
KNN.java (main)
107 lines (78 loc) · 6.19 KB
/
KNN.java (main)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
public class KNN {
public static void main(String[] args) {
try {
System.out.println("... WINE.DAT ..."); //ΓΙΑ WINE.DAT
System.out.println();
for (int k = 1; k <= 10; k++) { // τρέχει για τις τιμές του k από 1 έως 10.
double sumOfRate = 0; // μεταβλητή που θα κρατάει την ακρίβεια κατηγοριοποίησης ανά επανάληψη. ( όχι ανα k !!! ).
BigDecimal averageRate = null; // αντικείμενο τύπου big decimal που θα αποθηκεύει τη μέση ακρίβεια κατηγοριοποίησης.
for (int loop = 1; loop <= 10; loop++) { // τρέχουμε ένα loop 10 φορές για κάθε k, όπως ζητάει η εκφώνηση.
PointList bigList = PointList.readData("wine.dat"); // διαβάζει το αρχείο wine.dat και αποθηκεύει τα στοιχεία του στη μεγάλη λίστα.
int smallLength = (int)(bigList.length() * 0.25); // φτιάχνουμε το μήκος της μικρής λίστας. ( 25% της μεγάλης ).
PointList smallList = new PointList( bigList.getDim() ); // φτιάχνει τη μικρή λίστα.
int bigLength = bigList.length(); // αποθηκεύουμε σε μια μεταβλητή το μήκος της μεγάλης λίστας.
bigList.shuffle(); // κάνουμε shuffle τη μεγάλη λίστα.
// τρέχουμε ένα for loop και διαγράφοντας ένα - ένα τα στοιχεία της big list και τα αποθηκεύουμε στην small list. Οπότε μετά η big list θα περιέχει
// το υπόλοιπο 75%.
for (int i = 0; i < bigLength * 0.25 -1; i++) {
PointData data = bigList.rmFirst(); // διαγράφουμε το πρώτο node αποθηκεύοντας το σε ένα temp αντικείμενο τύπου point data.
smallList.append(data.x, data.y); // προσθέτουμε στη μικρή λίστα αυτό το temp.
}
double correct = 0; // θα μετράει τον αριθμό των σημείων που κατηγοριοποιήθηκαν σωστά.
smallList.shuffle(); // κάνουμε shuffle τη μικρή λίστα
while ( smallList.length() > 0 ) { // Όσο το μήκος της είναι θετικό
PointData temp = smallList.rmFirst(); // διαγράφει το πρώτο στοιχείο αποθηκεύοντας το πάλι σε ένα temp αντικείμενο.
// καλούμε την findKNearest για τη μεγάλη λίστα ψάχνοντας τους k κοντινότερους με τα στοιχεία της μικρής όμως και στη συνέχεια την classify
// για να τα κατηγοριοποιήσουμε.
double classification = bigList.findKNearest(temp.x, k).classify(temp.x);
// αν το y που επέστρεψε συμφωνεί με το προκαθορισμένο, αύξησε ένα μετρητή που μετράει τα σωστά αποτελέσματα κατά 1.
if ( classification == temp.y )
correct++;
}
sumOfRate = sumOfRate + correct / smallLength;
// φτιάχνει ένα αντικείμενο rate τύπου big decimal και το αρχικοποιεί με την τιμή sumOfRate * 10 , δηλαδή αν το sumOfRate
// είναι 9.521 , το * 10 το κάνει 95.21%.
averageRate = new BigDecimal(sumOfRate * 10);
// το περιορίζει να εμφανίζει μέχρι δύο δεκαδικά, στρογγυλοποιώντας το τελευταίο δεκαδικό.
averageRate = averageRate.setScale(2, RoundingMode.CEILING);
}
System.out.println("The average rate of accurately classified points for k = " + k + " is: " + String.format("%.2f", averageRate) + "%");
// εμφανίζει για κάθε επανάληψη και κάθε k το ποσοστό των σωστών κατηγοριοποιήσεων.
}
System.out.println();
System.out.println("... HOUSING.DAT ..."); // HOUSING.DAT ΑΠΟ'ΔΩ ΚΑΙ ΚΑΤΩ.
System.out.println();
// Ομοίως για τον παρακάτω κώδικα τα ίδια ακριβώς σχόλια, μόνο που διαβάζει το housing.dat αντί για το wine.dat.
for (int k = 1; k <= 10; k++) {
double sumOfRate = 0;
BigDecimal averageRate = null;
for (int loop = 1; loop <= 10; loop++) {
PointList bigList = PointList.readData("housing.dat");
PointList smallList = new PointList( bigList.getDim() );
int bigLength = bigList.length();
int smallLength = (int) (bigList.length() * 0.25);
bigList.shuffle();
for (int i = 0; i < bigLength * 0.25 -1; i++) {
PointData data = bigList.rmFirst();
smallList.append(data.x, data.y);
}
smallList.shuffle();
double correct = 0;
while ( smallList.length() > 0 ) {
PointData temp = smallList.rmFirst();
double classification = bigList.findKNearest(temp.x, k).classify(temp.x);
if ( classification == temp.y )
correct++;
}
sumOfRate = sumOfRate + correct / smallLength;
averageRate = new BigDecimal(sumOfRate * 10);
averageRate = averageRate.setScale(2, RoundingMode.CEILING);
}
System.out.println("The average rate of accurately classified points for k = " + k + " is: " + String.format("%.2f", averageRate) + "%");
}
}
catch (Exception e) {
e.printStackTrace();
}
}
}