This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Trains `vae` by looping `epochs` times over `data_dataloader`, using Adam with
# default hyperparameters. Only the start of the function is visible here; the
# body continues past the final `if` line shown.
| def train(vae, data_dataloader, epochs): | |
# Presumably collects loss values in the part of the function not shown — confirm.
| loss_evolution = [] | |
# NOTE(review): `autoencoder` is not a parameter of this function, so it resolves
# to some outer-scope name rather than `vae`. If that object is not `vae`, the
# optimizer updates a different model than the one run below. Likely should be
# `vae.parameters()`.
| opt = torch.optim.Adam(autoencoder.parameters()) | |
# tqdm wraps the epoch range to display a progress bar.
| for _ in tqdm(range(epochs)): | |
# `y` is unpacked but not used in the visible code.
| for i, (x, y) in enumerate(data_dataloader): | |
| opt.zero_grad() | |
| x_hat = vae(x) | |
| # Loss = sum of squared reconstruction errors (a sum, not a mean) + KL divergence | |
# NOTE(review): the KL term is read from `autoencoder.encoder`, not `vae.encoder`.
# Presumably the encoder stores `kl` during the forward pass above; if so, reading
# it from a different object gives the wrong value. Likely should be
# `vae.encoder.kl`.
| loss = ((x - x_hat) ** 2).sum() + autoencoder.encoder.kl | |
# NOTE(review): assuming `x` is a batch tensor, `len(x)` is the batch size, not the
# number of batches. If the intent is "last batch of the epoch", this probably
# should be `len(data_dataloader) - 1`.
| if i == (len(x) - 1): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generate `n` synthetic ECGs: draw latent vectors from the encoder's sampling
# distribution `N` (presumably a torch.distributions object — confirm) and decode
# each one with the trained VAE's decoder.
latent_vector_shape = (1, 25)  # shape passed to N.sample; presumably (1 sample, 25 latent dims) — confirm against Encoder
# The Decoder contains BatchNorm2d layers. In training mode they would normalize
# each single generated sample with its own statistics and overwrite their running
# statistics with generated data. Switch to inference behaviour before sampling.
vae_trained.eval()
with torch.no_grad():  # inference only: do not build an autograd graph per sample
    generated_ECG = [
        vae_trained.decoder(vae_trained.encoder.N.sample(latent_vector_shape))
        for _ in range(n)
    ]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VariationalAutoencoder(nn.Module):
    """Chain an ``Encoder`` and a ``Decoder``: input -> latent code -> reconstruction."""

    def __init__(self):
        super().__init__()
        # Assigning the sub-modules as attributes registers them with nn.Module,
        # so their parameters are included in ``self.parameters()``.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for that latent code."""
        return self.decoder(self.encoder(x))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class Decoder(nn.Module): | |
| """ | |
| Decodes the latent vector and returns the reconstructed data. | |
| """ | |
# Only the start of __init__ is visible here; the remaining layers and forward()
# are not shown.
| def __init__(self): | |
| super(Decoder, self).__init__() | |
| # Convolutional layers | |
# 1 input channel -> 16 feature maps; kernel 5 with stride 1 and padding 2 keeps
# the height and width unchanged.
| self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2) | |
# Batch normalization sized to the 16 channels produced by conv1.
| self.convbatchNorm1 = nn.BatchNorm2d(16) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class Encoder(nn.Module): | |
| """ | |
| Encodes the data and returns the latent vector. | |
| """ | |
# Only the start of __init__ is visible here; the remaining layers and forward()
# are not shown.
| def __init__(self): | |
| super(Encoder, self).__init__() | |
| # Convolutional layers | |
# 1 input channel -> 8 feature maps; kernel 5 with stride 1 and padding 2 keeps
# the height and width unchanged.
| self.conv1 = nn.Conv2d(1, 8, stride=1, padding=2, kernel_size=5) | |
# Batch normalization sized to the 8 channels produced by conv1.
| self.convbatchNorm1 = nn.BatchNorm2d(8) |