Skip to content

Instantly share code, notes, and snippets.

@Tony363
Created May 15, 2022 15:53
Show Gist options
  • Save Tony363/e82b42f48423f00a0376888d66fc93e6 to your computer and use it in GitHub Desktop.
Save Tony363/e82b42f48423f00a0376888d66fc93e6 to your computer and use it in GitHub Desktop.
def load_embeddings(
    embedding_list_train: list,
    label_path: str,
    mapping_path: str,
    data_dir: str,
    batch_size: int = 4,
) -> Tuple[np.ndarray, np.ndarray]:
    """Load per-video embeddings and build one binary label per chunk.

    For every key in ``embedding_list_train`` the matching ``.npy`` file is
    loaded from ``data_dir`` and split with ``_get_chunk_array``. A frame-level
    0/1 mask is built from the flicker frame indexes in the label file and
    split the same way; a chunk is labelled 1 when any frame in it is flagged.

    Args:
        embedding_list_train: Embedding keys; any ``".tfrecords"`` substring
            is removed to obtain both the mapping key and the ``.npy`` stem.
        label_path: JSON file mapping a real filename to its list of flicker
            frame indexes (the code subtracts 1, so presumably 1-based --
            confirm against the labelling tool).
        mapping_path: JSON file mapping an embedding key to a real filename.
        data_dir: Directory containing the ``<key>.npy`` embedding files.
        batch_size: Forwarded to ``_get_chunk_array`` (presumably the number
            of frames per chunk -- confirm against that helper).

    Returns:
        ``(X, y)``: the stacked embedding chunks and the matching 0/1 labels.

    Raises:
        KeyError: If a key is missing from the mapping or label file.
        IndexError: If a flicker index falls outside the loaded embedding.
    """
    # Context managers so both file handles are closed deterministically.
    with open(mapping_path, "r") as mapping_file:
        encoding_filename_mapping = json.load(mapping_file)
    with open(label_path, "r") as label_file:
        raw_labels = json.load(label_file)

    # Lists + extend: appending is amortised O(1), whereas growing a tuple
    # with ``+=`` recopies everything accumulated so far on each video.
    X_train, y_train = [], []
    for key in embedding_list_train:
        stem = key.replace(".tfrecords", "")
        real_filename = encoding_filename_mapping[stem]
        loaded = np.load("{}.npy".format(os.path.join(data_dir, stem)))
        X_train.extend(_get_chunk_array(loaded, batch_size))

        # Force an integer dtype: an empty label list would otherwise become
        # a float64 array, which NumPy rejects as an index (IndexError).
        # NOTE(review): a label of 0 becomes -1 and marks the last frame.
        flicker_idxs = np.asarray(raw_labels[real_filename], dtype=np.intp) - 1

        # Frame-level mask: 1 on flicker frames, 0 everywhere else.
        buf_label = np.zeros(loaded.shape[0], dtype=np.uint8)
        buf_label[flicker_idxs] = 1

        # ``np.any`` instead of the built-in ``sum``: same truth value, but
        # it cannot wrap around in uint8 arithmetic on very large chunks.
        # (consider using tf reduce sum for multiclass)
        y_train.extend(
            1 if np.any(chunk) else 0
            for chunk in _get_chunk_array(buf_label, batch_size)
        )
    return np.array(X_train), np.array(y_train)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment