Skip to content

Instantly share code, notes, and snippets.

@anderzzz
Last active November 5, 2020 10:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anderzzz/0568b03b45e481d6b46c4ec2ff4fa6c5 to your computer and use it in GitHub Desktop.
Save anderzzz/0568b03b45e481d6b46c4ec2ff4fa6c5 to your computer and use it in GitHub Desktop.
from torch.optim import SGD
from torch.utils.data import DataLoader
from sklearn.preprocessing import normalize
import fungidata
from ae_deep import EncoderVGGMerged
from cluster_utils import MemoryBank, LocalAggregationLoss
# Create fungi Dataset (details omitted)
dataset = fungidata.factory.create('grid basic idx', ...)
dataloader = DataLoader(dataset, ...)

# Instantiate custom-made model and criterion with initial memory bank from
# pre-trained VGG-Encoder. The memory bank is seeded with the encoder's codes
# for everything yielded by ``dataloader``, each row L2-normalised (sklearn
# ``normalize`` with axis=1).
# NOTE(review): n_vectors=5400 presumably equals len(dataset) -- TODO confirm.
model = EncoderVGGMerged(merger_type='mean')
memory_bank = MemoryBank(n_vectors=5400, dim_vector=model.channels_code,
                         memory_mixing_rate=0.5)
memory_bank.vectors = normalize(model.eval_codes_for_(dataloader), axis=1)
criterion = LocalAggregationLoss(memory_bank=memory_bank,
                                 temperature=0.07,
                                 k_nearest_neighbours=500,
                                 clustering_repeats=6,
                                 number_of_centroids=100)

# Instantiate a stochastic-gradient descent optimizer.
# ``lr`` is passed explicitly: older PyTorch releases (including those
# current when this script was written) give SGD no default learning rate
# and raise ValueError without it. 1e-3 is a placeholder matching the
# default of later releases -- tune as needed.
optimizer = SGD(model.parameters(), lr=1e-3)
# Rudimentary outline of training loop.
# Put the model in training mode explicitly before optimizing.
# NOTE(review): eval_codes_for_ (used to seed the memory bank) may leave the
# model in eval mode -- confirm; calling train() is harmless either way.
model.train()
for epoch in range(20):
    for inputs in dataloader:
        optimizer.zero_grad()
        # Each batch is accessed by key: 'image' is fed to the encoder and
        # 'idx' is handed to the loss alongside the encoder output
        # (presumably the dataset indices addressing the memory bank --
        # TODO confirm).
        output = model(inputs['image'])
        loss = criterion(output, inputs['idx'])
        loss.backward()
        optimizer.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment