Skip to content

Instantly share code, notes, and snippets.

@gabrielgarza
Last active May 29, 2019 06:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gabrielgarza/33a37a92c683eb11c8ee088ec5e3dcc8 to your computer and use it in GitHub Desktop.
Save gabrielgarza/33a37a92c683eb11c8ee088ec5e3dcc8 to your computer and use it in GitHub Desktop.
class InferenceConfig(config.__class__):
    """Inference-time overrides layered on top of the existing ``config``.

    Subclasses the class of the already-defined ``config`` object so every
    setting not listed here is inherited unchanged.
    """

    # Run detection on one image at a time
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Detection filtering overrides.
    # NOTE(review): presumably these follow the Mask R-CNN Config semantics
    # (minimum score to keep a detection / NMS overlap threshold / cap on
    # detections per image) -- confirm against the base Config class.
    DETECTION_MIN_CONFIDENCE = 0.95
    DETECTION_NMS_THRESHOLD = 0.0
    DETECTION_MAX_INSTANCES = 20

    # Input images are processed at a fixed 768 x 768 size.
    IMAGE_MIN_DIM = 768
    IMAGE_MAX_DIM = 768

    RPN_ANCHOR_SCALES = (64, 96, 128, 256, 512)
# Create model object in inference mode.
config = InferenceConfig()
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

# Instantiate dataset; it is used below only for its rle_encode() helper.
# NOTE(review): rle_encode is invoked as a bound method, so it must be declared
# with `self` (or as a @staticmethod). Otherwise the call fails with
# "takes 1 positional argument but 2 were given" -- verify in ship.ShipDataset.
dataset = ship.ShipDataset()

# Load weights
model.load_weights(os.path.join(ROOT_DIR, SHIP_WEIGHTS_PATH), by_name=True)

class_names = ['BG', 'ship']

# Run detection
# Load the image ids (filenames) to predict from the sample submission file.
images_path = "datasets/test"
sample_sub_csv = "sample_submission.csv"
# Alternative inputs for running against the validation split:
# images_path = "datasets/val"
# sample_sub_csv = "val_ship_segmentations.csv"
sample_submission_df = pd.read_csv(os.path.join(images_path, sample_sub_csv))
unique_image_ids = sample_submission_df.ImageId.unique()

# One dict per output CSV row: {'ImageId': ..., 'EncodedPixels': ...}
out_pred_rows = []
count = 0
for image_id in unique_image_ids:
    image_path = os.path.join(images_path, image_id)
    if not os.path.isfile(image_path):
        # Ids listed in the CSV but missing on disk are skipped (no row emitted).
        continue

    count += 1
    print("Step: ", count)

    # Start counting prediction time.
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # monotonic high-resolution replacement.
    tic = time.perf_counter()

    image = skimage.io.imread(image_path)
    results = model.detect([image], verbose=1)
    r = results[0]

    # The last axis of the mask array indexes detected instances; encode each
    # per-instance 2-D mask separately.
    masks = r['masks']
    re_encoded_to_rle_list = [
        dataset.rle_encode(masks[:, :, i]) for i in range(masks.shape[-1])
    ]

    if not re_encoded_to_rle_list:
        # No detection: still emit a row for the image, with an empty encoding.
        out_pred_rows.append({'ImageId': image_id, 'EncodedPixels': None})
        print("Found Ship: ", "NO")
    else:
        for rle_mask in re_encoded_to_rle_list:
            out_pred_rows.append({'ImageId': image_id, 'EncodedPixels': rle_mask})
            print("Found Ship: ", rle_mask)

    toc = time.perf_counter()
    print("Prediction time: ", toc - tic)

# Passing `columns` fixes the column order and also yields a valid (empty)
# frame when no image was processed, instead of raising KeyError.
submission_df = pd.DataFrame(out_pred_rows, columns=['ImageId', 'EncodedPixels'])

filename = "{}{:%Y%m%dT%H%M}.csv".format("./submissions/submission_", datetime.datetime.now())
# Make sure the output directory exists so to_csv() does not fail on a fresh checkout.
os.makedirs(os.path.dirname(filename), exist_ok=True)
submission_df.to_csv(filename, index=False)
@atlurip
Copy link

atlurip commented May 29, 2019

Hi, I am getting an error in rle_encode:

rle_encode() takes 1 positional argument but 2 were given.

Could you please provide the code for rle_encode()?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment