-
-
Save buyoh/02e5cd33b93e4a76e749 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// logistic-forest
// memo: hastily written rough draft
#include <iostream> | |
#include <cstdio> | |
#include <cmath> | |
#include <vector> | |
using namespace std; | |
// Logistic (sigmoid) function 1 / (1 + e^-x); maps a linear score into (0, 1).
double sigmoid(double x){
    return 1.0/(1.0+std::exp(-x));
}

// Binary logistic-regression model updated one sample at a time by
// gradient descent on the cross-entropy loss.
//
// A sample is passed as an iterator to the first of `dim` feature values;
// the caller must guarantee that at least `dim` elements are readable from it.
// The "amp" overloads take a second iterator over dim+1 per-weight
// multipliers (one for the bias, then one per feature).
class ml_logistic{
public: // TODO: fields are left unguarded (every member is public)
    int dim;                    // number of input features
    std::vector<double> weight; // dim+1 entries: weight[0] is the bias, weight[k] pairs with feature k-1

    ml_logistic(int d):dim(d),weight(d+1,0.0){}

    // Linear score (log-odds) w0 + sum_k w_k*x_k of the sample starting at id.
    // sigmoid(test(id)) is the model's probability for label `true`.
    template<class Iterator>
    double test(Iterator id) const{
        std::vector<double>::const_iterator iw=weight.begin();
        double a=*iw; // bias term
        for (++iw;iw!=weight.end();++iw,++id)
            a+=*id * *iw;
        return a;
    }

    // One gradient-descent step on sample id with label cd and learning rate eta.
    // d(loss)/d(w_k) = (sigmoid(a) - y) * x_k, with x_0 = 1 for the bias.
    template<class Iterator>
    void learn(Iterator id,bool cd,double eta){
        // Step scale shared by every weight; hoisted out of the loop.
        const double g=eta*(sigmoid(test(id))-(cd?1.0:0.0));
        std::vector<double>::iterator iw=weight.begin();
        *iw-=g;
        for (++iw;iw!=weight.end();++iw,++id)
            *iw-=g * *id;
    }

    // --- amplified variants ---
    // Score ia[0]*w0 + sum_k ia[k]*x_k*w_k, where ia walks dim+1 multipliers.
    // The data and multiplier iterators may be of different types.
    template<class Iterator,class AmpIterator>
    double test(Iterator id,AmpIterator ia) const{
        std::vector<double>::const_iterator iw=weight.begin();
        double a=*ia * *iw;
        for (++iw,++ia;iw!=weight.end();++iw,++id,++ia)
            a+=*ia * *id * *iw;
        return a;
    }

    // Gradient-descent step for the amplified score. ia[k]*x_k is the
    // effective input of w_k, so by the chain rule the gradient is
    // (sigmoid(a) - y) * ia[k] * x_k; dropping ia[k] would only be correct
    // when every multiplier equals 1.
    template<class Iterator,class AmpIterator>
    void learn(Iterator id,bool cd,double eta,AmpIterator ia){
        const double g=eta*(sigmoid(test(id,ia))-(cd?1.0:0.0));
        std::vector<double>::iterator iw=weight.begin();
        *iw-=g * *ia;
        for (++iw,++ia;iw!=weight.end();++iw,++id,++ia)
            *iw-=g * *ia * *id;
    }
};
int main(){ | |
int n,dim,asnum; | |
vector<vector<double>> data; | |
vector<bool> type; | |
vector<vector<double>> weight; | |
vector<double> amp; | |
double eta=0.25; | |
int i,j,t; | |
vector<vector<double>>::iterator ivd; | |
vector<bool>::iterator ib; | |
vector<double>::iterator id; | |
int cnt,lcnt,ac,bac; | |
cin>>n>>dim>>asnum; | |
data.resize(n); | |
for (i=0;i<n;i++){ | |
cin>>t;type.push_back(t); | |
for (j=0;j<dim;j++){ | |
cin>>t;data[i].push_back(t); | |
} | |
} | |
weight.resize(asnum); | |
for (i=0;i<asnum;i++) | |
weight[i].resize(dim+1); | |
vector<ml_logistic> mllg; | |
for (i=0;i<asnum;i++) | |
mllg.push_back(*(new ml_logistic(dim))); | |
bac=0; | |
for (cnt=0;cnt<100;cnt++){ | |
// init amp | |
amp.assign(amp.size(),1.0); | |
for (i=0;i<asnum;i++){ | |
for (lcnt=0;lcnt<10;lcnt++){ | |
for (vector<ml_logistic>::iterator ivml : mllg){ | |
for (ivd=data.begin(),ib=type.begin();ivd!=data.end();ivd++,ib++){ | |
mllg.learn((*ivd).begin(),*ib,eta,amp); | |
} | |
} | |
} | |
} | |
for (it=data.begin(),iu=type.begin();it!=data.end();it++,iu++){ | |
// learn | |
mllg.learn((*it).begin(),*iu,eta); | |
} | |
ac=0; | |
for (it=data.begin(),iu=type.begin();it!=data.end();it++,iu++){ | |
// test | |
if ((mllg.test((*it).begin())>=0.0)==*iu) ac++; | |
} | |
if (ac>bac){ | |
weight.assign(mllg.weight.begin(),mllg.weight.end()); | |
bac=ac; | |
} | |
if (ac==n){cout<<"accept:"<<cnt<<endl;break;} | |
if (1e-3<eta)eta*=0.9; | |
} | |
cout<<"bac:"<<bac<<endl; | |
if (ac!=n){ | |
mllg.weight.assign(weight.begin(),weight.end()); | |
} | |
for (iv=weight.begin();iv!=weight.end();iv++) | |
printf("%8.4f ",*iv); | |
cout<<endl; | |
for (it=data.begin(),iu=type.begin();it!=data.end();it++,iu++){ | |
printf("%d %d\n",mllg.test((*it).begin())>0?1:0,(int)(*iu)); | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment