Skip to content

Instantly share code, notes, and snippets.

@Filarius
Created January 22, 2022 18:40
Show Gist options
  • Save Filarius/ff0531c6d98fbb17124d2ddc621b34ff to your computer and use it in GitHub Desktop.
Save Filarius/ff0531c6d98fbb17124d2ddc621b34ff to your computer and use it in GitHub Desktop.
PyTorch (anti-)astigmatism: simple image pre-processing
import argparse
import random
import gym
import numpy as np
from itertools import count
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
class NeuralNetwork(nn.Module):
    """Single learnable horizontal convolution.

    The whole model is one ``nn.Conv2d`` whose kernel is one row high and
    ``kernel_width`` columns wide. No padding is applied, so the last axis of
    the output is ``kernel_width - 1`` elements shorter than the input.

    Args:
        channels: Number of input and output channels (default 3).
        kernel_width: Width of the 1-row kernel (default 41).
    """

    def __init__(self, channels=3, kernel_width=41):
        super().__init__()
        # Attribute name ``a1`` is kept so state_dict keys stay
        # ``a1.weight`` / ``a1.bias``.
        self.a1 = nn.Conv2d(channels, channels, (1, kernel_width))

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        return self.a1(x)
def blur(mt, width=10):
    """Box-blur a tensor along its last axis.

    Output element ``i`` is the mean of input elements ``i .. i + width - 1``
    along the last axis. The last axis of the result is ``width`` elements
    shorter than the input (slices ``[k : k - width]`` for ``k`` in
    ``0 .. width - 1`` are summed and divided by ``width``).

    Args:
        mt: Tensor to blur. A 3-D tensor gets a leading axis of size 1 added;
            a tensor that already has a leading batch axis is blurred as-is,
            so every batch element is kept.
        width: Number of neighbouring elements averaged (default 10).

    Returns:
        The blurred tensor; 4-D when the input was 3-D or 4-D.

    Raises:
        ValueError: If ``width`` is smaller than 1.
    """
    if width < 1:
        raise ValueError("width must be >= 1")
    if mt.dim() == 3:
        mt = mt[None, ...]
    # Accumulate the shifted views in the same left-to-right order as the
    # hand-unrolled sum, so floating-point results are unchanged.
    total = mt[..., :-width]
    for shift in range(1, width):
        total = total + mt[..., shift:shift - width]
    return total / width
def main():
    """Fit ``NeuralNetwork`` so that blurring its output reproduces the input.

    Loads ``path.jpg``, scales it to ``[0, 1]``, then runs 100000 Adam steps
    minimising the L1 distance between ``blur(clamp(net(image)))`` and the
    cropped original image. The loss is printed every 10 steps and the
    network output (before blurring) is written to ``imgs/<step>.jpg`` every
    100 steps.
    """
    import os  # local import: only needed here to create the output folder

    # Image.save() does not create missing directories.
    os.makedirs("imgs", exist_ok=True)

    loss_f = nn.L1Loss()

    # Close the file handle once the pixels are copied into a tensor, and
    # force three channels so grayscale/palette/alpha inputs match the conv.
    with Image.open("path.jpg") as m:
        mt = transforms.PILToTensor()(m.convert("RGB"))[:3]
    mt = mt / 255.0
    net_input = mt[None, ...]
    # The 1x41 conv trims 40 columns and blur() trims 10 more, so the target
    # is the original with 25 columns removed from each side.
    mt3 = mt[None, :, :, 25:-25]

    n = NeuralNetwork()
    optimizer = optim.Adam(n.parameters(), lr=1e-2)
    # Step index -> learning rate that takes effect from that step on.
    lr_schedule = {200: 1e-3, 50000: 1e-4}
    to_pil = transforms.ToPILImage()

    for epoch_i in range(100000):
        if epoch_i in lr_schedule:
            for group in optimizer.param_groups:
                group['lr'] = lr_schedule[epoch_i]
            print(":R")

        optimizer.zero_grad()
        im_fix = torch.clamp(n(net_input), 0, 1)
        im_blur = blur(im_fix)
        loss = loss_f(im_blur, mt3)
        loss.backward()
        optimizer.step()

        dif = loss.item()
        if epoch_i % 10 == 0:
            print(epoch_i, dif)
        if epoch_i % 100 == 0:
            # Detach so the snapshot is taken outside the autograd graph.
            img = to_pil(im_fix[0, :, :, :].detach())
            img.save("imgs/" + str(epoch_i) + '.jpg')
# Run the training loop only when this file is executed as a script.
if "__main__" == __name__:
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment