Last active
July 8, 2021 06:40
-
-
Save bmabir17/78cdf6766469efb1212a2fab71114c70 to your computer and use it in GitHub Desktop.
Logging Keras Validation Results with WandB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: 360_image_clf | |
channels: | |
- conda-forge | |
- defaults | |
dependencies: | |
- _libgcc_mutex=0.1=main | |
- _openmp_mutex=4.5=1_gnu | |
- _tflow_select=2.3.0=mkl | |
- absl-py=0.13.0=py38h06a4308_0 | |
- aiohttp=3.7.4=py38h27cfd23_1 | |
- albumentations=1.0.0=pyhd8ed1ab_0 | |
- alsa-lib=1.2.3=h516909a_0 | |
- astor=0.8.1=py38h06a4308_0 | |
- astunparse=1.6.3=py_0 | |
- async-timeout=3.0.1=py38h06a4308_0 | |
- attrs=21.2.0=pyhd3eb1b0_0 | |
- blas=1.0=mkl | |
- blinker=1.4=py38h06a4308_0 | |
- brotlipy=0.7.0=py38h27cfd23_1003 | |
- bzip2=1.0.8=h7b6447c_0 | |
- c-ares=1.17.1=h27cfd23_0 | |
- ca-certificates=2021.5.30=ha878542_0 | |
- cachetools=4.2.2=pyhd3eb1b0_0 | |
- cairo=1.16.0=hf32fb01_1 | |
- certifi=2021.5.30=py38h578d9bd_0 | |
- cffi=1.14.5=py38h261ae71_0 | |
- chardet=3.0.4=py38h06a4308_1003 | |
- click=8.0.1=pyhd3eb1b0_0 | |
- cloudpickle=1.6.0=py_0 | |
- coverage=5.5=py38h27cfd23_2 | |
- cryptography=3.4.7=py38hd23ed53_0 | |
- cycler=0.10.0=py38_0 | |
- cython=0.29.23=py38h2531618_0 | |
- cytoolz=0.11.0=py38h7b6447c_0 | |
- dask-core=2021.6.2=pyhd3eb1b0_0 | |
- dbus=1.13.18=hb2f20db_0 | |
- decorator=5.0.9=pyhd3eb1b0_0 | |
- expat=2.4.1=h2531618_2 | |
- ffmpeg=4.3.1=hca11adc_2 | |
- fontconfig=2.13.1=h6c09931_0 | |
- freetype=2.10.4=h5ab3b9f_0 | |
- fsspec=2021.6.0=pyhd3eb1b0_0 | |
- gast=0.4.0=py_0 | |
- geos=3.8.0=he6710b0_0 | |
- gettext=0.21.0=hf68c758_0 | |
- glib=2.68.3=h9c3ff4c_0 | |
- glib-tools=2.68.3=h9c3ff4c_0 | |
- gmp=6.2.1=h2531618_2 | |
- gnutls=3.6.15=he1e5248_0 | |
- google-auth=1.32.0=pyhd3eb1b0_0 | |
- google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 | |
- google-pasta=0.2.0=py_0 | |
- graphite2=1.3.14=h23475e2_0 | |
- grpcio=1.36.1=py38h2157cd5_1 | |
- gst-plugins-base=1.18.4=hf529b03_2 | |
- gstreamer=1.18.4=h76c114f_2 | |
- h5py=2.10.0=py38hd6299e0_1 | |
- harfbuzz=2.8.1=h83ec7ef_0 | |
- hdf5=1.10.6=hb1b8bf9_0 | |
- icu=68.1=h2531618_0 | |
- idna=2.10=pyhd3eb1b0_0 | |
- imageio=2.9.0=pyhd3eb1b0_0 | |
- imgaug=0.4.0=pyhd3eb1b0_0 | |
- importlib-metadata=3.10.0=py38h06a4308_0 | |
- intel-openmp=2021.2.0=h06a4308_610 | |
- jasper=1.900.1=hd497a04_4 | |
- joblib=1.0.1=pyhd8ed1ab_0 | |
- jpeg=9d=h36c2ea0_0 | |
- keras=2.4.3=0 | |
- keras-base=2.4.3=py_0 | |
- keras-preprocessing=1.1.2=pyhd3eb1b0_0 | |
- kiwisolver=1.3.1=py38h2531618_0 | |
- krb5=1.19.1=h3535a68_0 | |
- lame=3.100=h7b6447c_0 | |
- lcms2=2.12=h3be6417_0 | |
- ld_impl_linux-64=2.35.1=h7274673_9 | |
- libblas=3.9.0=9_mkl | |
- libcblas=3.9.0=9_mkl | |
- libclang=11.1.0=default_ha53f305_1 | |
- libedit=3.1.20210216=h27cfd23_1 | |
- libevent=2.1.10=hcdb4288_3 | |
- libffi=3.3=he6710b0_2 | |
- libgcc-ng=9.3.0=h5101ec6_17 | |
- libgfortran-ng=7.5.0=ha8ba4b0_17 | |
- libgfortran4=7.5.0=ha8ba4b0_17 | |
- libglib=2.68.3=h3e27bee_0 | |
- libgomp=9.3.0=h5101ec6_17 | |
- libiconv=1.16=h516909a_0 | |
- libidn2=2.3.1=h27cfd23_0 | |
- liblapack=3.9.0=9_mkl | |
- liblapacke=3.9.0=9_mkl | |
- libllvm11=11.1.0=hf817b99_2 | |
- libogg=1.3.5=h27cfd23_1 | |
- libopencv=4.5.2=py38hcdf9bf1_0 | |
- libopus=1.3.1=h7b6447c_0 | |
- libpng=1.6.37=hbc83047_0 | |
- libpq=13.3=hd57d9b9_0 | |
- libprotobuf=3.15.8=h780b84a_0 | |
- libstdcxx-ng=9.3.0=hd4cf53a_17 | |
- libtasn1=4.16.0=h27cfd23_0 | |
- libtiff=4.2.0=h85742a9_0 | |
- libunistring=0.9.10=h27cfd23_0 | |
- libuuid=1.0.3=h1bed415_2 | |
- libvorbis=1.3.7=h7b6447c_0 | |
- libwebp-base=1.2.0=h27cfd23_0 | |
- libxcb=1.14=h7b6447c_0 | |
- libxkbcommon=1.0.3=he3ba5ed_0 | |
- libxml2=2.9.12=h72842e0_0 | |
- locket=0.2.1=py38h06a4308_1 | |
- lz4-c=1.9.3=h2531618_0 | |
- markdown=3.3.4=py38h06a4308_0 | |
- matplotlib=3.3.4=py38h06a4308_0 | |
- matplotlib-base=3.3.4=py38h62a2d02_0 | |
- mkl=2021.2.0=h06a4308_296 | |
- mkl-service=2.3.0=py38h27cfd23_1 | |
- mkl_fft=1.3.0=py38h42c9631_2 | |
- mkl_random=1.2.1=py38ha9443f7_2 | |
- multidict=5.1.0=py38h27cfd23_2 | |
- mysql-common=8.0.25=ha770c72_0 | |
- mysql-libs=8.0.25=h935591d_0 | |
- ncurses=6.2=he6710b0_1 | |
- nettle=3.7.3=hbbd107a_1 | |
- networkx=2.5=py_0 | |
- nspr=4.30=h9c3ff4c_0 | |
- nss=3.67=hb5efdd6_0 | |
- numpy=1.20.2=py38h2d18471_0 | |
- numpy-base=1.20.2=py38hfae3a4d_0 | |
- oauthlib=3.1.0=py_0 | |
- olefile=0.46=py_0 | |
- opencv=4.5.2=py38h578d9bd_0 | |
- openh264=2.1.1=h780b84a_0 | |
- openssl=1.1.1k=h7f98852_0 | |
- opt_einsum=3.3.0=pyhd3eb1b0_1 | |
- pandas=1.2.5=py38h295c915_0 | |
- partd=1.2.0=pyhd3eb1b0_0 | |
- pcre=8.45=h295c915_0 | |
- pillow=8.2.0=py38he98fc37_0 | |
- pip=21.1.3=py38h06a4308_0 | |
- pixman=0.40.0=h7b6447c_0 | |
- protobuf=3.15.8=py38h709712a_0 | |
- py-opencv=4.5.2=py38hd0cf306_0 | |
- pyasn1=0.4.8=py_0 | |
- pyasn1-modules=0.2.8=py_0 | |
- pycparser=2.20=py_2 | |
- pyjwt=1.7.1=py38_0 | |
- pyopenssl=20.0.1=pyhd3eb1b0_1 | |
- pyparsing=2.4.7=pyhd3eb1b0_0 | |
- pyqt=5.12.3=py38h578d9bd_7 | |
- pyqt-impl=5.12.3=py38h7400c14_7 | |
- pyqt5-sip=4.19.18=py38h709712a_7 | |
- pyqtchart=5.12=py38h7400c14_7 | |
- pyqtwebengine=5.12.1=py38h7400c14_7 | |
- pysocks=1.7.1=py38h06a4308_0 | |
- python=3.8.10=h12debd9_8 | |
- python-dateutil=2.8.1=pyhd3eb1b0_0 | |
- python-flatbuffers=1.12=pyhd3eb1b0_0 | |
- python_abi=3.8=2_cp38 | |
- pytz=2021.1=pyhd3eb1b0_0 | |
- pywavelets=1.1.1=py38h7b6447c_2 | |
- pyyaml=5.4.1=py38h27cfd23_1 | |
- qt=5.12.9=hda022c4_4 | |
- readline=8.1=h27cfd23_0 | |
- requests=2.25.1=pyhd3eb1b0_0 | |
- requests-oauthlib=1.3.0=py_0 | |
- rsa=4.7.2=pyhd3eb1b0_1 | |
- scikit-image=0.16.2=py38h0573a6f_0 | |
- scikit-learn=0.24.2=py38hdc147b9_0 | |
- scipy=1.6.2=py38had2a1c9_1 | |
- seaborn=0.11.1=pyhd3eb1b0_0 | |
- setuptools=52.0.0=py38h06a4308_0 | |
- shapely=1.7.1=py38h98ec03d_0 | |
- six=1.16.0=pyhd3eb1b0_0 | |
- sqlite=3.36.0=hc218d9a_0 | |
- tensorboard=2.4.0=pyhc547734_0 | |
- tensorboard-plugin-wit=1.6.0=py_0 | |
- tensorflow=2.4.1=mkl_py38hb2083e0_0 | |
- tensorflow-base=2.4.1=mkl_py38h43e0292_0 | |
- tensorflow-estimator=2.5.0=pyh7b7c402_0 | |
- termcolor=1.1.0=py38h06a4308_1 | |
- threadpoolctl=2.1.0=pyh5ca1d4c_0 | |
- tk=8.6.10=hbc83047_0 | |
- toolz=0.11.1=pyhd3eb1b0_0 | |
- tornado=6.1=py38h27cfd23_0 | |
- typing-extensions=3.10.0.0=hd3eb1b0_0 | |
- typing_extensions=3.10.0.0=pyh06a4308_0 | |
- urllib3=1.26.6=pyhd3eb1b0_1 | |
- werkzeug=1.0.1=pyhd3eb1b0_0 | |
- wheel=0.36.2=pyhd3eb1b0_0 | |
- wrapt=1.12.1=py38h7b6447c_1 | |
- x264=1!161.3030=h7f98852_1 | |
- xz=5.2.5=h7b6447c_0 | |
- yaml=0.2.5=h7b6447c_0 | |
- yarl=1.6.3=py38h27cfd23_0 | |
- zipp=3.4.1=pyhd3eb1b0_0 | |
- zlib=1.2.11=h7b6447c_3 | |
- zstd=1.4.9=haebb681_0 | |
- pip: | |
- configparser==5.0.2 | |
- docker-pycreds==0.4.0 | |
- gitdb==4.0.7 | |
- gitpython==3.1.18 | |
- opencv-contrib-python==4.5.2.54 | |
- opencv-python==4.5.2.54 | |
- pathtools==0.1.2 | |
- promise==2.3 | |
- psutil==5.8.0 | |
- sentry-sdk==1.1.0 | |
- shortuuid==1.0.1 | |
- smmap==4.0.0 | |
- subprocess32==3.5.4 | |
- wandb==0.10.33 | |
prefix: /home/bmabir/anaconda3/envs/360_image_clf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from PIL import Image | |
from pathlib import Path | |
from collections import defaultdict | |
import keras | |
# import tensorflow as tf | |
# from tensorflow.python import keras | |
# from tensorflow.python.keras import backend as Keras | |
from tensorflow.keras.utils import Sequence | |
import logging | |
import matplotlib.pyplot as plt | |
import cv2 | |
from random import shuffle | |
from tensorflow.keras.utils import to_categorical | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(name)s %(levelname)s:%(message)s') | |
logger = logging.getLogger(__name__) | |
class SmartDataGenerator(Sequence):
    """Keras ``Sequence`` that streams images from disk with optional class balancing.

    Balancing oversamples minority classes by appending synthetic IDs of the
    form ``<real_image_path>/__aug__``; when such an ID is drawn, the real
    parent image is loaded and passed through ``balancing_augmentation`` to
    produce a new training sample on the fly.
    """

    def __init__(self, data, batch_size, target_resolution, n_channels, n_classes, balancing_augmentation,
                 class_balancing=True, shuffle=True, resize_technique=Image.BICUBIC):
        """Initialize the generator.

        Args:
            data: tuple ``(list_IDs, labels)`` — image paths and a
                ``{str(path): class_id}`` mapping.
            batch_size: number of samples per batch.
            target_resolution: ``(width, height)`` images are resized to.
            n_channels: number of image channels.
            n_classes: number of output classes (for one-hot encoding).
            balancing_augmentation: callable ``f(image=...)`` returning a dict
                with an ``"image"`` key (albumentations-style pipeline).
            class_balancing: if True, oversample minority classes up front.
            shuffle: if True, reshuffle sample order every epoch.
            resize_technique: PIL resampling filter used by read_img.
        """
        self.target_resolution = target_resolution
        self.n_channels = n_channels
        self.batch_size = batch_size
        self.list_IDs, self.labels = data
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.resize_technique = resize_technique
        self.class_balancing = class_balancing
        self.b_augment = balancing_augmentation
        self.balance_class()
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch (trailing remainder is dropped)."""
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, pivot_index):
        """Generate one batch of data for batch index ``pivot_index``."""
        indexes = self.indexes[pivot_index * self.batch_size: (pivot_index + 1) * self.batch_size]
        list_IDs_tmp = [self.list_IDs[k] for k in indexes]
        return self.__data_generation(list_IDs_tmp)

    def read_img(self, path):
        """Read an image file, resize to the target resolution, return a numpy array."""
        img_obj = Image.open(path)
        # BUG FIX: honour the configured resampling filter instead of a
        # hard-coded Image.BICUBIC (the constructor stores resize_technique
        # but the original never used it).
        img_obj = img_obj.resize(self.target_resolution, self.resize_technique)
        return np.array(img_obj)

    def _prepare_img(self, img):
        """Reshape to (H, W, C) and scale pixel values to [0, 1] float32."""
        img = img.reshape((img.shape[0], img.shape[1], self.n_channels))
        return img.astype('float32') / 255

    def __data_generation(self, list_IDs_temp):
        """Load (and, for synthetic ``__aug__`` IDs, augment) one batch from disk."""
        X = np.empty((self.batch_size, *self.target_resolution, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)
        for i, img_ID in enumerate(list_IDs_temp):
            path = Path(img_ID)
            if not path.exists() and path.name == '__aug__':
                # Synthetic sample: the real file is the parent of the
                # '__aug__' marker; load it and run the augmentation pipeline.
                img = self.read_img(path.parent)
                img = self.b_augment(image=img)["image"]
            elif path.exists():
                img = self.read_img(img_ID)
            else:
                # FileNotFoundError is a subclass of Exception, so existing
                # callers catching Exception keep working.
                raise FileNotFoundError('File not found at {}th iteration. Please check {}'.format(i, img_ID))
            X[i,] = self._prepare_img(img)
            y[i] = self.labels[str(img_ID)]
        y_cat = to_categorical(y, num_classes=self.n_classes)
        return X, y_cat

    def balance_class(self):
        """Oversample minority classes by appending synthetic ``__aug__`` IDs.

        Note: despite the log message this is augmentation-based random
        oversampling, not true SMOTE (no feature-space interpolation).
        """
        if not self.class_balancing:
            logging.info('Skipping data balancing')
            return
        logging.info('-------------------\n\nStarting Datasets balancing using SMOT...\n')
        # First, group all file names by the class they belong to.
        cls_filename_map = defaultdict(list)
        for path in self.list_IDs:
            cls_filename_map[self.labels[str(path)]].append(Path(path))
        # Work out how many synthetic files each class needs.
        counts = {cls: len(files) for cls, files in cls_filename_map.items()}
        total = sum(counts.values())
        max_cls_file_num = max(counts.values())
        dominate_pct = 4  # allowed imbalance pct of the dominant class; 0 means all classes end up equal
        new_estimated_total = int(len(cls_filename_map) * (max_cls_file_num - (max_cls_file_num * dominate_pct / 100.0)))
        logging.info('Found {} data. More {} synthetic data will be generated (new total = {}).\t{}/class\n'
                     .format(total, new_estimated_total - total, new_estimated_total, (new_estimated_total / len(cls_filename_map))))
        per_class_target = new_estimated_total / len(cls_filename_map)
        for cls, cls_file_count in counts.items():
            if cls_file_count < max_cls_file_num:
                more_needed = int(per_class_target - cls_file_count)
                logging.info('class-{} will need more {}.\tcurrently have {}'.format(cls, more_needed, cls_file_count))
                current_cls_filelist = cls_filename_map[cls]
                for i in range(more_needed):
                    # Cycle through the real files of this class.
                    tmp_file = current_cls_filelist[i % cls_file_count]
                    # '__aug__' is our defined flag to distinguish synthetic IDs from real files.
                    new_filename = Path(tmp_file) / '__aug__'
                    self.list_IDs.append(new_filename)
                    self.labels[str(new_filename)] = cls
            else:
                logging.info(
                    'class-{} need not any synthetic data. Already it has {}'.format(cls, cls_file_count))

    def on_epoch_end(self):
        """Updates indexes after each epoch (reshuffles if enabled)."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)
class TukiTaki:
    """Small grab-bag of stateless helper utilities."""

    @staticmethod
    def get_class_name_from_hotencoding(hot_encoding):
        """Map a one-hot/probability vector back to its original class name.

        Reads the ``{index: class_name}`` mapping that KfoldMaker.scan_dataset
        writes to 'Reverse_class_index_mapping.txt'.

        BUG FIX: the original def took the vector as its only parameter but was
        not a @staticmethod, so calling it on a TukiTaki *instance* raised
        TypeError (the vector landed in the implicit self slot). Class-level
        calls are unchanged.

        Args:
            hot_encoding: 1-D sequence of class scores/probabilities.

        Returns:
            The class name string for the argmax index.
        """
        import json
        i = np.argmax(hot_encoding, axis=0)
        with open('Reverse_class_index_mapping.txt') as f:
            data = json.load(f)
        # JSON keys are strings, so convert the numpy index before lookup.
        return data[str(i)]
class KfoldMaker:
    """Scans a class-per-folder image dataset and produces K-fold splits."""

    def __init__(self, dataset_dir, image_extensions):
        """Scan ``dataset_dir`` for images.

        Args:
            dataset_dir: root directory; each immediate parent folder name of
                an image is treated as its class label.
            image_extensions: list of extensions (without dot) to pick up.
        """
        self.dataset_dir = dataset_dir
        self.image_extensions = image_extensions
        self.all_path_list, self.all_labels_dic = self.scan_dataset()

    def generate_folds(self, K):
        """Shuffle the file list in place and return K (train, val) folds.

        Returns:
            list of ``((train_x, train_y), (val_x, val_y))`` tuples where the
            ``_x`` items are path lists and the ``_y`` items are
            ``{str(path): class_id}`` dicts.
        """
        np.random.shuffle(self.all_path_list)
        num_validation_samples = len(self.all_path_list) // K
        fold_list = []
        logging.info('\n\nDividing dataset into {} folds...'.format(K))
        for fold in range(K):
            # Validation slice rotates through the list; training is the rest.
            validation_data_x = self.all_path_list[num_validation_samples * fold: num_validation_samples * (fold + 1)]
            training_data_x = self.all_path_list[: num_validation_samples * fold] \
                + self.all_path_list[num_validation_samples * (fold + 1):]
            validation_data_y = {str(x): self.all_labels_dic[str(x)] for x in validation_data_x}
            training_data_y = {str(x): self.all_labels_dic[str(x)] for x in training_data_x}
            fold_list.append(((training_data_x, training_data_y), (validation_data_x, validation_data_y)))
            # Simple textual progress bar, e.g. '__#__'.
            status_bar = ['_'] * K
            status_bar[fold] = '#'
            logging.info('Generated K-fold metadata for fold: {} {}'.format(str(fold), ''.join(status_bar)))
        return fold_list

    def scan_dataset(self):
        """Recursively collect image paths and derive labels from folder names.

        Returns:
            ``(images_path_list, label_dic)`` where ``label_dic`` maps
            ``str(path)`` to an integer class id.
        """
        import json
        root_path = Path(self.dataset_dir)
        images_path_list = []
        for extension in self.image_extensions:
            img_list = list(root_path.rglob('*.' + extension))
            images_path_list.extend(img_list)
            logging.debug('{} images found with {} extension'.format(len(img_list), extension))
        logging.info('Found {} images with {} extensions\n'.format(len(images_path_list), self.image_extensions))
        label_dic = {}
        cls_id_map = {}
        reverse_cls_id_map = {}
        for path in images_path_list:
            cls_name = path.parent.name
            if cls_name not in cls_id_map:
                new_id = len(cls_id_map)
                cls_id_map[cls_name] = new_id
                reverse_cls_id_map[new_id] = cls_name
            # BUG FIX: the original assigned the running counter
            # (cls_id_counter) — i.e. the id of the most recently *seen* class
            # — so any file encountered after another class had been
            # registered (e.g. with multiple extensions) got a wrong label.
            # Look the id up by the file's own class name instead.
            label_dic[str(path)] = cls_id_map[cls_name]
        # Save the mapping between original class names and assigned indices;
        # it is used later during testing/prediction.
        with open('Reverse_class_index_mapping.txt', 'w') as file:
            file.write(json.dumps(reverse_cls_id_map))
        logging.info('Detected {} correctly labeled images inside {}'.format(len(images_path_list), self.dataset_dir))
        logging.info('Total {} class found. {}\n'.format(len(cls_id_map), cls_id_map))
        return images_path_list, label_dic
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_simple_model(config):
    """Build and compile a small CNN image classifier.

    Args:
        config: object exposing ``input_size``, ``n_channels`` and
            ``n_classes`` attributes (e.g. a wandb config).

    Returns:
        A compiled ``tf.keras`` Sequential model with a softmax head and
        categorical-crossentropy loss.
    """
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers

    # Create the model and stack the layers.
    model = keras.Sequential()
    model.add(layers.Conv2D(32, kernel_size=3, activation='relu',
                            input_shape=(config.input_size, config.input_size, config.n_channels)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, kernel_size=3, activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(config.n_classes, activation='softmax'))
    model.compile(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=[
            # BUG FIX: BinaryAccuracy thresholds each output unit at 0.5
            # independently, which is meaningless for a softmax/one-hot
            # multi-class head; CategoricalAccuracy compares argmaxes.
            # The reported metric keeps the name 'accuracy'.
            tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.TruePositives(name='true_pos'),
            tf.keras.metrics.FalsePositives(name='false_pos'),
            tf.keras.metrics.TrueNegatives(name='true_neg'),
            tf.keras.metrics.FalseNegatives(name='false_neg'),
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.AUC(name='auc'),
            tf.keras.metrics.AUC(name='prc', curve='PR')
        ])
    return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/home/bmabir/anaconda3/envs/360_image_clf/bin/python
"""Train the good/bad 360-image classifier, logging to Weights & Biases."""
import random

from DatasetTools import SmartDataGenerator, KfoldMaker
from PIL import Image
from albumentations import (
    Compose, HorizontalFlip,
    RandomBrightness, RandomContrast,
)
from model import build_simple_model
import tensorflow as tf
from tensorflow.python.keras import backend as K

# adjust values to your needs
tf_config = tf.compat.v1.ConfigProto(device_count={'GPU': 1, 'CPU': 10})
sess = tf.compat.v1.Session(config=tf_config)
K.set_session(sess)

import wandb
from wandb.keras import WandbCallback

# 1. Start a new run
wandb.init(project='360_image_clf', entity='bmabir17')

# 2. Save model inputs and hyperparameters
config = wandb.config
config.epochs = 20
config.num_validation = 100
config.batch_size = 16
config.n_channels = 3
config.n_classes = 2
config.input_size = 256
config.k_fold = 5

# ----------------------------------------------------------------------------------------------------------------------
dataset_dir = 'data/train'

# Generator parameters
params = {'batch_size': config.batch_size,
          'target_resolution': (config.input_size, config.input_size),
          'n_channels': config.n_channels,
          'n_classes': config.n_classes,
          'balancing_augmentation': Compose([HorizontalFlip(p=0.5), RandomContrast(limit=0.2, p=0.5), RandomBrightness(limit=0.2, p=0.5)]),
          'shuffle': True,
          'resize_technique': Image.BICUBIC}

# NOTE: here KfoldMaker is only used to scan files and generate the file list
k_obj = KfoldMaker(dataset_dir=dataset_dir, image_extensions=['jpeg'])
num_validation = config.num_validation  # number of validation samples you want
X = k_obj.all_path_list
Y = k_obj.all_labels_dic

# BUG FIX: the original split used train_x = X[:-n] and val_x = X[n:], which
# overlap almost entirely (the model validated on its own training data).
# Also shuffle first: the scan order is class-by-class, so a plain tail slice
# would make the validation set single-class.
random.shuffle(X)
train_x = X[:-num_validation]
val_x = X[-num_validation:]

training_gen = SmartDataGenerator(data=(train_x, Y), **params)
validation_gen = SmartDataGenerator(data=(val_x, Y), **params)

model = build_simple_model(config)
log = model.fit(x=training_gen,
                validation_data=validation_gen,
                workers=9,  # realtime loading with parallel processing
                epochs=config.epochs, verbose=1,
                callbacks=[
                    WandbCallback(
                        save_model=True,
                        mode="auto",
                        input_type="image",
                        labels=['bad', 'good'],
                        predictions=1)])
# BUG FIX: the original path was a plain string, so the literal text
# '{config.epochs}' ended up in the directory name; use an f-string.
model.save(f"saved_model/good_bad_model_e{config.epochs}_s256")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment