Skip to content

Instantly share code, notes, and snippets.

@yuq-1s
Last active February 1, 2022 09:27
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save yuq-1s/8bf91eaac76bbb5d6997eb36043ea1f8 to your computer and use it in GitHub Desktop.
Save yuq-1s/8bf91eaac76bbb5d6997eb36043ea1f8 to your computer and use it in GitHub Desktop.
[Implicit Maximum Likelihood Estimation](https://arxiv.org/abs/1809.09087) in 100 lines
import mxnet as mx
from mxnet import nd, autograd
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Identity, Concurrent
from mxnet import gluon
import logging
def d(a, b):
    """Return the distance between points `a` and `b`.

    Computed as ``(a - b).norm()``; for MXNet NDArrays this is the L2 norm
    of the flattened difference.
    """
    diff = a - b
    return diff.norm()
def R(a, b):
    """Nearest-neighbour distances from every element of `a` into `b`.

    Iterates `a` and `b` along their first axis. For each element of `a`,
    the result holds the smallest `d(element, other)` over all `other` in
    `b`. The result is a Python list with one entry per element of `a`.
    This makes len(a) * len(b) calls to `d`.
    """
    def nearest(point):
        return min([d(point, other) for other in b])

    return [nearest(point) for point in a]
def loss(fake, x):
    """IMLE objective: mean distance from each real sample to its closest fake.

    For every row of the real data `x`, take the distance to the nearest row
    of the generated batch `fake` (via `R`), then average over the rows of
    `x`.
    """
    distances = R(x, fake)
    total = sum(distances)
    return total / len(distances)
def visualize(x):
    """Scatter-plot a batch of 2-D points without blocking the caller.

    x: [n, 2] array exposing ``.shape`` and ``.asnumpy()`` (an MXNet NDArray
       in this script). Column 0 is plotted on the horizontal axis and
       column 1 on the vertical axis.

    Raises:
        ValueError: if `x` is not a 2-D array with exactly two columns.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`;
    # also covers 1-D input, which previously raised IndexError.
    if len(x.shape) != 2 or x.shape[1] != 2:
        raise ValueError(
            "visualize expects an [n, 2] array, got shape {}".format(x.shape))
    import matplotlib.pyplot as plt
    points = x.asnumpy()
    plt.scatter(points[:, 0], points[:, 1])
    plt.show(block=False)
    # A non-blocking show() does not run the GUI event loop by itself; pause
    # briefly so the figure is actually drawn/refreshed while training runs.
    plt.pause(0.001)
class Model(nn.Block):
    """Residual MLP generator mapping latent codes to output samples.

    Layout:
    * An input layer: Dense(hidden_size) -> BatchNorm -> LeakyReLU(0.1).
    * `num_blocks` residual blocks of the same form, each applied as
      ``z = blk(z) + z``.
    * A final linear Dense(out_dim) projection.

    Every sub-block is initialized here with the 'orthogonal' initializer.

    Args:
        hidden_size: width of the input layer and of every residual block.
            Default 5, as originally hard-coded.
        num_blocks: number of residual blocks. Default 10.
        out_dim: dimensionality of the generated samples. Default 2, which
            matches the 2-D data produced by `get_x`.
        **kwargs: forwarded to `nn.Block` (e.g. `prefix`, `params`).
    """

    def __init__(self, hidden_size=5, num_blocks=10, out_dim=2, **kwargs):
        super().__init__(**kwargs)
        self.blks = []
        self.input = self._make_layer(hidden_size)
        for _ in range(num_blocks):
            blk = self._make_layer(hidden_size)
            # Blocks stored in a plain Python list are not registered
            # automatically by Gluon, so register each one explicitly to get
            # its parameters into collect_params().
            self.register_child(blk)
            self.blks.append(blk)
        self.output = nn.Dense(out_dim)
        self.output.initialize('orthogonal')

    @staticmethod
    def _make_layer(width):
        """Build and initialize one Dense(width) -> BatchNorm -> LeakyReLU(0.1) stack."""
        layer = nn.Sequential()
        with layer.name_scope():
            layer.add(nn.Dense(width), nn.BatchNorm(), nn.LeakyReLU(0.1))
        layer.initialize('orthogonal')
        return layer

    def forward(self, z):
        """Map latent codes `z` to samples.

        In the training loop, `z` is an (n, z_dim) batch. The output has
        `out_dim` features per row.
        """
        z = self.input(z)
        for blk in self.blks:
            # Residual connection; requires every block to keep hidden_size.
            z = blk(z) + z
        return self.output(z)
def get_x(n=30):
    '''
    return [3*n, 2]
    sample data of dimension 2

    The data are n noisy points from each of three curves on [-1, 1]:
    sin(t), |1 - t^2| and -cos(t). Each point gets uniform noise in
    [-0.05, 0.05] on its second coordinate. The rows are shuffled before
    being returned.
    '''
    curves = [nd.sin, lambda t: nd.abs(1 - t * t), lambda t: -nd.cos(t)]

    def sample_curve(curve):
        # The abscissa is drawn first, then the noise, per curve in order.
        xs = nd.random.uniform(-1, 1, n)
        ys = curve(xs) + 0.05 * nd.random.uniform(-1, 1, n)
        # stack -> (2, n); transpose -> (n, 2) rows of (x, y)
        return nd.transpose(nd.stack(xs, ys))

    pieces = [sample_curve(curve) for curve in curves]
    stacked = nd.stack(*pieces)  # (len(curves), n, 2)
    flat = nd.reshape(stacked, [len(curves) * n, 2])
    return nd.shuffle(flat)
def main(n=30, z_dim=1, steps=20000, plot_every=100):
    """Fit `Model` to the toy dataset from `get_x` with the IMLE objective.

    Each step does the following:
    * Draw `n` latent codes uniformly from [-1, 1]^z_dim.
    * Generate fake samples and minimize `loss(fake, x)`.
    * Draw a fresh real batch `x` at the end of the step.

    Every `plot_every` steps the current fake batch is plotted. The last
    real batch is plotted once training ends.

    Args:
        n: points per curve passed to `get_x`; also the latent batch size.
        z_dim: dimensionality of the latent code.
        steps: number of optimization steps.
        plot_every: plot the fake samples every this many steps.
    """
    # Without a configured handler, the root logger falls back to
    # logging.lastResort, which only emits WARNING and above. DEBUG/INFO
    # records (ours, and the LR scheduler's) would be silently dropped.
    logging.basicConfig()
    log = logging.getLogger()
    log.setLevel(logging.DEBUG)

    model = Model()
    schedule = mx.lr_scheduler.MultiFactorScheduler(
        step=[2000, 7500, 10000], factor=0.5)
    optimizer = mx.optimizer.Adam(learning_rate=0.03, lr_scheduler=schedule)
    trainer = mx.gluon.Trainer(params=model.collect_params(),
                               optimizer=optimizer)

    x = get_x(n)
    log.debug("start training")
    for step in range(steps):
        with autograd.record():
            z = nd.random.uniform(-1, 1, (n, z_dim))
            fake = model(z)
            train_loss = loss(fake=fake, x=x)
        train_loss.backward()
        # loss() already averages over the real samples, so tell the
        # Trainer not to rescale the gradient by 1/batch_size a second time.
        trainer.step(1)
        print("step: {}, train_loss: {}".format(step, train_loss.asscalar()))
        if step % plot_every == 0:
            visualize(fake)
        x = get_x(n)
    visualize(x)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment