"""
Currently trains with decreasing loss
*** epoch: 0 epoch loss: 276.47448682785034
*** epoch: 1 epoch loss: 216.9058997631073
*** epoch: 2 epoch loss: 190.01888144016266
*** epoch: 3 epoch loss: 171.68642991781235
*** epoch: 4 epoch loss: 157.7317717075348
*** epoch: 5 epoch loss: 145.89844578504562
...
...
*** epoch: 90 epoch loss: 11.323879387229681
*** epoch: 91 epoch loss: 11.176103946752846
*** epoch: 92 epoch loss: 11.033554057590663
*** epoch: 93 epoch loss: 10.898204608820379
"""

import torch
import torch.nn.functional as F
import torchaudio


def train_ar_generative_model():
    """
    Train an auto-regressive generative model.

    Simplest model: a single convolutional layer with a per-sample
    softmax activation over the 256 mu-law classes. Inputs and outputs
    are lined up so that each output position predicts the next sample,
    and the model is optimized with a per-sample cross-entropy loss.
    """
    yesno_data = torchaudio.datasets.YESNO('./',
                                           download=True)
    data_loader = torch.utils.data.DataLoader(yesno_data,
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=1)

    KERNEL_SIZE = 100

    def left_pad(x):
        # left pad x with KERNEL_SIZE - 1 zeros so that the convolution
        # is causal and the output length matches the input length
        return F.pad(x,
                     pad=[KERNEL_SIZE - 1, 0],
                     mode='constant',
                     value=0)

    conv_layer = torch.nn.Conv1d(
        in_channels=256,
        out_channels=256,
        kernel_size=KERNEL_SIZE
    )

    def loss_criterion(output, input):
        # drop the first KERNEL_SIZE - 1 output positions (their receptive
        # fields overlap the zero padding) and shift by one, so that the
        # output at position t is scored against the input at position t + 1
        modified_output = output[:, :, KERNEL_SIZE - 1:-1]
        modified_input = torch.squeeze(input[:, :, KERNEL_SIZE:],
                                       dim=1)
        loss_fn = torch.nn.CrossEntropyLoss()
        return loss_fn(modified_output, modified_input)

    optimizer = torch.optim.SGD(conv_layer.parameters(),
                                lr=1e-1)
    encoding = torchaudio.transforms.MuLawEncoding(quantization_channels=256)

    nepochs = 10**3
    for epoch in range(nepochs):
        epoch_loss = 0.
        for i, sample in enumerate(data_loader):
            waveform, sample_rate, labels = sample
            n = waveform.shape[-1]
            n_sample = int(KERNEL_SIZE * 1.5)
            # break the waveform into chunks to spend less computation
            # time on each iteration (any trailing partial chunk is dropped)
            for j in range(0, n - n_sample, n_sample):
                waveform_chunk = waveform[:, :, j: j + n_sample]
                categorical_input = encoding(waveform_chunk)
                assert categorical_input.shape[:2] == (1, 1)
                input = torch.squeeze(F.one_hot(categorical_input, 256),
                                      dim=0).permute(0, 2, 1).float()
                assert input.shape[:2] == (1, 256)
                optimizer.zero_grad()
                lp_input = left_pad(input)
                assert lp_input.shape[:2] == (1, 256)
                assert lp_input.shape[-1] == (input.shape[-1] + KERNEL_SIZE - 1)
                output = conv_layer(lp_input)
                assert output.shape[:2] == (1, 256)
                assert output.shape[-1] == input.shape[-1]
                loss = loss_criterion(output, categorical_input)
                loss.backward()
                epoch_loss += loss.item()
                optimizer.step()
                # print(f'sample chunk loss: {loss.item()}')
        print(f'*** epoch: {epoch} epoch loss: {epoch_loss}')
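

# A minimal sampling sketch: assuming `conv_layer` has been trained as
# above with the same KERNEL_SIZE and 256 mu-law classes, draw one sample
# at a time from the per-position softmax and feed it back in, the
# generative counterpart of the training loop. The function name
# `generate` and the `n_steps` parameter are illustrative choices.
def generate(conv_layer, kernel_size=100, n_steps=1000):
    decoding = torchaudio.transforms.MuLawDecoding(quantization_channels=256)
    # start from a history of silence (mu-law encoding of 0 is class 128)
    history = torch.full((kernel_size,), 128, dtype=torch.int64)
    samples = []
    with torch.no_grad():
        for _ in range(n_steps):
            # one-hot the last kernel_size samples: shape (1, 256, kernel_size)
            x = F.one_hot(history[-kernel_size:], 256).T.unsqueeze(0).float()
            logits = conv_layer(x)  # shape (1, 256, 1)
            probs = torch.softmax(logits[0, :, -1], dim=0)
            nxt = torch.multinomial(probs, 1)  # sample the next class
            history = torch.cat([history, nxt])
            samples.append(nxt.item())
    # decode the sampled classes back to a waveform in (-1, 1)
    return decoding(torch.tensor(samples, dtype=torch.int64))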


def download_process_data():
    # following the torchaudio docs:
    # https://pytorch.org/audio/datasets.html
    yesno_data = torchaudio.datasets.YESNO('./',
                                           download=True)
    data_loader = torch.utils.data.DataLoader(yesno_data,
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=1)
    sample_ct = 0
    for i, sample in enumerate(data_loader):
        # from the torchaudio docs:
        # each item is a tuple of the form (waveform, sample_rate, labels);
        # each waveform has shape [1, 1, n], where n varies between roughly
        # 45000 and 55000, i.e. a single-channel waveform of variable
        # length; the sample rate is 8000 for all samples
        waveform, sample_rate, labels = sample
        assert torch.equal(sample_rate, torch.LongTensor([8000]))
        # mu-law quantization, one-hot encoded and reshaped; the signal is
        # assumed (and verified) to lie between -1 and 1, as required for
        # mu-law encoding
        assert -1. < waveform.min() < 1.
        assert -1. < waveform.max() < 1.
        encoding = torchaudio.transforms.MuLawEncoding(quantization_channels=256)
        quantized_waveform = torch.squeeze(F.one_hot(encoding(waveform), 256),
                                           dim=0).permute(0, 2, 1)
        # shape is (1, 256, n)
        assert quantized_waveform.shape[:2] == (1, 256)
        sample_ct += 1
    assert sample_ct == 60, 'expected 60 samples'
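

# A small round-trip check (illustrative addition): mu-law encoding
# followed by torchaudio's MuLawDecoding should approximately recover a
# waveform lying in (-1, 1), up to quantization error. The helper name
# `mu_law_round_trip` and the tolerance are assumptions, not fixed values.
def mu_law_round_trip(waveform):
    encoding = torchaudio.transforms.MuLawEncoding(quantization_channels=256)
    decoding = torchaudio.transforms.MuLawDecoding(quantization_channels=256)
    reconstructed = decoding(encoding(waveform))
    # with 256 channels the widest (largest-amplitude) bins give an error
    # of roughly 0.02 in linear amplitude, so 0.05 is a loose bound
    assert torch.allclose(waveform, reconstructed, atol=0.05)
    return reconstructed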


if __name__ == "__main__":
    train_ar_generative_model()