Skip to content

Instantly share code, notes, and snippets.

@smly
Created June 2, 2021 09:46
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save smly/ce2457841941789cc4b9dcc670765dca to your computer and use it in GitHub Desktop.
Save smly/ce2457841941789cc4b9dcc670765dca to your computer and use it in GitHub Desktop.
query expansion and database augmentation
import numpy as np
from .functions import l2norm_numpy
def qe_dba(
feats_test, feats_index, sims, topk_idx, alpha=3.0, qe=True, dba=True
):
# alpha-query expansion (DBA)
feats_concat = np.concatenate([feats_test, feats_index], axis=0)
weights = sims ** alpha
feats_tmp = np.zeros(feats_concat.shape)
for i in range(feats_concat.shape[0]):
feats_tmp[i, :] = weights[i].dot(feats_concat[topk_idx[i]])
del feats_concat
feats_concat = l2norm_numpy(feats_tmp).astype(np.float32)
split_at = [len(feats_test)]
if qe and dba:
return np.split(feats_concat, split_at, axis=0)
elif not qe and dba:
_, feats_index = np.split(feats_concat, split_at, axis=0)
return feats_test, feats_index
elif qe and not dba:
feats_test, _ = np.split(feats_concat, split_at, axis=0)
return feats_test, feats_index
else:
raise ValueError
def qe_dba_label_constrained(
feats_test, feats_index, label_index,
sims, topk_idx, alpha=3.0, qe=True, dba=True
):
labels_concat = np.concatenate([
# unlabeled data
np.ones(feats_test.shape[0]) * -1,
# labeled data
label_index
], axis=0)
feats_concat = np.concatenate([feats_test, feats_index], axis=0)
assert labels_concat.shape[0] == feats_concat.shape[0]
weights = sims ** alpha
feats_tmp = np.zeros(feats_concat.shape)
for i in range(feats_concat.shape[0]):
if feats_test.shape[0] > i:
# test images
feats_tmp[i, :] = weights[i].dot(feats_concat[topk_idx[i]])
else:
# train images
query_match = (labels_concat[topk_idx[i]] < 0) * 1.0
binary_label_match = (labels_concat[topk_idx[i]] == labels_concat[i]) * 1.0
weights_mask = (query_match + binary_label_match > 0.0) * 1.0
label_constrained_weights = weights[i] * weights_mask
feats_tmp[i, :] = label_constrained_weights.dot(feats_concat[topk_idx[i]])
del feats_concat
feats_concat = l2norm_numpy(feats_tmp).astype(np.float32)
split_at = [len(feats_test)]
if qe and dba:
return np.split(feats_concat, split_at, axis=0)
elif not qe and dba:
_, feats_index = np.split(feats_concat, split_at, axis=0)
return feats_test, feats_index
elif qe and not dba:
feats_test, _ = np.split(feats_concat, split_at, axis=0)
return feats_test, feats_index
else:
raise ValueError
@smly
Copy link
Author

smly commented Jun 2, 2021

import numpy as np

def l2norm_numpy(x):
    return x / np.linalg.norm(x, ord=2, axis=1, keepdims=True)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment