Last active
September 24, 2023 18:50
-
-
Save hazdzz/2155de16f99e0fd5f5256dfad24a0a4a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import Tensor | |
from typing import Callable, List, Optional, Tuple | |
import math | |
import warnings | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class CSigmoid(nn.Module):
    """Module form of :func:`c_sigmoid`.

    Complex input gets a sigmoid applied separately to its real and
    imaginary parts; real input gets a plain sigmoid.
    """

    def forward(self, input: Tensor) -> Tensor:
        activated = c_sigmoid(input)
        return activated
class CTanh(nn.Module):
    """Module form of :func:`c_tanh`.

    Complex input gets tanh applied separately to its real and imaginary
    parts; real input gets a plain tanh.
    """

    def forward(self, input: Tensor) -> Tensor:
        activated = c_tanh(input)
        return activated
class modTanh(nn.Module):
    """Module form of :func:`mod_tanh`: tanh applied to the modulus ``|z|``.

    Args:
        rounding_mode: forwarded to ``torch.div`` when ``mod_tanh`` computes
            the phase ``z / |z|``. Used by :meth:`forward` unless a
            per-call value is given.
    """

    __constants__ = ['rounding_mode']
    rounding_mode: Optional[str]

    def __init__(self, rounding_mode: Optional[str] = None):
        super(modTanh, self).__init__()
        self.rounding_mode = rounding_mode

    def forward(self, input: Tensor, rounding_mode: Optional[str] = None) -> Tensor:
        """Apply ``mod_tanh`` to ``input``.

        Args:
            input: real or complex tensor.
            rounding_mode: optional per-call override. When None (the
                default), the value given to the constructor is used.
                Previously the constructor value was silently ignored.
        """
        if rounding_mode is None:
            rounding_mode = self.rounding_mode
        return mod_tanh(input, rounding_mode=rounding_mode)

    def extra_repr(self) -> str:
        return 'rounding_mode={}'.format(self.rounding_mode)
class Hirose(nn.Module):
    """Module form of :func:`hirose`: ``tanh(|z| / m**2)`` with the phase of ``z``.

    Args:
        m: scale parameter; the modulus is divided by ``m ** 2`` before tanh.
        inplace: forwarded to ``hirose``.
    """

    __constants__ = ['m', 'inplace']
    m: float
    inplace: bool

    def __init__(self, m: float = 1., inplace: bool = False):
        super(Hirose, self).__init__()
        self.m = m
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return hirose(input, m=self.m, inplace=self.inplace)

    def extra_repr(self) -> str:
        # Leading ", " so the flag is separated from ``m`` (was "m=1.0inplace=True").
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'm={}{}'.format(self.m, inplace_str)
class Siglog(nn.Module):
    """Module form of :func:`siglog`, which computes ``z / (1 + |z|)``."""

    def forward(self, input: Tensor) -> Tensor:
        squashed = siglog(input)
        return squashed
class CCardioid(nn.Module):
    """Module form of :func:`c_cardioid`, which computes ``0.5 * (1 + cos(angle(z))) * z``."""

    def forward(self, input: Tensor) -> Tensor:
        scaled = c_cardioid(input)
        return scaled
class CReLU(nn.Module):
    """Module form of :func:`c_relu` (ReLU on real and imaginary parts separately).

    Args:
        inplace: forwarded to ``F.relu`` for each part.
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return c_relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        if self.inplace:
            return 'inplace=True'
        return ''
class zReLU(nn.Module):
    """Module form of :func:`z_relu`.

    For complex input, ``z_relu`` keeps values whose angle lies in
    ``[0, pi/2]`` and zeroes the rest.

    Args:
        inplace: forwarded to ``z_relu``.
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return z_relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        if self.inplace:
            return 'inplace=True'
        return ''
class modReLU(nn.Module):
    """Module form of :func:`mod_relu`: ``relu(|z| + bias)`` with the phase of ``z``.

    Args:
        bias: value added to the modulus before the ReLU (default ``-sqrt(2)``).
        rounding_mode: forwarded to ``torch.div`` when computing the phase.
        inplace: forwarded to ``mod_relu``.
    """

    __constants__ = ['bias', 'rounding_mode', 'inplace']
    bias: float
    rounding_mode: Optional[str]
    inplace: bool

    def __init__(self, bias: float = -math.sqrt(2), rounding_mode: Optional[str] = None, inplace: bool = False):
        super(modReLU, self).__init__()
        self.bias = bias
        self.rounding_mode = rounding_mode
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return mod_relu(input, bias=self.bias, rounding_mode=self.rounding_mode, inplace=self.inplace)

    def extra_repr(self) -> str:
        # Previously printed "inplace_str=inplace=True" (or a dangling
        # "inplace_str="); follow the ", inplace=True" convention used elsewhere.
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'bias={}, rounding_mode={}{}'.format(self.bias, self.rounding_mode, inplace_str)
class CLeakyReLU(nn.Module):
    """Module form of :func:`c_leaky_relu` (leaky ReLU on real and imaginary parts separately).

    Args:
        negative_slope: slope for negative inputs, forwarded to ``F.leaky_relu``.
        inplace: forwarded to ``F.leaky_relu``.
    """

    __constants__ = ['negative_slope', 'inplace']
    negative_slope: float
    inplace: bool

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return c_leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self) -> str:
        text = 'negative_slope={}'.format(self.negative_slope)
        if self.inplace:
            text += ', inplace=True'
        return text
class modSoftmax(nn.Module):
    """Module form of :func:`mod_softmax` (softmax over ``|z|`` for complex input).

    Args:
        dim: dimension along which softmax is computed.
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        vars(self).update(state)
        # If the restored state carries no ``dim``, fall back to None.
        if getattr(self, 'dim', None) is None:
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return mod_softmax(input, dim=self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
class modLogSoftmax(nn.Module):
    """Module form of :func:`mod_log_softmax` (log-softmax over ``|z|`` for complex input).

    Args:
        dim: dimension along which log-softmax is computed.
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        vars(self).update(state)
        # If the restored state carries no ``dim``, fall back to None.
        if getattr(self, 'dim', None) is None:
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return mod_log_softmax(input, dim=self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
class rSoftmax(nn.Module):
    """Module form of :func:`r_softmax` (softmax over the real part for complex input).

    Args:
        dim: dimension along which softmax is computed.
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        vars(self).update(state)
        # If the restored state carries no ``dim``, fall back to None.
        if getattr(self, 'dim', None) is None:
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return r_softmax(input, dim=self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
class rLogSoftmax(nn.Module):
    """Module form of :func:`r_log_softmax` (log-softmax over the real part for complex input).

    Args:
        dim: dimension along which log-softmax is computed.
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        vars(self).update(state)
        # If the restored state carries no ``dim``, fall back to None.
        if getattr(self, 'dim', None) is None:
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return r_log_softmax(input, dim=self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
def complex_fcaller(funtional_handle, *args):
    """Apply a real-valued function separately to the real and imaginary parts.

    Args:
        funtional_handle: callable taking a real tensor as its first argument.
        *args: the first element must be a complex tensor ``z``; any further
            elements are passed unchanged to both calls.

    Returns:
        ``torch.complex(f(z.real, *rest), f(z.imag, *rest))``.
    """
    z = args[0]
    extra = args[1:]
    real_out = funtional_handle(z.real, *extra)
    imag_out = funtional_handle(z.imag, *extra)
    return torch.complex(real_out, imag_out)
def c_sigmoid(input: Tensor):
    """Split sigmoid: apply sigmoid to the real and imaginary parts separately.

    Real (non-complex) input gets an ordinary ``torch.sigmoid``.
    """
    if not input.is_complex():
        return torch.sigmoid(input)
    real_part, imag_part = input.real, input.imag
    return torch.complex(torch.sigmoid(real_part), torch.sigmoid(imag_part))
def c_tanh(input: Tensor):
    """Split tanh: apply tanh to the real and imaginary parts separately.

    Real (non-complex) input gets an ordinary ``torch.tanh``.
    """
    if not input.is_complex():
        return torch.tanh(input)
    real_part, imag_part = input.real, input.imag
    return torch.complex(torch.tanh(real_part), torch.tanh(imag_part))
def mod_tanh(input: Tensor, rounding_mode: Optional[str] = None) -> Tensor:
    """Modulus tanh: ``tanh(|z|) * z / |z|`` for complex input, ``tanh(x)`` otherwise.

    Args:
        input: real or complex tensor.
        rounding_mode: forwarded to ``torch.div`` when computing the phase
            ``z / |z|``. Note that ``torch.div`` does not support a non-None
            rounding mode for complex tensors.

    Returns:
        A tensor with the same dtype as ``input``. Zero entries map to zero;
        previously they produced NaN from ``0 / 0``.
    """
    if not input.is_complex():
        return torch.tanh(input)
    magnitude = torch.abs(input)
    # Divide by 1 where |z| == 0. The numerator is 0 there too, so the phase
    # becomes 0 instead of NaN, and gradients through the division stay finite.
    safe_magnitude = torch.where(magnitude == 0, torch.ones_like(magnitude), magnitude)
    euler_phase = torch.div(input=input, other=safe_magnitude, rounding_mode=rounding_mode)
    return torch.mul(torch.tanh(magnitude), euler_phase).to(input.dtype)
def hirose(input: Tensor, m: float = 1., rounding_mode: Optional[str] = None, inplace: bool = False) -> Tensor:
    """Hirose activation.

    Complex input: ``tanh(|z| / m**2) * z / |z|``. Real input: ``tanh(x / m**2)``.

    Args:
        input: real or complex tensor.
        m: scale parameter; the modulus (or real value) is divided by ``m ** 2``.
        rounding_mode: forwarded to ``torch.div`` for the division by ``m ** 2``.
        inplace: kept for API compatibility. The result is always a new
            tensor, and ``input`` is not modified (same as the original code,
            which only rebound a local name).

    Returns:
        A tensor cast to ``input``'s dtype. Zero complex entries map to zero.
    """
    # ``m ** 2`` instead of ``torch.pow(m, 2)``: torch.pow requires at least one
    # tensor argument and raises TypeError for two Python scalars.
    scale = m ** 2
    if input.is_complex():
        magnitude = torch.abs(input)
        # Divide by 1 where |z| == 0 so the phase is 0 rather than NaN (0 / 0).
        safe_magnitude = torch.where(magnitude == 0, torch.ones_like(magnitude), magnitude)
        euler_phase = torch.div(input, safe_magnitude)
        squashed = torch.tanh(torch.div(input=magnitude, other=scale, rounding_mode=rounding_mode))
        result = torch.mul(squashed, euler_phase)
    else:
        result = torch.tanh(torch.div(input=input, other=scale, rounding_mode=rounding_mode))
    return result.to(input.dtype)
def siglog(input: Tensor):
    """Siglog activation: ``z / (1 + |z|)``; works for real or complex input."""
    denominator = 1 + input.abs()
    return input / denominator
def c_cardioid(input: Tensor):
    """Cardioid activation: ``0.5 * (1 + cos(angle(z))) * z``.

    The angle-dependent factor scales each entry without changing its phase.
    For real input, ``torch.angle`` gives 0 for non-negative values and pi for
    negative values, so the function acts like ReLU.
    """
    gain = 1 + torch.cos(torch.angle(input))
    return 0.5 * (gain * input)
def c_relu(input: Tensor, inplace: bool = False) -> Tensor:
    """Split ReLU: apply ReLU to the real and imaginary parts separately.

    Real input gets an ordinary ``F.relu``. ``inplace`` is forwarded to each
    ``F.relu`` call; for complex input those calls act on the ``.real`` /
    ``.imag`` views of ``input``.
    """
    if not input.is_complex():
        return F.relu(input, inplace=inplace)
    rectified_real = F.relu(input.real, inplace=inplace)
    rectified_imag = F.relu(input.imag, inplace=inplace)
    return torch.complex(rectified_real, rectified_imag)
def mod_relu(input: Tensor, bias: float = -math.sqrt(2), rounding_mode: Optional[str] = None, inplace: bool = False) -> Tensor:
    """modReLU: ``relu(|z| + bias) * z / |z|`` for complex input.

    Args:
        input: real or complex tensor.
        bias: added to the modulus before the ReLU (default ``-sqrt(2)``).
            Not used for real input.
        rounding_mode: forwarded to ``torch.div`` when computing the phase.
        inplace: for real input, forwarded to ``F.relu``. For complex input
            the result is always a new tensor, as in the original code.

    Returns:
        For complex input, a tensor with ``input``'s dtype in which zero
        entries map to zero (previously NaN from ``0 / 0``). For real input,
        ``F.relu(input)``.
    """
    if not input.is_complex():
        # Real fallback is a plain ReLU; ``bias`` is not applied here.
        return F.relu(input, inplace=inplace)
    magnitude = torch.abs(input)
    # Divide by 1 where |z| == 0 so the phase is 0 instead of NaN.
    safe_magnitude = torch.where(magnitude == 0, torch.ones_like(magnitude), magnitude)
    euler_phase = torch.div(input=input, other=safe_magnitude, rounding_mode=rounding_mode)
    return torch.mul(F.relu(magnitude + bias), euler_phase).to(input.dtype)
def z_relu(input: Tensor, inplace: bool = False) -> Tensor:
    """zReLU: keep complex entries whose angle lies in ``[0, pi/2]``, zero the rest.

    Real input gets an ordinary ``F.relu`` (``inplace`` forwarded). For complex
    input a new tensor is returned and ``input`` is not modified, regardless
    of ``inplace``.
    """
    if not input.is_complex():
        return F.relu(input, inplace=inplace)
    theta = torch.angle(input)
    # Entries whose angle compares as NaN are kept, since both comparisons are False.
    outside_quadrant = (theta < 0) | (theta > (math.pi / 2))
    return torch.where(outside_quadrant, torch.zeros_like(input), input)
def c_leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor:
    """Split leaky ReLU: apply ``F.leaky_relu`` to the real and imaginary parts separately.

    Real input gets an ordinary ``F.leaky_relu``. ``negative_slope`` and
    ``inplace`` are forwarded to every call.
    """
    if not input.is_complex():
        return F.leaky_relu(input=input, negative_slope=negative_slope, inplace=inplace)
    leaky_real = F.leaky_relu(input=input.real, negative_slope=negative_slope, inplace=inplace)
    leaky_imag = F.leaky_relu(input=input.imag, negative_slope=negative_slope, inplace=inplace)
    return torch.complex(leaky_real, leaky_imag)
def mod_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[torch.dtype] = None) -> Tensor:
    """Softmax over the modulus ``|z|`` for complex input, plain softmax otherwise.

    Args:
        input: real or complex tensor.
        dim: dimension along which softmax is computed (forwarded to ``F.softmax``).
        _stacklevel: forwarded to ``F.softmax``.
        dtype: optional ``torch.dtype`` forwarded to ``F.softmax``. The
            annotation previously said ``int``.

    Returns:
        A real-valued tensor.
    """
    scores = torch.abs(input) if input.is_complex() else input
    return F.softmax(scores, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
def mod_log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[torch.dtype] = None) -> Tensor:
    """Log-softmax over the modulus ``|z|`` for complex input, plain log-softmax otherwise.

    Args:
        input: real or complex tensor.
        dim: dimension along which log-softmax is computed.
        _stacklevel: forwarded to ``F.log_softmax``.
        dtype: optional ``torch.dtype`` forwarded to ``F.log_softmax``. The
            annotation previously said ``int``.

    Returns:
        A real-valued tensor.
    """
    scores = torch.abs(input) if input.is_complex() else input
    return F.log_softmax(scores, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
def r_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[torch.dtype] = None) -> Tensor:
    """Softmax over the real part for complex input, plain softmax otherwise.

    The imaginary part of complex input is discarded.

    Args:
        input: real or complex tensor.
        dim: dimension along which softmax is computed.
        _stacklevel: forwarded to ``F.softmax``.
        dtype: optional ``torch.dtype`` forwarded to ``F.softmax``. The
            annotation previously said ``int``.
    """
    scores = input.real if input.is_complex() else input
    return F.softmax(scores, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
def r_log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[torch.dtype] = None) -> Tensor:
    """Log-softmax over the real part for complex input, plain log-softmax otherwise.

    The imaginary part of complex input is discarded.

    Args:
        input: real or complex tensor.
        dim: dimension along which log-softmax is computed.
        _stacklevel: forwarded to ``F.log_softmax``.
        dtype: optional ``torch.dtype`` forwarded to ``F.log_softmax``. The
            annotation previously said ``int``.
    """
    scores = input.real if input.is_complex() else input
    return F.log_softmax(scores, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment