Skip to content

Instantly share code, notes, and snippets.

@doleron
Created April 9, 2023 00:55
Show Gist options
  • Save doleron/62f77567a53d5e33a8dfba64ba90e88a to your computer and use it in GitHub Desktop.
Save doleron/62f77567a53d5e33a8dfba64ba90e88a to your computer and use it in GitHub Desktop.
def create_pairs(x, digit_indices, *, num_classes=None, rng=None):
    """Build balanced positive/negative sample pairs (siamese-style).

    For every class ``d`` and every position ``i`` below ``n`` (one less than
    the size of the smallest class), two pairs are emitted:

    * a positive pair ``[x[digit_indices[d][i]], x[digit_indices[d][i + 1]]]``
      labelled ``1``;
    * a negative pair ``[x[digit_indices[d][i]], x[digit_indices[dn][i]]]``
      labelled ``0``, where ``dn`` is a different class picked by adding a
      random non-zero offset to ``d`` modulo ``num_classes``.

    Args:
        x: Indexable collection of samples (e.g. an array of images).
        digit_indices: Sequence where ``digit_indices[d]`` holds the indices
            into ``x`` of the samples belonging to class ``d``.
        num_classes: Number of classes to pair up. Defaults to the
            module-level ``CLASSES_SIZE``.
        rng: Object providing ``randrange`` (e.g. ``random.Random(seed)``),
            used to choose negative classes. Defaults to the ``random``
            module, so default calls draw exactly as before.

    Returns:
        ``(pairs, labels)`` as NumPy arrays. Each entry of ``pairs`` is
        ``[sample_a, sample_b]``; ``labels`` alternates ``1, 0`` in step with
        ``pairs``.

    Raises:
        ValueError: If ``num_classes`` is less than 2, because no negative
            pair can be formed.
    """
    if num_classes is None:
        num_classes = CLASSES_SIZE
    if rng is None:
        rng = random
    if num_classes < 2:
        raise ValueError(
            f"create_pairs needs at least 2 classes to build negative pairs, "
            f"got {num_classes}"
        )

    # Every class contributes the same number of pairs, bounded by the
    # smallest class; "- 1" leaves room for the i + 1 look-ahead below.
    n = min(len(digit_indices[d]) for d in range(num_classes)) - 1

    pairs = []
    labels = []
    for d in range(num_classes):
        for i in range(n):
            # Positive pair: two consecutive samples of the same class.
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            pairs.append([x[z1], x[z2]])
            # Negative pair: an offset in [1, num_classes - 1] guarantees dn != d.
            inc = rng.randrange(1, num_classes)
            dn = (d + inc) % num_classes
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            pairs.append([x[z1], x[z2]])
            labels += [1, 0]
    return np.array(pairs), np.array(labels)
def create_pairs_on_set(images, labels):
    """Group samples by class label and build positive/negative pairs.

    Args:
        images: Samples indexed by position; forwarded to ``create_pairs``.
        labels: Class label of each sample, aligned with ``images``. Any
            array-like is accepted and converted with ``np.asarray``. Only
            labels in ``range(CLASSES_SIZE)`` are grouped; samples carrying
            any other label are ignored.

    Returns:
        ``(pairs, y)`` as produced by ``create_pairs``, with ``y`` cast to
        ``float32``.
    """
    # Without asarray, ``labels == i`` on a plain Python list is a single
    # bool, so np.where would find no indices and every class would appear
    # empty, silently yielding an empty pair set.
    labels = np.asarray(labels)
    digit_indices = [np.where(labels == i)[0] for i in range(CLASSES_SIZE)]
    pairs, y = create_pairs(images, digit_indices)
    y = y.astype('float32')
    return pairs, y
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment