Skip to content

Instantly share code, notes, and snippets.

View J3698's full-sized avatar
🐛
Inch Worm

Anti J3698

🐛
Inch Worm
  • Working
  • Working
View GitHub Profile
@J3698
J3698 / manifest.json
Created August 31, 2021 18:05
Manifest
{
"name": "Detexify",
"description": "Symbol Recognition for Overleaf",
"version": "1.0",
"manifest_version": 2,
"content_scripts": [{
"js": ["content.js"],
"css": ["style.css"],
"matches": ["https://www.overleaf.com/project/*"]
def get_style_transfer_loss(encoder, decoder, content_image, style_image, lambda_content, lambda_style):
assert_shape(content_image, (g_batch_size, 3, 256, 256))
style_features = encoder(style_image)
content_features = encoder(content_image)
stylized_images, stylized_features = create_stylized_images(decoder, content_features, style_features)
features_of_stylized = encoder(stylized_images)
def train_epoch_style_loss(args, encoder, decoder, dataloader, val_dataloader,
optimizer, epoch_num, writer, run, device):
encoder.eval()
decoder.train()
total_loss = 0
num_batches = calc_num_batches(dataloader, args)
progress_bar = tqdm.tqdm(enumerate(dataloader), total = num_batches, dynamic_ncols = True)
for i, (content_image, style_image) in progress_bar:
# move to GPU
def __next__(self):
if self.ilength <= 0:
raise StopIteration
self.ilength -= 1
coco_idx, wiki_idx = self.random_pair_of_indices()
content_image = self.coco[coco_idx][0]
if not self.exclude_style:
features_list = []
for layer in features:
features_list.append(layer)
if isinstance(layer, nn.Conv2d):
features_list.append(nn.BatchNorm2d(layer.out_channels))
del features_list[-1]
self.features = nn.Sequential(*features_list)
def main():
target = torch.randint(-20, 20, (8, 3, 4, 4)).float()
source = torch.randint(-20, 20, (8, 3, 4, 4)).float()
stylized_source = adain(source, target)
target = target.view(8, 3, -1)
stylized_source = stylized_source.view(8, 3, -1)
# check variances the same
# check shapes
assert len(target.shape) == 4, "expected 4 dimensions"
assert target.shape == source.shape, "source/target shape mismatch"
batch_size, channels, width, height = source.shape
# calculate target stats
target_reshaped = target.view(batch_size, channels, 1, 1, -1)
target_variances = target_reshaped.var(-1, unbiased = False)
target_means = target_reshaped.mean(-1)
def train_epoch_reconstruct(encoder, decoder, dataloader, optimizer, epoch_num, writer, run):
encoder.train()
decoder.train()
total_loss = 0
for i, content_image in tqdm.tqdm(enumerate(dataloader), total = len(dataloader), dynamic_ncols = True):
content_image = content_image.to(DEVICE)
optimizer.zero_grad()
reconstruction = decoder(encoder(content_image)[-1])
def main():
encoder = VGG19Encoder()
decoder = Decoder()
print(decoder)
sample_input = torch.ones((1, 3, 256, 256))
outputs = encoder(sample_input)
output = decoder(outputs[-1])
print(f"Input shape: {sample_input.shape}")
for i, layer in enumerate(features):
if isinstance(layer, nn.MaxPool2d):
features[i] = nn.Upsample(scale_factor = (2, 2), mode = 'nearest')
elif isinstance(layer, nn.Conv2d):
conv2d = nn.Conv2d(layer.out_channels, layer.in_channels, \
kernel_size = layer.kernel_size, stride = layer.stride, \
padding = layer.padding, padding_mode = 'reflect')
with torch.no_grad():
conv2d.weight[...] = layer.weight.transpose(0, 1)
features[i] = conv2d