-
-
Save grovduck/d8f9df6da62be204d62435fb33ffed8a to your computer and use it in GitHub Desktop.
WIP on predict
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import pytest | |
class Dataset:
    """Reference CSV data for one project / method / ``k`` combination.

    Every file is read with ``pd.read_csv`` from ``./tests/data``, resolved
    against the current working directory. After construction the instance
    exposes these arrays (taken with ``.values``):

    - ``ref_distances``, ``ref_neighbors``, ``trg_distances``,
      ``trg_neighbors``: columns ``K1`` .. ``K{k}`` of the matching
      ``{method}_{project}_*_k{k}.csv`` file.
    - ``trg_predicted_weighted``, ``trg_predicted_unweighted``: every column
      except the first of the matching prediction file.
    - ``X``: every column except the first of ``{project}_env.csv``.
    - ``y``: every column except the first of ``{project}_spp.csv``.
    - ``ids``: the first column of ``{project}_env.csv``.
    """

    def __init__(self, project, method, k=5):
        # Insertion order is the order in which the files are read.
        csv_paths = {
            "ref_distances": (
                f"./tests/data/{method}_{project}_ref_distances_k{k}.csv"
            ),
            "ref_neighbors": (
                f"./tests/data/{method}_{project}_ref_neighbors_k{k}.csv"
            ),
            "trg_distances": (
                f"./tests/data/{method}_{project}_trg_distances_k{k}.csv"
            ),
            "trg_neighbors": (
                f"./tests/data/{method}_{project}_trg_neighbors_k{k}.csv"
            ),
            "trg_predicted_weighted": (
                f"./tests/data/{method}_{project}_trg_predicted_weighted_k{k}.csv"
            ),
            "trg_predicted_unweighted": (
                f"./tests/data/{method}_{project}_trg_predicted_unweighted_k{k}.csv"
            ),
            "env": f"./tests/data/{project}_env.csv",
            "spp": f"./tests/data/{project}_spp.csv",
        }
        # Load every file before selecting columns from any of them.
        frames = {name: pd.read_csv(path) for name, path in csv_paths.items()}

        # Neighbor files keep one column per neighbor rank: K1 .. Kk.
        cols = [f"K{i+1}" for i in range(k)]
        for name in (
            "ref_distances",
            "ref_neighbors",
            "trg_distances",
            "trg_neighbors",
        ):
            setattr(self, name, frames[name].loc[:, cols].values)

        # In the remaining files the first column is skipped; for the env
        # file that same first column is what ``ids`` is taken from.
        for name in ("trg_predicted_weighted", "trg_predicted_unweighted"):
            setattr(self, name, frames[name].iloc[:, 1:].values)

        env = frames["env"]
        self.X = env.iloc[:, 1:].values
        self.y = frames["spp"].iloc[:, 1:].values
        self.ids = env.iloc[:, 0].values
@pytest.fixture
def moscow_raw():
    """Load the k=5 reference data for project "moscow", method "raw"."""
    dataset = Dataset(project="moscow", method="raw", k=5)
    return dataset
@pytest.fixture
def moscow_euclidean():
    """Load the k=5 reference data for project "moscow", method "euclidean"."""
    dataset = Dataset(project="moscow", method="euclidean", k=5)
    return dataset
@pytest.fixture
def moscow_mahalanobis():
    """Load the k=5 reference data for project "moscow", method "mahalanobis"."""
    dataset = Dataset(project="moscow", method="mahalanobis", k=5)
    return dataset
@pytest.fixture
def moscow_msn():
    """Load the k=5 reference data for project "moscow", method "msn"."""
    dataset = Dataset(project="moscow", method="msn", k=5)
    return dataset
@pytest.fixture
def moscow_gnn():
    """Load the k=5 reference data for project "moscow", method "gnn"."""
    dataset = Dataset(project="moscow", method="gnn", k=5)
    return dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy.testing import assert_array_almost_equal | |
from sklearn.model_selection import train_test_split | |
from sknnr import ( | |
EuclideanKNNRegressor, | |
GNNRegressor, | |
MahalanobisKNNRegressor, | |
MSNRegressor, | |
RawKNNRegressor, | |
) | |
def yaimpute_weights(d):
    """Return the weight ``1 / (1 + d)`` for a distance (or array of them).

    Passed as the ``weights`` callable of the regressors in this module.
    NOTE(review): the name suggests this mirrors the weighting used by the
    yaImpute package -- confirm against the reference data generation.
    """
    shifted = 1.0 + d
    return 1.0 / shifted
def test_moscow_raw(moscow_raw):
    """Compare RawKNNRegressor output with the stored "raw" reference data.

    Fits on the first 80% of rows (no shuffling) and checks, to 3 decimals,
    the neighbor distances returned by ``kneighbors()`` and
    ``kneighbors(X_test)``, then the unweighted and
    ``yaimpute_weights``-weighted predictions for the held-out rows.
    """
    # TODO: neighbor indices are not yet compared with ``ref_neighbors`` /
    # ``trg_neighbors``; an earlier draft fit on ``moscow_raw.ids`` instead of
    # ``moscow_raw.y`` and used ``assert_array_equal`` for that check.
    expected = moscow_raw
    X_train, X_test, y_train, _ = train_test_split(
        expected.X, expected.y, train_size=0.8, shuffle=False
    )

    unweighted = RawKNNRegressor(n_neighbors=5).fit(X_train, y_train)

    train_dist, _ = unweighted.kneighbors()
    assert_array_almost_equal(train_dist, expected.ref_distances, decimal=3)

    test_dist, _ = unweighted.kneighbors(X_test)
    assert_array_almost_equal(test_dist, expected.trg_distances, decimal=3)

    predicted = unweighted.predict(X_test)
    assert_array_almost_equal(
        predicted, expected.trg_predicted_unweighted, decimal=3
    )

    weighted = RawKNNRegressor(n_neighbors=5, weights=yaimpute_weights).fit(
        X_train, y_train
    )
    predicted = weighted.predict(X_test)
    assert_array_almost_equal(predicted, expected.trg_predicted_weighted, decimal=3)
def test_moscow_euclidean(moscow_euclidean):
    """Compare EuclideanKNNRegressor output with the "euclidean" reference data.

    Fits on the first 80% of rows (no shuffling) and checks, to 3 decimals,
    the neighbor distances returned by ``kneighbors()`` and
    ``kneighbors(X_test)``, then the unweighted and
    ``yaimpute_weights``-weighted predictions for the held-out rows.
    """
    # TODO: neighbor indices are not yet compared with ``ref_neighbors`` /
    # ``trg_neighbors``; an earlier draft fit on ``moscow_euclidean.ids``
    # instead of ``moscow_euclidean.y`` and used ``assert_array_equal``.
    expected = moscow_euclidean
    X_train, X_test, y_train, _ = train_test_split(
        expected.X, expected.y, train_size=0.8, shuffle=False
    )

    unweighted = EuclideanKNNRegressor(n_neighbors=5).fit(X_train, y_train)

    train_dist, _ = unweighted.kneighbors()
    assert_array_almost_equal(train_dist, expected.ref_distances, decimal=3)

    test_dist, _ = unweighted.kneighbors(X_test)
    assert_array_almost_equal(test_dist, expected.trg_distances, decimal=3)

    predicted = unweighted.predict(X_test)
    assert_array_almost_equal(
        predicted, expected.trg_predicted_unweighted, decimal=3
    )

    weighted = EuclideanKNNRegressor(n_neighbors=5, weights=yaimpute_weights).fit(
        X_train, y_train
    )
    predicted = weighted.predict(X_test)
    assert_array_almost_equal(predicted, expected.trg_predicted_weighted, decimal=3)
def test_moscow_mahalanobis(moscow_mahalanobis):
    """Compare MahalanobisKNNRegressor output with its reference data.

    Fits on the first 80% of rows (no shuffling) and checks, to 3 decimals,
    the neighbor distances returned by ``kneighbors()`` and
    ``kneighbors(X_test)``, then the unweighted and
    ``yaimpute_weights``-weighted predictions for the held-out rows.
    """
    # TODO: neighbor indices are not yet compared with ``ref_neighbors`` /
    # ``trg_neighbors``; an earlier draft fit on ``moscow_mahalanobis.ids``
    # instead of ``moscow_mahalanobis.y`` and used ``assert_array_equal``.
    expected = moscow_mahalanobis
    X_train, X_test, y_train, _ = train_test_split(
        expected.X, expected.y, train_size=0.8, shuffle=False
    )

    unweighted = MahalanobisKNNRegressor(n_neighbors=5).fit(X_train, y_train)

    train_dist, _ = unweighted.kneighbors()
    assert_array_almost_equal(train_dist, expected.ref_distances, decimal=3)

    test_dist, _ = unweighted.kneighbors(X_test)
    assert_array_almost_equal(test_dist, expected.trg_distances, decimal=3)

    predicted = unweighted.predict(X_test)
    assert_array_almost_equal(
        predicted, expected.trg_predicted_unweighted, decimal=3
    )

    weighted = MahalanobisKNNRegressor(
        n_neighbors=5, weights=yaimpute_weights
    ).fit(X_train, y_train)
    predicted = weighted.predict(X_test)
    assert_array_almost_equal(predicted, expected.trg_predicted_weighted, decimal=3)
def test_moscow_msn(moscow_msn):
    """Compare MSNRegressor output with the stored "msn" reference data.

    Fits on the first 80% of rows (no shuffling), passing the matching slice
    of ``moscow_msn.y`` as ``spp``, and checks, to 3 decimals, the neighbor
    distances returned by ``kneighbors()`` and ``kneighbors(X_test)``, then
    the unweighted and ``yaimpute_weights``-weighted predictions for the
    held-out rows.
    """
    # TODO: neighbor indices are not yet compared with ``ref_neighbors`` /
    # ``trg_neighbors``; an earlier draft fit on ``moscow_msn.ids`` (keeping
    # ``moscow_msn.y`` only for ``spp``) and used ``assert_array_equal``.
    expected = moscow_msn
    X_train, X_test, y_train, _, y_spp, _ = train_test_split(
        expected.X, expected.y, expected.y, train_size=0.8, shuffle=False
    )

    unweighted = MSNRegressor(n_neighbors=5).fit(X_train, y_train, spp=y_spp)

    train_dist, _ = unweighted.kneighbors()
    assert_array_almost_equal(train_dist, expected.ref_distances, decimal=3)

    test_dist, _ = unweighted.kneighbors(X_test)
    assert_array_almost_equal(test_dist, expected.trg_distances, decimal=3)

    predicted = unweighted.predict(X_test)
    assert_array_almost_equal(
        predicted, expected.trg_predicted_unweighted, decimal=3
    )

    # Pass ``spp`` here too so the weighted and unweighted regressors are fit
    # on identical inputs and differ only in their ``weights`` setting.
    weighted = MSNRegressor(n_neighbors=5, weights=yaimpute_weights).fit(
        X_train, y_train, spp=y_spp
    )
    predicted = weighted.predict(X_test)
    assert_array_almost_equal(predicted, expected.trg_predicted_weighted, decimal=3)
def test_moscow_gnn(moscow_gnn):
    """Compare GNNRegressor output with the stored "gnn" reference data.

    Fits on the first 80% of rows (no shuffling), passing the matching slice
    of ``moscow_gnn.y`` as ``spp``, and checks, to 3 decimals, the neighbor
    distances returned by ``kneighbors()`` and ``kneighbors(X_test)``, then
    the unweighted and ``yaimpute_weights``-weighted predictions for the
    held-out rows.
    """
    # TODO: neighbor indices are not yet compared with ``ref_neighbors`` /
    # ``trg_neighbors``; an earlier draft fit on ``moscow_gnn.ids`` (keeping
    # ``moscow_gnn.y`` only for ``spp``) and used ``assert_array_equal``.
    expected = moscow_gnn
    X_train, X_test, y_train, _, y_spp, _ = train_test_split(
        expected.X, expected.y, expected.y, train_size=0.8, shuffle=False
    )

    unweighted = GNNRegressor(n_neighbors=5).fit(X_train, y_train, spp=y_spp)

    train_dist, _ = unweighted.kneighbors()
    assert_array_almost_equal(train_dist, expected.ref_distances, decimal=3)

    test_dist, _ = unweighted.kneighbors(X_test)
    assert_array_almost_equal(test_dist, expected.trg_distances, decimal=3)

    predicted = unweighted.predict(X_test)
    assert_array_almost_equal(
        predicted, expected.trg_predicted_unweighted, decimal=3
    )

    # Pass ``spp`` here too so the weighted and unweighted regressors are fit
    # on identical inputs and differ only in their ``weights`` setting.
    weighted = GNNRegressor(n_neighbors=5, weights=yaimpute_weights).fit(
        X_train, y_train, spp=y_spp
    )
    predicted = weighted.predict(X_test)
    assert_array_almost_equal(predicted, expected.trg_predicted_weighted, decimal=3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment