Skip to content

Instantly share code, notes, and snippets.

@rpCal
Created January 18, 2016 20:41
Show Gist options
  • Save rpCal/c68f46e283cdc8aee66f to your computer and use it in GitHub Desktop.
Save rpCal/c68f46e283cdc8aee66f to your computer and use it in GitHub Desktop.
NAI - projekt - v1
import java.io.File;
import java.io.IOException;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.instance.Randomize;
public class Main {
public static void main(String[] args) {
// plan dzialania aplikacji
// - wczytuje zbiór testowy oraz zbiór uczący
// - filtruje i aktualizuje dane:
// ---- zamien brakujace wartosci na średnie/nieznane
// ---- zamieniam wartosci typu string na wartosc typu number
// ---- normalizuje dane do wartosci z przedzialu 0-1
// - przygotowuje siec oraz sprawdzam jej skutecznosc
// - badam zmienność wspolczynnika uczenia
// - badam zmienność liczby neuronow
// - badam zmienność liczby epok
// - badam zmienność wspolczynnika momentum
System.out.println("Start aplikacji");
try {
Main m = new Main();
m.readData();
// przypadek z zmianą brakujacych wartosci na srednie
m.updateData(true);
m.verifyNNParams_basic();
// przypadek z zmianą brakujacych wartosci na nieznane
m.updateData(false);
m.verifyNNParams_basic();
// m.test();
} catch (Exception e) {
e.printStackTrace();
}
System.out.println("Koniec aplikacji");
}
// zmienne podstawowe
//
Instances initDataTrain;
Instances initDataTest;
Instances structureTrain;
Instances structureTest;
// wczytuje zbiór testowy oraz zbiór uczący
public void readData() throws IOException{
// lista atrybutow, ktore powinny byc w plikach
// @attribute age numeric
// @attribute workclass {Private,Local-gov,Self-emp-not-inc,Federal-gov,State-gov,Self-emp-inc,Without-pay,Never-worked}
// @attribute fnlwgt numeric
// @attribute education {11th,HS-grad,Assoc-acdm,Some-college,10th,Prof-school,7th-8th,Bachelors,Masters,Doctorate,5th-6th,Assoc-voc,9th,12th,1st-4th,Preschool}
// @attribute education-num numeric
// @attribute marital-status {Never-married,Married-civ-spouse,Widowed,Divorced,Separated,Married-spouse-absent,Married-AF-spouse}
// @attribute occupation {Machine-op-inspct,Farming-fishing,Protective-serv,Other-service,Prof-specialty,Craft-repair,Adm-clerical,Exec-managerial,Tech-support,Sales,Priv-house-serv,Transport-moving,Handlers-cleaners,Armed-Forces}
// @attribute relationship {Own-child,Husband,Not-in-family,Unmarried,Wife,Other-relative}
// @attribute race {Black,White,Asian-Pac-Islander,Other,Amer-Indian-Eskimo}
// @attribute sex {Male,Female}
// @attribute capital-gain numeric
// @attribute capital-loss numeric
// @attribute hours-per-week numeric
// @attribute native-country {United-States,Cuba,Jamaica,India,Mexico,South,Puerto-Rico,Honduras,England,Canada,Germany,Iran,Philippines,Italy,Poland,Columbia,Cambodia,Thailand,Ecuador,Laos,Taiwan,Haiti,Portugal,Dominican-Republic,El-Salvador,France,Guatemala,China,Japan,Yugoslavia,Peru,Outlying-US(Guam-USVI-etc),Scotland,Trinadad&Tobago,Greece,Nicaragua,Vietnam,Hong,Ireland,Hungary,Holand-Netherlands}
// @attribute Outcome {<=50K,>50K}
// odczytaj plik z zbiorem uczacym
String fileNameTrain = "./census-income-data.arff";
File fTrain = new File(fileNameTrain);
// plik istnieje?
if(!(fTrain.exists() || fTrain.canRead())){
System.out.println("nie mozna odczytac pliku: " + fileNameTrain);
return ;
}
// odczytaj plik z zbiorem testowym
String fileNameTest = "./census-income-test.arff";
File fTest = new File(fileNameTest);
// plik istnieje?
if(!(fTest.exists() || fTest.canRead())){
System.out.println("nie mozna odczytac pliku: " + fileNameTest);
return ;
}
// wczytaj dane
ArffLoader loaderTrain = new ArffLoader();
loaderTrain.setFile(fTrain);
initDataTrain = loaderTrain.getDataSet();
initDataTrain.setClassIndex(initDataTrain.numAttributes() - 1);
// wczytaj dane testowe
ArffLoader loaderTest = new ArffLoader();
loaderTest.setFile(fTest);
initDataTest = loaderTest.getDataSet();
initDataTest.setClassIndex(initDataTest.numAttributes() - 1);
}
// aktualizuje - filtruje dane wejsciowe oraz zmieniam ich postac
public void updateData(boolean canReplaceValue) throws Exception{
// zamieniam zbior danych uczacych
//
// robie kopie danych
Instances m_Train = new Instances(initDataTrain);
// zamieniam brakujace wartosci na wartosci srednie
if(canReplaceValue){
ReplaceMissingValues m_ReplaceMissingValues = new ReplaceMissingValues();
m_ReplaceMissingValues.setInputFormat(m_Train);
m_Train = Filter.useFilter(m_Train, m_ReplaceMissingValues);
}else{
Randomize m_Randomize = new Randomize();
m_Randomize.setInputFormat(m_Train);
m_Train = Filter.useFilter(m_Train, m_Randomize);
}
// zamieniam typ string na na typ binarny - wartosc liczbową
NominalToBinary m_NominalToBinary = new NominalToBinary();
m_NominalToBinary.setInputFormat(m_Train);
m_Train = Filter.useFilter(m_Train, m_NominalToBinary);
// normalizuje wartosci do przedzialu [0-1]
Normalize m_Normalize = new Normalize();
m_Normalize.setInputFormat(m_Train);
m_Train = Filter.useFilter(m_Train, m_Normalize);
// usuwam wpisy ktore nie posiadaja odpowiednich wartosci
// m_Train.deleteWithMissingClass();
// zapisuje nowe wartosci
structureTrain = m_Train;
// zamieniam zbior danych testowych
//
// robie kopie danych
Instances m_Test = new Instances(initDataTest);
// zamieniam brakujace wartosci na wartosci srednie
if(canReplaceValue){
ReplaceMissingValues m_ReplaceMissingValues2 = new ReplaceMissingValues();
m_ReplaceMissingValues2.setInputFormat(m_Test);
m_Test = Filter.useFilter(m_Test, m_ReplaceMissingValues2);
}else{
Randomize m_Randomize = new Randomize();
m_Randomize.setInputFormat(m_Test);
m_Test = Filter.useFilter(m_Test, m_Randomize);
}
// zamieniam typ string na na typ binarny - wartosc liczbową
NominalToBinary m_NominalToBinary2 = new NominalToBinary();
m_NominalToBinary2.setInputFormat(m_Test);
m_Test = Filter.useFilter(m_Test, m_NominalToBinary2);
// normalizuje wartosci do przedzialu [0-1]
Normalize m_Normalize2 = new Normalize();
m_Normalize2.setInputFormat(m_Test);
m_Test = Filter.useFilter(m_Test, m_Normalize2);
// usuwam wpisy ktore nie posiadaja odpowiednich wartosci
// m_Test.deleteWithMissingClass();
// zapisuje nowe wartosci
structureTest = m_Test;
}
// sprawdzam skutecznosc sieci dla roznych wartosci wspolczynnikow podstawowych sieci
public void verifyNNParams_basic() throws Exception{
// wartosci podstawowe uzywane do sprawdzania skutecznosci sieci
double init_value_LR = 0.2;
double init_value_M = 0.2;
int init_value_TT = 500;
String init_value_HL = "2";
// sprawdzam skutecznosc sieci dla roznej wartosci wspolczynnika ucenia
double[] LR_values = new double[] {0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9};
for (double currentValue : LR_values) {
this.runNN_basic(currentValue, init_value_M, init_value_TT, init_value_HL);
}
// sprawdzam skutecznosc sieci dla roznej ilosc neuronow w warstwie ukrytej
String[] HL_values = new String[] {"4","5","6","7","8", "9"};
for (String currentValue : HL_values) {
this.runNN_basic(init_value_LR, init_value_M, init_value_TT, currentValue);
}
// sprawdzam skutecznosc sieci dla roznej ilosci epok
int[] TT_values = new int[] {100,200,300,400,500,600,700,800,900,1000};
for (int currentValue : TT_values) {
this.runNN_basic(init_value_LR, init_value_M, currentValue, init_value_HL);
}
// sprawdzam skutecznosc sieci dla roznej ilosci epok
double[] M_values = new double[] {0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9};
for (double currentValue : M_values) {
this.runNN_basic(init_value_LR, currentValue, init_value_TT, init_value_HL);
}
}
// sprawdzenie skutecznosci odczytanych zbiorow dla przekazanych wspolczynnikow
public void runNN_basic(double LR, double M, int TT, String HL) throws Exception{
// Instancja sieci neuronowej
MultilayerPerceptron mlp = new MultilayerPerceptron();
// ustawiam parametry sieci
mlp.setLearningRate(LR);
mlp.setMomentum(M);
mlp.setTrainingTime(TT);
mlp.setHiddenLayers(HL);
// opis dzialania
System.out.print("Skutecznosc sieci z parametrami: " );
System.out.print("wspolczynnik uczenia="+LR+", ");
System.out.print("liczba neuronow="+HL+", ");
System.out.print("liczba epok="+TT+", ");
System.out.print("momentum="+M);
// robie kopie danych wejsciowych
Instances m_Train = new Instances(structureTrain);
Instances m_Test = new Instances(structureTest);
// nauka sieci
mlp.buildClassifier(m_Train);
// obliczanie skutecznosci
Evaluation eval = new Evaluation(m_Train);
eval.evaluateModel(mlp, m_Test);
System.out.println(eval.errorRate());
System.out.println(eval.toSummaryString());
System.out.println("");
}
// metoda testowa, uzywana jedynie podczas developmentu
public void test() throws Exception{
Instances newTrain = structureTrain;
Instances newTest = structureTest;
MultilayerPerceptron mlp2 = new MultilayerPerceptron();
mlp2.setLearningRate(0.2);
mlp2.setMomentum(0.2);
mlp2.setTrainingTime(100);
mlp2.setHiddenLayers("2");
mlp2.buildClassifier(newTrain);
Evaluation eval_dt = new Evaluation(newTrain);
eval_dt.evaluateModel(mlp2, newTest);
System.out.println(eval_dt.toSummaryString()); //Summary of Training
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment