Skip to content

Instantly share code, notes, and snippets.

@m-rath
Created April 26, 2021 01:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save m-rath/bf0c738c5abd1a6b7cca49b8118cb675 to your computer and use it in GitHub Desktop.
Save m-rath/bf0c738c5abd1a6b7cca49b8118cb675 to your computer and use it in GitHub Desktop.
def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature backed by a BytesList.

    If `value` is an eager tensor, its `.numpy()` value is stored instead, so
    the BytesList receives a plain value rather than a tensor object.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))
def _float_feature(value):
    """Wrap a single float in a tf.train.Feature backed by a FloatList."""
    packed = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=packed)
def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature backed by an Int64List."""
    packed = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=packed)
#---------------------------------------------------------------------------------------------------
def image_example(image_string, image_name, image_label):
    """Build a tf.train.Example describing one JPEG-encoded image.

    Args:
        image_string: Raw JPEG-encoded image bytes (stored unchanged).
        image_name: Image file name as bytes.
        image_label: Class label; anything ``int()`` accepts (Python int,
            NumPy integer scalar, numeric string).

    Returns:
        A tf.train.Example with the features 'image', 'height', 'width',
        'image_name' and 'target'.
    """
    # The 2019 cassava images vary in height and width, so each record stores
    # its own dimensions. extract_jpeg_shape parses only the JPEG header
    # ([height, width, channels]) instead of decoding every pixel as
    # tf.image.decode_jpeg would. Input must therefore be a JPEG.
    jpeg_shape = tf.io.extract_jpeg_shape(image_string).numpy()
    height, width = int(jpeg_shape[0]), int(jpeg_shape[1])
    feature = {
        'image': _bytes_feature(image_string),
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'image_name': _bytes_feature(image_name),
        # int() so the Int64List always receives a plain Python int,
        # e.g. when the label comes out of a NumPy array column.
        'target': _int64_feature(int(image_label)),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
#---------------------------------------------------------------------------------------------------
# Serialization of TRAINING_FOLDS, VALIDATION_FOLDS and TEST_SET to .tfrec files.
# Each image is read from `dir + image_id`, where `dir` is a module-level path
# prefix defined outside this section (presumably ending in a path separator;
# note it shadows the builtin dir()). Training_Folds / Validation_Folds / TEST
# are indexed as 2-D arrays: column 0 = image file name, column 1 = label.


def _read_bytes(path):
    """Return the raw contents of the file at `path`, always closing the handle."""
    with open(path, 'rb') as f:
        return f.read()


def _shard_sizes(total, num_shards, recs_per_shard):
    """Split `total` records into `num_shards` consecutive shard sizes.

    Every shard except the last holds `recs_per_shard` records; the last one
    takes whatever remains, so no record is ever dropped (including when
    num_shards == 1).

    Raises:
        ValueError: if num_shards < 1, or if there are too few records to
            fill the first num_shards - 1 shards. The check runs before any
            file is written, so no partial output is left behind.
    """
    if num_shards < 1:
        raise ValueError('num_shards must be >= 1, got {}'.format(num_shards))
    head = recs_per_shard * (num_shards - 1)
    if head > total:
        raise ValueError(
            'need at least {} records for {} shards of {}, got {}'.format(
                head, num_shards - 1, recs_per_shard, total))
    return [recs_per_shard] * (num_shards - 1) + [total - head]


def _write_tfrecord(record_file, image_ids, labels):
    """Write one TFRecord file holding a serialized Example per (image id, label)."""
    with tf.io.TFRecordWriter(record_file) as writer:
        for img, image_label in zip(image_ids, labels):
            image_string = _read_bytes(dir + img)
            image_name = bytes(img, 'utf-8')
            tf_example = image_example(image_string, image_name, image_label)
            writer.write(tf_example.SerializeToString())


def _write_sharded(prefix, image_ids, labels, num_shards, recs_per_shard):
    """Write consecutive shards named prefix + '<index>-<record count>.tfrec'."""
    start = 0
    sizes = _shard_sizes(len(image_ids), num_shards, recs_per_shard)
    for z, count in enumerate(sizes):
        stop = start + count
        record_file = prefix + str(z) + '-' + str(count) + '.tfrec'
        _write_tfrecord(record_file, image_ids[start:stop], labels[start:stop])
        start = stop


#---------------------------------------------------------------------------------------------------
# TRAINING_FOLDS
# Hard-coded shard layout (16 .tfrec files in total across train + val).
# Presumably derived as:
#   num_Train_tfrecs = ceil(16 * len(Training_Folds) / (len(Training_Folds) + len(Validation_Folds)))
#   records per shard = len(Training_Folds) // num_Train_tfrecs
num_Train_tfrecs = 13
_write_sharded('ld_train', Training_Folds[:, 0], Training_Folds[:, 1],
               num_Train_tfrecs, 3371)

#---------------------------------------------------------------------------------------------------
# VALIDATION_FOLDS
# 3 = 16 - num_Train_tfrecs; 3650 presumably len(Validation_Folds) // 3.
num_Val_tfrecs = 3
_write_sharded('ld_val', Validation_Folds[:, 0], Validation_Folds[:, 1],
               num_Val_tfrecs, 3650)

#---------------------------------------------------------------------------------------------------
# TEST_SET
# Written as a single file whose name records the total number of examples.
test_ids = TEST[:, 0]
_write_tfrecord('ld_test0-' + str(len(test_ids)) + '.tfrec', test_ids, TEST[:, 1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment