# Gist: anderzzz/0568b03b45e481d6b46c4ec2ff4fa6c5 (last active November 5, 2020, 10:56).
# GitHub notice: this file contains bidirectional Unicode text that may be interpreted
# or compiled differently than what appears below. To review, open the file in an
# editor that reveals hidden Unicode characters.
# Outline: fine-tune a VGG-based encoder with a Local Aggregation criterion on fungi images.
# Dataset and DataLoader arguments are deliberately elided with `...`; fill them in
# before running this script.
from torch.optim import SGD
from torch.utils.data import DataLoader
from sklearn.preprocessing import normalize

import fungidata
from ae_deep import EncoderVGGMerged
from cluster_utils import MemoryBank, LocalAggregationLoss

# Training hyper-parameters.
N_EPOCHS = 20
# Older PyTorch releases gave torch.optim.SGD no default learning rate (omitting `lr`
# raised a ValueError), so it is passed explicitly below.
# NOTE(review): placeholder value -- tune for the actual data.
LEARNING_RATE = 0.01

# Create fungi Dataset (details omitted). The training loop reads inputs['image'] and
# inputs['idx'], so each batch yielded by the DataLoader must provide both keys.
dataset = fungidata.factory.create('grid basic idx', ...)
dataloader = DataLoader(dataset, ...)

# Instantiate custom-made model and criterion, seeding the memory bank with codes
# computed by the pre-trained VGG encoder.
model = EncoderVGGMerged(merger_type='mean')
# One memory-bank vector per dataset item. Deriving the count from the dataset keeps it
# consistent with the codes computed from the same data below, instead of hard-coding
# 5400. Assumes the dataset is map-style and supports len() -- TODO confirm.
memory_bank = MemoryBank(n_vectors=len(dataset), dim_vector=model.channels_code,
                         memory_mixing_rate=0.5)
# Row-wise L2 normalisation of the initial codes (sklearn `normalize` with axis=1).
memory_bank.vectors = normalize(model.eval_codes_for_(dataloader), axis=1)
criterion = LocalAggregationLoss(memory_bank=memory_bank,
                                 temperature=0.07, k_nearest_neighbours=500,
                                 clustering_repeats=6, number_of_centroids=100)

# Instantiate a stochastic-gradient descent optimizer.
optimizer = SGD(model.parameters(), lr=LEARNING_RATE)

# Rudimentary outline of training loop.
for epoch in range(N_EPOCHS):
    # Ensure training mode at the start of each epoch. Computing the initial codes above
    # may have left the model in eval mode (NOTE(review): confirm what eval_codes_for_
    # does). Presumes the model is a torch nn.Module, as model.parameters() suggests.
    model.train()
    for inputs in dataloader:
        optimizer.zero_grad()
        output = model(inputs['image'])
        loss = criterion(output, inputs['idx'])
        loss.backward()
        optimizer.step()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.