Skip to content

Instantly share code, notes, and snippets.

@kdplus
Created January 14, 2017 14:51
Show Gist options
  • Save kdplus/ff06e0a295e51b0c7ae9a368fff50df7 to your computer and use it in GitHub Desktop.
Save kdplus/ff06e0a295e51b0c7ae9a368fff50df7 to your computer and use it in GitHub Desktop.
for cs433
#!/usr/bin/env python
import time
import numpy as np
import matplotlib.pyplot as plt
from fc_net import *
from data_utils import get_CIFAR10_data
from gradient_check import eval_numerical_gradient, eval_numerical_gradient_array
from solver import Solver
import sys
import json
# Read the training records streamed on stdin (one JSON object per line).
def parse_records(lines, image_shape=(32, 32, 3)):
    """Parse JSON-encoded image records into a training array and label list.

    Each line is expected to be a JSON object with a ``"label"`` field and a
    ``"data"`` field holding enough numbers to fill ``image_shape``. Lines
    that are not valid JSON, are not JSON objects, lack either field, or
    whose data cannot be reshaped to ``image_shape`` print the ``yi``
    marker and are skipped.

    Marker output (tab-separated key/value lines on stdout):
      * ``gaga\\t1`` for every input line;
      * ``<label>\\t1`` for every decoded record that has a label;
      * ``yi\\t1`` for every skipped line.

    Args:
        lines: Iterable of text lines, e.g. ``sys.stdin``.
        image_shape: Shape each record's flat ``data`` is reshaped to.

    Returns:
        ``(X, y)``. ``X`` is a float32 array of shape ``(n,) + image_shape``,
        where ``n`` may be 0. ``y`` is a list of the ``n`` labels in input
        order.
    """
    images = []
    labels = []
    for line in lines:
        print('%s\t%s' % ("gaga", 1))
        try:
            record = json.loads(line)
            label = record["label"]
            print('%s\t%s' % (str(label), 1))
            image = np.array(record["data"], dtype=np.float32).reshape(image_shape)
        except (ValueError, KeyError, TypeError):
            # Malformed line: emit the skip marker and continue rather than
            # aborting the whole run on one bad record.
            print('%s\t%s' % ("yi", 1))
            continue
        images.append(image)
        labels.append(label)
    if not images:
        return np.empty((0,) + tuple(image_shape), dtype=np.float32), labels
    # Stack once at the end. Calling np.append per record copied the whole
    # array every time (quadratic in the number of records).
    return np.stack(images), labels


# The training step below only runs when this flag is 1.
exeornot = 1
if __name__ == "__main__":
    X_train, y_train = parse_records(sys.stdin)
# Train a fully-connected network on the records parsed from stdin; the
# remaining splits come from get_CIFAR10_data().
if __name__ == "__main__" and exeornot == 1:
    X_train = np.array(X_train)
    y_train = np.array(y_train)
    # Report how many training samples were parsed, as a key/value line.
    print('%s\t%s' % ("train", X_train.shape[0]))
    if X_train.shape[0] == 0:
        # Nothing usable arrived on stdin: exit cleanly instead of training.
        sys.exit(0)
    # NOTE(review): presumably only the non-training splits from this call
    # matter, since X_train/y_train are overwritten just below — confirm.
    data = get_CIFAR10_data()
    data["X_train"] = X_train
    data["y_train"] = y_train
    model = FullyConnectedNet([100, 100, 100, 100, 100],
                              weight_scale=5e-2,
                              reg=0.0297468942265)
    solver = Solver(model, data,
                    num_epochs=30,
                    batch_size=3000,
                    update_rule='adam',
                    optim_config={'learning_rate': 0.000844664246013},
                    verbose=True)
    # NOTE(review): kept from the original; Solver's constructor may already
    # reset this state — confirm whether the explicit call is needed.
    solver._reset()
    solver.train()
    # The trained model is neither saved nor evaluated here; the previous
    # val/test predictions were computed and discarded, so they are dropped.
    print('%s\t%s' % ("hey", 1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment