@crowsonkb
Last active February 1, 2021 00:33
Modified Differential Multiplier Method
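This gist contains two files: the mdmm module (the demo script that follows imports it with "import mdmm") and a demo that constrains the total variation of an optimized image. The module implements Platt and Barr's modified differential method of multipliers for inequality constraints. For a primary loss f(θ) and constraints g_i(θ) ≤ m_i, each with a slack variable s_i, multiplier λ_i, and damping c_i, forward() below returns the damped Lagrangian

L(\theta, s, \lambda) = f(\theta) - \sum_i \lambda_i \left(m_i - g_i(\theta) - s_i^2\right) + \sum_i \frac{c_i}{2} \left(m_i - g_i(\theta) - s_i^2\right)^2

and make_optimizer() takes gradient descent steps on θ and the slacks and gradient ascent steps on the multipliers (via a negated learning rate), which drives g_i(θ) + s_i² toward m_i and so enforces g_i(θ) ≤ m_i.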
import torch
from torch import nn, optim


class Constraint(nn.Module):
    """An inequality constraint fn() <= maximum with quadratic damping."""

    def __init__(self, fn, maximum, damping=1e-2):
        super().__init__()
        self.fn = fn
        self.register_buffer('maximum', torch.as_tensor(maximum))
        self.register_buffer('damping', torch.as_tensor(damping))

    def extra_repr(self):
        return f'maximum={self.maximum:g}, damping={self.damping:g}'


class MDMM(nn.Module):
    """Builds the damped Lagrangian for a set of constraints, with one slack
    variable and one Lagrange multiplier per constraint."""

    def __init__(self, constraints):
        super().__init__()
        self.constraints = nn.ModuleList()
        self.slacks = nn.ParameterList()
        self.lambdas = nn.ParameterList()
        for c in constraints:
            loss = c.fn().detach()
            c = c.to(loss.device, loss.dtype)
            self.constraints.append(c)
            # Initialize the slack so that maximum - fn() - slack**2 starts at
            # zero whenever the constraint is already satisfied.
            slack = (c.maximum - loss).relu().pow(1/2)
            self.slacks.append(nn.Parameter(slack))
            self.lambdas.append(nn.Parameter(slack.new_zeros([])))

    def make_optimizer(self, params, lr=2e-3):
        # Gradient descent on the parameters and slacks, gradient ascent on
        # the multipliers (via the negated learning rate).
        return optim.Adamax([{'params': params, 'lr': lr},
                             {'params': self.slacks, 'lr': lr},
                             {'params': self.lambdas, 'lr': -lr}])

    def forward(self, loss):
        out = loss.clone()
        losses = []
        for i, c in enumerate(self.constraints):
            losses.append(c.fn())
            c_value = c.maximum - losses[i] - self.slacks[i]**2
            out -= self.lambdas[i] * c_value
            out += c.damping * c_value**2 / 2
        return out, losses
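A minimal usage sketch (the tensor x and this particular constraint are illustrative, not part of the gist): pull x toward all ones while keeping its sum of squares at or below 1.

# Illustrative example, not from the gist: minimize the distance of x to
# all ones subject to sum(x**2) <= 1.
x = torch.zeros(10, requires_grad=True)
mdmm_mod = MDMM([Constraint(lambda: x.pow(2).sum(), 1.)])
opt = mdmm_mod.make_optimizer([x], lr=2e-3)
for _ in range(1000):
    loss = (x - 1).pow(2).sum()
    lagrangian, losses = mdmm_mod(loss)
    opt.zero_grad()
    lagrangian.backward()
    opt.step()
# The constraint value approaches the 1.0 budget as training proceeds.
print(x.pow(2).sum().item())

The demo script below does the same with an image: it minimizes the summed l2 distance to an input image subject to a total variation budget proportional to the image size.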
#!/usr/bin/env python3

import argparse

from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.transforms import functional as TF

import mdmm


class TVLoss(nn.Module):
    """Isotropic total variation, summed over the batch."""

    def forward(self, input):
        input = F.pad(input, (0, 1, 0, 1), 'replicate')
        x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
        y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
        diff = x_diff**2 + y_diff**2 + 1e-8
        return diff.sum(dim=1).sqrt().sum()


def main():
    p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('input_image', type=str,
                   help='the input image')
    p.add_argument('output_image', type=str, nargs='?', default='out.png',
                   help='the output image')
    p.add_argument('--max-tv', type=float, default=0.02,
                   help='the maximum allowable total variation per sample')
    p.add_argument('--damping', type=float, default=1e-2,
                   help='the constraint damping strength')
    p.add_argument('--lr', type=float, default=2e-3,
                   help='the learning rate')
    args = p.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    pil_image = Image.open(args.input_image).resize((128, 128), Image.LANCZOS)
    target = TF.to_tensor(pil_image)[None].to(device)
    input = target.clone().requires_grad_()
    # torch.manual_seed(0)
    # target += torch.randn_like(target) / 10
    # target.clamp_(0, 1)

    # Minimize the l2 distance to the target subject to a TV constraint.
    crit_l2 = nn.MSELoss(reduction='sum')
    crit_tv = TVLoss()
    max_tv = args.max_tv * input.numel()
    mdmm_mod = mdmm.MDMM([mdmm.Constraint(lambda: crit_tv(input), max_tv, args.damping)])
    opt = mdmm_mod.make_optimizer([input], lr=args.lr)

    try:
        i = 0
        while True:
            i += 1
            loss = crit_l2(input, target)
            lagrangian, losses = mdmm_mod(loss)
            msg = '{} l2={:g}, tv={:g}'
            print(msg.format(i,
                             loss.item() / input.numel(),
                             losses[0].item() / input.numel()))
            if not lagrangian.isfinite():
                break
            opt.zero_grad()
            lagrangian.backward()
            opt.step()
    except KeyboardInterrupt:
        pass

    TF.to_pil_image(input[0].clamp(0, 1)).save(args.output_image)


if __name__ == '__main__':
    main()
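To try the demo, save the second file next to mdmm.py and run it with an input image, for example: python demo.py input.png out.png (the script's filename is not given in the gist; demo.py is assumed here). Each step prints the per-element l2 and tv values; the tv value should fall toward the --max-tv budget while the l2 term stays small. Stop with Ctrl-C once it settles and the current image is written to the output path; the loop also stops on its own if the Lagrangian becomes non-finite.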