Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active January 15, 2021 07:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wassname/298c5a7b4e4fb0974afbe48d9717dd14 to your computer and use it in GitHub Desktop.
Save wassname/298c5a7b4e4fb0974afbe48d9717dd14 to your computer and use it in GitHub Desktop.
Causal k-nearest-neighbors (KNN) search that only looks back (at earlier points)
# %%
import numpy as np
from functools import partial
from pykdtree.kdtree import KDTree
class LeftKDTree(KDTree):
    """
    KNN that only looks left.

    For each query point, only data points whose index is strictly smaller
    than the query's own index are eligible neighbours. When "left" means
    "the past", the search therefore respects causality. This is useful for
    the local outlier factor (LOF) and other causal anomaly scores.

    url: https://gist.github.com/wassname/298c5a7b4e4fb0974afbe48d9717dd14
    """

    def query(self, y, y_inds=None, *args, **kwargs):
        """Find nearest neighbours of each row of ``y`` among earlier data points.

        Parameters
        ----------
        y : array-like
            Query points, one per row.
        y_inds : array-like of int, optional
            Position of each query row on the tree's data index axis. For row
            ``i``, every data point with index ``>= y_inds[i]`` is excluded.
            Defaults to ``0 .. self.n - 1``, which assumes ``y`` is the
            tree's own data in order.
        *args, **kwargs
            Forwarded unchanged to ``KDTree.query`` (e.g. ``k``). ``mask`` is
            set internally per row, so it must not be passed here.

        Returns
        -------
        dists, inds
            The per-row results of ``KDTree.query``, concatenated along axis 0.
            A row with no eligible neighbour (e.g. index 0) receives pykdtree's
            "not found" sentinel values instead of a real neighbour.

        Raises
        ------
        ValueError
            If ``y`` and ``y_inds`` differ in length, or if ``y`` is empty
            (``np.concatenate`` has nothing to join).
        """
        y = np.asarray(y)
        x_inds = np.arange(self.n)
        # Read-only below, so aliasing x_inds is safe (no copy needed).
        y_inds = x_inds if y_inds is None else np.asarray(y_inds)
        if len(y) != len(y_inds):
            raise ValueError(f'{y.shape}!={y_inds.shape}')

        dists = []
        inds = []
        # One tree query per row: pykdtree takes a single mask per call, and
        # the set of eligible neighbours differs for every query row.
        for i, y_ind in enumerate(y_inds):
            # pykdtree's mask marks points to IGNORE (True = excluded), so hide
            # every data point at or after this row's position. Comparing with
            # `>=` (instead of `> y_ind - 1`) means an unsigned y_ind of 0
            # cannot wrap around to a huge value and unmask every point.
            future = x_inds >= y_ind
            d, ind = super().query(y[i:i + 1], *args, mask=future, **kwargs)
            dists.append(d)
            inds.append(ind)
        return np.concatenate(dists), np.concatenate(inds)
# %%
# Toy data: 20 rows where row i is [i, i, i, i], so row index == value.
# The queries are the middle rows 4..15, tagged with their own row indices.
X = np.tile(np.arange(20), (4, 1)).T
y = X[4:len(X) - 4]
y_inds = np.arange(4, len(X) - 4)
# %%
# Causal tree: every returned neighbour index must lie strictly before its
# query's index (the nearest such row is i-1, at distance sqrt(4 * 1) = 2).
tree = LeftKDTree(X)
d, ids = tree.query(y, y_inds)
id_is_left = np.all(y_inds - ids > 0)
print(id_is_left, d, ids)
assert id_is_left
# True [2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.] [ 3 4 5 6 7 8 9 10 11 12 13 14]
# %%
# Baseline: the ordinary KDTree matches each query to itself (distance 0),
# so its neighbours are NOT all to the left.
tree = KDTree(X)
d, ids = tree.query(y)
id_is_left = np.all(y_inds - ids > 0)
print(id_is_left, d, ids)
assert not id_is_left
# False [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [ 4 5 6 7 8 9 10 11 12 13 14 15]
# %%
# Note: here every row of X is queried, and row 0 has nothing to its left.
# pykdtree therefore returns "no neighbour" sentinels for that row:
# ids[0] == 4294967295 and d[0] == 1.3407807929942596e+154 (see output below).
# Treat them as missing values, for example:
#     d[d == 1.3407807929942596e+154] = np.nan
# ids is an integer array and cannot hold NaN, so mask it instead.
y_inds = np.arange(len(X))
tree = LeftKDTree(X)
d, ids = tree.query(X, y_inds)
# Skip row 0: its id is the sentinel, not a real neighbour.
id_is_left = (y_inds - ids > 0)[1:].all()
print(id_is_left, d, ids)
assert id_is_left
# True [1.34078079e+154 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
# 2. 2.] [4294967295 0 1 2 3 4
# 5 6 7 8 9 10
# 11 12 13 14 15 16
# 17 18]
# %%
# Baseline on the full data: the ordinary KDTree returns each row as its own
# nearest neighbour, so no neighbour is strictly to the left.
tree = KDTree(X)
d, ids = tree.query(X)
id_is_left = np.all(y_inds - ids > 0)
print(id_is_left, d, ids)
assert not id_is_left
# False [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment