Created
January 18, 2016 20:41
-
-
Save rpCal/c68f46e283cdc8aee66f to your computer and use it in GitHub Desktop.
NAI - projekt - v1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.File; | |
import java.io.IOException; | |
import weka.classifiers.Evaluation; | |
import weka.classifiers.functions.MultilayerPerceptron; | |
import weka.core.Instances; | |
import weka.core.converters.ArffLoader; | |
import weka.filters.Filter; | |
import weka.filters.supervised.attribute.NominalToBinary; | |
import weka.filters.unsupervised.attribute.Normalize; | |
import weka.filters.unsupervised.attribute.ReplaceMissingValues; | |
import weka.filters.unsupervised.instance.Randomize; | |
public class Main { | |
public static void main(String[] args) { | |
// plan dzialania aplikacji | |
// - wczytuje zbiór testowy oraz zbiór uczący | |
// - filtruje i aktualizuje dane: | |
// ---- zamien brakujace wartosci na średnie/nieznane | |
// ---- zamieniam wartosci typu string na wartosc typu number | |
// ---- normalizuje dane do wartosci z przedzialu 0-1 | |
// - przygotowuje siec oraz sprawdzam jej skutecznosc | |
// - badam zmienność wspolczynnika uczenia | |
// - badam zmienność liczby neuronow | |
// - badam zmienność liczby epok | |
// - badam zmienność wspolczynnika momentum | |
System.out.println("Start aplikacji"); | |
try { | |
Main m = new Main(); | |
m.readData(); | |
// przypadek z zmianą brakujacych wartosci na srednie | |
m.updateData(true); | |
m.verifyNNParams_basic(); | |
// przypadek z zmianą brakujacych wartosci na nieznane | |
m.updateData(false); | |
m.verifyNNParams_basic(); | |
// m.test(); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
System.out.println("Koniec aplikacji"); | |
} | |
// zmienne podstawowe | |
// | |
Instances initDataTrain; | |
Instances initDataTest; | |
Instances structureTrain; | |
Instances structureTest; | |
// wczytuje zbiór testowy oraz zbiór uczący | |
public void readData() throws IOException{ | |
// lista atrybutow, ktore powinny byc w plikach | |
// @attribute age numeric | |
// @attribute workclass {Private,Local-gov,Self-emp-not-inc,Federal-gov,State-gov,Self-emp-inc,Without-pay,Never-worked} | |
// @attribute fnlwgt numeric | |
// @attribute education {11th,HS-grad,Assoc-acdm,Some-college,10th,Prof-school,7th-8th,Bachelors,Masters,Doctorate,5th-6th,Assoc-voc,9th,12th,1st-4th,Preschool} | |
// @attribute education-num numeric | |
// @attribute marital-status {Never-married,Married-civ-spouse,Widowed,Divorced,Separated,Married-spouse-absent,Married-AF-spouse} | |
// @attribute occupation {Machine-op-inspct,Farming-fishing,Protective-serv,Other-service,Prof-specialty,Craft-repair,Adm-clerical,Exec-managerial,Tech-support,Sales,Priv-house-serv,Transport-moving,Handlers-cleaners,Armed-Forces} | |
// @attribute relationship {Own-child,Husband,Not-in-family,Unmarried,Wife,Other-relative} | |
// @attribute race {Black,White,Asian-Pac-Islander,Other,Amer-Indian-Eskimo} | |
// @attribute sex {Male,Female} | |
// @attribute capital-gain numeric | |
// @attribute capital-loss numeric | |
// @attribute hours-per-week numeric | |
// @attribute native-country {United-States,Cuba,Jamaica,India,Mexico,South,Puerto-Rico,Honduras,England,Canada,Germany,Iran,Philippines,Italy,Poland,Columbia,Cambodia,Thailand,Ecuador,Laos,Taiwan,Haiti,Portugal,Dominican-Republic,El-Salvador,France,Guatemala,China,Japan,Yugoslavia,Peru,Outlying-US(Guam-USVI-etc),Scotland,Trinadad&Tobago,Greece,Nicaragua,Vietnam,Hong,Ireland,Hungary,Holand-Netherlands} | |
// @attribute Outcome {<=50K,>50K} | |
// odczytaj plik z zbiorem uczacym | |
String fileNameTrain = "./census-income-data.arff"; | |
File fTrain = new File(fileNameTrain); | |
// plik istnieje? | |
if(!(fTrain.exists() || fTrain.canRead())){ | |
System.out.println("nie mozna odczytac pliku: " + fileNameTrain); | |
return ; | |
} | |
// odczytaj plik z zbiorem testowym | |
String fileNameTest = "./census-income-test.arff"; | |
File fTest = new File(fileNameTest); | |
// plik istnieje? | |
if(!(fTest.exists() || fTest.canRead())){ | |
System.out.println("nie mozna odczytac pliku: " + fileNameTest); | |
return ; | |
} | |
// wczytaj dane | |
ArffLoader loaderTrain = new ArffLoader(); | |
loaderTrain.setFile(fTrain); | |
initDataTrain = loaderTrain.getDataSet(); | |
initDataTrain.setClassIndex(initDataTrain.numAttributes() - 1); | |
// wczytaj dane testowe | |
ArffLoader loaderTest = new ArffLoader(); | |
loaderTest.setFile(fTest); | |
initDataTest = loaderTest.getDataSet(); | |
initDataTest.setClassIndex(initDataTest.numAttributes() - 1); | |
} | |
// aktualizuje - filtruje dane wejsciowe oraz zmieniam ich postac | |
public void updateData(boolean canReplaceValue) throws Exception{ | |
// zamieniam zbior danych uczacych | |
// | |
// robie kopie danych | |
Instances m_Train = new Instances(initDataTrain); | |
// zamieniam brakujace wartosci na wartosci srednie | |
if(canReplaceValue){ | |
ReplaceMissingValues m_ReplaceMissingValues = new ReplaceMissingValues(); | |
m_ReplaceMissingValues.setInputFormat(m_Train); | |
m_Train = Filter.useFilter(m_Train, m_ReplaceMissingValues); | |
}else{ | |
Randomize m_Randomize = new Randomize(); | |
m_Randomize.setInputFormat(m_Train); | |
m_Train = Filter.useFilter(m_Train, m_Randomize); | |
} | |
// zamieniam typ string na na typ binarny - wartosc liczbową | |
NominalToBinary m_NominalToBinary = new NominalToBinary(); | |
m_NominalToBinary.setInputFormat(m_Train); | |
m_Train = Filter.useFilter(m_Train, m_NominalToBinary); | |
// normalizuje wartosci do przedzialu [0-1] | |
Normalize m_Normalize = new Normalize(); | |
m_Normalize.setInputFormat(m_Train); | |
m_Train = Filter.useFilter(m_Train, m_Normalize); | |
// usuwam wpisy ktore nie posiadaja odpowiednich wartosci | |
// m_Train.deleteWithMissingClass(); | |
// zapisuje nowe wartosci | |
structureTrain = m_Train; | |
// zamieniam zbior danych testowych | |
// | |
// robie kopie danych | |
Instances m_Test = new Instances(initDataTest); | |
// zamieniam brakujace wartosci na wartosci srednie | |
if(canReplaceValue){ | |
ReplaceMissingValues m_ReplaceMissingValues2 = new ReplaceMissingValues(); | |
m_ReplaceMissingValues2.setInputFormat(m_Test); | |
m_Test = Filter.useFilter(m_Test, m_ReplaceMissingValues2); | |
}else{ | |
Randomize m_Randomize = new Randomize(); | |
m_Randomize.setInputFormat(m_Test); | |
m_Test = Filter.useFilter(m_Test, m_Randomize); | |
} | |
// zamieniam typ string na na typ binarny - wartosc liczbową | |
NominalToBinary m_NominalToBinary2 = new NominalToBinary(); | |
m_NominalToBinary2.setInputFormat(m_Test); | |
m_Test = Filter.useFilter(m_Test, m_NominalToBinary2); | |
// normalizuje wartosci do przedzialu [0-1] | |
Normalize m_Normalize2 = new Normalize(); | |
m_Normalize2.setInputFormat(m_Test); | |
m_Test = Filter.useFilter(m_Test, m_Normalize2); | |
// usuwam wpisy ktore nie posiadaja odpowiednich wartosci | |
// m_Test.deleteWithMissingClass(); | |
// zapisuje nowe wartosci | |
structureTest = m_Test; | |
} | |
// sprawdzam skutecznosc sieci dla roznych wartosci wspolczynnikow podstawowych sieci | |
public void verifyNNParams_basic() throws Exception{ | |
// wartosci podstawowe uzywane do sprawdzania skutecznosci sieci | |
double init_value_LR = 0.2; | |
double init_value_M = 0.2; | |
int init_value_TT = 500; | |
String init_value_HL = "2"; | |
// sprawdzam skutecznosc sieci dla roznej wartosci wspolczynnika ucenia | |
double[] LR_values = new double[] {0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9}; | |
for (double currentValue : LR_values) { | |
this.runNN_basic(currentValue, init_value_M, init_value_TT, init_value_HL); | |
} | |
// sprawdzam skutecznosc sieci dla roznej ilosc neuronow w warstwie ukrytej | |
String[] HL_values = new String[] {"4","5","6","7","8", "9"}; | |
for (String currentValue : HL_values) { | |
this.runNN_basic(init_value_LR, init_value_M, init_value_TT, currentValue); | |
} | |
// sprawdzam skutecznosc sieci dla roznej ilosci epok | |
int[] TT_values = new int[] {100,200,300,400,500,600,700,800,900,1000}; | |
for (int currentValue : TT_values) { | |
this.runNN_basic(init_value_LR, init_value_M, currentValue, init_value_HL); | |
} | |
// sprawdzam skutecznosc sieci dla roznej ilosci epok | |
double[] M_values = new double[] {0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9}; | |
for (double currentValue : M_values) { | |
this.runNN_basic(init_value_LR, currentValue, init_value_TT, init_value_HL); | |
} | |
} | |
// sprawdzenie skutecznosci odczytanych zbiorow dla przekazanych wspolczynnikow | |
public void runNN_basic(double LR, double M, int TT, String HL) throws Exception{ | |
// Instancja sieci neuronowej | |
MultilayerPerceptron mlp = new MultilayerPerceptron(); | |
// ustawiam parametry sieci | |
mlp.setLearningRate(LR); | |
mlp.setMomentum(M); | |
mlp.setTrainingTime(TT); | |
mlp.setHiddenLayers(HL); | |
// opis dzialania | |
System.out.print("Skutecznosc sieci z parametrami: " ); | |
System.out.print("wspolczynnik uczenia="+LR+", "); | |
System.out.print("liczba neuronow="+HL+", "); | |
System.out.print("liczba epok="+TT+", "); | |
System.out.print("momentum="+M); | |
// robie kopie danych wejsciowych | |
Instances m_Train = new Instances(structureTrain); | |
Instances m_Test = new Instances(structureTest); | |
// nauka sieci | |
mlp.buildClassifier(m_Train); | |
// obliczanie skutecznosci | |
Evaluation eval = new Evaluation(m_Train); | |
eval.evaluateModel(mlp, m_Test); | |
System.out.println(eval.errorRate()); | |
System.out.println(eval.toSummaryString()); | |
System.out.println(""); | |
} | |
// metoda testowa, uzywana jedynie podczas developmentu | |
public void test() throws Exception{ | |
Instances newTrain = structureTrain; | |
Instances newTest = structureTest; | |
MultilayerPerceptron mlp2 = new MultilayerPerceptron(); | |
mlp2.setLearningRate(0.2); | |
mlp2.setMomentum(0.2); | |
mlp2.setTrainingTime(100); | |
mlp2.setHiddenLayers("2"); | |
mlp2.buildClassifier(newTrain); | |
Evaluation eval_dt = new Evaluation(newTrain); | |
eval_dt.evaluateModel(mlp2, newTest); | |
System.out.println(eval_dt.toSummaryString()); //Summary of Training | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment