Created
December 27, 2023 01:06
-
-
Save wassname/0e3b8a88074c74927c03e4e821186640 to your computer and use it in GitHub Desktop.
wrap sklearn scalers for torch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
how to wrap a scikit-learn scalar like RobustScaler for pytorch | |
""" | |
import torch | |
import numpy as np | |
from einops import rearrange | |
from sklearn.preprocessing import StandardScaler, RobustScaler | |
class TorchRobustScaler(RobustScaler):
    """RobustScaler that works on 4-D ``(batch, length, height, vars)`` data.

    Each sample is flattened to a single feature vector, the sklearn parent
    method is applied, and numpy results are converted back to torch tensors
    with the original shape restored.
    """

    def wrap(self, X, method: str):
        """Flatten ``X`` to 2-D, call the parent's *method*, restore the shape.

        Parameters
        ----------
        X : np.ndarray or torch.Tensor of shape (b, l, h, v)
        method : str
            Name of the RobustScaler method to delegate to
            ("fit", "transform", "inverse_transform").

        Returns
        -------
        torch.Tensor for array-returning methods (transform/inverse_transform);
        the fitted estimator (self) for "fit".
        """
        # sklearn operates on numpy: detach from the autograd graph and move
        # off the GPU first, otherwise np.asarray inside sklearn raises.
        if isinstance(X, torch.Tensor):
            X = X.detach().cpu().numpy()
        b, l, h, v = X.shape
        # Pure flatten (b, l, h, v) -> (b, l*h*v); no einops needed.
        out = getattr(super(), method)(X.reshape(b, l * h * v))
        if isinstance(out, np.ndarray):
            # transform/inverse_transform return arrays; restore the 4-D
            # shape and hand back a torch tensor.
            out = torch.from_numpy(out.reshape(b, l, h, v))
        return out

    def fit(self, X, y=None):
        # ``y`` is accepted (and ignored) for sklearn Pipeline compatibility.
        return self.wrap(X, "fit")

    def transform(self, X):
        return self.wrap(X, "transform")

    def inverse_transform(self, X):
        return self.wrap(X, "inverse_transform")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment