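# Manga/text image 2x upscaler training script (PyTorch).
# Trains waifu2x-style UpConv models for line art and for kanji text, plus a
# selector network that blends their outputs, and provides visualization and
# benchmarking commands. The dataset_gen module providing gen_case_memory()
# is external to this file.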
import torch
from torch import nn
import torch.nn.functional as NNF
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import datasets
from torchvision.transforms import ToTensor, PILToTensor, RandomResizedCrop
import torchvision.transforms as TVT
from torchvision.transforms.functional import to_pil_image, center_crop
from torchvision.io import read_image, ImageReadMode
import os
from zipfile import ZipFile
from PIL import Image
import re
import io
import sys
from dataset_gen import gen_case_memory
import matplotlib.pyplot as plt

# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
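
# Per-pixel integer noise for a grayscale PIL image, applied in place.
# Currently unused: the only call site (in CbzDataset.__getitem__) is
# commented out.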
def add_noise(im):
    noise_up = torch.rand((1,)).item() <= 0.75
    noise_down = torch.rand((1,)).item() <= 0.25
    square_noise = torch.rand((1,)).item() <= 0.5
    if noise_up or noise_down:
        noise_max = 3 if square_noise else 5
        noise_amount_up = 1 << torch.randint(1, noise_max+1, (1,)).item() if noise_up else 0
        noise_amount_down = 1 << torch.randint(1, noise_max-1+1, (1,)).item() if noise_down else 0
        for y in range(im.size[1]):
            for x in range(im.size[0]):
                v = im.getpixel((x, y))
                noise = torch.randint(-noise_amount_down, noise_amount_up+1, (1,)).item()
                if square_noise:
                    if noise > 0:
                        noise *= torch.randint(1, noise_amount_up+1, (1,)).item()
                    elif noise < 0:
                        noise *= torch.randint(1, noise_amount_down+1, (1,)).item()
                v = min(max(v + noise, 0), 255)
                im.putpixel((x, y), v)
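
# Pairs of pre-rendered (clean PNG, degraded JPEG) images from a directory,
# matched by sort order. Files containing "_zp" (visualization output from
# vis() below) are skipped.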
class ImageDataset(Dataset):
    def __init__(self, path, limit=None, cache=True):
        self.path = path
        self.use_cache = cache
        files = sorted(os.listdir(self.path))
        files = [f for f in files if "_zp" not in f]
        result_files = []
        # Files are expected to come in sorted (png, jpg) pairs.
        for n in range(len(files) // 2):
            if limit is not None and len(result_files) >= limit:
                break
            result_files.append((files[n*2+0], files[n*2+1]))
        self.files = result_files
        self.cache = {}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        cached = self.cache.get(index)
        if cached is None:
            png, jpg = self.files[index]
            png_path = os.path.join(self.path, png)
            jpg_path = os.path.join(self.path, jpg)
            png_tensor = read_image(png_path, ImageReadMode.GRAY)
            jpg_tensor = read_image(jpg_path, ImageReadMode.GRAY)
            cached = (png_tensor, jpg_tensor)
            if self.use_cache:
                self.cache[index] = cached
        a, b = cached
        # Return (low-quality JPEG input, clean PNG target), scaled to [0, 1].
        scale = 1.0 / 255.0
        return b.to(torch.float32) * scale, a.to(torch.float32) * scale
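
# Synthetic text dataset: generates (clean, JPEG-degraded) grayscale pairs
# on the fly via dataset_gen.gen_case_memory(), seeded per sample.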
class MemoryFontDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        seed = bytes(torch.randint(0, 2**31, (16,), device="cpu").byte())
        im_t, im_jpg = gen_case_memory(seed)
        assert im_t.mode == "L"
        assert im_jpg.mode == "L"
        hi_tensor = PILToTensor()(im_t)
        lo_tensor = PILToTensor()(im_jpg)
        # Debug output: dumps the most recently generated pair to disk.
        im_t.save("dbg_a.png", "PNG")
        im_jpg.save("dbg_b.png", "PNG")
        scale = 1.0 / 255.0
        return lo_tensor.to(torch.float32) * scale, hi_tensor.to(torch.float32) * scale
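
# Training cases extracted from .cbz (zipped image) archives. Pages with low
# color saturation are assumed to be black-and-white, converted to grayscale,
# and cached to cache_path as PNG; with from_cache=True the archive scan is
# skipped and cached files are used directly.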
class CbzDataset(Dataset):
    def __init__(self, path, cache_path, max_file_cases, from_cache, resolution=256, repeat=1):
        self.path = path
        self.cache_path = cache_path
        self.cached_zip = (None, None)
        self.cases = []
        self.repeat = repeat
        self.resolution = resolution
        if not from_cache:
            re_key = re.compile(r"[^a-z0-9]+")
            for cbz_name in os.listdir(path):
                print(f"Scanning {cbz_name}")
                with ZipFile(os.path.join(path, cbz_name)) as zf:
                    num_added = 0
                    num_total = 0
                    for name in zf.namelist():
                        if num_added >= max_file_cases: break
                        with zf.open(name) as f:
                            try:
                                im = Image.open(f).convert("RGB")
                            except Exception:
                                continue
                        tensor = PILToTensor()(im)
                        tn = tensor.to(torch.float32)
                        # Mean per-pixel color deviation; used to reject color pages.
                        saturation = (tn[1] - tn[0]).abs() + (tn[1] - tn[2]).abs()
                        sat = saturation.sum().item() / (tn.shape[1] * tn.shape[2])
                        num_total += 1
                        # print(f"{cbz_name}:{name}: {sat} ({tn.shape})")
                        if sat < 5.0:
                            name_key = name
                            if "." in name_key:
                                name_key = name_key[:name_key.rindex(".")]
                            cbz_key = re_key.sub("-", cbz_name.lower())
                            name_key = re_key.sub("-", name_key.lower())
                            key = f"{cbz_key}-{name_key}"
                            cached = os.path.join(cache_path, f"{key}.png")
                            im.convert("L").save(cached, "PNG")
                            self.cases.append(cached)
                            num_added += 1
                    print(f"Added {num_added}/{num_total} cases")
        else:
            for f in os.listdir(self.cache_path):
                self.cases.append(os.path.join(self.cache_path, f))
            print(f"Added {len(self.cases)} cases from cache {cache_path}")

    def __len__(self):
        return len(self.cases) * self.repeat
    def __getitem__(self, index):
        case_path = self.cases[index % len(self.cases)]
        hi_original = read_image(case_path)
        hi_cropped = TVT.RandomCrop(self.resolution, pad_if_needed=True, padding_mode="reflect")(hi_original)
        hi_flipped = TVT.RandomHorizontalFlip()(hi_cropped)
        hi_tensor = hi_flipped
        # Degrade: blur, downscale to half resolution, optional noise, JPEG round-trip.
        hi_blur = TVT.GaussianBlur(5, (0.01, 1.05))(hi_tensor)
        lo_raw = TVT.Resize(self.resolution // 2, antialias=True)(hi_blur)
        lo_raw = lo_raw.to(torch.int32)
        if torch.rand((1,)).item() <= 0.15:
            noise = torch.randint(-8, 32+1, lo_raw.shape, dtype=torch.int32)
            lo_raw += noise
        elif torch.rand((1,)).item() <= 0.25:
            noise = torch.randint(-2, 8+1, lo_raw.shape, dtype=torch.int32) * torch.randint(1, 8+1, lo_raw.shape, dtype=torch.int32)
            lo_raw += noise
        lo_raw = torch.clamp(lo_raw, 0, 255).to(torch.uint8)
        lo_im = to_pil_image(lo_raw)
        qual = torch.randint(35, 80, (1,), device="cpu").item()
        lo_mem = io.BytesIO()
        # add_noise(lo_im)
        lo_im.save(lo_mem, "JPEG", quality=qual)
        lo_mem.seek(0)
        lo_reload = Image.open(lo_mem).convert("L")
        lo_tensor = PILToTensor()(lo_reload)
        scale = 1.0 / 255.0
        return lo_tensor.to(torch.float32) * scale, hi_tensor.to(torch.float32) * scale
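
# Wraps a dataset so each target also carries a constant label
# (1 = text data, 0 = image data in the train-select command below).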
class LabeledDataset(Dataset):
    def __init__(self, dataset, label):
        self.dataset = dataset
        self.label = label

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        src, dst = self.dataset[index]
        return src, (dst, self.label)
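
# NeuralNetwork below appears to be the MNIST classifier from the PyTorch
# quickstart tutorial; it is not used by any of the commands in this file.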
# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
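
# UpConv_7: waifu2x-style 2x upscaler. Six 3x3 valid convolutions followed by
# a stride-2 transposed convolution; each 3x3 conv trims one pixel per side,
# so the output is (2*W - 2*padding) for padding = 14.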
class UpConv_7(nn.Module):
    # https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311
    def __init__(self):
        super().__init__()
        self.offset = 2 # because of 0 padding
        self.padding = 14
        # self.pad = nn.ZeroPad2d(self.offset)
        m = [nn.Conv2d(1, 16, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(16, 32, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(32, 64, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(64, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 256, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             # ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
             nn.ConvTranspose2d(256, 1, 4, 2, 3),
             ]
        self.Sequential = nn.Sequential(*m)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x = self.pad(x)
        return self.Sequential(x)
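
# Variant of UpConv_7 with a 15x15 convolution in the middle for a larger
# receptive field; padding = 10+14 accounts for the extra context consumed.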
class UpConv_7_Kanji(nn.Module):
    # https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311
    def __init__(self):
        super().__init__()
        self.offset = 2 # because of 0 padding
        self.padding = 10+14
        # self.pad = nn.ZeroPad2d(self.offset)
        m = [nn.Conv2d(1, 16, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(16, 32, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(32, 128, 15, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 256, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             # ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
             nn.ConvTranspose2d(256, 1, 4, 2, 3),
             ]
        self.Sequential = nn.Sequential(*m)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x = self.pad(x)
        return self.Sequential(x)
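
# Two-branch kanji upscaler: a pooled "Feature" branch with a wide receptive
# field is upsampled and concatenated with a "Main" branch that sees the
# inverted image plus its x/y gradients; a Tanh head produces the output,
# which forward() maps back to roughly [0, 1].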
class UpConv_7_KanjiV2(nn.Module):
    # https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311
    def __init__(self):
        super().__init__()
        self.offset = 2 # because of 0 padding
        self.padding = 16+14
        # self.pad = nn.ZeroPad2d(self.offset)
        feature = [nn.Conv2d(1, 16, 3, 1, 0),
                   nn.LeakyReLU(0.1, inplace=True),
                   nn.Conv2d(16, 64, 4, 1, 0),
                   nn.LeakyReLU(0.1, inplace=True),
                   nn.MaxPool2d(4),
                   nn.Conv2d(64, 256, 5, 1, 0),
                   nn.LeakyReLU(0.1),
                   nn.Upsample(scale_factor=4, mode="bilinear"),
                   ]
        self.Feature = nn.Sequential(*feature)
        main = [nn.Conv2d(3, 24, 3, 1, 0),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Conv2d(24, 32, 3, 1, 0),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Conv2d(32, 64, 3, 1, 0),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Conv2d(64, 128, 3, 1, 0),
                nn.LeakyReLU(0.1, inplace=True),
                ]
        self.Main = nn.Sequential(*main)
        upscale = [
            nn.Conv2d(128+256, 256, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(256, 256, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            # ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
            nn.ConvTranspose2d(256, 1, 4, 2, 3),
            nn.Tanh(),
        ]
        self.Upscale = nn.Sequential(*upscale)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu', a=0.1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x = self.pad(x)
        feat = self.Feature(x)
        # Invert so ink is 1.0, then feed value plus x/y gradients to Main.
        x = torch.ones_like(x) - x
        v = x[:, :, 8:-8, 8:-8]
        dy = torch.diff(x[:, :, 8:-7, 8:-8], dim=2)
        dx = torch.diff(x[:, :, 8:-8, 8:-7], dim=3)
        vs = torch.cat([v, dy, dx], 1)
        main = self.Main(vs)
        # print(main.shape)
        # print(feat.shape)
        up_in = torch.cat([main, feat], 1)
        y = self.Upscale(up_in)
        # Map the Tanh output back to (approximately) [0, 1], re-inverted.
        y = torch.ones_like(y) * 0.5 - y * 0.55
        return y
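
# Selector network: same trunk as UpConv_7 but with a 2-channel Tanh output.
# Channel 0 chooses between the image and text models, channel 1 between a
# plain upscale and the model blend (see train_select / test_select).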
class UpSelect_7(nn.Module):
    # https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311
    def __init__(self):
        super().__init__()
        self.offset = 2 # because of 0 padding
        self.padding = 14
        # self.pad = nn.ZeroPad2d(self.offset)
        m = [nn.Conv2d(1, 16, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(16, 32, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(32, 64, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(64, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 128, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             nn.Conv2d(128, 256, 3, 1, 0),
             nn.LeakyReLU(0.1, inplace=True),
             # ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
             nn.ConvTranspose2d(256, 2, 4, 2, 3),
             nn.Tanh(),
             ]
        self.Sequential = nn.Sequential(*m)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x = self.pad(x)
        x = torch.ones_like(x) - x
        y = self.Sequential(x)
        y = torch.ones_like(y) * 0.5 - y * 0.55
        return y
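
# ConvBlock/CNN/UpSelect_Ex: a VGG-like whole-image classifier, presumably an
# alternative selector; not used by any of the commands below.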
class ConvBlock(nn.Module):
    def __init__(self, num_layers, in_channels, out_channels):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Sequential(
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, 3, padding=1),
                nn.ReLU()
            )
            for i in range(num_layers)]
        )
        self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        x = self.downsample(x)
        return x

class CNN(nn.Module):
    def __init__(self, in_channels, num_blocks, num_classes):
        super().__init__()
        first_channels = 64
        self.blocks = nn.ModuleList(
            [ConvBlock(
                2 if i == 0 else 3,
                in_channels=(in_channels if i == 0 else first_channels*(2**(i-1))),
                out_channels=first_channels*(2**i))
            for i in range(num_blocks)]
        )
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.cls = nn.Linear(first_channels*(2**(num_blocks-1)), num_classes)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.global_pool(x)
        x = x.flatten(1)
        x = self.cls(x)
        return x

class UpSelect_Ex(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = CNN(1, 4, 2)

    def forward(self, x):
        return self.cnn(x)
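
# One training epoch. The target is center-cropped to the model's output
# size: 2x the input minus 2*model.padding pixels lost to valid convolutions.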
def train(dataloader, model, loss_fn, optimizer, report_freq):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pad = model.padding
        y = center_crop(y, (X.shape[2]*2 - pad*2, X.shape[3]*2 - pad*2))
        # print(X.shape)
        # print(y.shape)

        # Compute prediction error
        pred = model(X)
        # print(pred.shape)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % report_freq == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            pad = model.padding
            y = center_crop(y, (X.shape[2]*2 - pad*2, X.shape[3]*2 - pad*2))
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    print(f"Test Error: \nAvg loss: {test_loss:>8f} \n")
    return test_loss
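
# Selector training: model_a (image) and model_b (text) are frozen; the
# selector is trained so channel 0 matches the dataset label and channel 1
# saturates toward the model blend. The blended-output loss is left
# commented out; test_select evaluates that blend instead.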
def train_select(dataloader, model_a, model_b, model, loss_fn, optimizer, report_freq):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, (y, y_label)) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        y_label = y_label.to(device)
        pred = model(X)
        # Bicubic 2x upscale of the input as a plain fallback baseline.
        xh, xw = X.shape[2], X.shape[3]
        Xs = TVT.Resize((xh*2, xw*2), TVT.InterpolationMode.BICUBIC, antialias=True)(X)
        with torch.no_grad():
            Xa = model_a(X)
            Xb = model_b(X)
        h = min(Xa.shape[2], Xb.shape[2], pred.shape[2])
        w = min(Xa.shape[3], Xb.shape[3], pred.shape[3])
        Xa = center_crop(Xa, (h, w))
        Xb = center_crop(Xb, (h, w))
        pred = center_crop(pred, (h, w))
        y = center_crop(y, (h, w))
        Xs = center_crop(Xs, (h, w))
        pred_t = pred[:, 0:1, :, :]
        pred_s = pred[:, 1:2, :, :]
        # Xp = Xa * (1.0 - pred_t) + Xb * pred_t
        # Xp = Xs * (1.0 - pred_s) + Xp * pred_s
        pred_y = torch.ones_like(pred_t) * y_label.view(-1, 1, 1, 1)
        # loss = loss_fn(Xp, y) + loss_fn(pred, pred_y)
        loss = loss_fn(pred_t, pred_y) + loss_fn(torch.ones_like(pred_s), pred_s)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % report_freq == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_select(dataloader, model_a, model_b, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, (y, y_label) in dataloader:
            X, y = X.to(device), y.to(device)
            y_label = y_label.to(device)
            # Bicubic 2x upscale of the input as a plain fallback baseline.
            xh, xw = X.shape[2], X.shape[3]
            Xs = TVT.Resize((xh*2, xw*2), TVT.InterpolationMode.BICUBIC, antialias=True)(X)
            pred = model(X)
            Xa = model_a(X)
            Xb = model_b(X)
            h = min(Xa.shape[2], Xb.shape[2], pred.shape[2])
            w = min(Xa.shape[3], Xb.shape[3], pred.shape[3])
            Xa = center_crop(Xa, (h, w))
            Xb = center_crop(Xb, (h, w))
            pred = center_crop(pred, (h, w))
            y = center_crop(y, (h, w))
            Xs = center_crop(Xs, (h, w))
            pred_t = pred[:, 0:1, :, :]
            pred_s = pred[:, 1:2, :, :]
            Xp = Xa * (1.0 - pred_t) + Xb * pred_t
            Xp = Xs * (1.0 - pred_s) + Xp * pred_s
            # pred_y = torch.ones_like(pred_t) * y_label.view(-1, 1, 1, 1)
            test_loss += loss_fn(Xp, y).item()  # + loss_fn(pred, pred_y)
    test_loss /= num_batches
    print(f"Test Error: \nAvg loss: {test_loss:>8f} \n")
    return test_loss
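
# Writes clamped model outputs for the first `limit` samples; the "_zp"
# suffix keeps these files out of ImageDataset scans.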
def vis(data, model, epoch, limit=None):
    with torch.no_grad():
        for i, (x, y) in enumerate(data):
            if limit is not None and i >= limit: break
            x = x.to(device)
            x = torch.unsqueeze(x, 0)
            pred = model(x)
            pred = torch.squeeze(pred, 0)
            pred = torch.clamp(pred, 0, 1)
            im = to_pil_image(pred, mode="L")
            im.save(f"test/case_{i:06}_zp{epoch:03}.png")
if __name__ == "__main__":
    cmd = sys.argv[1]
    print(f"Using {device} device")

    if cmd == "train-kanji":
        model = UpConv_7_KanjiV2().to(device)
        print(model)

        internal_loss_fn = nn.L1Loss()
        # L1 loss on pixel values plus L1 on x/y image gradients.
        def loss_fn(a, b):
            adx = torch.diff(a, dim=-1)
            ady = torch.diff(a, dim=-2)
            bdx = torch.diff(b, dim=-1)
            bdy = torch.diff(b, dim=-2)
            loss = 0
            loss += internal_loss_fn(a, b)
            loss += internal_loss_fn(adx, bdx)
            loss += internal_loss_fn(ady, bdy)
            return loss

        optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)

        training_data = ImageDataset("training", limit=1_000_000)
        # training_data = MemoryFontDataset(200_000)
        test_data = ImageDataset("test")

        batch_size = 64

        # Create data loaders.
        train_dataloader = DataLoader(training_data, batch_size=batch_size)
        test_dataloader = DataLoader(test_data, batch_size=batch_size)

        for X, y in test_dataloader:
            print(f"Shape of X [N, C, H, W]: {X.shape}")
            print(f"Shape of y: {y.shape} {y.dtype}")
            break

        epochs = 100
        for t in range(epochs):
            print(f"Epoch {t+1}\n-------------------------------")
            train(train_dataloader, model, loss_fn, optimizer, 400)
            test(test_dataloader, model, loss_fn)
            vis(test_data, model, t+1, limit=1000)
            torch.save(model.state_dict(), f"checkpoint/model_{t:04}.pth")
        print("Done!")

        torch.save(model.state_dict(), "model.pth")
        print("Saved PyTorch Model State to model.pth")
if cmd == "train-image": | |
MAX_FILE_CASES = 10000 | |
FROM_CACHE = True | |
training_data = CbzDataset(r"W:\MangaTraining\training", r"E:\MangaCache\training", MAX_FILE_CASES, FROM_CACHE, repeat=10) | |
test_data = CbzDataset(r"W:\MangaTraining\test", r"E:\MangaCache\test", MAX_FILE_CASES, FROM_CACHE) | |
model = UpConv_7().to(device) | |
print(model) | |
loss_fn = nn.L1Loss() | |
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025) | |
batch_size = 32 | |
# Create data loaders. | |
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=16) | |
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=4) | |
if False: | |
cols, rows = 8, 8 | |
figure = plt.figure(figsize=(8, 8)) | |
for i in range(1, cols * rows + 1): | |
sample_idx = torch.randint(len(training_data), size=(1,)).item() | |
img, label = training_data[sample_idx] | |
figure.add_subplot(rows, cols, i) | |
plt.axis("off") | |
plt.imshow(img.squeeze(), cmap="gray") | |
plt.show() | |
exit(0) | |
for X, y in test_dataloader: | |
print(f"Shape of X [N, C, H, W]: {X.shape}") | |
print(f"Shape of y: {y.shape} {y.dtype}") | |
break | |
epochs = 100 | |
for t in range(epochs): | |
print(f"Epoch {t+1}\n-------------------------------") | |
train(train_dataloader, model, loss_fn, optimizer, 50) | |
test(test_dataloader, model, loss_fn) | |
torch.save(model.state_dict(), f"checkpoint/image_{t:04}.pth") | |
print("Done!") | |
torch.save(model.state_dict(), "image.pth") | |
print("Saved PyTorch Model State to image.pth") | |
if cmd == "train-select": | |
model = UpSelect_7().to(device) | |
model_a = UpConv_7().to(device) | |
model_a.load_state_dict(torch.load("image.pth")) | |
model_a.eval() | |
model_b = UpConv_7_KanjiV2().to(device) | |
model_b.load_state_dict(torch.load("model.pth")) | |
model_b.eval() | |
print(model) | |
loss_fn = nn.L1Loss() | |
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025) | |
text_training_data = ImageDataset("training", limit=250_000, cache=False) | |
text_test_data = ImageDataset("test", limit=3_000, cache=False) | |
MAX_FILE_CASES = 1000 | |
FROM_CACHE = True | |
REPEAT = 35 | |
RESOLUTION = 112 | |
image_training_data = CbzDataset(r"W:\MangaTraining\training", r"E:\MangaCache\training", MAX_FILE_CASES, FROM_CACHE, resolution=RESOLUTION, repeat=REPEAT) | |
image_test_data = CbzDataset(r"W:\MangaTraining\test", r"E:\MangaCache\test", MAX_FILE_CASES, FROM_CACHE, resolution=RESOLUTION, repeat=2) | |
training_data = ConcatDataset([LabeledDataset(text_training_data, 1), LabeledDataset(image_training_data, 0)]) | |
test_data = ConcatDataset([LabeledDataset(text_test_data, 1), LabeledDataset(image_test_data, 0)]) | |
# training_data = image_training_data | |
# test_data = image_test_data | |
if False: | |
cols, rows = 8, 8 | |
figure = plt.figure(figsize=(16, 8)) | |
for i in range(1, cols * rows + 1): | |
sample_idx = torch.randint(len(training_data), size=(1,)).item() | |
img, img2 = training_data[sample_idx] | |
figure.add_subplot(rows, cols*2, (i-1)*2+1) | |
plt.axis("off") | |
plt.imshow(img.squeeze(), cmap="gray", vmin=0.0, vmax=1.0) | |
figure.add_subplot(rows, cols*2, (i-1)*2+2) | |
plt.axis("off") | |
plt.imshow(img2.squeeze(), cmap="gray", vmin=0.0, vmax=1.0) | |
plt.show() | |
exit(0) | |
batch_size = 32 | |
# Create data loaders. | |
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=20, persistent_workers=True, drop_last=True) | |
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=4, persistent_workers=True) | |
epochs = 100 | |
for t in range(epochs): | |
print(f"Epoch {t+1}\n-------------------------------") | |
train_select(train_dataloader, model_a, model_b, model, loss_fn, optimizer, 400) | |
test_select(test_dataloader, model_a, model_b, model, loss_fn) | |
torch.save(model.state_dict(), f"checkpoint/select_{t:04}.pth") | |
print("Done!") | |
torch.save(model.state_dict(), "select.pth") | |
print("Saved PyTorch Model State to select.pth") | |
if cmd == "zp-text": | |
model = UpConv_7_KanjiV2().to(device) | |
model.load_state_dict(torch.load("model.pth")) | |
model.eval() | |
test_data = ImageDataset("test") | |
with torch.no_grad(): | |
for i,(x,y) in enumerate(test_data): | |
x = x.to(device) | |
x = torch.unsqueeze(x, 0) | |
pred = model(x) | |
pred = torch.squeeze(pred, 0) | |
pred = torch.clamp(pred, 0, 1) | |
im = to_pil_image(pred, mode="L") | |
im.save(f"test/case_{i:06}_zp.png") | |
if cmd == "bench-image": | |
model = UpConv_7().to(device) | |
model.load_state_dict(torch.load("image.pth")) | |
model.eval() | |
with torch.no_grad(): | |
for f in os.listdir("bench"): | |
if "_result" in f or "_old" in f or "_text" in f or "_select" in f or "_final" in f: continue | |
x = read_image(f"bench/{f}", ImageReadMode.GRAY) | |
scale = 1.0 / 255.0 | |
x = x.to(torch.float32) * scale | |
x = x.to(device) | |
pred = model(x) | |
pred = torch.clamp(pred, 0, 1) | |
im = to_pil_image(pred, mode="L") | |
f_base, f_ext = f.split(".", maxsplit=1) | |
im.save(f"bench/{f_base}_result.{f_ext}") | |
if cmd == "bench-text": | |
model = UpConv_7_KanjiV2().to(device) | |
model.load_state_dict(torch.load("model.pth")) | |
model.eval() | |
with torch.no_grad(): | |
for f in os.listdir("bench"): | |
if "_result" in f or "_old" in f or "_text" in f or "_select" in f or "_final" in f: continue | |
x = read_image(f"bench/{f}", ImageReadMode.GRAY) | |
scale = 1.0 / 255.0 | |
x = x[:, :1024+128, :512+256] | |
x = x.to(torch.float32) * scale | |
x = x.to(device) | |
x = torch.unsqueeze(x, 0) | |
print(x.shape) | |
pred = model(x) | |
pred = torch.squeeze(pred, 0) | |
pred = torch.clamp(pred, 0, 1) | |
im = to_pil_image(pred, mode="L") | |
f_base, f_ext = f.split(".", maxsplit=1) | |
im.save(f"bench/{f_base}_text.png") | |
if cmd == "bench-final": | |
model = UpSelect_7().to(device) | |
model.load_state_dict(torch.load("select.pth")) | |
model.eval() | |
model_a = UpConv_7().to(device) | |
model_a.load_state_dict(torch.load("image.pth")) | |
model_a.eval() | |
model_b = UpConv_7_KanjiV2().to(device) | |
model_b.load_state_dict(torch.load("model.pth")) | |
model_b.eval() | |
pad = 128 | |
for f in os.listdir("bench"): | |
if "_result" in f or "_old" in f or "_text" in f or "_select" in f or "_final" in f: continue | |
# if "003" not in f: continue | |
xs = read_image(f"bench/{f}", ImageReadMode.GRAY) | |
xs = NNF.pad(xs, (pad, pad, pad, pad), mode="constant") | |
_, sh, sw = xs.shape | |
y = torch.zeros(1, sh*2, sw*2) | |
ysel = torch.zeros(1, sh*2, sw*2) | |
with torch.no_grad(): | |
for by in range(0, xs.shape[-2], 64): | |
for bx in range(0, xs.shape[-1], 64): | |
if bx + 128 > xs.shape[-1] or by + 128 > xs.shape[-2]: | |
continue | |
print(bx, by) | |
scale = 1.0 / 255.0 | |
x = xs[:, by:by+128, bx:bx+128] | |
x = x.to(torch.float32) * scale | |
x = x.to(device) | |
x = torch.unsqueeze(x, 0) | |
Xa = model_a(x) | |
Xb = model_b(x) | |
pred = model(x) | |
h = min(Xa.shape[2], Xb.shape[2], pred.shape[2]) | |
w = min(Xa.shape[3], Xb.shape[3], pred.shape[3]) | |
Xa = center_crop(Xa, (h, w)) | |
Xb = center_crop(Xb, (h, w)) | |
pred = center_crop(pred, (h, w)) | |
pred = torch.clamp(pred, 0, 1) | |
Xp = Xa * (1.0 - pred) + Xb * pred | |
pred = torch.squeeze(pred, 0) | |
Xp = torch.clamp(Xp, 0, 1) | |
oy = (256 - h) // 2 | |
ox = (256 - w) // 2 | |
dy = by*2 + oy | |
dx = bx*2 + ox | |
y[:, dy:dy+h, dx:dx+h] = Xp | |
ysel[:, dy:dy+h, dx:dx+h] = pred | |
y = y[:, pad*2:pad*-2, pad*2:pad*-2] | |
im = to_pil_image(y, mode="L") | |
f_base, f_ext = f.split(".", maxsplit=1) | |
im.save(f"bench/{f_base}_final.png") | |
ysel = ysel[:, pad*2:pad*-2, pad*2:pad*-2] | |
imsel = to_pil_image(ysel, mode="L") | |
f_base, f_ext = f.split(".", maxsplit=1) | |
imsel.save(f"bench/{f_base}_select.png") |