Created
April 4, 2020 06:02
-
-
Save cccntu/88e13255a4d06244c21db99225faf63d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
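"""
The encoder must be able to disentangle word and style in order to learn well, as in FontGAN.
Embeddings of the same word / the same style should be close to each other: use F.pairwise_distance.
The image must then be decodable from the embeddings: use l1_loss.
"""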
class Encoder(nn.Module):
    """Map an image batch to a (style embedding, word embedding) pair.

    A shared backbone produces one feature vector per image; two independent
    linear heads project that vector into the style space and the word space.

    Args:
        backbone: Module mapping an image batch to features of shape
            ``(N, feat_dim)``. When ``None``, a placeholder MLP
            (flatten -> linear -> ReLU) sized from ``img_shape`` is built.
            NOTE(review): the original sketch left the backbone unspecified;
            the default is only there to make the class runnable.
        img_shape: Per-sample image shape, used only to size the default
            backbone.
        feat_dim: Width of the backbone output fed to both heads.
        style_dim: Size of the style embedding.
        word_dim: Size of the word embedding.
    """

    def __init__(self, backbone=None, img_shape=(1, 64, 64),
                 feat_dim: int = 256, style_dim: int = 64,
                 word_dim: int = 64):
        # Must run before any submodule is assigned to ``self``.
        super().__init__()
        if backbone is None:
            in_features = 1
            for size in img_shape:
                in_features *= size
            backbone = nn.Sequential(
                nn.Flatten(),
                nn.Linear(in_features, feat_dim),
                nn.ReLU(),
            )
        # The original instantiated Encoder() here, which recursed forever.
        self.encoder = backbone
        self.style = nn.Linear(feat_dim, style_dim)
        self.word = nn.Linear(feat_dim, word_dim)

    def forward(self, img):
        """Return ``(style_emb, word_emb)`` computed from ``img``."""
        emb = self.encoder(img)
        style_emb = self.style(emb)
        word_emb = self.word(emb)
        return style_emb, word_emb
class Decoder(nn.Module):
    """Reconstruct an image batch from style and word embeddings.

    The two embeddings are concatenated per sample, fused by a linear layer,
    and passed to a decoding backbone.

    Args:
        backbone: Module mapping fused features of shape ``(N, feat_dim)`` to
            the output. Its result is returned unchanged. When ``None``, a
            placeholder MLP (ReLU -> linear) is built and its output is
            reshaped to ``(N, *img_shape)``.
            NOTE(review): the original sketch left the backbone unspecified;
            the default is only there to make the class runnable.
        img_shape: Per-sample output shape for the default backbone.
        feat_dim: Width of the fused feature vector.
        style_dim: Size of the incoming style embedding.
        word_dim: Size of the incoming word embedding.
    """

    def __init__(self, backbone=None, img_shape=(1, 64, 64),
                 feat_dim: int = 256, style_dim: int = 64,
                 word_dim: int = 64):
        # Must run before any submodule is assigned to ``self``.
        super().__init__()
        self.linear = nn.Linear(style_dim + word_dim, feat_dim)
        if backbone is None:
            out_features = 1
            for size in img_shape:
                out_features *= size
            backbone = nn.Sequential(
                nn.ReLU(),
                nn.Linear(feat_dim, out_features),
            )
            # Only the built-in backbone emits flat pixels needing a reshape.
            self.out_shape = tuple(img_shape)
        else:
            self.out_shape = None
        # The original instantiated Decoder() here, which recursed forever.
        self.decoder = backbone

    def forward(self, style_emb, word_emb):
        """Return the decoded output for the given embedding pair."""
        # Join along the feature axis; the default dim=0 would stack the two
        # embeddings along the batch axis instead.
        fused = torch.cat([style_emb, word_emb], dim=-1)
        emb = self.linear(fused)
        out = self.decoder(emb)
        if self.out_shape is not None:
            out = out.reshape(out.size(0), *self.out_shape)
        return out
class Model(pl.LightningModule):
    """Autoencoder trained to separate word and style embeddings.

    Training pulls together the word embeddings of two images showing the
    same word and the style embeddings of two images sharing a style, while
    requiring the anchor image to be reconstructable from its own embeddings.

    Args:
        encoder: Module returning ``(style_emb, word_emb)`` for an image
            batch. Defaults to ``Encoder()``.
        decoder: Module mapping ``(style_emb, word_emb)`` back to an image
            batch. Defaults to ``Decoder()``.
    """

    def __init__(self, encoder=None, decoder=None):
        # Must run before any submodule is assigned to ``self``.
        super().__init__()
        self.encoder = Encoder() if encoder is None else encoder
        self.decoder = Decoder() if decoder is None else decoder

    def training_step(self, batch, batch_idx):
        """Compute the combined word, style and reconstruction loss.

        Args:
            batch: Triple ``(img, same_word_diff_style_as_img,
                diff_word_same_style_as_img)``.
            batch_idx: Index of the batch; unused.

        Returns:
            Dict with the scalar total loss under the key ``'loss'``.
        """
        img, same_word_diff_style_as_img, diff_word_same_style_as_img = batch

        style0, word0 = self.encoder(img)
        _, word1 = self.encoder(same_word_diff_style_as_img)
        style1, _ = self.encoder(diff_word_same_style_as_img)
        img_out = self.decoder(style0, word0)

        # pairwise_distance yields one distance per row, so it is averaged
        # to get the scalar needed for the total loss.
        loss_word = F.pairwise_distance(word0, word1).mean()
        loss_style = F.pairwise_distance(style0, style1).mean()
        # Conventional (prediction, target) order.
        loss_img = F.l1_loss(img_out, img)

        loss = loss_word + loss_style + loss_img
        return {'loss': loss}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment