Skip to content

Instantly share code, notes, and snippets.

@alexlyzhov
Created July 7, 2018 09:53
Show Gist options
  • Save alexlyzhov/045d4b1cc348872a570d45e7238c11ad to your computer and use it in GitHub Desktop.
Save alexlyzhov/045d4b1cc348872a570d45e7238c11ad to your computer and use it in GitHub Desktop.
Baseline
import json
import os
from collections import defaultdict
import numpy as np
from PIL import Image
class DSModel:
    """Baseline model that always predicts the most frequent training class.

    Training counts how many images belong to each class, stores the name of
    the winner in a small text file, and prediction returns the same
    placeholder markup for that class for every input image.
    """

    # Name of the file (inside the model directory) that stores the class name.
    MODEL_REL_FILENAME = 'class.txt'

    # Class keys present in every generated markup dict, in output order.
    CLASSES = ('ring', 'Other', 'Brooch', 'earring', 'pendant', 'necklace')

    @staticmethod
    def get_class(img_data):
        """Return the first class of ``img_data`` that has any bounding boxes.

        Args:
            img_data: markup of one image; ``img_data['aabb']`` maps a class
                name to a list of boxes (presumably axis-aligned bounding
                boxes -- confirm against the dataset description).

        Returns:
            The name of the first class, in mapping order, whose box list is
            non-empty.

        Raises:
            ValueError: if every class has an empty box list.
        """
        for class_, class_bboxes in img_data['aabb'].items():
            if class_bboxes:
                return class_
        raise ValueError(
            'image markup contains no bounding boxes for any class')

    @staticmethod
    def get_img_data(class_):
        """Build placeholder markup that assigns one dummy box to ``class_``.

        Args:
            class_: name of the class that should receive the dummy box.

        Returns:
            A new dict ``{'aabb': {...}}`` with an empty list for every known
            class and a single box made of five ``[0, 0]`` pairs for
            ``class_``. A ``class_`` outside ``CLASSES`` is added as an extra
            key.
        """
        img_data = {'aabb': {cur_class_: [] for cur_class_ in DSModel.CLASSES}}
        # Build each [0, 0] pair separately: `[[0] * 2] * 5` would repeat a
        # single inner list five times, so editing one point edits them all.
        img_data['aabb'][class_] = [[[0, 0] for _ in range(5)]]
        return img_data

    def __init__(self, path_to_assets_dir: str):
        """Create an untrained model.

        Args:
            path_to_assets_dir: accepted for interface compatibility; this
                baseline does not read any assets.
        """
        # Filled in by load_model() or train().
        self.best_class = None

    def load_model(self, path_to_model_dir: str):
        """Load the name of the most popular class from the model directory.

        Args:
            path_to_model_dir: directory containing ``MODEL_REL_FILENAME``.
        """
        class_filename = os.path.join(path_to_model_dir,
                                      self.MODEL_REL_FILENAME)
        with open(class_filename, 'r', encoding='utf-8') as f:
            # Strip so that a trailing newline (e.g. from a hand-edited file)
            # does not become part of the class name.
            self.best_class = f.read().strip()

    def train(self, path_to_training_data: str, path_to_model_dir: str):
        """Choose the most popular class and save its name.

        Args:
            path_to_training_data: directory containing ``markup.json``, a
                mapping from image filename to that image's markup.
            path_to_model_dir: existing directory where the class name is
                written.

        Raises:
            ValueError: if the markup is empty or an image has no boxes.
        """
        markup_filename = os.path.join(path_to_training_data, 'markup.json')
        with open(markup_filename, encoding='utf-8') as f:
            markup = json.load(f)

        # # This is how you iterate over training images:
        # for filename in markup.keys():
        #     full_path = os.path.join(path_to_training_data, filename)
        #     img = np.asarray(Image.open(full_path))
        #     # do something with the image here

        class_occurences = defaultdict(int)
        for img_data in markup.values():
            class_occurences[self.get_class(img_data)] += 1
        if not class_occurences:
            raise ValueError('training markup contains no images')

        # Ties are resolved in favour of the class that was seen first.
        best_class = max(class_occurences, key=class_occurences.get)

        class_filename = os.path.join(path_to_model_dir,
                                      self.MODEL_REL_FILENAME)
        # Overwrite rather than append: appending on a second training run
        # would concatenate class names (e.g. 'ringring') in the model file.
        with open(class_filename, 'w', encoding='utf-8') as f:
            f.write(best_class)
        self.best_class = best_class

    def predict(self, batch: list) -> list:
        """Predict the most popular class for all objects.

        Args:
            batch: encoded images (bytes objects); only their count is used.

        Returns:
            One independent placeholder markup dict per item of ``batch``.

        Raises:
            RuntimeError: if neither load_model() nor train() has been called.
        """
        # # This is how you iterate over batch of test images
        # # (needs `from io import BytesIO`):
        # for cur_data in batch:
        #     bytesio = BytesIO(cur_data)
        #     img = np.asarray(Image.open(bytesio))
        #     # do something with the image here

        best_class = getattr(self, 'best_class', None)
        if best_class is None:
            raise RuntimeError(
                'model is not loaded: call load_model() or train() first')
        # Build a fresh dict per image; `[d] * n` would hand every caller the
        # same object, so mutating one prediction would change all of them.
        return [self.get_img_data(best_class) for _ in batch]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment