@oxyflour
Last active May 3, 2023 12:47
80x accel for scikit-rf with pytorch
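The speedup comes from the overridden s property below: the circuit's X and C matrices are copied to the GPU and the per-frequency inversion of (I - C @ X) is done in one batched torch.linalg.inv call instead of with NumPy on the CPU. Here is a minimal sketch of just that batched-inversion idea, assuming made-up shapes, dtype, and a cuda:0 device (none of these values come from the gist); the original gist code follows below.

import torch

# hypothetical sizes: 1001 frequency points, 16x16 circuit matrices
nfreq, dim = 1001, 16
a = torch.randn(nfreq, dim, dim, dtype=torch.complex64, device='cuda:0')
# torch.linalg.inv treats the leading dimension as a batch,
# so all nfreq matrices are inverted in a single GPU call
inv = torch.linalg.inv(a)
print(inv.shape)  # torch.Size([1001, 16, 16])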
from typing import List, Tuple
from skrf import Network, Circuit, DefinedGammaZ0
from time import perf_counter
import numpy as np
import torch

# load a 2-port touchstone network and a medium used to create lumped resistors
n = Network('test/touchstone/dipole-x2.s2p')
f = n.frequency
z = DefinedGammaZ0(f, 50)

class TorchCircuit(Circuit):
    def __init__(self, connections: List[List[Tuple]], dev=torch.device('cuda:0')) -> None:
        self.dev = dev
        super().__init__(connections)

    @property
    def s(self) -> np.ndarray:
        # copy the circuit's global scattering (X) and connection (C) matrices to the torch device
        x = torch.tensor(self.X, device=self.dev)
        c = torch.tensor(self.C, device=self.dev)
        # invert (I - C @ X) for all frequency points in one batched call
        mat = torch.eye(self.dim, device=self.dev) - torch.matmul(c, x)
        inv = torch.linalg.inv(mat)
        ret = torch.matmul(x, inv).permute(0, 2, 1)
        # return a plain numpy array on the CPU, like the base-class property
        return ret.to('cpu').detach().numpy()

# connect port 0 of the 2-port network to an external port, and port 1
# through a chain of ten 5-ohm resistors to ground
rs = [z.resistor(5, name=f'res-{i}') for i in range(10)]
p = Circuit.Port(f, name='port-0')
g = Circuit.Ground(f, name='gnd-0')
cx = [
    [(n, 0), (p, 0)],
    [(n, 1), (rs[0], 0)],
    *[[(ri, 1), (rj, 0)] for ri, rj in zip(rs[0:-1], rs[1:])],
    [(rs[-1], 1), (g, 0)],
]

def run(dev=None):
    # use the torch-accelerated circuit when a device is given, otherwise plain skrf
    c = TorchCircuit(cx, dev) if dev else Circuit(cx)
    return c.s

# sanity check: the torch result should match stock skrf (mean squared error close to zero)
dev = torch.device('cuda:0')
print('mse', np.average(np.abs(run(dev) - run()) ** 2))

# time 10 evaluations with the torch backend, then 10 with stock skrf
start = perf_counter()
for i in range(10):
    run(dev)
torch_time = perf_counter() - start
print('PERF: run torch in', torch_time, 'seconds')

start = perf_counter()
for i in range(10):
    run()
origin_time = perf_counter() - start
print('PERF: run skrf in', origin_time, 'seconds')
print('PERF: torch is', int(origin_time / torch_time), 'times faster')