Skip to content

Instantly share code, notes, and snippets.

@masaponto
Last active December 6, 2016 15:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save masaponto/a9bd6e9390bd9a2b998ccb3989f10958 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from tabulate import tabulate
from sklearn.preprocessing import normalize
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import cross_val_score, KFold
from elm import ELM
def main(db_names=("australian", "iris"), hid_nums=(10, 20, 30),
         n_repeats=10, n_splits=5):
    """
    Benchmark ELM hidden-layer sizes and print the results with tabulate.

    For every dataset in ``db_names``, each hidden-layer size in
    ``hid_nums`` is scored with ``n_repeats`` rounds of shuffled
    ``n_splits``-fold cross-validation. The mean accuracy over all rounds
    is printed as a pipe-format (Markdown) table.

    ELM is from https://github.com/masaponto/Python-ELM

    Parameters
    ----------
    db_names : iterable of str
        Dataset names passed to ``fetch_mldata``.
    hid_nums : iterable of int
        Hidden-layer sizes (the ``N_h`` column) passed to ``ELM``.
    n_repeats : int
        Number of shuffled cross-validation rounds per hidden-layer size.
    n_splits : int
        Number of folds in each cross-validation round.

    Raises
    ------
    ValueError
        If ``n_repeats`` is less than 1.

    Example output from one run (accuracy is a fraction in [0, 1],
    not a percentage):

    australian
    |   # of N_h |   Accuracy |
    |-----------:|-----------:|
    |         10 |   0.718551 |
    |         20 |   0.763623 |
    |         30 |   0.775507 |

    iris
    |   # of N_h |   Accuracy |
    |-----------:|-----------:|
    |         10 |   0.893333 |
    |         20 |   0.897333 |
    |         30 |   0.9      |
    """
    if n_repeats < 1:
        # Averaging zero rounds would silently report NaN.
        raise ValueError("n_repeats must be at least 1")

    # cross_val_score(scoring='accuracy') yields fractions, so the column
    # is labelled plainly rather than as a percentage.
    header = ["# of N_h", "Accuracy"]

    for db_name in db_names:
        print(db_name)
        # NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and
        # removed in 0.22 (mldata.org is offline); on newer versions this
        # needs to be ported to sklearn.datasets.fetch_openml.
        data_set = fetch_mldata(db_name)
        # normalize() defaults to scaling each sample (row) to unit L2 norm.
        data_set.data = normalize(data_set.data)

        table = []
        for hid_num in hid_nums:
            e = ELM(hid_num)
            # Repeat shuffled K-fold CV so the estimate does not hinge on a
            # single random split; each round contributes its mean fold score.
            round_means = []
            for _ in range(n_repeats):
                cv = KFold(n_splits=n_splits, shuffle=True)
                scores = cross_val_score(e, data_set.data, data_set.target,
                                         cv=cv, scoring='accuracy', n_jobs=-1)
                round_means.append(scores.mean())
            table.append([hid_num, np.mean(round_means)])

        print(tabulate(table, header, tablefmt="pipe"))
        print()
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment