# Manga/text 2x-upscaling training and benchmarking script (gist by @bqqbarbhg, 2023-08-04).
import torch
from torch import nn
import torch.nn.functional as NNF
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import datasets
from torchvision.transforms import ToTensor, PILToTensor, RandomResizedCrop
import torchvision.transforms as TVT
from torchvision.transforms.functional import to_pil_image, center_crop
from torchvision.io import read_image, ImageReadMode
import os
from zipfile import ZipFile
from PIL import Image
import re
import io
import sys
from dataset_gen import gen_case_memory
import matplotlib.pyplot as plt
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
def add_noise(im):
    # Mutates a grayscale PIL image in place with random per-pixel noise.
    # ~75% of cases add upward noise, ~25% downward, and half of the cases use
    # "square" noise where the sampled value is multiplied by a second sample.
    noise_up = torch.rand((1,)).item() <= 0.75
    noise_down = torch.rand((1,)).item() <= 0.25
    square_noise = torch.rand((1,)).item() <= 0.5
    if noise_up or noise_down:
        noise_max = 3 if square_noise else 5
        noise_amount_up = (1 << torch.randint(1, noise_max+1, (1,)).item()) if noise_up else 0
        noise_amount_down = (1 << torch.randint(1, noise_max-1+1, (1,)).item()) if noise_down else 0
        for y in range(im.size[1]):
            for x in range(im.size[0]):
                v = im.getpixel((x, y))
                noise = torch.randint(-noise_amount_down, noise_amount_up+1, (1,)).item()
                if square_noise:
                    if noise > 0:
                        noise *= torch.randint(1, noise_amount_up+1, (1,)).item()
                    elif noise < 0:
                        noise *= torch.randint(1, noise_amount_down+1, (1,)).item()
                v = min(max(v + noise, 0), 255)
                im.putpixel((x, y), v)
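# Usage sketch (not part of the original pipeline): add_noise works on "L"-mode
# PIL images in place, so the noise model can be eyeballed with e.g.
#
#   demo = Image.new("L", (64, 64), 128)
#   add_noise(demo)
#   demo.save("noise_demo.png", "PNG")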
class ImageDataset(Dataset):
    def __init__(self, path, limit=None, cache=True):
        self.path = path
        self.use_cache = cache
        files = sorted(os.listdir(self.path))
        files = [f for f in files if "_zp" not in f]
        result_files = []
        # Assumes the directory contains alternating (target, degraded) pairs
        # that sort next to each other.
        for n in range(len(files) // 2):
            if limit is not None and len(result_files) >= limit:
                break
            result_files.append((files[n*2+0], files[n*2+1]))
        self.files = result_files
        self.cache = {}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        cached = self.cache.get(index)
        if cached is None:
            png, jpg = self.files[index]
            png_path = os.path.join(self.path, png)
            jpg_path = os.path.join(self.path, jpg)
            png_tensor = read_image(png_path, ImageReadMode.GRAY)
            jpg_tensor = read_image(jpg_path, ImageReadMode.GRAY)
            cached = (png_tensor, jpg_tensor)
            if self.use_cache:
                self.cache[index] = cached
        a, b = cached
        scale = 1.0 / 255.0
        return b.to(torch.float32) * scale, a.to(torch.float32) * scale
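# Usage sketch (paths are placeholders; the real runs below use "training" and
# "test" directories produced by the external dataset generator):
#
#   ds = ImageDataset("training", limit=1000)
#   lo, hi = ds[0]  # (degraded, target), both float32 in [0, 1], shape (1, H, W)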
class MemoryFontDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        seed = bytes(torch.randint(0, 2**31, (16,), device="cpu").byte())
        im_t, im_jpg = gen_case_memory(seed)
        assert im_t.mode == "L"
        assert im_jpg.mode == "L"
        hi_tensor = PILToTensor()(im_t)
        lo_tensor = PILToTensor()(im_jpg)
        # Debug dump of the most recently generated pair.
        im_t.save("dbg_a.png", "PNG")
        im_jpg.save("dbg_b.png", "PNG")
        scale = 1.0 / 255.0
        return lo_tensor.to(torch.float32) * scale, hi_tensor.to(torch.float32) * scale
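# Note (added for clarity): gen_case_memory comes from the external dataset_gen
# module, which is not included in this gist; each __getitem__ draws a fresh
# random 16-byte seed, so the dataset is effectively unbounded even though
# __len__ reports a fixed size.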
class CbzDataset(Dataset):
    def __init__(self, path, cache_path, max_file_cases, from_cache, resolution=256, repeat=1):
        self.path = path
        self.cache_path = cache_path
        self.cached_zip = (None, None)
        self.cases = []
        self.repeat = repeat
        self.resolution = resolution
        if not from_cache:
            # Scan .cbz archives, keep near-grayscale pages, and cache them as
            # grayscale PNGs keyed by a sanitized archive/page name.
            re_key = re.compile(r"[^a-z0-9]+")
            for cbz_name in os.listdir(path):
                print(f"Scanning {cbz_name}")
                with ZipFile(os.path.join(path, cbz_name)) as zf:
                    num_added = 0
                    num_total = 0
                    for name in zf.namelist():
                        if num_added >= max_file_cases:
                            break
                        with zf.open(name) as f:
                            try:
                                im = Image.open(f).convert("RGB")
                            except Exception:
                                continue
                            tensor = PILToTensor()(im)
                            tn = tensor.to(torch.float32)
                            # Mean absolute channel difference as a cheap
                            # colorfulness measure.
                            saturation = (tn[1] - tn[0]).abs() + (tn[1] - tn[2]).abs()
                            sat = saturation.sum().item() / (tn.shape[1] * tn.shape[2])
                            num_total += 1
                            if sat < 5.0:
                                name_key = name
                                if "." in name_key:
                                    name_key = name_key[:name_key.rindex(".")]
                                cbz_key = re_key.sub("-", cbz_name.lower())
                                name_key = re_key.sub("-", name_key.lower())
                                key = f"{cbz_key}-{name_key}"
                                cached = os.path.join(cache_path, f"{key}.png")
                                im.convert("L").save(cached, "PNG")
                                self.cases.append(cached)
                                num_added += 1
                print(f"Added {num_added}/{num_total} cases")
        else:
            for f in os.listdir(self.cache_path):
                self.cases.append(os.path.join(self.cache_path, f))
            print(f"Added {len(self.cases)} cases from cache {cache_path}")

    def __len__(self):
        return len(self.cases) * self.repeat

    def __getitem__(self, index):
        # Target: random crop + flip of the cached page. Input: the target
        # blurred, 2x downscaled, optionally noised, then JPEG round-tripped.
        case_path = self.cases[index % len(self.cases)]
        hi_original = read_image(case_path)
        hi_cropped = TVT.RandomCrop(self.resolution, pad_if_needed=True, padding_mode="reflect")(hi_original)
        hi_flipped = TVT.RandomHorizontalFlip()(hi_cropped)
        hi_tensor = hi_flipped
        hi_blur = TVT.GaussianBlur(5, (0.01, 1.05))(hi_tensor)
        lo_raw = TVT.Resize(self.resolution // 2, antialias=True)(hi_blur)
        lo_raw = lo_raw.to(torch.int32)
        if torch.rand((1,)).item() <= 0.15:
            noise = torch.randint(-8, 32+1, lo_raw.shape)
            lo_raw += noise
        elif torch.rand((1,)).item() <= 0.25:
            noise = torch.randint(-2, 8+1, lo_raw.shape) * torch.randint(1, 8+1, lo_raw.shape)
            lo_raw += noise
        lo_raw = torch.clamp(lo_raw, 0, 255).to(torch.uint8)
        lo_im = to_pil_image(lo_raw)
        qual = torch.randint(35, 80, (1,), device="cpu").item()
        lo_mem = io.BytesIO()
        # add_noise(lo_im)
        lo_im.save(lo_mem, "JPEG", quality=qual)
        lo_mem.seek(0)
        lo_reload = Image.open(lo_mem).convert("L")
        lo_tensor = PILToTensor()(lo_reload)
        scale = 1.0 / 255.0
        return lo_tensor.to(torch.float32) * scale, hi_tensor.to(torch.float32) * scale
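# Usage sketch (directory names are placeholders; __main__ below uses absolute
# Windows paths):
#
#   ds = CbzDataset("cbz_dir", "cache_dir", max_file_cases=100, from_cache=False)
#   lo, hi = ds[0]  # lo: degraded half-res crop, hi: clean full-res crop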
class LabeledDataset(Dataset):
    def __init__(self, dataset, label):
        self.dataset = dataset
        self.label = label

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        src, dst = self.dataset[index]
        return src, (dst, self.label)
# Define model
class NeuralNetwork(nn.Module):
    # Unused MNIST-style example network, apparently left over from the PyTorch
    # quickstart tutorial.
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
class UpConv_7(nn.Module):
    # https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311
    def __init__(self):
        super().__init__()
        self.offset = 2  # because of 0 padding
        self.padding = 14
        # self.pad = nn.ZeroPad2d(self.offset)
        m = [nn.Conv2d(1, 16, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(16, 32, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(32, 64, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(64, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 256, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             # in_channels, out_channels, kernel_size, stride, padding
             nn.ConvTranspose2d(256, 1, 4, 2, 3),
             ]
        self.Sequential = nn.Sequential(*m)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x = self.pad(x)
        return self.Sequential.forward(x)
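# Shape check (a sketch added for clarity, not part of the original): each of the
# six valid 3x3 convolutions trims one pixel per side (12 total), and
# ConvTranspose2d(kernel_size=4, stride=2, padding=3) maps m -> (m - 1)*2 - 2*3 + 4,
# so an n x n input yields 2*(n - 14) -- which is why self.padding is 14 above.
def upconv7_out_size(n: int) -> int:
    m = n - 12                      # six 3x3 valid convolutions
    return (m - 1) * 2 - 2 * 3 + 4  # transposed convolution, == 2 * (n - 14)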
class UpConv_7_Kanji(nn.Module):
    # https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311
    def __init__(self):
        super().__init__()
        self.offset = 2  # because of 0 padding
        self.padding = 10+14  # the 15x15 convolution costs 10 px more per side than UpConv_7
        # self.pad = nn.ZeroPad2d(self.offset)
        m = [nn.Conv2d(1, 16, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(16, 32, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(32, 128, 15, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 256, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             # in_channels, out_channels, kernel_size, stride, padding
             nn.ConvTranspose2d(256, 1, 4, 2, 3),
             ]
        self.Sequential = nn.Sequential(*m)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x = self.pad(x)
        return self.Sequential.forward(x)
class UpConv_7_KanjiV2(nn.Module):
    # https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311
    def __init__(self):
        super().__init__()
        self.offset = 2  # because of 0 padding
        self.padding = 16+14
        # self.pad = nn.ZeroPad2d(self.offset)
        # Context branch: pooled wide receptive field, upsampled back to match
        # the main branch.
        feature = [nn.Conv2d(1, 16, 3, 1, 0),
                   nn.LeakyReLU(0.1, inplace=True),
                   nn.Conv2d(16, 64, 4, 1, 0),
                   nn.LeakyReLU(0.1, inplace=True),
                   nn.MaxPool2d(4),
                   nn.Conv2d(64, 256, 5, 1, 0),
                   nn.LeakyReLU(0.1),
                   nn.Upsample(scale_factor=4, mode="bilinear"),
                   ]
        self.Feature = nn.Sequential(*feature)
        # Main branch: operates on (value, dy, dx) gradient features, see forward().
        main = [nn.Conv2d(3, 24, 3, 1, 0),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Conv2d(24, 32, 3, 1, 0),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Conv2d(32, 64, 3, 1, 0),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Conv2d(64, 128, 3, 1, 0),
                nn.LeakyReLU(0.1, inplace=True),
                ]
        self.Main = nn.Sequential(*main)
        upscale = [
            nn.Conv2d(128+256, 256, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(256, 256, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            # in_channels, out_channels, kernel_size, stride, padding
            nn.ConvTranspose2d(256, 1, 4, 2, 3),
            nn.Tanh(),
        ]
        self.Upscale = nn.Sequential(*upscale)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu', a=0.1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x = self.pad(x)
        feat = self.Feature.forward(x)
        # Invert so ink is 1 and paper is 0, then build (value, dy, dx)
        # finite-difference features on a crop that keeps the branches aligned.
        x = torch.ones_like(x) - x
        v = x[:, :, 8:-8, 8:-8]
        dy = torch.diff(x[:, :, 8:-7, 8:-8], dim=2)
        dx = torch.diff(x[:, :, 8:-8, 8:-7], dim=3)
        vs = torch.cat([v, dy, dx], 1)
        main = self.Main.forward(vs)
        up_in = torch.cat([main, feat], 1)
        y = self.Upscale.forward(up_in)
        # Map the tanh output from [-1, 1] back to (slightly beyond) [0, 1],
        # inverted again so paper is bright.
        y = torch.ones_like(y) * 0.5 - y * 0.55
        return y
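# Sketch (added for clarity, not in the original): trace the spatial sizes through
# the two KanjiV2 branches to see why they line up before the torch.cat. For a
# 128 x 128 input both branches arrive at 104 x 104.
def kanjiv2_shape_trace(n: int = 128) -> None:
    feat = (n - 5 - 4) // 4 + 1  # 3x3 conv, 4x4 conv, then MaxPool2d(4)
    feat = (feat - 4) * 4        # 5x5 conv, then 4x bilinear upsample
    main = (n - 16) - 8          # 8 px crop per side, then four 3x3 valid convs
    print(f"feature branch: {feat}, main branch: {main}")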
class UpSelect_7(nn.Module):
    # https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311
    def __init__(self):
        super().__init__()
        self.offset = 2  # because of 0 padding
        self.padding = 14
        # self.pad = nn.ZeroPad2d(self.offset)
        # Same trunk as UpConv_7, but outputs two blend-mask channels through tanh.
        m = [nn.Conv2d(1, 16, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(16, 32, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(32, 64, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(64, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 256, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             # in_channels, out_channels, kernel_size, stride, padding
             nn.ConvTranspose2d(256, 2, 4, 2, 3),
             nn.Tanh(),
             ]
        self.Sequential = nn.Sequential(*m)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x = self.pad(x)
        x = torch.ones_like(x) - x
        y = self.Sequential.forward(x)
        y = torch.ones_like(y) * 0.5 - y * 0.55
        return y
class ConvBlock(nn.Module):
    def __init__(self, num_layers, in_channels, out_channels):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Sequential(
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, 3, padding=1),
                nn.ReLU()
            )
            for i in range(num_layers)]
        )
        self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        x = self.downsample(x)
        return x
class CNN(nn.Module):
    def __init__(self, in_channels, num_blocks, num_classes):
        super().__init__()
        first_channels = 64
        self.blocks = nn.ModuleList(
            [ConvBlock(
                2 if i == 0 else 3,
                in_channels=(in_channels if i == 0 else first_channels*(2**(i-1))),
                out_channels=first_channels*(2**i))
            for i in range(num_blocks)]
        )
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.cls = nn.Linear(first_channels*(2**(num_blocks-1)), num_classes)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.global_pool(x)
        x = x.flatten(1)
        x = self.cls(x)
        return x
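# Note (added for clarity): with num_blocks=4 the channel progression is
# 64 -> 128 -> 256 -> 512 and each block halves the spatial resolution, so
# UpSelect_Ex below is a small VGG-style classifier with two output logits.
# It is defined here but not used by any of the commands in __main__.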
class UpSelect_Ex(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = CNN(1, 4, 2)

    def forward(self, x):
        return self.cnn(x)
def train(dataloader, model, loss_fn, optimizer, report_freq):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # Crop the target to match the model's valid-convolution output size.
        pad = model.padding
        y = center_crop(y, (X.shape[2]*2 - pad*2, X.shape[3]*2 - pad*2))

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % report_freq == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            pad = model.padding
            y = center_crop(y, (X.shape[2]*2 - pad*2, X.shape[3]*2 - pad*2))
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    print(f"Test Error: \nAvg loss: {test_loss:>8f} \n")
    return test_loss
def train_select(dataloader, model_a, model_b, model, loss_fn, optimizer, report_freq):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, (y, y_label)) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        y_label = y_label.to(device)
        pred = model(X)
        xh, xw = X.shape[2], X.shape[3]
        # Bicubic 2x upscale of the input, sized to match the model outputs.
        Xs = TVT.Resize((xh*2, xw*2), TVT.InterpolationMode.BICUBIC, antialias=True)(X)
        with torch.no_grad():
            Xa = model_a(X)
            Xb = model_b(X)
        # Crop everything to the smallest common output size.
        h = min(Xa.shape[2], Xb.shape[2], pred.shape[2])
        w = min(Xa.shape[3], Xb.shape[3], pred.shape[3])
        Xa = center_crop(Xa, (h, w))
        Xb = center_crop(Xb, (h, w))
        pred = center_crop(pred, (h, w))
        y = center_crop(y, (h, w))
        Xs = center_crop(Xs, (h, w))
        pred_t = pred[:, 0:1, :, :]  # text-vs-image blend mask
        pred_s = pred[:, 1:2, :, :]  # sharp-vs-bicubic blend mask
        # Xp = Xa * (1.0 - pred_t) + Xb * pred_t
        # Xp = Xs * (1.0 - pred_s) + Xp * pred_s
        # Supervise the masks directly: pred_t should match the per-sample
        # dataset label, and pred_s is pushed toward 1.
        pred_y = torch.ones_like(pred_t) * y_label.view(-1, 1, 1, 1)
        loss = loss_fn(pred_t, pred_y) + loss_fn(pred_s, torch.ones_like(pred_s))

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % report_freq == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_select(dataloader, model_a, model_b, model, loss_fn):
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, (y, y_label) in dataloader:
            X, y = X.to(device), y.to(device)
            y_label = y_label.to(device)
            xh, xw = X.shape[2], X.shape[3]
            Xs = TVT.Resize((xh*2, xw*2), TVT.InterpolationMode.BICUBIC, antialias=True)(X)
            pred = model(X)
            Xa = model_a(X)
            Xb = model_b(X)
            h = min(Xa.shape[2], Xb.shape[2], pred.shape[2])
            w = min(Xa.shape[3], Xb.shape[3], pred.shape[3])
            Xa = center_crop(Xa, (h, w))
            Xb = center_crop(Xb, (h, w))
            pred = center_crop(pred, (h, w))
            y = center_crop(y, (h, w))
            Xs = center_crop(Xs, (h, w))
            pred_t = pred[:, 0:1, :, :]
            pred_s = pred[:, 1:2, :, :]
            # Evaluate the actual blended output against the target.
            Xp = Xa * (1.0 - pred_t) + Xb * pred_t
            Xp = Xs * (1.0 - pred_s) + Xp * pred_s
            test_loss += loss_fn(Xp, y).item()
    test_loss /= num_batches
    print(f"Test Error: \nAvg loss: {test_loss:>8f} \n")
    return test_loss
def vis(data, model, epoch, limit=None):
    with torch.no_grad():
        for i, (x, y) in enumerate(data):
            if limit is not None and i >= limit:
                break
            x = x.to(device)
            x = torch.unsqueeze(x, 0)
            pred = model(x)
            pred = torch.squeeze(pred, 0)
            pred = torch.clamp(pred, 0, 1)
            im = to_pil_image(pred, mode="L")
            im.save(f"test/case_{i:06}_zp{epoch:03}.png")
if __name__ == "__main__":
cmd = sys.argv[1]
print(f"Using {device} device")
if cmd == "train-kanji":
model = UpConv_7_KanjiV2().to(device)
print(model)
internal_loss_fn = nn.L1Loss()
def loss_fn(a, b):
adx = torch.diff(a, dim=-1)
ady = torch.diff(a, dim=-2)
bdx = torch.diff(b, dim=-1)
bdy = torch.diff(b, dim=-2)
loss = 0
loss += internal_loss_fn(a, b)
loss += internal_loss_fn(adx, bdx)
loss += internal_loss_fn(ady, bdy)
return loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)
training_data = ImageDataset("training", limit=1_000_000)
# training_data = MemoryFontDataset(200_000)
test_data = ImageDataset("test")
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
for X, y in test_dataloader:
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
break
epochs = 100
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer, 400)
test(test_dataloader, model, loss_fn)
vis(test_data, model, t+1, limit=1000)
torch.save(model.state_dict(), f"checkpoint/model_{t:04}.pth")
print("Done!")
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
if cmd == "train-image":
MAX_FILE_CASES = 10000
FROM_CACHE = True
training_data = CbzDataset(r"W:\MangaTraining\training", r"E:\MangaCache\training", MAX_FILE_CASES, FROM_CACHE, repeat=10)
test_data = CbzDataset(r"W:\MangaTraining\test", r"E:\MangaCache\test", MAX_FILE_CASES, FROM_CACHE)
model = UpConv_7().to(device)
print(model)
loss_fn = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)
batch_size = 32
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=16)
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=4)
if False:
cols, rows = 8, 8
figure = plt.figure(figsize=(8, 8))
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
exit(0)
for X, y in test_dataloader:
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
break
epochs = 100
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer, 50)
test(test_dataloader, model, loss_fn)
torch.save(model.state_dict(), f"checkpoint/image_{t:04}.pth")
print("Done!")
torch.save(model.state_dict(), "image.pth")
print("Saved PyTorch Model State to image.pth")
if cmd == "train-select":
model = UpSelect_7().to(device)
model_a = UpConv_7().to(device)
model_a.load_state_dict(torch.load("image.pth"))
model_a.eval()
model_b = UpConv_7_KanjiV2().to(device)
model_b.load_state_dict(torch.load("model.pth"))
model_b.eval()
print(model)
loss_fn = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)
text_training_data = ImageDataset("training", limit=250_000, cache=False)
text_test_data = ImageDataset("test", limit=3_000, cache=False)
MAX_FILE_CASES = 1000
FROM_CACHE = True
REPEAT = 35
RESOLUTION = 112
image_training_data = CbzDataset(r"W:\MangaTraining\training", r"E:\MangaCache\training", MAX_FILE_CASES, FROM_CACHE, resolution=RESOLUTION, repeat=REPEAT)
image_test_data = CbzDataset(r"W:\MangaTraining\test", r"E:\MangaCache\test", MAX_FILE_CASES, FROM_CACHE, resolution=RESOLUTION, repeat=2)
training_data = ConcatDataset([LabeledDataset(text_training_data, 1), LabeledDataset(image_training_data, 0)])
test_data = ConcatDataset([LabeledDataset(text_test_data, 1), LabeledDataset(image_test_data, 0)])
# training_data = image_training_data
# test_data = image_test_data
if False:
cols, rows = 8, 8
figure = plt.figure(figsize=(16, 8))
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, img2 = training_data[sample_idx]
figure.add_subplot(rows, cols*2, (i-1)*2+1)
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray", vmin=0.0, vmax=1.0)
figure.add_subplot(rows, cols*2, (i-1)*2+2)
plt.axis("off")
plt.imshow(img2.squeeze(), cmap="gray", vmin=0.0, vmax=1.0)
plt.show()
exit(0)
batch_size = 32
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=20, persistent_workers=True, drop_last=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=4, persistent_workers=True)
epochs = 100
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train_select(train_dataloader, model_a, model_b, model, loss_fn, optimizer, 400)
test_select(test_dataloader, model_a, model_b, model, loss_fn)
torch.save(model.state_dict(), f"checkpoint/select_{t:04}.pth")
print("Done!")
torch.save(model.state_dict(), "select.pth")
print("Saved PyTorch Model State to select.pth")
if cmd == "zp-text":
model = UpConv_7_KanjiV2().to(device)
model.load_state_dict(torch.load("model.pth"))
model.eval()
test_data = ImageDataset("test")
with torch.no_grad():
for i,(x,y) in enumerate(test_data):
x = x.to(device)
x = torch.unsqueeze(x, 0)
pred = model(x)
pred = torch.squeeze(pred, 0)
pred = torch.clamp(pred, 0, 1)
im = to_pil_image(pred, mode="L")
im.save(f"test/case_{i:06}_zp.png")
if cmd == "bench-image":
model = UpConv_7().to(device)
model.load_state_dict(torch.load("image.pth"))
model.eval()
with torch.no_grad():
for f in os.listdir("bench"):
if "_result" in f or "_old" in f or "_text" in f or "_select" in f or "_final" in f: continue
x = read_image(f"bench/{f}", ImageReadMode.GRAY)
scale = 1.0 / 255.0
x = x.to(torch.float32) * scale
x = x.to(device)
pred = model(x)
pred = torch.clamp(pred, 0, 1)
im = to_pil_image(pred, mode="L")
f_base, f_ext = f.split(".", maxsplit=1)
im.save(f"bench/{f_base}_result.{f_ext}")
if cmd == "bench-text":
model = UpConv_7_KanjiV2().to(device)
model.load_state_dict(torch.load("model.pth"))
model.eval()
with torch.no_grad():
for f in os.listdir("bench"):
if "_result" in f or "_old" in f or "_text" in f or "_select" in f or "_final" in f: continue
x = read_image(f"bench/{f}", ImageReadMode.GRAY)
scale = 1.0 / 255.0
x = x[:, :1024+128, :512+256]
x = x.to(torch.float32) * scale
x = x.to(device)
x = torch.unsqueeze(x, 0)
print(x.shape)
pred = model(x)
pred = torch.squeeze(pred, 0)
pred = torch.clamp(pred, 0, 1)
im = to_pil_image(pred, mode="L")
f_base, f_ext = f.split(".", maxsplit=1)
im.save(f"bench/{f_base}_text.png")
if cmd == "bench-final":
model = UpSelect_7().to(device)
model.load_state_dict(torch.load("select.pth"))
model.eval()
model_a = UpConv_7().to(device)
model_a.load_state_dict(torch.load("image.pth"))
model_a.eval()
model_b = UpConv_7_KanjiV2().to(device)
model_b.load_state_dict(torch.load("model.pth"))
model_b.eval()
pad = 128
for f in os.listdir("bench"):
if "_result" in f or "_old" in f or "_text" in f or "_select" in f or "_final" in f: continue
# if "003" not in f: continue
xs = read_image(f"bench/{f}", ImageReadMode.GRAY)
xs = NNF.pad(xs, (pad, pad, pad, pad), mode="constant")
_, sh, sw = xs.shape
y = torch.zeros(1, sh*2, sw*2)
ysel = torch.zeros(1, sh*2, sw*2)
with torch.no_grad():
for by in range(0, xs.shape[-2], 64):
for bx in range(0, xs.shape[-1], 64):
if bx + 128 > xs.shape[-1] or by + 128 > xs.shape[-2]:
continue
print(bx, by)
scale = 1.0 / 255.0
x = xs[:, by:by+128, bx:bx+128]
x = x.to(torch.float32) * scale
x = x.to(device)
x = torch.unsqueeze(x, 0)
Xa = model_a(x)
Xb = model_b(x)
pred = model(x)
h = min(Xa.shape[2], Xb.shape[2], pred.shape[2])
w = min(Xa.shape[3], Xb.shape[3], pred.shape[3])
Xa = center_crop(Xa, (h, w))
Xb = center_crop(Xb, (h, w))
pred = center_crop(pred, (h, w))
pred = torch.clamp(pred, 0, 1)
Xp = Xa * (1.0 - pred) + Xb * pred
pred = torch.squeeze(pred, 0)
Xp = torch.clamp(Xp, 0, 1)
oy = (256 - h) // 2
ox = (256 - w) // 2
dy = by*2 + oy
dx = bx*2 + ox
y[:, dy:dy+h, dx:dx+h] = Xp
ysel[:, dy:dy+h, dx:dx+h] = pred
y = y[:, pad*2:pad*-2, pad*2:pad*-2]
im = to_pil_image(y, mode="L")
f_base, f_ext = f.split(".", maxsplit=1)
im.save(f"bench/{f_base}_final.png")
ysel = ysel[:, pad*2:pad*-2, pad*2:pad*-2]
imsel = to_pil_image(ysel, mode="L")
f_base, f_ext = f.split(".", maxsplit=1)
imsel.save(f"bench/{f_base}_select.png")
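# Usage sketch (the script filename is assumed; the commands are the ones handled
# above):
#
#   python upscale.py train-image    # train the image upscaler   -> image.pth
#   python upscale.py train-kanji    # train the text upscaler    -> model.pth
#   python upscale.py train-select   # train the blend-mask model -> select.pth
#   python upscale.py bench-final    # tiled inference over images in bench/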