Skip to content

Instantly share code, notes, and snippets.

@cccntu
Created April 4, 2020 06:02
Show Gist options
  • Save cccntu/88e13255a4d06244c21db99225faf63d to your computer and use it in GitHub Desktop.
Save cccntu/88e13255a4d06244c21db99225faf63d to your computer and use it in GitHub Desktop.
"""
讓encoder能disentangleword跟style才會學得好 像fontgan那樣
相同的word/style的embedding要接近,用F.pairwise_distance
然後要能decode出來,用l1_loss,
"""
class Encoder(nn.Module):
    """Encode an image into a separate style embedding and word embedding.

    A shared backbone produces one feature tensor; two independent linear
    heads project it into the style space and the word space.

    Args:
        backbone: Module mapping an image batch to a feature tensor. Defaults
            to ``nn.Flatten()`` so the class is usable without arguments; pass
            a real feature extractor (e.g. a CNN) for actual training.
        style_dim: Size of the style embedding.
        word_dim: Size of the word embedding.

    Note:
        The heads are ``nn.LazyLinear``, so their input width is inferred from
        the backbone output on the first forward call. PyTorch recommends one
        dry-run forward pass before constructing an optimizer over lazy
        modules.
    """

    def __init__(self, backbone=None, style_dim: int = 64, word_dim: int = 64):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        # The original sketch built another Encoder here, which recurses
        # forever; the feature extractor is injected instead.
        self.encoder = nn.Flatten() if backbone is None else backbone
        self.style = nn.LazyLinear(style_dim)
        self.word = nn.LazyLinear(word_dim)

    def forward(self, img):
        """Return ``(style_emb, word_emb)`` computed from ``img``."""
        emb = self.encoder(img)
        style_emb = self.style(emb)
        word_emb = self.word(emb)
        return style_emb, word_emb
class Decoder(nn.Module):
    """Reconstruct an output from a style embedding and a word embedding.

    The two embeddings are concatenated feature-wise, fused by a linear layer,
    and passed to a backbone that produces the final output.

    Args:
        backbone: Module mapping the fused embedding to the output (e.g. an
            image generator). Defaults to ``nn.Identity()`` so the class is
            usable without arguments; with the default the output is just the
            fused embedding, so supply a real generator to reconstruct images.
        hidden_dim: Size of the fused embedding fed to the backbone.

    Note:
        The fusion layer is ``nn.LazyLinear``, so its input width is inferred
        from the concatenated embeddings on the first forward call. PyTorch
        recommends one dry-run forward pass before constructing an optimizer
        over lazy modules.
    """

    def __init__(self, backbone=None, hidden_dim: int = 256):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        # The original sketch built another Decoder here, which recurses
        # forever; the generator is injected instead.
        self.decoder = nn.Identity() if backbone is None else backbone
        self.linear = nn.LazyLinear(hidden_dim)

    def forward(self, style_emb, word_emb):
        """Return the backbone output for the fused ``style_emb``/``word_emb``."""
        # Join along the last (feature) axis. torch.cat defaults to dim=0,
        # which would stack the two embeddings along the first axis instead.
        # assumes embeddings are (batch, dim) — TODO confirm
        fused = torch.cat([style_emb, word_emb], dim=-1)
        emb = self.linear(fused)
        out = self.decoder(emb)
        return out
class Model(pl.LightningModule):
    """Train an encoder/decoder pair to disentangle word and style.

    Each training batch is a triple of images: an anchor, an image with the
    same word but a different style, and an image with a different word but
    the same style. The loss pulls matching word/style embeddings together
    and asks the decoder to reconstruct the anchor.

    Args:
        encoder: Module returning ``(style_emb, word_emb)`` for an image
            batch. Defaults to ``Encoder()``.
        decoder: Module mapping ``(style_emb, word_emb)`` to a reconstruction.
            Defaults to ``Decoder()``.

    NOTE(review): no ``configure_optimizers`` is defined in this class as
    shown; one is needed before Lightning can fit it.
    """

    def __init__(self, encoder=None, decoder=None):
        # LightningModule must be initialised before submodules are assigned.
        super().__init__()
        self.encoder = Encoder() if encoder is None else encoder
        self.decoder = Decoder() if decoder is None else decoder

    def training_step(self, batch, batch_idx):
        """Compute the combined word, style and reconstruction loss.

        Args:
            batch: Triple ``(img, same_word_diff_style_as_img,
                diff_word_same_style_as_img)``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the scalar training loss under the key ``'loss'``.
        """
        img, same_word_diff_style_as_img, diff_word_same_style_as_img = batch

        style0, word0 = self.encoder(img)
        # Only the word embedding of the same-word image is needed ...
        _, word1 = self.encoder(same_word_diff_style_as_img)
        # ... and only the style embedding of the same-style image.
        style1, _ = self.encoder(diff_word_same_style_as_img)

        img_out = self.decoder(style0, word0)

        # pairwise_distance yields one distance per sample; average them so
        # every term is a scalar and the summed loss can be back-propagated.
        loss_word = F.pairwise_distance(word0, word1).mean()
        loss_style = F.pairwise_distance(style0, style1).mean()
        # assumes the decoder output has the same shape as img — TODO confirm
        loss_img = F.l1_loss(img_out, img)

        loss = loss_word + loss_style + loss_img
        return {'loss': loss}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment