This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import tensorflow as tf | |
| from keras.utils import plot_model | |
| import pydot | |
| import graphviz | |
| import numpy as np # linear algebra | |
| from sklearn.model_selection import train_test_split | |
| import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) | |
| from tqdm import tqdm | |
| from numpy import expand_dims, zeros, ones, vstack |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Show the layer-by-layer architecture of the discriminator, then the
# generator, before kicking off adversarial training.
for _net in (d_model, g_model):
    _net.summary()

# Run the GAN training loop with ops placed on the first GPU device.
_gan_device = '/device:GPU:0'
with tf.device(_gan_device):
    train(g_model, d_model, gan_model, dataset, latent_dim)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # train the generator and discriminator | |
| def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=320, n_batch=64): | |
| bat_per_epo = int(len(dataset) / n_batch) | |
| #print('Batches per Epoch is %d' %bat_per_epo) | |
| half_batch = int(n_batch / 2) | |
| #print("Half Batch %d" % half_batch) | |
| # manually enumerate epochs | |
| for i in range(n_epochs): | |
| # enumerate batches over the training set | |
| for j in range(bat_per_epo): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # create and save a plot of generated images | |
| def save_plot(examples, epoch, n=7): | |
| # scale from [-1,1] to [0,1] | |
| examples = (examples + 1) / 2.0 | |
| # plot images | |
| for i in range(n * n): | |
| # define subplot | |
| pyplot.subplot(n, n, 1 + i) | |
| # turn off axis | |
| pyplot.axis('off') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # read images from df_4 dataframe into x_train | |
| x_train = [] | |
| def load_real_samples(): | |
| for f, breed in tqdm(df_4.values): | |
| try: | |
| img = image.load_img(('/storage/train/{}'.format(f)), target_size=(128, 128)) | |
| # convert to float32 | |
| arr1 = image.img_to_array(img, dtype = 'float32') | |
| # scale images to [-1,1] from [0,255] | |
| arr = (arr1 - 127.5) / 127.5 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # define the standalone discriminator model using GAN training hacks | |
| def define_discriminator(in_shape=(128,128,3)): | |
| model = Sequential() | |
| # input layer with image size of 128x128, since its a colored image it has 3 channels | |
| model.add(Conv2D(16, (3,3), padding='same', input_shape=in_shape)) | |
| model.add(LeakyReLU(alpha=0.2)) | |
| # downsample to 64x64 using strides of 2,2 and use of LeakyReLU | |
| model.add(Conv2D(8, (3,3), strides=(2,2), padding='same')) | |
| model.add(LeakyReLU(alpha=0.2)) | |
| # downsample to 32x32 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Read all the labels from the CSV file and turn each bare image id into
# its on-disk file name (e.g. "10_left" -> "10_left.jpeg").
df_csv = pd.read_csv('/storage/trainLabels.csv')
df_csv['image'] = df_csv['image'].astype(str) + '.jpeg'

## Delete all the images of size zero (0 KB). No need to do this step while rerunning the program. ##
## There are multiple reasons for zero-size data, such as an issue while downloading the database, ##
## limited space on the VM, or corrupted data from the source.                                       ##
# Implemented in plain Python rather than the IPython `cd` / `!find` magics
# so this cell also runs outside a notebook. Each removed path is printed,
# matching the output of `find /storage/train/ -size 0 -print`.
os.chdir('/storage/train/')
for _root, _dirs, _files in os.walk('/storage/train/'):
    for _name in _files:
        _path = os.path.join(_root, _name)
        # lstat: inspect the entry itself (like `find`), without following symlinks
        if os.lstat(_path).st_size == 0:
            print(_path)
            os.remove(_path)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #save data before one hot encoding | |
| train.to_csv(r'train.csv', header=None, index=None, sep=' ', mode='a') | |
| #Use full training data | |
| LoadData = ImageDataGenerator( | |
| horizontal_flip=True) | |
| hard_data_gen = LoadData.flow_from_dataframe( | |
| dataframe= df_train, directory="/storage/train/", | |
| x_col="image", y_col= None, shuffle=False, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import matplotlib.pyplot as plt | |
| def plot_history(history): | |
| loss_list = [s for s in history.history.keys() if 'loss' in s and 'val' not in s] | |
| val_loss_list = [s for s in history.history.keys() if 'loss' in s and 'val' in s] | |
| acc_list = [s for s in history.history.keys() if 'acc' in s and 'val' not in s] | |
| val_acc_list = [s for s in history.history.keys() if 'acc' in s and 'val' in s] | |
| if len(loss_list) == 0: | |
| print('Loss is missing in history') | |
| return |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train the model for 5 epochs on the first GPU device, validating on
# `valid_generator` and re-weighting classes with `class_weights`.
# `Model.fit_generator` is deprecated since TF 2.1 and removed in Keras 3;
# `Model.fit` accepts generators / `Sequence` inputs directly with the same
# keyword arguments.
with tf.device('/device:GPU:0'):
    m1 = model.fit(
        train_generator,
        epochs=5,
        validation_data=valid_generator,
        class_weight=class_weights,
        callbacks=Callbacks,
        verbose=1,
    )
NewerOlder