import math
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
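
# Complex-valued activation functions for PyTorch. Each nn.Module below wraps
# one of the functional implementations defined further down; complex inputs
# get a complex-specific rule, and real inputs fall back to the standard
# real-valued activation.
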
class CSigmoid(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return c_sigmoid(input)

class CTanh(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return c_tanh(input)

class modTanh(nn.Module):
    __constants__ = ['rounding_mode']
    rounding_mode: Optional[str]

    def __init__(self, rounding_mode: Optional[str] = None):
        super(modTanh, self).__init__()
        self.rounding_mode = rounding_mode

    def forward(self, input: Tensor) -> Tensor:
        # Use the rounding mode the module was configured with.
        return mod_tanh(input, rounding_mode=self.rounding_mode)

    def extra_repr(self) -> str:
        return 'rounding_mode={}'.format(self.rounding_mode)

class Hirose(nn.Module):
    __constants__ = ['m', 'inplace']
    m: float
    inplace: bool

    def __init__(self, m: float = 1., inplace: bool = False):
        super(Hirose, self).__init__()
        self.m = m
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return hirose(input, m=self.m, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'm={}{}'.format(self.m, inplace_str)

class Siglog(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return siglog(input)

class CCardioid(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return c_cardioid(input)

class CReLU(nn.Module):
    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super(CReLU, self).__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return c_relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        return 'inplace=True' if self.inplace else ''

class zReLU(nn.Module):
    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super(zReLU, self).__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return z_relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        return 'inplace=True' if self.inplace else ''

class modReLU(nn.Module):
    __constants__ = ['bias', 'rounding_mode', 'inplace']
    bias: float
    rounding_mode: Optional[str]
    inplace: bool

    def __init__(self, bias: float = -math.sqrt(2), rounding_mode: Optional[str] = None, inplace: bool = False):
        super(modReLU, self).__init__()
        self.bias = bias
        self.rounding_mode = rounding_mode
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return mod_relu(input, bias=self.bias, rounding_mode=self.rounding_mode, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'bias={}, rounding_mode={}{}'.format(self.bias, self.rounding_mode, inplace_str)

class CLeakyReLU(nn.Module):
    __constants__ = ['negative_slope', 'inplace']
    negative_slope: float
    inplace: bool

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super(CLeakyReLU, self).__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return c_leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)

class modSoftmax(nn.Module):
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(modSoftmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return mod_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

class modLogSoftmax(nn.Module):
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(modLogSoftmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return mod_log_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

class rSoftmax(nn.Module):
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(rSoftmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return r_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

class rLogSoftmax(nn.Module):
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(rLogSoftmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return r_log_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

def complex_fcaller(functional_handle: Callable, *args):
    # Apply a real-valued functional to the real and imaginary parts separately.
    return torch.complex(functional_handle(args[0].real, *args[1:]),
                         functional_handle(args[0].imag, *args[1:]))

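# Illustrative note (not part of the original gist): complex_fcaller lifts any
# real-valued functional into a split complex activation, e.g.
#     z = torch.randn(3, dtype=torch.cfloat)
#     complex_fcaller(F.relu, z)  # == torch.complex(F.relu(z.real), F.relu(z.imag))
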
def c_sigmoid(input: Tensor) -> Tensor:
    # Split sigmoid: apply sigmoid to the real and imaginary parts separately.
    if input.is_complex():
        return torch.complex(torch.sigmoid(input.real), torch.sigmoid(input.imag))
    else:
        return torch.sigmoid(input)

def c_tanh(input: Tensor) -> Tensor:
    # Split tanh: apply tanh to the real and imaginary parts separately.
    if input.is_complex():
        return torch.complex(torch.tanh(input.real), torch.tanh(input.imag))
    else:
        return torch.tanh(input)

def mod_tanh(input: Tensor, rounding_mode: Optional[str] = None) -> Tensor:
    # Squash the magnitude with tanh while preserving the phase:
    # mod_tanh(z) = tanh(|z|) * z / |z|.
    if input.is_complex():
        magnitude = torch.abs(input)
        euler_phase = torch.div(input=input, other=magnitude, rounding_mode=rounding_mode)
        return torch.mul(torch.tanh(magnitude), euler_phase).type(input.type())
    else:
        return torch.tanh(input)

def hirose(input: Tensor, m: float = 1., rounding_mode: Optional[str] = None, inplace: bool = False) -> Tensor:
    # Hirose activation: tanh(|z| / m^2) * z / |z| (amplitude is squashed, phase is kept).
    # Note: m is a Python float, so use m ** 2 rather than torch.pow, which
    # requires at least one tensor argument.
    if input.is_complex():
        magnitude = torch.abs(input)
        euler_phase = torch.div(input, magnitude)
        out = torch.mul(torch.tanh(torch.div(input=magnitude, other=m ** 2, rounding_mode=rounding_mode)),
                        euler_phase).type(input.type())
    else:
        out = torch.tanh(torch.div(input=input, other=m ** 2, rounding_mode=rounding_mode)).type(input.type())
    if inplace:
        # Write the result back into the caller's tensor.
        input.copy_(out)
        return input
    return out

def siglog(input: Tensor) -> Tensor:
    # siglog(z) = z / (1 + |z|): maps any input into the open unit disc.
    return torch.div(input, 1 + torch.abs(input))

def c_cardioid(input: Tensor) -> Tensor:
    # Complex cardioid: scale z by (1 + cos(angle(z))) / 2, attenuating values
    # whose phase is far from the positive real axis.
    phase = torch.angle(input)
    return 0.5 * torch.mul(1 + torch.cos(phase), input)

def c_relu(input: Tensor, inplace: bool = False) -> Tensor:
    # Split ReLU: apply ReLU to the real and imaginary parts separately.
    if input.is_complex():
        return torch.complex(F.relu(input.real, inplace=inplace), F.relu(input.imag, inplace=inplace))
    else:
        return F.relu(input, inplace=inplace)

def mod_relu(input: Tensor, bias: float = -math.sqrt(2), rounding_mode: Optional[str] = None, inplace: bool = False) -> Tensor:
    # modReLU: ReLU on the biased magnitude with the phase preserved:
    # mod_relu(z) = relu(|z| + b) * z / |z|.
    if input.is_complex():
        magnitude = torch.abs(input)
        euler_phase = torch.div(input=input, other=magnitude, rounding_mode=rounding_mode)
        out = torch.mul(F.relu(magnitude + bias, inplace=False), euler_phase).type(input.type())
        if inplace:
            # Write the result back into the caller's tensor.
            input.copy_(out)
            return input
        return out
    else:
        return F.relu(input, inplace=inplace)

def z_relu(input: Tensor, inplace: bool = False) -> Tensor:
    # zReLU: keep z only where its phase lies in the first quadrant
    # (0 <= angle(z) <= pi/2); zero it everywhere else.
    if input.is_complex():
        mask = torch.zeros_like(input)
        out = torch.where(torch.angle(input) < 0, mask, input)
        out = torch.where(torch.angle(out) > (math.pi / 2), mask, out)
        if inplace:
            # Write the result back into the caller's tensor.
            input.copy_(out)
            return input
        return out
    else:
        return F.relu(input, inplace=inplace)

def c_leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor:
    # Split leaky ReLU: apply leaky ReLU to the real and imaginary parts separately.
    if input.is_complex():
        return torch.complex(F.leaky_relu(input=input.real, negative_slope=negative_slope, inplace=inplace),
                             F.leaky_relu(input=input.imag, negative_slope=negative_slope, inplace=inplace))
    else:
        return F.leaky_relu(input=input, negative_slope=negative_slope, inplace=inplace)

def mod_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[torch.dtype] = None) -> Tensor:
    # Softmax over the magnitudes of a complex input (output is real-valued).
    if input.is_complex():
        return F.softmax(torch.abs(input), dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    else:
        return F.softmax(input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)

def mod_log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[torch.dtype] = None) -> Tensor:
    # Log-softmax over the magnitudes of a complex input (output is real-valued).
    if input.is_complex():
        return F.log_softmax(torch.abs(input), dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    else:
        return F.log_softmax(input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)

def r_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[torch.dtype] = None) -> Tensor:
    # Softmax over the real parts only, discarding the imaginary parts.
    if input.is_complex():
        return F.softmax(input.real, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    else:
        return F.softmax(input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)

def r_log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[torch.dtype] = None) -> Tensor:
    # Log-softmax over the real parts only, discarding the imaginary parts.
    if input.is_complex():
        return F.log_softmax(input.real, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    else:
        return F.log_softmax(input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)

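
# Minimal usage sketch (illustrative addition; the shapes and bias value below
# are arbitrary and not part of the original gist).
if __name__ == '__main__':
    z = torch.randn(4, 8, dtype=torch.cfloat)  # random complex input
    act = modReLU(bias=-0.5)
    print(act)                    # modReLU(bias=-0.5, rounding_mode=None)
    out = act(z)
    print(out.shape, out.dtype)   # torch.Size([4, 8]) torch.complex64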