Skip to content

Instantly share code, notes, and snippets.

@purple4reina
Created November 15, 2021 04:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save purple4reina/500d042eeece2f8794da4b2c0e06131e to your computer and use it in GitHub Desktop.
Save purple4reina/500d042eeece2f8794da4b2c0e06131e to your computer and use it in GitHub Desktop.
# https://realpython.com/generative-adversarial-networks/
# https://salu133445.github.io/lakh-pianoroll-dataset/
import torch
from torch import nn
import time
import json
import math
import matplotlib.pyplot as plt
import pypianoroll as ppr
import numpy as np
import os
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
import pygame
# --- Hyperparameters ---------------------------------------------------------
lr = 1e-3  # Adam learning rate shared by both networks
num_epochs = 10_000  # full passes over train_loader
data_length = 1024  # pianoroll rows (first axis) kept per training sample
samples = 32  # how many qualifying clarinet tracks to collect
# Dataset index: maps each file id to a metadata dict. Only the
# 'constant_tempo' and 'tempo' keys are read here.
with open('final-project/datasets/lpd/lpd_cleansed/midi_info_v2.json') as f:
    metadata = json.load(f)

# Each training example is the first `data_length` pianoroll rows flattened
# into one vector (128 columns per row).
row_length = data_length * 128


def _iter_clarinet_rows(meta):
    """Yield flattened float32 clarinet pianorolls, one per qualifying track.

    Only songs flagged constant-tempo with tempo exactly 96 are opened. Within
    each, every track whose name contains "clarinet" (case-insensitive) and
    that has at least ``data_length`` rows is truncated to ``data_length`` rows
    and flattened to ``row_length`` float32 values.

    Files are loaded lazily, so the caller can stop consuming early without
    paying for the rest of the dataset.
    """
    for file_id, info in meta.items():
        if not info['constant_tempo'] or info['tempo'] != 96:
            continue
        mtrack = ppr.load(
            f'final-project/datasets/lpd/lpd_full/{file_id[0]}/{file_id}.npz'
        )
        for track in mtrack.tracks:
            if 'clarinet' not in track.name.lower():
                continue
            piece = track.pianoroll
            if piece.shape[0] < data_length:
                continue
            # NOTE(review): values are used unscaled. The Generator ends in a
            # Sigmoid, so if these are velocities rather than 0/1 the
            # discriminator can tell real from fake by magnitude alone;
            # consider binarising (piece > 0). Confirm the stored dtype.
            yield piece[:data_length].reshape(row_length).astype(np.float32)


# Take exactly `samples` rows. range() comes first in zip so that iteration
# stops *before* another file is pulled from the lazy generator. The previous
# nested-loop `break` only left the inner track loop, so the scan continued
# over every file and the counter could go negative.
train_data = [
    row for _, row in zip(range(samples), _iter_clarinet_rows(metadata))
]
if len(train_data) < samples:
    # Explicit error instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError(
        f'found only {len(train_data)} qualifying clarinet tracks; '
        f'need {samples}'
    )

train_data_length = len(train_data)
# Placeholder labels: the training loop ignores them and builds its own.
train_labels = torch.zeros(train_data_length)
train_set = list(zip(train_data, train_labels))
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
class Discriminator(nn.Module):
    """MLP that scores a flattened pianoroll: near 1 for real, near 0 for fake.

    Three ReLU hidden layers (256 -> 128 -> 64), each followed by 30% dropout,
    then one Sigmoid output unit.
    """

    def __init__(self, in_features=None):
        """Build the network.

        Args:
            in_features: Length of each flattened input vector. Defaults to
                the module-level ``row_length`` (read at construction time),
                so ``Discriminator()`` behaves exactly as before.
        """
        super().__init__()
        if in_features is None:
            in_features = row_length
        self.model = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return probabilities in [0, 1]; the last dim of x must be in_features.

        The output keeps x's leading dimensions and has a trailing dim of 1.
        """
        return self.model(x)
# The single discriminator instance. It is trained by optimizer_discriminator
# and scored against generator output in the training loop.
discriminator = Discriminator()
class Generator(nn.Module):
    """MLP mapping a latent vector to a flattened pianoroll with values in [0, 1].

    ReLU hidden layers of width 512 -> 128 -> 32, then a Sigmoid output layer.
    """

    def __init__(self, latent_dim=1024, out_features=None):
        """Build the network.

        Args:
            latent_dim: Width of the input noise vector (default 1024, matching
                the ``torch.randn(..., 1024)`` calls used to drive it).
            out_features: Length of the generated flattened vector. Defaults
                to the module-level ``row_length`` (read at construction time),
                so ``Generator()`` behaves exactly as before.
        """
        super().__init__()
        if out_features is None:
            out_features = row_length
        # Kept so callers can size their noise tensors from the model itself.
        self.latent_dim = latent_dim
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, out_features),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map noise (last dim == latent_dim) to outputs of width out_features."""
        return self.model(x)
generator = Generator()

# Binary real/fake objective; both networks already end in a Sigmoid, so the
# plain (probability-space) BCE loss applies.
loss_function = nn.BCELoss()

# One independent Adam optimizer per network, both at the same learning rate.
optimizer_discriminator, optimizer_generator = (
    torch.optim.Adam(net.parameters(), lr=lr)
    for net in (discriminator, generator)
)
# Standard alternating GAN updates: one discriminator step, then one generator
# step, per batch.
for epoch in range(num_epochs):
    for real_samples, _ in train_loader:
        # Size everything from the actual batch. The last DataLoader batch can
        # be smaller than batch_size, and fixed-size labels would break
        # torch.cat and the loss.
        n = real_samples.size(0)
        real_samples_labels = torch.ones((n, 1))
        generated_samples_labels = torch.zeros((n, 1))

        # Data for training the discriminator. detach(): the discriminator's
        # loss must not backprop into the generator. The generator's grads are
        # zeroed before its own step anyway, so this only saves work.
        latent_space_samples = torch.randn((n, 1024))
        generated_samples = generator(latent_space_samples).detach()
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels
        )
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Training the generator: reward fakes the discriminator labels "real".
        latent_space_samples = torch.randn((n, 1024))
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()

    # Show the last batch's losses once per epoch; .item() prints the plain
    # number instead of a tensor repr with grad_fn.
    print(
        f"Epoch: {epoch} Loss D.: {loss_discriminator.item()} "
        f"Loss G.: {loss_generator.item()}"
    )
# Sample one piece from the trained generator. no_grad: inference only, so
# there is no need to build an autograd graph.
latent_space_samples = torch.randn(1, 1024)
with torch.no_grad():
    generated_samples = generator(latent_space_samples)

# Binarise in one comparison: cells above 0.99 become note-on. This yields a
# bool numpy array directly instead of mutating a float tensor twice.
piece = (generated_samples.reshape((data_length, 128)) > 0.99).numpy()

track = ppr.BinaryTrack(pianoroll=piece, name='piano')
mtrack = ppr.Multitrack(tracks=[track])
fname = 'final-project/original.mid'
mtrack.write(fname)

# Play the written MIDI file while the piano-roll plot is shown.
pygame.init()
pygame.mixer.music.load(fname)
pygame.mixer.music.play()
mtrack.plot()
plt.show()
# Keep the process alive until playback finishes.
while pygame.mixer.music.get_busy():
    time.sleep(0.5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment