Skip to content

Instantly share code, notes, and snippets.

@Abhishek-Shaw-Kolkata
Created March 13, 2021 17:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Abhishek-Shaw-Kolkata/f2638ab72ef1d18159571625b518489c to your computer and use it in GitHub Desktop.
Save Abhishek-Shaw-Kolkata/f2638ab72ef1d18159571625b518489c to your computer and use it in GitHub Desktop.
import os

import cv2
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tqdm import tqdm

# NOTE(review): presumably installs the hook that lets `utils` be imported
# from a notebook -- keep it before the `utils` import below.
import import_ipynb
from utils import mask2rle, combined_loss, dice_coef
def predict_results(test_files_png, *, model_path='models/Uefficientnetb4',
                    output_csv='submission.csv', threshold=0.5):
    '''
    Generate Pneumothorax segmentation predictions for chest X-ray PNG files
    and write them as a run-length-encoded CSV submission.

    Args:
        test_files_png : iterable of PNG file paths to generate predictions for.
        model_path : path of the saved Keras segmentation model
            (default 'models/Uefficientnetb4').
        output_csv : path of the CSV file to write (default 'submission.csv').
        threshold : probability above which a pixel is treated as positive
            (default 0.5).

    Returns:
        None. Writes ``output_csv`` with columns ``ImageId`` (the file name
        without directory or extension) and ``EncodedPixels`` (the
        ``mask2rle`` encoding of the 1024x1024 binary mask, or ``'-1'`` when
        no pixel exceeds ``threshold``). If several paths share the same
        ImageId, the last one wins.

    NOTE(review): relies on IMG_HEIGHT, IMG_WIDTH and N_CHANNELS, which are
    not defined in this file -- confirm they are provided at module level
    (e.g. by the notebook environment) before calling.
    '''
    full_size = 1024  # resolution of the submitted masks, passed to mask2rle

    model_seg = tf.keras.models.load_model(
        model_path,
        custom_objects={'combined_loss': combined_loss, 'dice_coef': dice_coef},
    )

    predictions = {}  # ImageId -> RLE string (insertion order = CSV order)
    for path in tqdm(test_files_png):
        image_id = os.path.splitext(os.path.basename(path))[0]

        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, channels=N_CHANNELS)
        img = tf.image.convert_image_dtype(img, tf.float32)
        # tf.image.resize takes [new_height, new_width].
        img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
        # The model is fed 3-channel input; assumes N_CHANNELS == 3 -- confirm.
        img.set_shape((IMG_HEIGHT, IMG_WIDTH, 3))

        # Calling the model directly is cheaper than predict() for one image.
        probs = model_seg(tf.expand_dims(img, axis=0), training=False)
        probs = probs.numpy().reshape((IMG_HEIGHT, IMG_WIDTH))

        rle = '-1'
        # Cheap low-resolution pre-check: bilinear upsampling never raises the
        # maximum, so an empty low-res mask stays empty at full resolution.
        if (probs > threshold).any():
            full_mask = cv2.resize(probs, (full_size, full_size))
            binary = (full_mask > threshold).astype(int)
            # Re-check after upsampling: small blobs can drop below the
            # threshold, and an all-zero mask must be reported as '-1'.
            if binary.any():
                rle = mask2rle(binary.T * 255, full_size, full_size)
        predictions[image_id] = rle

    # Build the frame with explicit columns so an empty input still yields a
    # valid (header-only) CSV instead of a column-count mismatch.
    sub = pd.DataFrame(list(predictions.items()),
                       columns=['ImageId', 'EncodedPixels'])
    sub.to_csv(output_csv, index=False, header=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment