Skip to content

Instantly share code, notes, and snippets.

View shuuchen's full-sized avatar
🎯
Focusing

Shuchen Du shuuchen

🎯
Focusing
View GitHub Profile
@shuuchen
shuuchen / TripletNetwork.py
Created March 4, 2020 10:43
Triplet Network
class TripletNetwork(SiameseNetwork):
    """Triplet variant of ``SiameseNetwork``.

    Runs all three inputs through the inherited ``forward_once`` method of
    this same module instance, so the three branches share one set of
    weights. ``SiameseNetwork`` is defined elsewhere (not in this file).
    """

    # No __init__ override: the base-class constructor is inherited as-is.

    def forward(self, input1, input2, input3):
        """Embed three inputs with the shared branch.

        Returns:
            Tuple ``(output1, output2, output3)`` of ``forward_once`` results,
            in the same order as the inputs.
        """
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        output3 = self.forward_once(input3)
        return output1, output2, output3
@shuuchen
shuuchen / ContrastiveLoss.py
Last active June 8, 2020 13:08
Contrastive Loss
import torch
from torch import nn
from torch.nn import functional as F
class ContrastiveLoss(nn.Module):
    """Contrastive loss module parameterized by ``margin``.

    NOTE(review): only the constructor is present in this copy. The
    ``forward`` pass appears to have been cut off (truncated gist preview),
    so calling an instance fails because ``nn.Module`` provides no default
    ``forward``. Restore it from the original gist before use.
    """

    def __init__(self, margin=5.0):
        """Store the ``margin`` hyperparameter (default 5.0)."""
        super().__init__()
        self.margin = margin
@shuuchen
shuuchen / TripletLoss.py
Last active June 8, 2020 13:09
Triplet Loss
import torch
from torch import nn
from torch.nn.modules.distance import PairwiseDistance
class TripletLoss(nn.Module):
    """Triplet loss module parameterized by ``margin``.

    NOTE(review): only the constructor is present in this copy. The
    ``forward`` pass (which presumably uses the imported ``PairwiseDistance``
    -- TODO confirm) appears to have been cut off, so calling an instance
    fails because ``nn.Module`` provides no default ``forward``. Restore it
    from the original gist before use.
    """

    def __init__(self, margin=5.0):
        """Store the ``margin`` hyperparameter (default 5.0)."""
        super().__init__()
        self.margin = margin
# NOTE(review): stray top-level snippet fragment. `ch` is not defined in the
# visible source, so running these lines raises NameError, and the layers
# they build are never assigned (results are discarded). The first line is
# duplicated. Positional args are (in, out, kernel_size=3, stride=2,
# padding=1), i.e. they look like a channel-doubling, stride-2 3x3 conv
# followed by its BatchNorm -- move into a model definition before use.
nn.Conv2d(ch, ch * 2, 3, 2, 1)
nn.Conv2d(ch, ch * 2, 3, 2, 1)
nn.BatchNorm2d(ch * 2)
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d block with an optional trailing ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Square kernel size. Padding is ``(kernel_size - 1) // 2``,
            so odd kernels preserve spatial size when ``stride == 1``.
        stride: Convolution stride.
        relued: If True, apply ReLU after batch norm.
        bn_momentum: BatchNorm momentum. When omitted, falls back to the
            module-level ``BN_MOMENTUM`` constant (original behavior); note
            that constant is not defined in the visible source.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, relued=True,
                 bn_momentum=None):
        super().__init__()
        if bn_momentum is None:
            # Looked up lazily so the class itself can be defined (and used
            # with an explicit bn_momentum) even if BN_MOMENTUM is absent.
            bn_momentum = BN_MOMENTUM
        padding = (kernel_size - 1) // 2
        # bias=False: the following BatchNorm has its own learnable shift,
        # which makes a conv bias redundant.
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum))
        self.relu = nn.ReLU()
        self.relued = relued

    def forward(self, x):
        """Apply conv + batch norm, then ReLU if ``self.relued``."""
        # NOTE(review): forward was missing from this copy; reconstructed from
        # the ``conv_bn``/``relu``/``relued`` attributes -- confirm vs. gist.
        x = self.conv_bn(x)
        return self.relu(x) if self.relued else x
# NOTE(review): top-level fragment. `ch` and `out_ch` are not defined in the
# visible source (and `Conv` looks up BN_MOMENTUM by default), so this raises
# NameError if executed as-is; intended to live inside a model definition.
head = nn.Sequential(
Conv(ch, ch, 1), # fusion: 1x1 Conv block (conv + BN + ReLU by default)
nn.Conv2d(ch, out_ch, 1)) # final prediction: plain 1x1 conv to out_ch channels
def init_weights(model):
    """Initialize the Conv2d and BatchNorm2d parameters of ``model`` in place.

    This was originally a fragment meant to be pasted inside a class
    ``__init__`` (iterating ``self.modules()``). Call ``init_weights(self)``
    at the end of ``__init__`` instead.

    - ``nn.Conv2d`` weights are drawn from N(0, 0.001**2); conv biases are
      left at their default initialization.
    - ``nn.BatchNorm2d`` weight is set to 1 and bias to 0. Layers built with
      ``affine=False`` have no such parameters and are skipped.

    Args:
        model: Any ``nn.Module``; all of its submodules (including itself)
            are visited via ``model.modules()``.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.001)
        elif isinstance(m, nn.BatchNorm2d) and m.affine:
            # affine=False leaves weight/bias as None; constant_ would fail.
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
import torch
from torch import nn
class SelfAttnBottleneck(nn.Module):
    """Self-attention bottleneck module (constructor stub only).

    NOTE(review): this copy contains only the ``expansion`` class attribute
    and a constructor that does nothing beyond ``nn.Module`` setup.
    ``in_channel`` and ``out_channel`` are accepted but unused, and there is
    no ``forward``. The body appears truncated; restore it from the original
    gist.
    """

    # Class-level expansion factor; presumably scales channel width the way
    # ResNet-style bottlenecks do -- TODO confirm against the full source.
    expansion = 8

    def __init__(self, in_channel, out_channel):
        super().__init__()
@shuuchen
shuuchen / test_sync_batch_ddp.py
Last active December 1, 2020 03:27
Multi-GPU sync-batch-norm test
import os
import argparse
import torch
import shutil
import torch.optim as optim
import torch.nn as nn
import numpy as np
import pandas as pd