Skip to content

Instantly share code, notes, and snippets.

@kingjr
Created October 29, 2020 15:02
Show Gist options
  • Save kingjr/5132a4b510bc4e4b596cabd36ecc7313 to your computer and use it in GitHub Desktop.
Save kingjr/5132a4b510bc4e4b596cabd36ecc7313 to your computer and use it in GitHub Desktop.
tanh_scaler
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.preprocessing import RobustScaler
class TanhScaler(TransformerMixin, BaseEstimator):
    """Robustly standardize features, then squash them with ``tanh``.

    The data are first centered and scaled by a :class:`RobustScaler`
    (median / inter-quantile range, with ``unit_variance=True`` so that
    normally distributed features end up with unit variance), and then
    mapped through ``tanh(X_scaled / factor)``.  Values within roughly
    ``factor`` robust standard deviations of the median stay in the
    near-linear part of ``tanh``; outliers are smoothly bounded to [-1, 1].

    Parameters
    ----------
    factor : float, default=2.
        Divisor applied to the robust-scaled data before ``tanh``.  Larger
        values keep more of the distribution in the near-linear regime.
    decim : int, default=1
        Fit the robust statistics on every ``decim``-th sample only
        (``X[::decim]``).  A falsy value (``0``/``None``) disables it.
    n_max : int | None, default=None
        If set, fit on at most ``n_max`` samples (taken after decimation).
    shuffle : bool, default=False
        When ``n_max`` is set, draw the ``n_max`` samples at random with
        ``np.random.permutation`` (global NumPy RNG) instead of taking the
        first ``n_max``.
    with_centering : bool, default=True
        Forwarded to :class:`RobustScaler`.
    with_scaling : bool, default=True
        Forwarded to :class:`RobustScaler`.
    quantile_range : tuple of float, default=(25.0, 75.0)
        Forwarded to :class:`RobustScaler`.

    Attributes
    ----------
    scaler : RobustScaler
        The fitted robust scaler, set by :meth:`fit`.
    """

    def __init__(self,
                 factor=2.,
                 decim=1,
                 n_max=None,
                 shuffle=False,
                 with_centering=True,
                 with_scaling=True,
                 quantile_range=(25.0, 75.0),):
        # Store every constructor argument under its own name so that
        # BaseEstimator.get_params / clone can reconstruct the estimator.
        self.factor = factor
        self.decim = decim
        self.n_max = n_max
        self.shuffle = shuffle
        self.with_centering = with_centering
        self.with_scaling = with_scaling
        self.quantile_range = quantile_range

    def fit(self, X, y=None):
        """Fit the robust centering/scaling statistics on (a subset of) ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : None
            Ignored; present for scikit-learn API compatibility.

        Returns
        -------
        self : TanhScaler
            The fitted transformer.
        """
        self.scaler = RobustScaler(
            with_centering=self.with_centering,
            with_scaling=self.with_scaling,
            quantile_range=self.quantile_range,
            copy=True,  # transform() must never modify the caller's array
            unit_variance=True,
        )
        if self.decim:
            # Subsample to reduce the cost of estimating the quantiles.
            X = X[::self.decim]
        if self.n_max:
            if self.shuffle:
                select = np.random.permutation(len(X))[:self.n_max]
            else:
                select = slice(0, self.n_max)
            X = X[select]
        self.scaler.fit(X)
        return self

    def transform(self, X, y=None):
        """Robust-scale ``X`` with the fitted statistics, then apply ``tanh``.

        ``X`` itself is left unmodified.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to transform.
        y : None
            Ignored; present for scikit-learn API compatibility.

        Returns
        -------
        X_tr : ndarray of shape (n_samples, n_features)
            ``tanh((X - center) / scale / factor)``, with values in [-1, 1].
            Centering/scaling is skipped when disabled in the constructor.
        """
        # RobustScaler.transform subtracts center_ / divides by scale_ only
        # when enabled, and works on a copy (copy=True in fit).
        X_scaled = self.scaler.transform(X)
        return np.tanh(X_scaled / self.factor)
# Demo: sorted values of the raw data, of a plain RobustScaler output and of
# a TanhScaler output, on Gaussian data where 1% of the samples are outliers.
import matplotlib.pyplot as plt  # only needed for this demo

x = np.random.randn(10000, 1) / 2
x[:100] += 200  # 100 of the 10000 samples become extreme outliers

ax_raw = plt.subplot(131)
ax_raw.plot(np.sort(x.ravel()))
ax_raw.set_title("raw")

y = RobustScaler(unit_variance=False).fit_transform(x)
ax_robust = plt.subplot(132)
ax_robust.plot(np.sort(y.ravel()))
ax_robust.set_title("RobustScaler")

z = TanhScaler(10).fit_transform(x)
ax_tanh = plt.subplot(133)
ax_tanh.plot(np.sort(z.ravel()))
ax_tanh.set_title("TanhScaler(factor=10)")

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment