Skip to content

Instantly share code, notes, and snippets.

# Training setup: choose the compute device, place the model on it, and
# create the loss, optimizer and loop bookkeeping used by the training code.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

model = MildNet()
model.to(device)

# reduction="none" makes the loss return one value per triplet rather than a
# single reduced scalar; the caller is expected to reduce it.
criterion = nn.TripletMarginLoss(margin=0.1, reduction="none")
optimizer = torch.optim.Adam(model.parameters())

n_epochs = 161    # total number of training epochs
print_every = 20  # reporting interval (presumably in epochs — confirm against the loop)
eval_losses = []  # presumably filled with evaluation losses by the training loop; verify
class MildNet(nn.Module):
'''
Reference:
https://github.com/gofynd/mildnet/blob/master/trainer/model.py
'''
def __init__(self):
super(MildNet, self).__init__()
# VGG16 part
self.convblock1 = nn.Sequential(
class PokemonDataset(Dataset):
def __init__(self, images, root_dir, imageset=None):
"""
pokemon dataset: loads image and target
"""
self.imageset = np.load(imageset, mmap_mode="r+") if not imageset is None else None
self.root_dir = root_dir
self.images = images
self.anchor_transform = transforms.Compose([
This file has been truncated, but you can view the full file.
{
"metadata": {
"name": "",
"signature": "sha256:a335db0b8a53af13275b4349230c1da2cfc2cfb3656ff35165ac0953e5d11441"
},
{
"metadata": {
"name": "",
"signature": "sha256:54166f6914b9ffaf3770b33c61adc8ba2fb19fcb83dc41d343e3f243da1c2350"
},
This file has been truncated, but you can view the full file.
{
"metadata": {
"name": "",
"signature": "sha256:7dcbc3494d103bb64638a59d37b2f07316e47138a6a03f43f38c153b5ea5e2b1"
},
This file has been truncated, but you can view the full file.
{
"metadata": {
"name": "",
"signature": "sha256:0d6cec8428c42be839832b08cbffae47bbb34cabb8643b9e4f0dbac014f577de"
},
This file has been truncated, but you can view the full file.
{
"metadata": {
"name": "",
"signature": "sha256:bc9e1652c8a802dca603f12cdc6f292033e9bc0ca0c93fb1c07808754617a29a"
},
This file has been truncated, but you can view the full file.
{
"metadata": {
"name": "",
"signature": "sha256:a8398ce40087121501e924c32e900fecbc5db8dd5d2713c651e50d79a75f3012"
},
This file has been truncated, but you can view the full file.
{
"metadata": {
"name": "",
"signature": "sha256:3f377aaa4530cb2b7f98231fc80000d4ce68686f1805ceaa487a9269a324a17d"
},