Skip to content

Instantly share code, notes, and snippets.

@wassname
Created December 27, 2023 01:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wassname/0e3b8a88074c74927c03e4e821186640 to your computer and use it in GitHub Desktop.
Save wassname/0e3b8a88074c74927c03e4e821186640 to your computer and use it in GitHub Desktop.
wrap sklearn scalers for torch
"""
How to wrap a scikit-learn scaler such as RobustScaler for PyTorch.
"""
import torch
import numpy as np
from einops import rearrange
from sklearn.preprocessing import StandardScaler, RobustScaler
class TorchRobustScaler(RobustScaler):
    """scikit-learn ``RobustScaler`` that accepts and returns torch tensors.

    Inputs of shape ``(batch, *feature_dims)`` are flattened to
    ``(batch, prod(feature_dims))`` before being passed to ``RobustScaler``.
    As a result, every position in ``feature_dims`` is scaled with its own
    statistics. Transformed outputs are reshaped back to the input shape and
    returned as torch tensors.
    """

    def wrap(self, X, method: str, *args, **kwargs):
        """Run ``RobustScaler.<method>`` on ``X`` flattened to 2-D.

        Args:
            X: torch tensor or array-like with at least 2 dimensions. Axis 0
                is the sample axis; all remaining axes are treated as features.
            method: name of the ``RobustScaler`` method to call, e.g.
                ``"fit"``, ``"transform"`` or ``"inverse_transform"``.
            *args, **kwargs: forwarded to that method (e.g. ``y`` for ``fit``).

        Returns:
            If the method returns a numpy array, a torch tensor with the same
            shape as ``X``. When ``X`` is a tensor, the result is moved to
            ``X``'s device and, for floating-point ``X``, cast back to ``X``'s
            dtype. Any other return value (``fit`` returns the estimator) is
            passed through unchanged.

        Raises:
            ValueError: if ``X`` has fewer than 2 dimensions.
        """
        is_tensor = torch.is_tensor(X)
        if is_tensor:
            device, dtype = X.device, X.dtype
            # sklearn needs a CPU numpy array: detach so tensors that require
            # grad can be converted, and move the data off any accelerator.
            X_cpu = X.detach().cpu()
            if X_cpu.dtype == torch.bfloat16:
                # numpy has no bfloat16, so upcast before conversion
                X_cpu = X_cpu.float()
            X_np = X_cpu.numpy()
        else:
            X_np = np.asarray(X)

        if X_np.ndim < 2:
            raise ValueError(
                "expected input with at least 2 dimensions (batch, ...), "
                f"got shape {X_np.shape}"
            )

        shape = X_np.shape
        # Explicit feature count (not -1) so a zero-length batch still reshapes.
        n_features = int(np.prod(shape[1:]))
        result = getattr(super(), method)(
            X_np.reshape(shape[0], n_features), *args, **kwargs
        )

        if not isinstance(result, np.ndarray):
            # e.g. fit() returns the estimator itself
            return result

        out = torch.from_numpy(result.reshape(shape))
        if is_tensor:
            out = out.to(device=device)
            # Integer inputs are scaled to floats; casting back would truncate.
            if dtype.is_floating_point:
                out = out.to(dtype=dtype)
        return out

    def fit(self, X, y=None):
        """Fit the scaler on ``X`` flattened to ``(batch, features)``.

        ``y`` is forwarded to ``RobustScaler.fit``. Accepting it lets
        ``fit_transform`` and pipelines call ``fit(X, y)``. Returns the
        fitted estimator.
        """
        return self.wrap(X, "fit", y)

    def transform(self, X):
        """Scale ``X`` and return a tensor with the same shape as ``X``."""
        return self.wrap(X, "transform")

    def inverse_transform(self, X):
        """Undo the scaling of ``X`` and return a tensor with ``X``'s shape."""
        return self.wrap(X, "inverse_transform")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment