Created
January 12, 2019 05:38
-
-
Save Erotemic/02eec53a78b3134509264ecaed37e9c1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
This module can be used as both a script and an importable module. | |
Run `python ggr_matching.py --help` for more details. | |
See docstring in fit for more details on the importable module. | |
conda install opencv | |
conda install pytorch torchvision -c pytorch | |
TestMe: | |
xdoctest ~/code/netharn/netharn/examples/ggr_matching.py all | |
""" | |
from os.path import join | |
import os | |
import ubelt as ub | |
import numpy as np | |
import netharn as nh | |
import torch | |
import torchvision | |
import ndsampler | |
import itertools as it | |
class MatchingHarness(nh.FitHarn):
    """
    Define how to process a batch, compute loss, and evaluate validation
    metrics.

    Two batch layouts are supported:

    * triplet mode: ``img1``/``img2``/``img3`` where img1 and img2 are a
      positive pair and img3 comes from a different PCC (see
      ``sample_triples``).
    * pair mode: ``img1``/``img2`` plus a binary ``label`` (1 = positive,
      0 = negative).

    The mode is detected from the batch contents (presence of ``img3``), so
    the harness also works when a dataset is wrapped in
    ``torch.utils.data.Subset``.

    Example:
        >>> from ggr_matching import *
        >>> harn = setup_harn()
        >>> harn.initialize()
        >>> batch = harn._demo_batch(0, 'train')
        >>> batch = harn._demo_batch(0, 'vali')
    """

    def after_initialize(harn, **kw):
        # (true_labels, distances) chunks accumulated over an epoch
        harn.confusion_vectors = []
        harn._has_preselected = False

    def before_epochs(harn):
        """
        Presample the pairs / triples that each dataset will serve.

        The training set is resampled every time this hook runs (random
        negatives). Every other dataset is sampled only the first time.
        """
        verbose = 0
        for tag, dset in harn.datasets.items():
            # The sampling logic lives on the underlying dataset, not on a
            # Subset wrapper.
            if isinstance(dset, torch.utils.data.Subset):
                dset = dset.dataset
            if not harn._has_preselected:
                harn.log('Presample {} dataset'.format(tag))
            # Randomly resample training negatives every time; other sets
            # are sampled only once.
            if tag == 'train' or not harn._has_preselected:
                dset.preselect(verbose=verbose)
        harn._has_preselected = True

    def prepare_batch(harn, raw_batch):
        """
        Ensure the batch is in a standardized structure on the harness device.

        Returns:
            dict: ``img1``/``img2``/``img3`` in triplet mode, otherwise
                ``img1``/``img2``/``label``.
        """
        # Detect the mode from the batch itself; ``harn.datasets['train']``
        # may be a Subset without a ``triple`` attribute.
        if 'img3' in raw_batch:
            keys = ['img1', 'img2', 'img3']
        else:
            keys = ['img1', 'img2', 'label']
        return {key: harn.xpu.move(raw_batch[key]) for key in keys}

    def run_batch(harn, batch):
        """
        Connect data -> network -> loss

        Args:
            batch: item returned by ``prepare_batch``

        Returns:
            Tuple[dict, Tensor]: the network outputs and the scalar loss.
        """
        if 'img3' in batch:
            outputs = harn.model(batch['img1'], batch['img2'], batch['img3'])
            d1, d2, d3 = outputs['dvecs']
            loss = harn.criterion(d1, d2, d3).sum()
        else:
            outputs = harn.model(batch['img1'], batch['img2'])
            loss = harn.criterion(outputs['dist12'], batch['label']).sum()
        return outputs, loss

    def _decode(harn, outputs):
        """
        Convert network outputs to numpy.

        Returns:
            dict: ``dvecs`` (one array per input). ``dist12`` is included
                when there are at least two inputs, and ``dist13``/``dist23``
                when there are three.
        """
        dvecs = list(outputs['dvecs'])
        decoded = {
            'dvecs': [d.data.cpu().numpy() for d in dvecs],
        }
        if len(dvecs) > 1:
            dist12 = torch.nn.functional.pairwise_distance(dvecs[0], dvecs[1])
            decoded['dist12'] = dist12.data.cpu().numpy()
        if len(dvecs) > 2:
            dist13 = torch.nn.functional.pairwise_distance(dvecs[0], dvecs[2])
            decoded['dist13'] = dist13.data.cpu().numpy()
            dist23 = torch.nn.functional.pairwise_distance(dvecs[1], dvecs[2])
            decoded['dist23'] = dist23.data.cpu().numpy()
        return decoded

    def on_batch(harn, batch, outputs, loss):
        """ custom callback

        Draws batches whose index is below 8 to ``<train_dpath>/monitor`` and
        records (true label, distance) pairs for the epoch metrics.

        Example:
            >>> from ggr_matching import *
            >>> harn = setup_harn().initialize()
            >>> batch = harn._demo_batch(0, tag='vali')
            >>> outputs, loss = harn.run_batch(batch)
            >>> decoded = harn._decode(outputs)
            >>> stacked = harn._draw_batch(batch, decoded, limit=42)
        """
        bx = harn.bxs[harn.current_tag]
        decoded = harn._decode(outputs)
        if bx < 8:
            stacked = harn._draw_batch(batch, decoded)
            dpath = ub.ensuredir((harn.train_dpath, 'monitor', harn.current_tag))
            fpath = join(dpath, 'batch_{}_epoch_{}.jpg'.format(bx, harn.epoch))
            nh.util.imwrite(fpath, stacked)

        POS_LABEL = 1
        NEG_LABEL = 0
        if 'img3' in batch:
            # img1/img2 are positives; img3 is a negative for both of them.
            n = len(decoded['dist12'])
            harn.confusion_vectors.append(([POS_LABEL] * n, decoded['dist12']))
            harn.confusion_vectors.append(([NEG_LABEL] * n, decoded['dist13']))
            harn.confusion_vectors.append(([NEG_LABEL] * n, decoded['dist23']))
        else:
            # Record metrics for epoch scores
            y_true = torch.squeeze(batch['label'].data.cpu()).numpy()
            y_dist = torch.squeeze(outputs['dist12'].data.cpu()).numpy()
            harn.confusion_vectors.append((y_true, y_dist))

    def on_epoch(harn):
        """ custom callback

        Summarize the distances recorded by ``on_batch``. A pair is predicted
        positive when its distance is ``<= margin``.

        Returns:
            dict: ``mcc``, ``brier``, ``accuracy``, ``pos_dist``,
                ``neg_dist``. Empty if nothing was recorded.
        """
        from sklearn import metrics
        margin = harn.hyper.criterion_params['margin']
        epoch_metrics = {}
        if harn.confusion_vectors:
            y_true = np.hstack([true for true, _ in harn.confusion_vectors])
            y_dist = np.hstack([dist for _, dist in harn.confusion_vectors])
            POS_LABEL = 1
            NEG_LABEL = 0
            pos_dist = np.nanmean(y_dist[y_true == POS_LABEL])
            neg_dist = np.nanmean(y_dist[y_true == NEG_LABEL])
            # Transform distance into a probability-like space: distances
            # below the margin map to values above 0.5.
            y_probs = torch.sigmoid(torch.Tensor(-(y_dist - margin))).numpy()
            y_pred = (y_dist <= margin).astype(y_true.dtype)
            accuracy = (y_true == y_pred).mean()
            mcc = metrics.matthews_corrcoef(y_true, y_pred)
            brier = ((y_probs - y_true) ** 2).mean()
            epoch_metrics = {
                'mcc': mcc,
                'brier': brier,
                'accuracy': accuracy,
                'pos_dist': pos_dist,
                'neg_dist': neg_dist,
            }
        # Clear scores for next epoch
        harn.confusion_vectors.clear()
        return epoch_metrics

    def _draw_batch(harn, batch, decoded, limit=32):
        """
        Render up to ``limit`` items of a batch side by side, annotated with
        their distances (and the label in pair mode).

        Example:
            >>> from ggr_matching import *
            >>> harn = setup_harn().initialize()
            >>> batch = harn._demo_batch(0, tag='vali')
            >>> outputs, loss = harn.run_batch(batch)
            >>> decoded = harn._decode(outputs)
            >>> stacked = harn._draw_batch(batch, decoded, limit=42)
            >>> # xdoctest: +REQUIRES(--show)
            >>> import netharn as nh
            >>> nh.util.autompl()
            >>> nh.util.imshow(stacked, colorspace='rgb', doclf=True)
            >>> nh.util.show_if_requested()
        """
        if 'img3' in batch:
            imgs = [batch[k].data.cpu().numpy() for k in ['img1', 'img2', 'img3']]
            labels = None  # triplet batches carry no label
        else:
            imgs = [batch[k].data.cpu().numpy() for k in ['img1', 'img2']]
            labels = batch['label'].data.cpu().numpy()

        tostack = []
        fontkw = {
            'fontScale': 1.0,
            'thickness': 2
        }
        n = min(limit, len(imgs[0]))
        for i in range(n):
            # CHW -> HWC for drawing
            ims = [g[i].transpose(1, 2, 0) for g in imgs]
            img = nh.util.stack_images(ims, overlap=-2, axis=1)
            if 'dist13' in decoded:
                text = 'dist12={:.2f} --- dist13={:.2f} --- dist23={:.2f}'.format(
                    decoded['dist12'][i],
                    decoded['dist13'][i],
                    decoded['dist23'][i],
                )
            else:
                dist = decoded['dist12'][i]
                label = labels[i]
                text = 'dist={:.2f}, label={}'.format(dist, label)
            img = (img * 255).astype(np.uint8)
            img = nh.util.draw_text_on_image(img, text,
                                             org=(2, img.shape[0] - 2),
                                             color='blue', **fontkw)
            tostack.append(img)
        stacked = nh.util.stack_images_grid(tostack, overlap=-10,
                                            bg_value=(10, 40, 30),
                                            axis=1, chunksize=3)
        return stacked
class MatchingNetworkLP(torch.nn.Module):
    """
    Siamese pairwise distance.

    Every input is embedded by the same ResNet branch, whose pooling and
    final fully-connected layers are replaced by a pooling layer sized for
    ``input_shape`` and a 1024-d linear embedding. With two inputs, the LP
    distance between the two embeddings is also returned.

    Args:
        p (int): norm order passed to ``torch.nn.PairwiseDistance``.
        branch (torchvision.models.ResNet, optional): backbone shared by all
            inputs. Defaults to a pretrained resnet50.
        input_shape (tuple): (N, C, H, W) input shape used to size the
            replacement average-pooling layer.

    Example:
        >>> self = MatchingNetworkLP(input_shape=(4, 3, 224, 224))
        >>> input_shapes = [(4, 3, 224, 224), (4, 3, 224, 224)]
        >>> self.output_shape_for(*input_shapes)  # todo pdist module
    """
    def __init__(self, p=2, branch=None, input_shape=(1, 3, 416, 416)):
        super(MatchingNetworkLP, self).__init__()
        if branch is None:
            self.branch = torchvision.models.resnet50(pretrained=True)
        else:
            self.branch = branch
        assert isinstance(self.branch, torchvision.models.ResNet)

        # Note the advanced usage of output-shape-for: look up the hidden
        # shape produced by layer4 to size the pooling layer.
        branch_shape = nh.OutputShapeFor(self.branch)(input_shape)
        prepool_shape = branch_shape.hidden.shallow(1)['layer4']

        # Replace the last layer of resnet with a linear embedding to learn the
        # LP distance between pairs of images. Also replace the pooling layer
        # in case the input has a different size.
        self.prepool_shape = prepool_shape
        pool_channels = prepool_shape[1]
        pool_dims = prepool_shape[2:]
        self.branch.avgpool = torch.nn.AvgPool2d(pool_dims, stride=1)
        self.branch.fc = torch.nn.Linear(pool_channels, 1024)

        self.pdist = torch.nn.PairwiseDistance(p=p)

    def forward(self, *inputs):
        """
        Compute a resnet vector for each input and look at the LP-distance
        between the vectors.

        Returns:
            dict: ``dvecs`` (one embedding per input) and, for exactly two
                inputs, ``dist12``.

        Example:
            >>> input1 = nh.XPU(None).variable(torch.rand(4, 3, 224, 224))
            >>> input2 = nh.XPU(None).variable(torch.rand(4, 3, 224, 224))
            >>> self = MatchingNetworkLP(input_shape=input2.shape)
            >>> output = self(input1, input2)

        Ignore:
            >>> input1 = nh.XPU(None).variable(torch.rand(1, 3, 416, 416))
            >>> input2 = nh.XPU(None).variable(torch.rand(1, 3, 416, 416))
            >>> self = MatchingNetworkLP(input_shape=input2.shape)
            >>> self(input1, input2)
        """
        dvecs = [self.branch(i) for i in inputs]
        # L2 normalize the vectors
        # dvecs = [torch.nn.functional.normalize(d) for d in dvecs]
        output = {'dvecs': dvecs}
        if len(inputs) == 2:
            output['dist12'] = self.pdist(*dvecs)
        return output

    def output_shape_for(self, *input_shapes):
        """
        Shape-level counterpart of ``forward``.

        Returns:
            dict: ``dvecs`` shapes and, for exactly two inputs, the distance
                shape under ``dist12`` (the key ``forward`` uses). The same
                shape is also stored under ``dist`` for older callers.
        """
        dvecs = [nh.OutputShapeFor(self.branch)(s) for s in input_shapes]
        output = {'dvecs': dvecs}
        if len(input_shapes) == 2:
            dist = nh.OutputShapeFor(self.pdist)(*dvecs)
            output['dist12'] = dist
            # Backwards-compatible alias; prefer 'dist12'.
            output['dist'] = dist
        return output
def extract_ggr_pccs(coco_dset):
    """
    Compute the positive connected components (PCCs) of the review graph.

    Each annotation in ``coco_dset.anns`` must have a ``review_ids`` list of
    ``(other_aid, decision)`` pairs. ``decision`` is one of ``'positive'``,
    ``'negative'`` or ``'incomparable'``. Reviews that point at annotations
    outside ``coco_dset`` are skipped.

    Args:
        coco_dset: dataset exposing ``annots().aids`` and ``anns``.

    Returns:
        List[frozenset]: the annotation ids of each PCC.

    Raises:
        KeyError: if a review has an unrecognized decision.
    """
    import graphid
    decision_to_evidence = {
        'positive': graphid.core.POSTV,
        'negative': graphid.core.NEGTV,
        'incomparable': graphid.core.INCMP,
    }

    all_aids = list(coco_dset.annots().aids)
    aids_set = set(all_aids)

    graph = graphid.api.GraphID()
    graph.add_annots_from(all_aids)
    infr = graph.infr
    # Add all feedback with inference disabled, then run a single update.
    infr.params['inference.enabled'] = False

    for aid1 in ub.ProgIter(all_aids, desc='construct graph'):
        annot = coco_dset.anns[aid1]
        for aid2, decision in annot['review_ids']:
            if aid2 not in aids_set:
                # hack because data is setup wrong
                continue
            # Unknown decisions raise KeyError(decision)
            evidence = decision_to_evidence[decision]
            infr.add_feedback((aid1, aid2), evidence_decision=evidence)

    infr.params['inference.enabled'] = True
    infr.apply_nondynamic_update()

    print('status = {}'.format(ub.repr2(infr.status(True))))

    pccs = list(map(frozenset, infr.positive_components()))
    # Sanity check: every PCC member must be an annotation of this dataset.
    missing = [aid for pcc in pccs for aid in pcc if aid not in coco_dset.anns]
    assert not missing, 'PCC aids missing from coco_dset: {!r}'.format(missing[:10])
    return pccs
class MatchingCocoDataset(torch.utils.data.Dataset):
    """
    Torch dataset of annotation pairs or triples drawn from the positive
    connected components (PCCs) of a COCO-style review graph.

    In triple mode an item is ``{'img1', 'img2', 'img3'}``, where img1/img2
    come from the same PCC and img3 does not (see ``sample_triples``). In
    pair mode an item is ``{'img1', 'img2', 'label'}``, where label is 1 for
    a positive pair and 0 for a negative pair.

    Args:
        sampler: sampler used to load annotation chips via
            ``load_positive``.
        coco_dset: the COCO dataset used to compute the PCCs.
        workdir: currently unused.
        augment (bool): if True, apply flip/HSV/crop jitter in
            ``__getitem__`` and redraw samples whenever ``preselect`` runs.
        dim (int): chips are letterboxed to ``dim x dim``.
        triple (bool): serve triples instead of labeled pairs.

    Example:
        >>> harn = setup_harn(dbname='ggr2', xpu='cpu').initialize()
        >>> self = harn.datasets['train']
    """
    def __init__(self, sampler, coco_dset, workdir=None, augment=False,
                 dim=416, triple=True):
        print('make MatchingCocoDataset')
        # Cache the PCCs on disk, keyed by the dataset tag.
        cacher = ub.Cacher('pccs', cfgstr=coco_dset.tag, verbose=True)
        pccs = cacher.tryload()
        if pccs is None:
            pccs = extract_ggr_pccs(coco_dset)
            cacher.save(pccs)
        self.pccs = pccs
        self.sampler = sampler

        print('target index')
        # annotation id -> index into the sampler's target table
        self.aid_to_tx = {aid: tx for tx, aid in
                          enumerate(sampler.regions.targets['aid'])}

        self.coco_dset = coco_dset
        self.triple = triple

        # Never request more samples than there are positive pairs.
        self.pos_ceil = sum((n * (n - 1)) // 2 for n in map(len, self.pccs))
        self.max_num = min(self.pos_ceil, int(1e5))

        if self.triple:
            self.sample_gen = sample_triples(pccs, rng=0)
            self.samples = [next(self.sample_gen) for _ in range(self.max_num)]
            nh.util.shuffle(self.samples, rng=0)
        else:
            print('Find Samples')
            # Seed the samplers like the triple branch does, so the fixed
            # (non-augmented) sample set, e.g. validation, is reproducible.
            self.sample_gen = sample_edges_inf(self.pccs, rng=0)
            self.samples = sample_edges_finite(self.pccs, max_num=self.max_num,
                                               pos_neg_ratio=1.0, rng=0)
            nh.util.shuffle(self.samples, rng=0)
        print('Finished sampling')

        window_dim = dim
        self.dim = window_dim
        self.window_dim = window_dim
        # RNG for the shared horizontal flip in __getitem__
        self.rng = nh.util.ensure_rng(0)

        if augment:
            import imgaug.augmenters as iaa
            # NOTE: we are only using `self.augmenter` to make a hyper hashid.
            # In __getitem__ we invoke the transforms explicitly for finer
            # control.
            self.hue = nh.data.transforms.HSVShift(hue=0.1, sat=1.5, val=1.5)
            self.crop = iaa.Crop(percent=(0, .2))
            self.flip = iaa.Fliplr(p=.5)
            self.dependent = self.flip
            # Applied in __getitem__ after the shared flip.
            self.independent = iaa.Sequential([
                self.hue,
                self.crop,
            ])
            self.augmenter = iaa.Sequential([
                self.crop,
                self.flip
            ])
        else:
            self.augmenter = None

        target_size = (window_dim, window_dim)
        self.letterbox = nh.data.transforms.Resize(target_size=target_size,
                                                   mode='letterbox')
        self.preselect()

    def preselect(self, verbose=0):
        """
        Redraw ``len(self)`` samples from ``self.sample_gen`` when
        augmentation is enabled. Otherwise keep the fixed sample set built in
        ``__init__``. ``verbose`` is currently unused.
        """
        if self.augmenter:
            n = len(self)
            self.samples = [next(self.sample_gen) for _ in it.repeat(None, n)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        if self.triple:
            aids = self.samples[index]
        else:
            import graphid
            aid1, aid2, label = self.samples[index]
            aids = [aid1, aid2]
            # graphid decision -> binary target
            if label == graphid.core.POSTV:
                label = 1
            elif label == graphid.core.NEGTV:
                label = 0
            else:
                raise KeyError(label)

        txs = [self.aid_to_tx[aid] for aid in aids]
        samples = [self.sampler.load_positive(index=tx) for tx in txs]
        chips = [s['im'] for s in samples]

        if self.augmenter:
            # One flip decision is shared by all chips of the item
            # (presumably so they keep a consistent orientation). Color and
            # crop jitter are applied afterwards.
            if self.rng.rand() > .5:
                chips = [np.fliplr(c) for c in chips]
            chips = self.independent.augment_images(chips)

        chips = self.letterbox.augment_images(chips)

        # HWC uint8-range -> CHW float in [0, 1]
        chips = [
            torch.FloatTensor(c.transpose(2, 0, 1).astype(np.float32) / 255)
            for c in chips
        ]

        if self.triple:
            item = {
                'img1': chips[0],
                'img2': chips[1],
                'img3': chips[2],
            }
        else:
            item = {
                'img1': chips[0],
                'img2': chips[1],
                'label': label,
            }
        return item
def e_(e):
    """
    Canonicalize an undirected edge.

    Args:
        e: a 2-item sequence ``(u, v)``.

    Returns:
        tuple: ``(u, v)`` if ``u < v``, otherwise ``(v, u)``.
    """
    first, second = e
    if not first < second:
        return (second, first)
    return (first, second)
def generate_positives(pccs, rng=None):
    """
    Lazily yield positive edges (pairs of annotations within one PCC).

    Works round-robin: each pass takes one pair from every PCC that still
    has pairs left, in PCC order. A PCC is dropped once its pair generator
    is exhausted.

    Args:
        pccs (List[frozenset]): positive connected components.
        rng: seed or random state accepted by ``nh.util.ensure_rng``.

    Yields:
        tuple: an edge canonicalized by :func:`e_`.
    """
    rng = nh.util.ensure_rng(rng)
    active = [nh.util.random_combinations(pcc, 2, rng=rng) for pcc in pccs]
    while active:
        survivors = []
        for gen in active:
            try:
                pair = next(gen)
            except StopIteration:
                continue
            survivors.append(gen)
            yield e_(pair)
        active = survivors
def generate_negatives(pccs, rng=None):
    """
    Lazily yield negative edges (pairs of annotations from different PCCs).

    Works round-robin over pairs of PCCs. The first round walks the PCC
    pairs given by ``nh.util.random_combinations(pccs, 2)`` and yields one
    edge from each pair's ``nh.util.random_product`` generator. Later rounds
    revisit only the pairs whose generator is not exhausted, and the
    function returns once every generator is exhausted.

    Args:
        pccs (List[frozenset]): positive connected components.
        rng: seed or random state accepted by ``nh.util.ensure_rng``.

    Yields:
        tuple: an edge canonicalized by :func:`e_`.
    """
    rng = nh.util.ensure_rng(rng, api='python')
    # (pcc1, pcc2) -> generator of cross edges between the two PCCs
    generators = {}
    keys = (tuple(pair)
            for pair in nh.util.random_combinations(pccs, 2, rng=rng))
    remaining = None  # keys with edges left; known after the first round
    while True:
        exhausted = set()
        for key in keys:
            gen = generators.get(key)
            if gen is None:
                gen = nh.util.random_product(list(key), rng=rng)
                generators[key] = gen
            try:
                edge = e_(next(gen))
            except StopIteration:
                exhausted.add(key)
                continue
            yield edge
        if remaining is None:
            remaining = set(generators)
        remaining.difference_update(exhausted)
        # Drop exhausted generators so they do not accumulate in memory.
        for key in exhausted:
            generators.pop(key, None)
        if not remaining:
            return
        keys = remaining
def sample_triples(pccs, rng=None):
    """
    Infinitely yield ``(aid1, aid2, aid3)`` triples where aid1/aid2 are a
    positive edge and aid3 is a random annotation from a different PCC.

    Positives cycle through :func:`generate_positives`. aid3 is chosen by
    rejection sampling over all annotation ids.

    Note: does not take into account incomparable

    Example:
        >>> pccs = list(map(frozenset, [{1, 2, 3}, {4, 5}, {6}]))
        >>> gen = (sample_triples(pccs))
        >>> next(gen)
    """
    assert len(pccs) > 1
    rng = nh.util.ensure_rng(rng, api='python')
    pcc_of = {}
    for pcc in pccs:
        for aid in pcc:
            pcc_of[aid] = pcc
    candidates = sorted(pcc_of.keys())
    for aid1, aid2 in it.cycle(generate_positives(pccs, rng=rng)):
        same_pcc = pcc_of[aid1]
        while True:
            aid3 = rng.choice(candidates)
            if aid3 not in same_pcc:
                break
        yield aid1, aid2, aid3
def sample_edges_inf(pccs, rng=None):
    """
    Infinitely yield labeled edges, alternating one positive
    ``(u, v, POSTV)`` then one negative ``(u, v, NEGTV)``. Each stream
    cycles once its underlying generator is exhausted.

    Note: does not take into account incomparable

    Example:
        >>> pccs = list(map(frozenset, [{1, 2, 3}, {4, 5}, {6}]))
        >>> gen = (sample_edges_inf(pccs))
        >>> next(gen)
    """
    import graphid
    rng = nh.util.ensure_rng(rng, api='python')
    positives = ((u, v, graphid.core.POSTV)
                 for u, v in generate_positives(pccs, rng=rng))
    negatives = ((u, v, graphid.core.NEGTV)
                 for u, v in generate_negatives(pccs, rng=rng))
    for pos_edge, neg_edge in zip(it.cycle(positives), it.cycle(negatives)):
        yield pos_edge
        yield neg_edge
def sample_edges_finite(pccs, max_num=1000, pos_neg_ratio=None, rng=None):
    """
    Build a finite list of labeled edges: up to ``max_num // 2`` positives,
    followed by up to ``max_num - max_num // 2`` negatives. If
    ``pos_neg_ratio`` is given, negatives are further capped at
    ``int(pos_neg_ratio * n_positives)``.

    Note: does not take into account incomparable

    Returns:
        List[Tuple]: ``(aid1, aid2, label)`` with label POSTV or NEGTV.

    Example:
        >>> pccs = list(map(frozenset, [{1, 2, 3}, {4, 5}, {6}]))
        >>> list(sample_edges_finite(pccs))
    """
    import graphid
    rng = nh.util.ensure_rng(rng, api='python')
    # Simpler very randomized sample strategy
    total = int(max_num)
    n_pos = total // 2
    n_neg = total - n_pos

    pos_stream = generate_positives(pccs, rng=rng)
    pos_pairs = [pair for _, pair in zip(range(n_pos), pos_stream)]

    if pos_neg_ratio is not None:
        n_neg = min(int(pos_neg_ratio * len(pos_pairs)), n_neg)

    neg_stream = generate_negatives(pccs, rng=rng)
    neg_pairs = [pair for _, pair in zip(range(n_neg), neg_stream)]

    samples = [(a, b, graphid.core.POSTV) for a, b in pos_pairs]
    samples.extend((a, b, graphid.core.NEGTV) for a, b in neg_pairs)
    return samples
def setup_harn(**kwargs):
    """
    Build datasets, loaders and hyperparameters, and return a
    ``MatchingHarness``.

    Keyword Args (all optional):
        nice (str): run tag. Default 'untitled'.
        batch_size (int): default 6.
        bstep (int): batches accumulated per optimizer step. Default 1.
        workers (int): data loader workers. Default 0.
        decay (float): weight decay. Default 1e-5.
        lr (float): base learning rate. Default 0.019.
        dim (int): network input width/height. Default 416.
        xpu: device spec for ``nh.XPU.cast``. Default 'argv'.
        workdir (str): default ``~/work/siam-ibeis2/<dbname>``.
        triple (bool | str): use the triplet loss instead of the contrastive
            loss. String values such as ``'False'`` or ``'0'`` are parsed.
            Default False.
        dbname (str): ``'ggr2'`` or an IBEIS database name. Default 'ggr2'.
        margin (float): loss margin. Default 4.
        data_root (str): root of the ggr2 COCO data, used only when
            ``dbname == 'ggr2'``. Default
            ``'/media/joncrall/raid/data/ggr2-coco'``.

    CommandLine:
        python ~/code/netharn/netharn/examples/ggr_matching.py setup_harn

    Example:
        >>> harn = setup_harn(dbname='PZ_MTEST')
        >>> harn.initialize()
    """
    nice = kwargs.get('nice', 'untitled')
    batch_size = int(kwargs.get('batch_size', 6))
    bstep = int(kwargs.get('bstep', 1))
    workers = int(kwargs.get('workers', 0))
    decay = float(kwargs.get('decay', 1e-5))
    lr = float(kwargs.get('lr', 0.019))
    dim = int(kwargs.get('dim', 416))
    xpu = kwargs.get('xpu', 'argv')
    workdir = kwargs.get('workdir', None)
    triple = kwargs.get('triple', False)
    if isinstance(triple, str):
        # Command-line values arrive as strings, and 'False' would be truthy.
        triple = triple.strip().lower() in ('1', 'true', 'yes', 'on')
    dbname = kwargs.get('dbname', 'ggr2')
    margin = float(kwargs.get('margin', 4))
    data_root = kwargs.get('data_root', '/media/joncrall/raid/data/ggr2-coco')

    if workdir is None:
        workdir = ub.truepath(os.path.join('~/work/siam-ibeis2', dbname))
    ub.ensuredir(workdir)

    if dbname == 'ggr2':
        print('Creating torch CocoDataset')
        train_dset = ndsampler.CocoDataset(
            data=os.path.join(data_root, 'annotations', 'instances_train2018.json'),
            img_root=os.path.join(data_root, 'images', 'train2018'),
        )
        train_dset.hashid = 'ggr2-coco-train2018'

        vali_dset = ndsampler.CocoDataset(
            data=os.path.join(data_root, 'annotations', 'instances_val2018.json'),
            img_root=os.path.join(data_root, 'images', 'val2018'),
        )
        vali_dset.hashid = 'ggr2-coco-val2018'

        print('Creating samplers')
        train_sampler = ndsampler.CocoSampler(train_dset, workdir=workdir)
        vali_sampler = ndsampler.CocoSampler(vali_dset, workdir=workdir)

        print('Creating torch Datasets')
        datasets = {
            'train': MatchingCocoDataset(train_sampler, train_dset, workdir,
                                         dim=dim, augment=True, triple=triple),
            'vali': MatchingCocoDataset(vali_sampler, vali_dset, workdir,
                                        dim=dim, triple=triple),
        }
    else:
        from ibeis_utils import randomized_ibeis_dset
        datasets = randomized_ibeis_dset(dbname, dim=dim)

    for k, v in datasets.items():
        print('* len({}) = {}'.format(k, len(v)))

    if workers > 0:
        # Keep OpenCV from spawning its own threads inside loader workers.
        import cv2
        cv2.setNumThreads(0)

    loaders = {
        key: torch.utils.data.DataLoader(
            dset, batch_size=batch_size, num_workers=workers,
            shuffle=(key == 'train'), pin_memory=True)
        for key, dset in datasets.items()
    }

    xpu = nh.XPU.cast(xpu)

    if triple:
        criterion_ = (torch.nn.TripletMarginLoss, {
            'margin': margin,
        })
    else:
        criterion_ = (nh.criterions.ContrastiveLoss, {
            'margin': margin,
            'weight': None,
        })

    hyper = nh.HyperParams(**{
        'nice': nice,
        'workdir': workdir,
        'datasets': datasets,
        'loaders': loaders,

        'xpu': xpu,

        'model': (MatchingNetworkLP, {
            'p': 2,
            'input_shape': (1, 3, dim, dim),
        }),

        'criterion': criterion_,

        'optimizer': (torch.optim.SGD, {
            'lr': lr,
            'weight_decay': decay,
            'momentum': 0.9,
            'nesterov': True,
        }),

        'initializer': (nh.initializers.NoOp, {}),

        'scheduler': (nh.schedulers.ListedScheduler, {
            'points': {
                'lr': {
                    0: lr * 0.1,
                    35: lr * 1.0,
                    70: lr * 0.001,
                },
                'momentum': {
                    0: 0.95,
                    35: 0.85,
                    70: 0.95,
                    71: 0.999,
                },
            },
        }),

        'monitor': (nh.Monitor, {
            'minimize': ['loss', 'pos_dist', 'brier'],
            'maximize': ['accuracy', 'neg_dist', 'mcc'],
            'patience': 40,
            'max_epoch': 100,
        }),

        'dynamics': {
            # Controls how many batches to process before taking a step in the
            # gradient direction. Effectively simulates a batch_size that is
            # `bstep` times bigger.
            'batch_step': bstep,
        },

        'other': {
            'n_classes': 2,
        },
    })
    harn = MatchingHarness(hyper=hyper)
    harn.config['prog_backend'] = 'progiter'
    harn.intervals['log_iter_train'] = 1
    harn.intervals['log_iter_test'] = None
    harn.intervals['log_iter_vali'] = None

    return harn
def fit(dbname='PZ_MTEST', nice='untitled', dim=416, batch_size=6, bstep=1,
        lr=0.001, decay=0.0005, workers=0, xpu='argv', margin=4.0):
    """
    Train a siamese chip descriptor for animal identification.

    Args:
        dbname (str): Name of the database to use. ``'ggr2'`` loads the GGR2
            COCO data; any other name is loaded as an IBEIS database.
        nice (str): Custom tag for this run
        dim (int): Width and height of the network input
        batch_size (int): Base batch size. Number of examples in GPU at any time.
        bstep (int): Multiply by batch_size to simulate a larger batches.
        lr (float): Base learning rate
        decay (float): Weight decay (L2 regularization)
        workers (int): Number of parallel data loader workers
        xpu (str): Device to train on. Can be either `'cpu'`, `'gpu'`, a number
            indicating a GPU (e.g. `0`), or a list of numbers (e.g. `[0,1,2]`)
            indicating multiple GPUs
        margin (float): Margin of the contrastive / triplet loss. Also used
            as the distance threshold for the accuracy and MCC metrics.
    """
    # Snapshot the arguments before any other local is bound, so exactly the
    # parameters are forwarded to setup_harn without listing them twice.
    # (Replaces inspect.getargspec, which was removed in Python 3.11.)
    kw = dict(locals())
    print('RUNNING FIT')
    print('SETUP HARNESS')
    harn = setup_harn(**kw)
    print('INIT')
    harn.initialize()
    print('RUN')
    harn.run()
def main():
    """
    Command-line entry point: parse the arguments of :func:`fit` and run it.

    CommandLine:
        python examples/ggr_matching.py --help

        # Test Runs:
        # use a very small input dimension to test things out
        python examples/ggr_matching.py --dbname PZ_MTEST --workers=0 --dim=32 --xpu=cpu

        # test that GPU works
        python examples/ggr_matching.py --dbname PZ_MTEST --workers=0 --dim=32 --xpu=gpu0

        # test that running at a large size works
        python examples/ggr_matching.py --dbname PZ_MTEST --workers=6 --dim=416 --xpu=gpu0

        # Real Run:
        python examples/ggr_matching.py --dbname GZ_Master1 --workers=6 --dim=512 --xpu=gpu0 --batch_size=10 --lr=0.00001 --nice=gzrun
        python examples/ggr_matching.py --dbname GZ_Master1 --workers=6 --dim=512 --xpu=gpu0 --batch_size=6 --lr=0.001 --nice=gzrun

    Notes:
        # Some database names
        PZ_Master1
        GZ_Master1
        RotanTurtles
        humpbacks_fb
    """
    import xinspect
    parser = xinspect.auto_argparse(fit)
    args, unknown = parser.parse_known_args()
    if unknown:
        # parse_known_args tolerates extra flags, but silently dropping an
        # unsupported or misspelled option would run with a default the
        # user did not ask for.
        print('WARNING: ignoring unrecognized arguments: {}'.format(unknown))
    fit(**vars(args))
if __name__ == '__main__':
    # Script entry point.
    #
    # CommandLine:
    #     python ~/code/netharn/examples/ggr_matching.py --help
    #     python ~/code/netharn/examples/ggr_matching.py --workers=6 --xpu=0 --dbname=ggr2 --batch_size=8 --nice=ggr2-test
    #     python ~/code/netharn/examples/ggr_matching.py --workers=6 --xpu=0 --dbname=ggr2 --batch_size=8 --nice=ggr2-test-v2
    #     python ~/code/netharn/examples/ggr_matching.py --workers=6 --xpu=0 --dbname=ggr2 --batch_size=8 --nice=ggr2-test-v3
    #     python ~/code/netharn/examples/ggr_matching.py --workers=6 --xpu=0 --dbname=ggr2 --batch_size=8 --nice=ggr2-triple-test-v4
    #     python ~/code/netharn/examples/ggr_matching.py --workers=6 --xpu=0 --dbname=ggr2 --batch_size=8 --nice=ggr2-triple-test-v4
    #     python ~/code/netharn/examples/ggr_matching.py --workers=6 --xpu=0 --dbname=ggr2 --batch_size=8 --nice=ggr2-triple-test-v5-newlr
    #     python ~/code/netharn/examples/ggr_matching.py --workers=6 --xpu=0 --dbname=ggr2 --batch_size=8 --nice=ggr2-triple-test-v5-newlr_m3 --margin=2
    #     python ~/code/netharn/examples/ggr_matching.py --workers=6 --xpu=0 --dbname=ggr2 --batch_size=8 --nice=ggr2-pair-test-v6-newlr_m4 --margin=4
    #     python ~/code/netharn/examples/ggr_matching.py --workers=6 --xpu=0 --dbname=ggr2 --batch_size=8 --nice=ggr2-pair-test-v7 --margin=4 --lr=0.019 --bstep=8 --triple=False
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment