Skip to content

Instantly share code, notes, and snippets.

@Ab1992ao
Created May 17, 2021 09:16
Show Gist options
  • Save Ab1992ao/328bfb3a44ce9923134a37d85fae6780 to your computer and use it in GitHub Desktop.
Save Ab1992ao/328bfb3a44ce9923134a37d85fae6780 to your computer and use it in GitHub Desktop.
Generate triplet data for a multitask learning pipeline
class TripletGenerator:
    """Endless source of (anchor, positive, negative) triplet batches.

    ``datadict`` maps an anchor id to a mapping of class name -> indexable
    collection of samples. The methods below read the class names
    ``'anchor'``, ``'entailment'``, ``'contradiction'`` and ``'neutral'``.

    Attributes:
        datadict: The mapping passed to the constructor.
        generator: Generator produced by ``generate_batch(batch_size)``;
            call ``next()`` on it to obtain one batch.
    """

    def __init__(self, datadict, hard_frac=0.5, batch_size=256):
        """Store the data and create the batch generator.

        Args:
            datadict: Mapping ``anchor_id -> {class_name: samples}``.
            hard_frac: Fraction of each batch whose negative is taken from
                the anchor's own ``'contradiction'`` samples.
            batch_size: Number of triplets per batch of ``self.generator``.
        """
        self.datadict = datadict
        self._anchor_idx = np.array(list(self.datadict))
        self._hard_frac = hard_frac
        self.generator = self.generate_batch(batch_size)

    def generate_batch(self, size):
        """Yield ``([anchors, positives, negatives], labels)`` forever.

        Each batch draws ``size`` distinct anchor ids. The first
        ``int(size * hard_frac)`` triplets use hard negatives, the rest use
        random negatives taken from a different anchor.

        Args:
            size: Number of triplets per batch.

        Yields:
            A pair of a three-element list of arrays (each of length
            ``size``) and an all-ones array of length ``size``.

        Raises:
            ValueError: From ``np.random.choice`` when ``size`` exceeds the
                number of anchors (sampling is without replacement).
        """
        while True:
            hards = int(size * self._hard_frac)
            anchor_ids = np.random.choice(self._anchor_idx, size, replace=False)
            anchors = self.get_anchors(anchor_ids)
            positives = self.get_positives(anchor_ids)
            negative_parts = [
                self.get_hard_negatives(anchor_ids[:hards]),
                self.get_random_negatives(anchor_ids[hards:]),
            ]
            # Join along the batch axis. np.hstack would join along axis 1
            # for 2-D sample arrays; empty parts are dropped because their
            # shape (0,) cannot be concatenated with N-D parts.
            non_empty = [part for part in negative_parts if len(part)]
            if non_empty:
                negatives = np.concatenate(non_empty, axis=0)
            else:
                negatives = np.array([])
            # NOTE(review): all-ones labels look like a placeholder target
            # for a triplet loss -- confirm against the training code.
            labels = np.ones(size)
            assert len(anchors) == len(positives) == len(negatives) == len(labels) == size
            yield [anchors, positives, negatives], labels

    @staticmethod
    def get_random(items, exclude=None):
        """Return one uniformly random element of ``items``.

        The original snippet called this helper without defining it; the
        behaviour here is inferred from its call sites.

        Args:
            items: Non-empty indexable collection (list or 1-D array).
            exclude: Optional value that must not be returned. ``None``
                means "no exclusion".

        Returns:
            A random element of ``items`` that differs from ``exclude``.

        Raises:
            ValueError: If ``items`` is empty or holds nothing but
                ``exclude``.
        """
        count = len(items)
        if count == 0:
            raise ValueError("cannot sample from an empty collection")
        if exclude is None:
            return items[np.random.randint(count)]
        # Rejection sampling keeps the usual case O(1) instead of filtering
        # the whole collection for every draw.
        for _ in range(32):
            candidate = items[np.random.randint(count)]
            if candidate != exclude:
                return candidate
        # Fallback for collections dominated by the excluded value.
        remaining = [item for item in items if item != exclude]
        if not remaining:
            raise ValueError(
                "no element left to sample after excluding %r" % (exclude,)
            )
        return remaining[np.random.randint(len(remaining))]

    def get_anchors(self, anchor_ids):
        """Return one random ``'anchor'`` sample per id in ``anchor_ids``."""
        return self.get_samples_from_ids(anchor_ids, ['anchor'])

    def get_positives(self, anchor_ids):
        """Return one random ``'entailment'`` sample per id."""
        return self.get_samples_from_ids(anchor_ids, ['entailment'])

    def get_hard_negatives(self, anchor_ids):
        """Return one random ``'contradiction'`` sample per id."""
        return self.get_samples_from_ids(anchor_ids, ['contradiction'])

    def get_random_negatives(self, anchor_ids):
        """Return, per id, a random sample belonging to a different anchor.

        For every id another anchor is drawn at random, then one of that
        anchor's available classes among ``'contradiction'``, ``'neutral'``
        and ``'entailment'``, then one sample of that class.

        Args:
            anchor_ids: Iterable of anchor ids.

        Returns:
            ``np.ndarray`` with one sample per input id.
        """
        classes = ['contradiction', 'neutral', 'entailment']
        samples = []
        for anchor_id in anchor_ids:
            other_anchor_id = self.get_random(self._anchor_idx, anchor_id)
            other_entry = self.datadict[other_anchor_id]
            # Only pick among the classes this anchor actually has.
            avail_classes = [name for name in classes if name in other_entry]
            sample_class = self.get_random(avail_classes)
            samples.append(self.get_random(other_entry[sample_class]))
        return np.array(samples)

    def get_samples_from_ids(self, anchor_ids, classes):
        """Return one random sample per id, drawn from one of ``classes``.

        Args:
            anchor_ids: Iterable of anchor ids present in ``datadict``.
            classes: Class names to choose from; one is picked at random
                for every id.

        Returns:
            ``np.ndarray`` with one sample per input id.

        Raises:
            KeyError: If an anchor lacks the chosen class.
        """
        return np.array([
            self.get_random(self.datadict[anchor_id][self.get_random(classes)])
            for anchor_id in anchor_ids
        ])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment