Skip to content

Instantly share code, notes, and snippets.

@stes
Created October 2, 2023 21:03
Show Gist options
  • Save stes/8722195744040ba186f4820c9faee7a5 to your computer and use it in GitHub Desktop.
Save stes/8722195744040ba186f4820c9faee7a5 to your computer and use it in GitHub Desktop.
from cebra import CEBRA
import joblib as jl
import sklearn.linear_model
# Load the synthetic dataset stored as a joblib dump. The path is relative to
# the current working directory, so run the script from the repository root.
# Downstream code indexes the result with the keys 'x', 'u' and 'z'
# (presumably neural data, continuous labels and ground-truth latents, judging
# by the file name -- TODO confirm against the file's contents).
# NOTE(review): joblib.load unpickles its input; only load trusted files.
data = jl.load(filename='data/synthetic/continuous_label_poisson.jl')
def reconstruction_score(x, y):
    """Fit an ordinary least-squares map from ``x`` to ``y`` and score it.

    The regression is fit and evaluated on the same samples, so the returned
    score is an in-sample (reconstruction) measure, not a held-out one.

    Args:
        x: Feature matrix passed to ``LinearRegression.fit``.
        y: Targets passed to ``LinearRegression.fit``.

    Returns:
        A ``(score, prediction)`` tuple: ``LinearRegression.score(x, y)``
        (R^2) followed by the fitted values ``LinearRegression.predict(x)``.
    """
    regressor = sklearn.linear_model.LinearRegression()
    regressor.fit(x, y)
    # Score first, then predict -- same evaluation order as a tuple literal.
    r2 = regressor.score(x, y)
    fitted = regressor.predict(x)
    return r2, fitted
# Repeatedly train a fresh CEBRA model and report how well its 2-D embedding
# linearly reconstructs the first two columns of data['z'] (presumably the
# ground-truth latents of the synthetic dataset -- TODO confirm). Each
# iteration builds a new, untrained estimator, so no model state carries over
# between runs; printing one R2 per run exposes run-to-run variance.
N_RUNS = 100
N_TRAIN = 12000  # only the first N_TRAIN samples are used for training
MAX_ITERATIONS = 1000  # loop-invariant: defined once, not on every run

for _ in range(N_RUNS):
    cebra_model = CEBRA(
        model_architecture="offset1-model-mse",
        batch_size=512,
        learning_rate=1e-4,
        max_iterations=MAX_ITERATIONS,
        delta=0.1,
        conditional='delta',
        output_dimension=2,
        distance='euclidean',
        # Uses the GPU when present, but falls back to CPU instead of
        # crashing on machines without CUDA (the original forced "cuda").
        device="cuda_if_available",
        verbose=False,
    )
    cebra_model.partial_fit(data['x'][:N_TRAIN], data['u'][:N_TRAIN])
    # Embed every sample in data['x'], not only the training slice.
    cebra_output = cebra_model.transform(data['x'])
    # Only the score is reported; the fitted reconstruction is discarded.
    cebra_score, _ = reconstruction_score(cebra_output, data['z'][:, :2])
    print("R2", cebra_score)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment