Skip to content

Instantly share code, notes, and snippets.

@taras-sereda
Created January 23, 2019 10:26
Show Gist options
  • Save taras-sereda/130ef385382577ee7f8c6e9b2ce9dd90 to your computer and use it in GitHub Desktop.
Save taras-sereda/130ef385382577ee7f8c6e9b2ce9dd90 to your computer and use it in GitHub Desktop.
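WaveGlow CPU inference: patch that replaces CUDA-only tensor allocations and .cuda() calls and loads the checkpoint onto the CPU, so inference runs without a GPU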
diff --git a/glow_old.py b/glow_old.py
index 0de2375..5895300 100644
--- a/glow_old.py
+++ b/glow_old.py
@@ -183,7 +183,7 @@ class WaveGlow(torch.nn.Module):
self.n_remaining_channels,
spect.size(2)).normal_()
else:
- audio = torch.cuda.FloatTensor(spect.size(0),
+ audio = torch.FloatTensor(spect.size(0),
self.n_remaining_channels,
spect.size(2)).normal_()
@@ -215,7 +215,7 @@ class WaveGlow(torch.nn.Module):
self.n_early_size,
spect.size(2)).normal_()
else:
- z = torch.cuda.FloatTensor(spect.size(0),
+ z = torch.FloatTensor(spect.size(0),
diff --git a/glow_old.py b/glow_old.py
index 0de2375..5895300 100644
--- a/glow_old.py
+++ b/glow_old.py
@@ -183,7 +183,7 @@ class WaveGlow(torch.nn.Module):
self.n_remaining_channels,
spect.size(2)).normal_()
else:
- audio = torch.cuda.FloatTensor(spect.size(0),
+ audio = torch.FloatTensor(spect.size(0),
self.n_remaining_channels,
spect.size(2)).normal_()
@@ -215,7 +215,7 @@ class WaveGlow(torch.nn.Module):
self.n_early_size,
spect.size(2)).normal_()
else:
- z = torch.cuda.FloatTensor(spect.size(0),
+ z = torch.FloatTensor(spect.size(0),
self.n_early_size,
spect.size(2)).normal_()
audio = torch.cat((sigma*z, audio),1)
diff --git a/inference.py b/inference.py
index 2c67605..61cf6f2 100644
--- a/inference.py
+++ b/inference.py
@@ -32,9 +32,9 @@ from mel2samp import files_to_list, MAX_WAV_VALUE
def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
mel_files = files_to_list(mel_files)
- waveglow = torch.load(waveglow_path)['model']
+ waveglow = torch.load(waveglow_path, map_location=lambda storage, loc: storage)['model']
waveglow = waveglow.remove_weightnorm(waveglow)
- waveglow.cuda().eval()
+ waveglow.eval()
if is_fp16:
waveglow.half()
for k in waveglow.convinv:
@@ -43,7 +43,7 @@ def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
for i, file_path in enumerate(mel_files):
file_name = os.path.splitext(os.path.basename(file_path))[0]
mel = torch.load(file_path)
- mel = torch.autograd.Variable(mel.cuda())
+ mel = torch.autograd.Variable(mel)
mel = torch.unsqueeze(mel, 0)
mel = mel.half() if is_fp16 else mel
with torch.no_grad():
diff --git a/tacotron2 b/tacotron2
--- a/tacotron2
+++ b/tacotron2
@@ -1 +1 @@
-Subproject commit fc0cf6a89a47166350b65daa1beaa06979e4cddf
+Subproject commit fc0cf6a89a47166350b65daa1beaa06979e4cddf-dirty
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment