@crowsonkb, created February 1, 2021
Modified Differential Multiplier Method
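
A PyTorch sketch of the modified differential multiplier method (MDMM) of Platt and Barr ("Constrained Differential Optimization", NIPS 1987). Each constraint adds a damped Lagrangian term, damping * c(x)**2 / 2 - lmbda * c(x), to the loss, where c(x) is the constraint function value (zero when the constraint is satisfied); the multipliers lmbda are updated by gradient ascent (implemented with a negative learning rate in the optimizer) while the parameters descend. Inequality constraints are handled either with squared slack variables (MaxConstraint, MinConstraint) or by clamping (the *Hard variants). The first file below is the mdmm_2 module; the second is a demo script that enforces a total variation budget on an image.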
# mdmm_2.py: constraints for the Modified Differential Multiplier Method.

import abc

import torch
from torch import nn, optim


class Constraint(nn.Module, metaclass=abc.ABCMeta):
    """Base class for MDMM constraints on the scalar value of fn()."""

    def __init__(self, fn, damping):
        super().__init__()
        self.fn = fn
        self.register_buffer('damping', torch.as_tensor(damping))
        # The Lagrange multiplier; MDMM.make_optimizer() gives it a
        # negative learning rate so it is updated by gradient ascent.
        self.lmbda = nn.Parameter(torch.tensor(0.))

    @abc.abstractmethod
    def c_value(self, loss):
        """Returns the constraint function value (zero when satisfied)."""
        ...

    def forward(self):
        loss = self.fn()
        c_value = self.c_value(loss)
        # Damped Lagrangian term: quadratic penalty plus multiplier term.
        penalty = self.damping * c_value**2 / 2 - self.lmbda * c_value
        return penalty, loss


class EqConstraint(Constraint):
    """Constrains fn() to equal value."""

    def __init__(self, fn, value, damping=1e-2):
        super().__init__(fn, damping)
        self.register_buffer('value', torch.as_tensor(value))

    def extra_repr(self):
        return f'value={self.value:g}, damping={self.damping:g}'

    def c_value(self, loss):
        return self.value - loss


class MaxConstraint(Constraint):
    """Constrains fn() to be at most max, converting the inequality to an
    equality with a squared slack variable: max - fn() - slack**2 = 0."""

    def __init__(self, fn, max, damping=1e-2):
        super().__init__(fn, damping)
        loss = self.fn()
        self.register_buffer('max', loss.new_tensor(max))
        # Initialize the slack so c_value is zero at the start, when feasible.
        self.slack = nn.Parameter((self.max - loss).relu().pow(1/2))

    def extra_repr(self):
        return f'max={self.max:g}, damping={self.damping:g}'

    def c_value(self, loss):
        return self.max - loss - self.slack**2


class MaxConstraintHard(Constraint):
    """Constrains fn() to be at most max, without a slack variable."""

    def __init__(self, fn, max, damping=1e-2):
        super().__init__(fn, damping)
        self.register_buffer('max', torch.as_tensor(max))

    def extra_repr(self):
        return f'max={self.max:g}, damping={self.damping:g}'

    def c_value(self, loss):
        # Zero when loss <= max, max - loss (negative) otherwise.
        return loss.clamp(max=self.max) - loss


class MinConstraint(Constraint):
    """Constrains fn() to be at least min, converting the inequality to an
    equality with a squared slack variable: fn() - min - slack**2 = 0."""

    def __init__(self, fn, min, damping=1e-2):
        super().__init__(fn, damping)
        loss = self.fn()
        self.register_buffer('min', loss.new_tensor(min))
        self.slack = nn.Parameter((loss - self.min).relu().pow(1/2))

    def extra_repr(self):
        return f'min={self.min:g}, damping={self.damping:g}'

    def c_value(self, loss):
        return loss - self.min - self.slack**2


class MinConstraintHard(Constraint):
    """Constrains fn() to be at least min, without a slack variable."""

    def __init__(self, fn, min, damping=1e-2):
        super().__init__(fn, damping)
        self.register_buffer('min', torch.as_tensor(min))

    def extra_repr(self):
        return f'min={self.min:g}, damping={self.damping:g}'

    def c_value(self, loss):
        return loss.clamp(min=self.min) - loss


class BoundConstraintHard(Constraint):
    """Constrains fn() to lie in [min, max], without slack variables."""

    def __init__(self, fn, min, max, damping=1e-2):
        super().__init__(fn, damping)
        self.register_buffer('min', torch.as_tensor(min))
        self.register_buffer('max', torch.as_tensor(max))

    def extra_repr(self):
        return f'min={self.min:g}, max={self.max:g}, damping={self.damping:g}'

    def c_value(self, loss):
        return loss.clamp(self.min, self.max) - loss


class MDMM(nn.ModuleList):
    """A list of constraints that turns a loss into the MDMM Lagrangian."""

    def make_optimizer(self, params, *, optimizer=optim.Adamax, lr=2e-3):
        lambdas = [c.lmbda for c in self]
        slacks = [c.slack for c in self if hasattr(c, 'slack')]
        # Negative learning rate on the multipliers: gradient ascent on
        # lmbda, descent on the parameters and slacks.
        return optimizer([{'params': params, 'lr': lr},
                          {'params': lambdas, 'lr': -lr},
                          {'params': slacks, 'lr': lr}])

    def forward(self, loss):
        output = loss.clone()
        losses = []
        for c in self:
            penalty, c_loss = c()
            output += penalty
            losses.append(c_loss)
        return output, losses
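
As a minimal usage sketch (not part of the original gist; the step count and learning rate are arbitrary), the module can be exercised on a toy problem, minimizing ||x||**2 subject to sum(x) = 1:

# Toy example (illustrative, not from the gist): minimize ||x||^2
# subject to sum(x) = 1; the optimum is x = [0.25, 0.25, 0.25, 0.25].
import torch
import mdmm_2 as mdmm

x = torch.zeros(4, requires_grad=True)
mdmm_mod = mdmm.MDMM([mdmm.EqConstraint(lambda: x.sum(), 1.)])
opt = mdmm_mod.make_optimizer([x], lr=1e-2)

for _ in range(5000):
    lagrangian, losses = mdmm_mod(x.pow(2).sum())
    opt.zero_grad()
    lagrangian.backward()
    opt.step()

print(x)  # should approach [0.25, 0.25, 0.25, 0.25]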
#!/usr/bin/env python3

"""Demo: smooths an image by minimizing the L2 distance to the target,
subject to a maximum total variation constraint, using mdmm_2."""

import argparse

from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.transforms import functional as TF

import mdmm_2 as mdmm


class TVLoss(nn.Module):
    """Isotropic total variation, summed over the image."""

    def forward(self, input):
        input = F.pad(input, (0, 1, 0, 1), 'replicate')
        x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
        y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
        diff = x_diff**2 + y_diff**2 + 1e-8
        return diff.sum(dim=1).sqrt().sum()


def main():
    p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('input_image', type=str,
                   help='the input image')
    p.add_argument('output_image', type=str, nargs='?', default='out.png',
                   help='the output image')
    p.add_argument('--max-tv', type=float, default=0.02,
                   help='the maximum allowable total variation per sample')
    p.add_argument('--damping', type=float, default=1e-2,
                   help='the constraint damping strength')
    p.add_argument('--lr', type=float, default=2e-3,
                   help='the learning rate')
    args = p.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    pil_image = Image.open(args.input_image).resize((128, 128), Image.LANCZOS)
    target = TF.to_tensor(pil_image)[None].to(device)
    input = target.clone().requires_grad_()
    # Optionally corrupt the target with noise to make this a denoising demo:
    # torch.manual_seed(0)
    # target += torch.randn_like(target) / 10
    # target.clamp_(0, 1)

    crit_l2 = nn.MSELoss(reduction='sum')
    crit_tv = TVLoss()
    # The --max-tv budget is per element, so scale it by the image size.
    max_tv = args.max_tv * input.numel()
    mdmm_mod = mdmm.MDMM([mdmm.MaxConstraint(lambda: crit_tv(input), max_tv, args.damping)])
    opt = mdmm_mod.make_optimizer([input], lr=args.lr)

    try:
        i = 0
        while True:
            i += 1
            loss = crit_l2(input, target)
            lagrangian, losses = mdmm_mod(loss)
            msg = '{} l2={:g}, tv={:g}'
            print(msg.format(i,
                             loss.item() / input.numel(),
                             losses[0].item() / input.numel()))
            if not lagrangian.isfinite():
                break
            opt.zero_grad()
            lagrangian.backward()
            opt.step()
    except KeyboardInterrupt:
        pass

    TF.to_pil_image(input[0].clamp(0, 1)).save(args.output_image)


if __name__ == '__main__':
    main()
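
Assuming the demo is saved as, say, demo_tv.py (the filename is not given in the gist) next to mdmm_2.py, a run looks like: python3 demo_tv.py input.png out.png --max-tv 0.02. The loop prints the per-element l2 and tv losses each step and runs until interrupted with Ctrl-C, after which the result is written to the output image.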