Skip to content

Instantly share code, notes, and snippets.

@anoken
Created September 16, 2019 15:47
Show Gist options
  • Save anoken/4ad1a73cf77a7aea0d341576e04ef049 to your computer and use it in GitHub Desktop.
Save anoken/4ad1a73cf77a7aea0d341576e04ef049 to your computer and use it in GitHub Desktop.
190917_cnn_test.py
from keras.models import Sequential
from keras.layers import Activation, Dense, Dropout,Conv2D,MaxPooling2D,Flatten
from keras.utils.np_utils import to_categorical
from keras.optimizers import Adagrad
from keras.optimizers import Adam
import numpy as np
from PIL import Image
import os
from sklearn.model_selection import train_test_split
# Build the training data: TRAIN_DIR holds one sub-directory per class,
# each containing that class's image files.
TRAIN_DIR = "train"
# Size every image is resized to; must match the model's input_shape (14, 14, 3).
IMAGE_SIZE = (14, 14)
# Class sub-directory name -> integer class label.
CLASS_LABELS = {"001": 0, "002": 1, "003": 2, "004": 3, "005": 4}

image_list = []
label_list = []
# sorted() gives a deterministic sample order; os.listdir order is arbitrary.
for class_dir in sorted(os.listdir(TRAIN_DIR)):
    class_path = os.path.join(TRAIN_DIR, class_dir)
    # Skip .DS_Store and any other non-directory entry at the top level
    # (os.listdir on a regular file would raise).
    if not os.path.isdir(class_path):
        continue
    label = CLASS_LABELS.get(class_dir)
    if label is None:
        # Unknown directories used to be silently labelled as class 0.
        print("skipping unknown class directory:", class_path)
        continue
    for file_name in sorted(os.listdir(class_path)):
        # Skip hidden files such as .DS_Store, which are not images.
        if file_name.startswith("."):
            continue
        file_path = os.path.join(class_path, file_name)
        # The context manager closes the file once the pixels are copied out.
        # convert/resize are no-ops (copies) for images that are already
        # 14x14 RGB, and otherwise coerce them to the shape the model expects.
        with Image.open(file_path) as image:
            data = np.asarray(image.convert("RGB").resize(IMAGE_SIZE))
        # Append image and label together so the two lists stay aligned.
        image_list.append(data)
        label_list.append(label)
image_list = np.array(image_list)
image_list = image_list.astype('float32')
image_list = image_list / 255.0
Y = to_categorical(label_list)
print(Y)
print(image_list)
X_train, X_test, y_train, y_test = train_test_split(image_list, Y, test_size=0.20)
# Create the model and build the convolutional neural network.
model = Sequential()
# 14x14x3 input; 'same' padding keeps the spatial size at 14x14.
model.add(Conv2D(32, (3, 3), padding='same', input_shape=(14, 14, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# Default 'valid' padding: 7x7 -> 5x5, then pooling -> 2x2.
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
# One output unit per class (labels 0-4).
model.add(Dense(5))
model.add(Activation('softmax'))
# The learning rate is passed positionally: its keyword was renamed from `lr`
# to `learning_rate` across Keras versions, but it is the first parameter in all.
opt = Adam(0.001)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
# `epochs` replaces the Keras 1 `nb_epoch` argument, which newer Keras rejects.
model.fit(X_train, y_train, epochs=50)
# Score the held-out 20% split, which was previously never used.
test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
print("test loss:", test_loss, "test accuracy:", test_accuracy)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment