Skip to content

Instantly share code, notes, and snippets.

@jeremyfix
Last active September 19, 2022 07:08
Show Gist options
  • Save jeremyfix/ee183491af00d536006c4531ec3e536b to your computer and use it in GitHub Desktop.
Save jeremyfix/ee183491af00d536006c4531ec3e536b to your computer and use it in GitHub Desktop.
Experiments with gradient descent on complex-valued tensors in PyTorch 1.12
name: torchcomplex
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotli=1.0.9=h5eee18b_7
- brotli-bin=1.0.9=h5eee18b_7
- brotlipy=0.7.0=py39h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.07.19=h06a4308_0
- certifi=2022.6.15=py39h06a4308_0
- cffi=1.15.1=py39h74dc2b5_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- cpuonly=2.0=0
- cryptography=37.0.1=py39h9ce1e76_0
- cycler=0.11.0=pyhd3eb1b0_0
- dbus=1.13.18=hb2f20db_0
- expat=2.4.4=h295c915_0
- ffmpeg=4.3=hf484d3e_0
- fontconfig=2.13.1=h6c09931_0
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.11.0=h70c0345_0
- giflib=5.2.1=h7b6447c_0
- glib=2.69.1=h4ff587b_1
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- gst-plugins-base=1.14.0=h8213a91_2
- gstreamer=1.14.0=h28cd5cc_2
- icu=58.2=he6710b0_3
- idna=3.3=pyhd3eb1b0_0
- intel-openmp=2021.4.0=h06a4308_3561
- jpeg=9e=h7f8727e_0
- kiwisolver=1.4.2=py39h295c915_0
- krb5=1.19.2=hac12032_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libbrotlicommon=1.0.9=h5eee18b_7
- libbrotlidec=1.0.9=h5eee18b_7
- libbrotlienc=1.0.9=h5eee18b_7
- libclang=10.0.1=default_hb85057a_2
- libdeflate=1.8=h7f8727e_5
- libedit=3.1.20210910=h7f8727e_0
- libevent=2.1.12=h8f2d780_0
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h7f8727e_2
- libidn2=2.3.2=h7f8727e_0
- libllvm10=10.0.1=hbcb73fb_5
- libpng=1.6.37=hbc83047_0
- libpq=12.9=h16c4e8d_3
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.4.0=hecacb30_0
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.0.3=h7f8727e_2
- libwebp=1.2.2=h55f646e_0
- libwebp-base=1.2.2=h7f8727e_0
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.0.1=hfa300c1_0
- libxml2=2.9.14=h74e7548_0
- libxslt=1.1.35=h4e12654_0
- lz4-c=1.9.3=h295c915_1
- matplotlib=3.5.2=py39h06a4308_0
- matplotlib-base=3.5.2=py39hf590b9c_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py39h7f8727e_0
- mkl_fft=1.3.1=py39hd3c417c_0
- mkl_random=1.2.2=py39h51133e4_0
- munkres=1.1.4=py_0
- ncurses=6.3=h5eee18b_3
- nettle=3.7.3=hbbd107a_1
- nspr=4.33=h295c915_0
- nss=3.74=h0370c37_0
- numpy=1.23.1=py39h6c91a56_0
- numpy-base=1.23.1=py39ha15fc14_0
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1q=h7f8727e_0
- packaging=21.3=pyhd3eb1b0_0
- pcre=8.45=h295c915_0
- pillow=9.2.0=py39hace64e9_1
- pip=22.1.2=py39h06a4308_0
- ply=3.11=py39h06a4308_0
- pycparser=2.21=pyhd3eb1b0_0
- pyopenssl=22.0.0=pyhd3eb1b0_0
- pyparsing=3.0.9=py39h06a4308_0
- pyqt=5.15.7=py39h6a678d5_1
- pyqt5-sip=12.11.0=py39h6a678d5_1
- pysocks=1.7.1=py39h06a4308_0
- python=3.9.13=haa1d7c7_1
- python-dateutil=2.8.2=pyhd3eb1b0_0
- pytorch=1.12.1=py3.9_cpu_0
- pytorch-mutex=1.0=cpu
- qt-main=5.15.2=h327a75a_7
- qt-webengine=5.15.9=hd2b0992_4
- qtwebkit=5.212=h4eab89a_4
- readline=8.1.2=h7f8727e_1
- requests=2.28.1=py39h06a4308_0
- setuptools=63.4.1=py39h06a4308_0
- sip=6.6.2=py39h6a678d5_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.39.2=h5082296_0
- tk=8.6.12=h1ccaba5_0
- toml=0.10.2=pyhd3eb1b0_0
- torchvision=0.13.1=py39_cpu
- tornado=6.2=py39h5eee18b_0
- tqdm=4.64.0=py39h06a4308_0
- typing_extensions=4.3.0=py39h06a4308_0
- tzdata=2022c=h04d1e81_0
- urllib3=1.26.11=py39h06a4308_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.5=h7f8727e_1
- zlib=1.2.12=h5eee18b_3
- zstd=1.5.2=ha4553b6_0
prefix: /home/fix_jer/.local/conda/envs/torchcomplex
#!/usr/bin/env python
"""
Script for demoing complex data in pytorch
"""
import sys
import argparse
import itertools
import torch
import torch.utils.data
import torch.nn as nn
import tqdm
import numpy as np
import matplotlib.pyplot as plt
class Dataset(torch.utils.data.IterableDataset):
    """Endless stream of ``(z, |z|)`` training pairs.

    Every draw samples the real part and then the imaginary part of ``z``
    independently and uniformly from ``[-1, 1)``. ``z`` is a complex tensor
    of shape ``(1,)``; the target is its modulus, a real tensor of the same
    shape.
    """

    def __iter__(self):
        # The dataset is its own iterator and never raises StopIteration.
        return self

    def __next__(self):
        real_part = 2 * torch.rand(1) - 1.0
        imag_part = 2 * torch.rand(1) - 1
        z = torch.complex(real_part, imag_part)
        return z, z.abs()
class CReLU(nn.Module):
    """Split-complex ReLU: rectify the real and imaginary parts separately.

    The output is ``relu(Re z) + 1j * relu(Im z)``.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, z):
        # Apply the same ReLU to each Cartesian component, then recombine.
        re_part, im_part = (self.relu(component) for component in (z.real, z.imag))
        return re_part + im_part * 1j
class zReLU(nn.Module):
    """Complex ReLU that keeps ``z`` only in the open first quadrant.

    Elements whose real part AND imaginary part are both strictly positive
    pass through unchanged. All other elements are multiplied by zero.
    """

    def forward(self, z):
        keep_re = z.real > 0
        keep_im = z.imag > 0
        # The boolean masks are promoted to z's complex dtype by the products.
        return (z * keep_re) * keep_im
class Mod(nn.Module):
    """Map a complex activation to its modulus ``|z|``, a real tensor."""

    def forward(self, z):
        return z.abs()
class Dropout(nn.Module):
    """Element-wise dropout for complex tensors.

    A real-valued mask is obtained by running ``F.dropout`` on a tensor of
    ones shaped like ``z``. The input is then multiplied by that mask, so the
    real and imaginary parts of an element are always dropped together.
    In eval mode the mask is all ones and ``z`` is returned unchanged in value.

    Args:
        p: probability of zeroing an element (default 0.5).
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, z):
        # Build the mask on z's device. A CPU-only mask would make the product
        # fail with a device mismatch as soon as z lives on a GPU.
        mask = torch.nn.functional.dropout(
            torch.ones(z.shape, device=z.device), self.p, training=self.training
        )
        return mask * z
class Dropout2d(nn.Module):
    """Channel-wise dropout for complex tensors.

    A real-valued mask is obtained by running ``F.dropout2d`` on a tensor of
    ones shaped like ``z``. Whole channels are therefore kept or zeroed
    together, and the real and imaginary parts share the same mask.

    NOTE(review): for 3-D inputs such as Conv1d outputs ``(N, C, L)``,
    PyTorch >= 1.12 warns on every call and performs 1-D channel dropout.
    ``F.dropout1d`` is the non-deprecated spelling of that behaviour.

    Args:
        p: probability of zeroing a channel (default 0.5).
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, z):
        # Create the mask on z's device so GPU inputs do not hit a device
        # mismatch in the product below.
        mask = torch.nn.functional.dropout2d(
            torch.ones(z.shape, device=z.device), self.p, training=self.training
        )
        return mask * z
def test_data():
    """Pull one batch through a two-worker ``DataLoader`` and report it.

    Prints the shape and dtype of the batched inputs; the targets are ignored.
    """
    loader = torch.utils.data.DataLoader(Dataset(), batch_size=32, num_workers=2)
    first_batch = next(iter(loader))
    inputs, _targets = first_batch
    print(f"Got an input tensor of shape {inputs.shape} and type {inputs.dtype}")
def test_linear(args):
    """Train a complex-valued MLP to regress the modulus ``|z|`` of its input.

    Complex ``nn.Linear`` layers are interleaved with the custom ``Dropout``
    and ``zReLU`` modules. A final ``Mod`` makes the prediction real, so a
    plain squared error can be used as the loss. After training:

    * the root mean squared error over 1000 fresh samples is printed;
    * the model is evaluated on a grid over ``[-1, 1] x [-1, 1]`` of the
      complex plane;
    * the result is plotted, saved to ``abs.png`` and shown.

    Args:
        args: parsed command-line namespace; only ``args.device`` is read.
    """
    dataset = Dataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=2)
    dtype = torch.complex64
    device = torch.device(args.device)
    model = nn.Sequential(
        nn.Linear(1, 128, dtype=dtype),
        Dropout(0.5),
        zReLU(),
        nn.Linear(128, 128, dtype=dtype),
        Dropout(0.5),
        zReLU(),
        nn.Linear(128, 1, dtype=dtype),
        Mod(),
    )
    # Parameters must be on the same device as the inputs. Previously only the
    # data was moved, so ``--device cuda`` failed on the first forward pass.
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=3e-4)

    # Train the network. Dataset.__next__ never raises StopIteration, so
    # next() on this iterator keeps producing batches.
    it_train = iter(dataloader)
    n_steps = 5000
    model.train()
    progress = tqdm.tqdm(range(n_steps))
    for _ in progress:
        x, y = next(it_train)
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = ((y_pred - y) ** 2).sum()
        optim.zero_grad()
        loss.backward()
        optim.step()
        # Report the loss through tqdm instead of writing "\r ..." to stdout,
        # which overwrote the progress bar's own line on every step.
        progress.set_postfix(loss=loss.item())

    # Evaluate: root of the mean squared error over fresh single samples.
    model.eval()
    n_samples = 1000
    squared_error = 0.0
    with torch.no_grad():
        for x, y in itertools.islice(dataset, n_samples):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            squared_error += ((y_pred - y) ** 2).sum().item()
    test_loss = np.sqrt(squared_error / n_samples)
    print(f"The loss evaluated on {n_samples} samples is {test_loss}")

    # Display the learned function over a grid of the complex plane.
    X, Y = np.meshgrid(np.linspace(-1, 1), np.linspace(-1, 1))
    # The real part follows X (horizontal axis) and the imaginary part follows
    # Y (vertical axis), matching how pcolormesh draws the grid below. The
    # former ``X * 1j + Y`` swapped the two and displayed a transposed function.
    inputs = torch.tensor(X + 1j * Y, dtype=dtype).reshape(-1, 1).to(device)
    expected = inputs.abs()
    with torch.no_grad():
        outputs = model(inputs)
    print(((outputs - expected) ** 2).mean())
    # numpy can only read host memory, so copy back from the device first.
    Z = outputs.reshape(X.shape).cpu().numpy()
    plt.figure()
    plt.pcolormesh(X, Y, Z)
    plt.colorbar()
    plt.xlabel("Re(z)")
    plt.ylabel("Im(z)")
    plt.tight_layout()
    plt.savefig("abs.png")
    plt.show()
def test_conv(args):
    """Train a complex-valued 1x1-convolution network to regress ``|z|``.

    This is a dummy example: with ``kernel_size=1`` and a length-1 signal, the
    ``Conv1d`` layers compute the same thing as fully connected layers. Each
    complex scalar input is reshaped to ``(batch, channels=1, length=1)``.
    After training:

    * the root mean squared error over 1000 fresh samples is printed;
    * the model is evaluated on a grid over ``[-1, 1] x [-1, 1]`` of the
      complex plane;
    * the result is plotted and shown.

    Args:
        args: parsed command-line namespace; only ``args.device`` is read.
    """
    dataset = Dataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=2)
    dtype = torch.complex64
    device = torch.device(args.device)
    model = nn.Sequential(
        nn.Conv1d(1, 128, kernel_size=1, dtype=dtype),
        zReLU(),
        Dropout2d(),
        nn.Conv1d(128, 128, kernel_size=1, dtype=dtype),
        zReLU(),
        Dropout2d(),
        nn.Flatten(),
        nn.Linear(128, 1, dtype=dtype),
        Mod(),
    )
    # Parameters must be on the same device as the inputs. Previously only the
    # data was moved, so ``--device cuda`` failed on the first forward pass.
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=3e-4)

    # Train the network. Dataset.__next__ never raises StopIteration, so
    # next() on this iterator keeps producing batches.
    it_train = iter(dataloader)
    n_steps = 1000
    model.train()
    progress = tqdm.tqdm(range(n_steps))
    for _ in progress:
        x, y = next(it_train)
        x, y = x.to(device), y.to(device)
        x = x.reshape(-1, 1, 1)  # (batch, channels=1, length=1) for Conv1d
        y_pred = model(x)
        loss = ((y_pred - y) ** 2).sum()
        optim.zero_grad()
        loss.backward()
        optim.step()
        # Report the loss through tqdm instead of writing "\r ..." to stdout,
        # which overwrote the progress bar's own line on every step.
        progress.set_postfix(loss=loss.item())

    # Evaluate: root of the mean squared error over fresh single samples.
    model.eval()
    n_samples = 1000
    squared_error = 0.0
    with torch.no_grad():
        for x, y in itertools.islice(dataset, n_samples):
            x, y = x.to(device), y.to(device)
            x = x.reshape(-1, 1, 1)
            y_pred = model(x)
            squared_error += ((y_pred - y) ** 2).sum().item()
    test_loss = np.sqrt(squared_error / n_samples)
    print(f"The loss evaluated on {n_samples} samples is {test_loss}")

    # Display the learned function over a grid of the complex plane.
    X, Y = np.meshgrid(np.linspace(-1, 1), np.linspace(-1, 1))
    # The real part follows X (horizontal axis) and the imaginary part follows
    # Y (vertical axis), matching how pcolormesh draws the grid below. The
    # former ``X * 1j + Y`` swapped the two and displayed a transposed function.
    inputs = torch.tensor(X + 1j * Y, dtype=dtype).reshape(-1, 1, 1).to(device)
    expected = inputs.abs()
    with torch.no_grad():
        outputs = model(inputs)
    print(((outputs - expected) ** 2).mean())
    # numpy can only read host memory, so copy back from the device first.
    Z = outputs.reshape(X.shape).cpu().numpy()
    plt.figure()
    plt.pcolormesh(X, Y, Z)
    plt.colorbar()
    plt.xlabel("Re(z)")
    plt.ylabel("Im(z)")
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    # Command-line entry point: choose the device, sanity-check the data
    # pipeline, then run the dense and convolutional experiments in turn.
    cli = argparse.ArgumentParser()
    cli.add_argument("--device", default="cpu", choices=["cpu", "cuda"])
    args = cli.parse_args()
    test_data()
    for run_experiment in (test_linear, test_conv):
        run_experiment(args)
@jeremyfix
Copy link
Author

Example output of the dense (fully connected) model:

abs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment