[python] TLSH Vantage Point Trees
import os
import json
import numpy as np
import tlsh # from `py-tlsh`
from joblib import Parallel, delayed, parallel_backend
from sklearn.metrics import pairwise_distances_chunked
from tqdm import tqdm
from vpt import VantagePointTree
# ----------------------------------------------------------------------------
# demo build hashes function
# Compute a TLSH hash for every path in `files`.
# Returns the tuple (tlsh_lookup, tlsh_hashers, duplicates, failures).
# NOTE(review): this listing has lost its indentation and several statements
# (e.g. the `try:` that belongs to `except ValueError:` and the bodies of some
# branches); restore them from the original file before running.
def build_hashes(files):
# digest -> file (assigned at the end of the loop body)
tlsh_lookup = dict()
# NOTE(review): nothing visible here appends to this list - presumably the hasher objects; confirm
tlsh_hashers = list()
# digest -> list of entries sharing that digest
duplicates = dict()
# NOTE(review): nothing visible here appends to this list - presumably inputs that could not be hashed; confirm
failures = list()
for file in files:
hasher = tlsh.Tlsh()
# the file is opened in binary mode and iterated line by line
with open(file, "rb") as fp:
for line in fp:
except ValueError:
if not hasher.is_valid:
digest = hasher.hexdigest()
# digest already known: remember the entry that was recorded first
if digest in tlsh_lookup:
entry = tlsh_lookup[digest]
if digest not in duplicates:
duplicates[digest] = [entry]
tlsh_lookup[digest] = file
return tlsh_lookup, tlsh_hashers, duplicates, failures
# ----------------------------------------------------------------------------
# sklearn `pairwise_distances_chunked` brute-force distance calculation
# can be parallel with joblib backend
# Brute-force search for pairs of hashes whose TLSH distance is at most `max_dist`.
# Returns (below_max_dist_idxs, below_max_dist_vals): a list of (row, col) index
# pairs and the concatenated distance values.
# NOTE(review): indentation, the arguments of `pairwise_distances_chunked(` and
# several closing brackets are missing from this listing; restore before running.
def find_near_duplicates(tlsh_hashers, max_dist=100):
Xlen = len(tlsh_hashers)
# column vector of indexes 0..Xlen-1; the metric below receives indexes and
# looks the hash objects up in `tlsh_hashers`
X = np.arange(Xlen)[:, np.newaxis]
# distance callback working on indexes: 0 for the same index, -1 as marker
# for row > col, otherwise the TLSH diff of the two hash objects
def tlsh_sim(row, col):
row = int(row)
col = int(col)
# skip computation if left-lower half or same
if row == col:
return 0
# print(row, col) # DEBUG
# add this check for slice computation
# (if whole pairwise matrix can be computed, sklearn assumes symmetry?)
if row > col:
return -1
# return (np.arange(len(tlsh_hashers)) + 1)[row] + ((np.arange(len(tlsh_hashers)) + 1) / 10)[col] # DEBUG
h_row = tlsh_hashers[row]
h_col = tlsh_hashers[col]
return h_row.diff(h_col)
# memory for a single row (instead of warning for `=0`)
# 1 MB --> 1 B --> int8 (8 byte?) * row-length
working_memory = 1 / 2**20 * 8 * Xlen
# numbers of rows per chunk
num_rows = 1
working_memory *= num_rows
# max_dist: maximum distance (if None or negative, compute all)
# NOTE(review): the filter below compares `d <= max_dist` unconditionally; no
# special handling of None/negative values is visible here - confirm
# reducer function (to only return col indexes per row that are below max distance)
def reduce_func(D_chunk, start_row):
# filter at least positive, less than max_dist and not same
idxs = [
np.flatnonzero((d >= 0) & (d <= max_dist) & (np.arange(Xlen) != pos))
for pos, d in enumerate(D_chunk, start_row)
# distance values at the selected column indexes, per row of the chunk
vals = [d[i] for d, i in zip(D_chunk, idxs)]
return idxs, vals
# NOTE: if whole NxN matrix fits into working_memory, sklearn assumes symmetry and only computes top-right half and mirrors it on -1
pwdst_iter = pairwise_distances_chunked(
below_max_dist_idxs = []
below_max_dist_vals = []
# NOTE(review): the loop body that collects chunk_idxs / chunk_vals is missing here
for chunk_idxs, chunk_vals in tqdm(
pwdst_iter, total=round(Xlen / num_rows), desc="Filter max_dist"
# flatten the per-row column lists into (row, col) pairs
below_max_dist_idxs = [
(row, col) for row, cols in enumerate(below_max_dist_idxs) for col in cols
below_max_dist_vals = np.concatenate(below_max_dist_vals)
return below_max_dist_idxs, below_max_dist_vals
# ----------------------------------------------------------------------------
# iterator to build hashes matrix (will compute it for every pair row, no shortcuts)
# Generator that yields whatever `pairwise_distances_chunked` produces for the
# index-based TLSH metric defined below.
# NOTE(review): indentation and the arguments of `pairwise_distances_chunked(`
# are missing from this listing; restore before running.
def compute_near_duplicates_matrix_iter(tlsh_hashers):
# column vector of indexes; the metric resolves them to hash objects
X = np.arange(len(tlsh_hashers))[:, np.newaxis]
# distance callback working on indexes: 0 for the same index, -1 as marker
# for row > col, otherwise the TLSH diff of the two hash objects
def tlsh_sim(row, col):
row = int(row)
col = int(col)
# skip computation if left-lower half or same
if row == col:
return 0
# add this check for slice computation
# (if whole pairwise matrix can be computed, sklearn assumes symmetry?)
if row > col:
return -1
h_row = tlsh_hashers[row]
h_col = tlsh_hashers[col]
return h_row.diff(h_col)
# memory for a single row (instead of warning for `=0`)
# 1 MB --> 1 B --> int8 (8 byte?) * row-length
working_memory = 1 / 2**20 * 8 * len(tlsh_hashers)
# numbers of rows per chunk
num_rows = 1
working_memory *= num_rows
# max_dist: maximum distance (if None or negative, compute all)
# NOTE(review): this function has no `max_dist` parameter; the comment above looks copied over
# NOTE: if whole NxN matrix fits into working_memory, sklearn assumes symmetry and only computes top-right half and mirrors it on -1
pwdst_iter = pairwise_distances_chunked(
yield from pwdst_iter
# ----------------------------------------------------------------------------
# parallel search (using global globaltree VPT tree instance)
# Worker for the parallel search: handles one digest with index `i`.
# Returns None when there is no result or the distance exceeds `max_dist`,
# otherwise (result_key, (drow, dcol)).
# NOTE(review): indentation is lost and the right-hand side of `result =` is
# missing; `digest` is not used by any statement that is still visible.
def _find_nearest_vpt_one(digest, i, max_dist):
# the tree is read from a module-level global set by the caller
global globaltree
tlsh_hash = tlsh.Tlsh()
result =
# no result
if not result:
return None
# too distant
dist, idx, th_other = result
if dist > max_dist:
return None
# precompute for later
# key is (dist, smaller index, larger index), i.e. independent of pair order
result_key = tuple([dist] + sorted([i, idx]))
drow = tlsh_hash.hexdigest()
dcol = th_other.hexdigest()
return result_key, (drow, dcol)
# Parallel variant of the nearest-neighbour search: dispatches
# `_find_nearest_vpt_one` for every hash via joblib and groups the results.
# Returns a dict: dist -> list of ((digest_row, digest_col), (entry_row, entry_col)).
# NOTE(review): indentation, the closing bracket of the `Parallel(...)(` call and
# the bodies of `if not result:` / `if result_key in seen:` are missing here;
# nothing visible adds to `seen`.
def find_nearest_with_VPT_parallel(tlsh_hashers, tlsh_lookup, tree, max_dist=100):
# publish the tree as module-level global for `_find_nearest_vpt_one`
global globaltree
globaltree = tree
results = dict()
seen = set()
tlsh_hashers = tqdm(tlsh_hashers, desc="Search for most similar")
# Parallel(return_generator=True) ?
# n_jobs=-3 here will overwrite parallel_backend() settings
for result in Parallel(prefer="processes")(
# only the hex digest and the index are handed to the worker
delayed(_find_nearest_vpt_one)(th.hexdigest(), i, max_dist)
for i, th in enumerate(tlsh_hashers)
if not result:
result_key, (drow, dcol) = result
dist = result_key[0]
if result_key in seen:
if dist not in results:
results[dist] = list()
# map digests back to their entries
erow = tlsh_lookup[drow]
ecol = tlsh_lookup[dcol]
results[dist].append(((drow, dcol), (erow, ecol)))
# remove the global again once the search is done
del globaltree
return results
# ----------------------------------------------------------------------------
# non-parallel variant of `find_nearest_with_VPT_parallel`
# Sequential nearest-neighbour search over all hashes.
# Returns a dict: dist -> list of ((digest_row, digest_col), (entry_row, entry_col)).
# NOTE(review): indentation, the right-hand side of `result =` and the bodies of
# the three guard `if`s are missing here; nothing visible adds to `seen`.
def find_nearest_with_VPT(tlsh_hashers, tlsh_lookup, tree, max_dist=100):
results = dict()
seen = set()
for i, th in enumerate(tqdm(tlsh_hashers, desc="Search for most similar")):
result =
if not result:
dist, idx, th_other = result
if dist > max_dist:
# key is (dist, smaller index, larger index), i.e. independent of pair order
result_key = tuple([dist] + sorted([i, idx]))
if result_key in seen:
if dist not in results:
results[dist] = list()
drow = tlsh_hashers[i].hexdigest()
dcol = tlsh_hashers[idx].hexdigest()
erow = tlsh_lookup[drow] # file
ecol = tlsh_lookup[dcol]
results[dist].append(((drow, dcol), (erow, ecol)))
return results
# ----------------------------------------------------------------------------
# json de/encoder for serializing binary strings
class ByteJSONEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes ``bytes`` values as text."""

    def default(self, o):
        """Return a JSON-serializable replacement for ``o``.

        ``bytes`` are decoded with the default codec (UTF-8). Every other
        unsupported type is delegated to the base class.

        Raises:
            TypeError: if ``o`` is not ``bytes`` and not JSON serializable.
            UnicodeDecodeError: if ``o`` is ``bytes`` but not valid UTF-8.
        """
        if isinstance(o, bytes):
            return o.decode()
        # Handing back ``o`` itself would make the encoder call default() on
        # the same object again and abort with a misleading
        # "Circular reference detected" ValueError.
        return super().default(o)
# ----------------------------------------------------------------------------
# caching of TLSH hashes
def _load_tlsh_file(filename, encoding="utf-8"):
    """Load cached TLSH digests from a tab-separated file.

    Every non-empty line is expected to be ``<digest>\\t<sourcefile>``.

    Args:
        filename: path of the cache file; it is read in binary mode.
        encoding: codec used to decode both columns.

    Returns:
        Tuple ``(tlsh_lookup, tlsh_hashers)``: a dict mapping digest to
        source file, and a list with one ``tlsh.Tlsh`` object per line.

    Raises:
        ValueError: if a non-empty line contains no tab separator.
    """
    tlsh_lookup = dict()
    tlsh_hashers = list()
    with open(filename, "rb") as fp:
        for line in tqdm(fp, desc="Load TLSH hashes"):
            line = line.rstrip()
            if not line:
                # tolerate blank lines, e.g. at the end of the file
                continue
            # split on the first tab only, so a tab inside the file name
            # does not break the unpacking
            digest, sourcefile = line.split(b"\t", 1)
            digest = digest.decode(encoding)
            sourcefile = sourcefile.decode(encoding)
            tlsh_lookup[digest] = sourcefile
            # regenerate TLSH hash object from its digest; without these two
            # steps the returned list would always stay empty
            tlsh_hash = tlsh.Tlsh()
            tlsh_hash.fromTlshStr(digest)
            tlsh_hashers.append(tlsh_hash)
    return tlsh_lookup, tlsh_hashers
# Write the digest -> sourcefile mapping `tlsh_lookup` to `filename` (binary mode).
# NOTE(review): indentation, the body of `if not tlsh_lookup:` and the write
# statement inside the loop are missing from this listing. `_load_tlsh_file`
# parses lines as `<digest>\t<sourcefile>`, so presumably that is the format
# written here - confirm against the original file.
def _save_tlsh_file(filename, tlsh_lookup, encoding="utf-8"):
if not tlsh_lookup:
with open(filename, "wb") as fp:
for digest, sourcefile in tlsh_lookup.items():
# ----------------------------------------------------------------------------
# Command line entry point: load or build the TLSH hashes, optionally store
# hashes / errors / duplicates, then run the selected action.
# NOTE(review): indentation is lost and many lines (`print(`, `else:`, closing
# brackets, `parser.error(`, right-hand sides) are missing from this listing;
# restore them from the original file before running.
if __name__ == "__main__":
import argparse
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("file_or_folder", type=str, help="Input file or folder with source files")
parser.add_argument("-t", "--tlshfile", dest="tlshfilename", type=str, help="Input/Output file for TLSH hashes")
parser.add_argument("--errorfile", dest="errorfilename", type=str, help="Output file for errors")
parser.add_argument("--dupsfile", dest="duplicatefilename", type=str, help="Output file for duplicates (same TLSH)")
# one sub-command per action; an action is mandatory
subparsers = parser.add_subparsers(dest="action", required=True, help="Actions")
parser_maxdist = subparsers.add_parser("maxdist", help="Compute TLSH distances below 'max_dist'")
parser_matrix = subparsers.add_parser("matrix", help="Compute full NxN TLSH distance matrix")
parser_mostsimilar = subparsers.add_parser("mostsimilar", help="Compute most similar TLSH document pairs")
parser_hash = subparsers.add_parser("hash", help="Precompute TLSH hashes only (requires -t/--tlshfile to be set)")
parser_maxdist.add_argument("-d", "--maxdist", dest="max_dist", default=100, type=int)
parser_maxdist.add_argument("-r", "--resultfile", dest="resultfilename", default="result.json", type=str)
parser_matrix.add_argument("-r", "--resultfile", dest="resultfilename", default="result.txt", type=str)
parser_mostsimilar.add_argument("-r", "--resultfile", dest="resultfilename", default="result.json", type=str)
parser_mostsimilar.add_argument("-d", "--maxdist", dest="max_dist", default=100, type=int)
# fmt: on
args = parser.parse_args()
# validate arguments
if not os.path.exists(args.file_or_folder):
raise FileNotFoundError(args.file_or_folder)
if args.action == "hash" and not args.tlshfilename:
parser.error("Action 'hash' requires '-t/--tlshfile' to be set!")
if args.action == "maxdist" and args.max_dist < 0:
# NOTE(review): the call wrapping this message is missing - presumably `parser.error(`; confirm
"Action 'maxdist' requires '-d/--maxdist' to be at least 0 or larger!"
# ------------------------------------------
# reuse an existing hash file unless the action is "hash"
if (
args.action != "hash"
and args.tlshfilename
and os.path.exists(args.tlshfilename)
print("[!] Found existing TLSH hash file. Loading pre-computed hashes ...")
tlsh_lookup, tlsh_hashers = _load_tlsh_file(args.tlshfilename)
print("[*] Number of hashes: {}".format(len(tlsh_lookup)))
# NOTE(review): an `else:` is presumably missing here - the lines below compute the hashes from the input
files = [args.file_or_folder]
# for a folder, hash its direct entries (os.listdir is not recursive)
if os.path.isdir(args.file_or_folder):
files = [os.path.join(args.file_or_folder, f) for f in os.listdir(args.file_or_folder)]
tlsh_lookup, tlsh_hashers, duplicates, failures = build_hashes(files)
print("[*] Number of hashes: {}".format(len(tlsh_lookup)))
"[*] Number of duplicates: {} ({} total)".format(
len(duplicates), sum(map(len, duplicates.values()))
print("[*] Number of failures (for TLSH hash computation): {}".format(len(failures))) # fmt: skip
if args.tlshfilename:
_save_tlsh_file(args.tlshfilename, tlsh_lookup)
print("[*] TLSH hashes stored in '{}'".format(args.tlshfilename))
if args.errorfilename:
# store failures as [entry1..N]
with open(args.errorfilename, "w") as fp:
json.dump(failures, fp, indent=2, cls=ByteJSONEncoder)
"[*] Errors (TLSH computation) stored in '{}'".format(
del failures
if args.duplicatefilename:
# store duplicates as digest -> [entry1..N]
with open(args.duplicatefilename, "w") as fp:
json.dump(duplicates, fp, indent=2, cls=ByteJSONEncoder)
print("[*] Duplicates stored in '{}'".format(args.duplicatefilename))
del duplicates
# ------------------------------------------
# action "maxdist": brute-force pairs below the distance limit, stored as JSON
if args.action == "maxdist":
print("[*] Find near duplicates with max_dist={} ...".format(args.max_dist))
below_max_dist_idxs, below_max_dist_vals = find_near_duplicates(
tlsh_hashers, max_dist=args.max_dist
print("[*] Number of near duplicates: {}".format(len(below_max_dist_vals)))
if len(below_max_dist_vals):
# histogram: how many pairs were found per distance value
vals, cnts = np.unique(below_max_dist_vals, return_counts=True)
"[*] Near duplicates: {}".format(
", ".join("[{}]: {}".format(int(v), c) for v, c in zip(vals, cnts))
# store near duplicates as D -> ((digestA, digestB), (entryA, entryB))
result = dict()
# pre-create the keys in the order returned by np.unique
for val in np.unique(below_max_dist_vals):
result[int(val)] = list()
for (row, col), val in zip(below_max_dist_idxs, below_max_dist_vals):
val = int(val)
if val not in result:
result[val] = list()
drow = tlsh_hashers[row].hexdigest()
dcol = tlsh_hashers[col].hexdigest()
erow = tlsh_lookup[drow] # file
ecol = tlsh_lookup[dcol]
result[val].append(((drow, dcol), (erow, ecol)))
with open(args.resultfilename, "w") as fp:
json.dump(result, fp, indent=2, cls=ByteJSONEncoder)
print("[*] Results stored in '{}'".format(args.resultfilename))
# ------------------------------------------
# action "matrix": write the distance chunks as text, chunk by chunk
elif args.action == "matrix":
print("[*] Compute complete near document duplicate matrix ...")
matrix_iter = compute_near_duplicates_matrix_iter(tlsh_hashers)
with open(args.resultfilename, "w") as fp:
for chunks in tqdm(matrix_iter, desc="Compute dist matrix"):
# NOTE(review): "%u" is an unsigned format, while the metric returns -1 for row > col - confirm the output
np.savetxt(fp, chunks, fmt="%u")
print("[*] Results stored in '{}'".format(args.resultfilename))
# ------------------------------------------
# action "mostsimilar": nearest neighbour per hash via a vantage point tree
elif args.action == "mostsimilar":
"[*] Compute most similar near document duplicates with max_dist={} ...".format(
print("[*] Build Vantage Point Tree ...")
# NOTE(review): right-hand side missing - presumably the tree is built from `tlsh_hashers`; confirm
tree =
with parallel_backend("multiprocessing", n_jobs=8):
results = find_nearest_with_VPT_parallel(
tlsh_hashers, tlsh_lookup, tree, max_dist=args.max_dist
# order the result by ascending distance
results = {key: results[key] for key in sorted(results.keys())}
with open(args.resultfilename, "w") as fp:
json.dump(results, fp, indent=2, cls=ByteJSONEncoder)
print("[*] Results stored in '{}'".format(args.resultfilename))

Timing experiments

Setup code

import tlsh
import dedup
import vpt
import vpt_simple

def load_hashes(fn="hashes.tlsh.tsv"):
  """Return the first tab-separated column of every line in ``fn``.

  Lines without a tab are returned whole, including their line ending.
  """
  first_columns = []
  with open(fn, "r") as handle:
    for record in handle:
      # everything before the first tab (or the whole line if there is none)
      head, _sep, _tail = record.partition("\t")
      first_columns.append(head)
  return first_columns

hashes = load_hashes()[:10]

def h2t(h):
  """Convert the TLSH digest string ``h`` into a ``tlsh.Tlsh`` object."""
  t = tlsh.Tlsh()
  # load the digest; without this `h` is ignored and an empty hash is returned
  t.fromTlshStr(h)
  return t

# turn every digest string into a hash object
data = list(map(h2t, hashes))

# simple VPT (vpt_simple module)
tree = vpt_simple.VPTSimple(data).build()

# NOTE(review): right-hand side missing - presumably the tree from the `vpt` module; confirm
tree2 =

# forest of 3 simple VPTs
forrest = vpt_simple.VPTSimpleForrest().build(data, 3)

# NOTE(review): the loop body is missing from this listing
for row in dedup.compute_near_duplicates_matrix_iter(data):

Timing code (ignoring accuracy)

# One distance list per search variant, computed for every element of `data`.
# NOTE(review): the search calls of the last three variants are missing from
# this listing (`[0] for xi`, `[ for xi in data]`); restore before running.
dists_naive = [dedup.linear_search(data, xi)[0] for xi in data]
dists_vpt = [[0] for xi in data]
dists_vptpaper = [r[1] if r else r for r in [ for xi in data]]
dists_vptforrestpaper = [r[1] if r else r for r in [ for xi in data]]

python3 -m timeit

# Common setup: read the first N_HASHES digests from FN_HASHES and load them into Tlsh objects.
# NOTE(review): SETUP_CONF, SETUP_N, SETUP_VPT, SETUP_PT and SETUP_PF are used
# below but not defined in this listing; the search calls inside TEST_VPT,
# TEST_PT and TEST_PF are missing as well.
SETUP_PRE='import tlsh,vpt,vpt_simple,dedup;fp=open(FN_HASHES);hashes=[line.split("\t", 1)[0] for line in fp];fp.close();hashes=hashes[:N_HASHES];data=[tlsh.Tlsh() for _ in range(N_HASHES)];[t.fromTlshStr(h) for t,h in zip(data,hashes)];'

# one timed statement per search variant
TEST_N='dists_gold=[dedup.linear_search(data,xi)[0] for xi in data]'
TEST_VPT='dists=[[0] for xi in data]'
TEST_PT='dists=[r[1] if r else r for r in ( for xi in data)]'
TEST_PF='dists=[r[1] if r else r for r in ( for xi in data)]'
# setup (-s) is excluded from the measured time
python3 -m timeit -s "${SETUP_CONF}${SETUP_PRE}${SETUP_N}" "${TEST_N}"
python3 -m timeit -s "${SETUP_CONF}${SETUP_PRE}${SETUP_VPT}" "${TEST_VPT}"
python3 -m timeit -s "${SETUP_CONF}${SETUP_PRE}${SETUP_PT}" "${TEST_PT}"
python3 -m timeit -s "${SETUP_CONF}${SETUP_PRE}${SETUP_PF}" "${TEST_PF}"


NOTE: accuracy is ignored in these measurements; the PT results will be really bad.


  • N: 500 loops, best of 5: 401 usec per loop
  • VPT: 2000 loops, best of 5: 193 usec per loop
  • PT: 5000 loops, best of 5: 43.2 usec per loop
  • PF: 500 loops, best of 5: 232 usec per loop


  • N: 50 loops, best of 5: 9 msec per loop
  • VPT: 20 loops, best of 5: 14.1 msec per loop
  • PT: 500 loops, best of 5: 629 usec per loop
  • PF: 100 loops, best of 5: 3.37 msec per loop


  • N: 1 loop, best of 5: 551 msec per loop
  • VPT: 1 loop, best of 5: 983 msec per loop
  • PT: 50 loops, best of 5: 9.1 msec per loop
  • PF: 5 loops, best of 5: 47.4 msec per loop

N=1000 with NTREES=10

  • PF: 5 loops, best of 5: 93.8 msec per loop

N=10000, NTREES=10

  • N: 1 loop, best of 5: 53.5 sec per loop
  • VPT: 1 loop, best of 5: 64.1 sec per loop
  • PT: 2 loops, best of 5: 111 msec per loop
  • PF: 1 loop, best of 5: 1.28 sec per loop


TODO: how to combine timing with accuracy? Run until the results are correct? But how to do this with PT (might never be correct, could abort with failure state)?

import sys
from functools import partial
import numpy as np
import tlsh # from `py-tlsh`
# ----------------------------------------------------------------------------
# vantage point tree search
def _median(arr):
if len(arr) % 2 == 0:
arr += [-1]
return int(np.median(arr))
# Derive a nesting level for the calling function by walking up the call
# stack; used below to size the "~" prefix of debug output.
# NOTE(review): indentation is lost, so it is unclear what the `if` guards.
# Presumably a `break` followed it (stop at the first caller with a different
# function name, i.e. count the recursion depth) - confirm against the original.
def _nestlevel():
# frame of the function that called _nestlevel()
frame = sys._getframe(1)
name = frame.f_code.co_name
level = 1
while frame.f_back:
frame = frame.f_back
if frame.f_code.co_name != name:
level += 1
return level
# Vantage point tree over TLSH hash objects. Every node stores one hash
# (`tlsh`), its index in the input list (`index`), a distance `threshold`
# (-1 marks a leaf) and the two subtrees `inner` / `outer`.
# NOTE(review): this listing has lost its indentation and many lines:
# decorators (`num_*` are used as attributes and `build` takes `cls`, so
# presumably `@property` / `@classmethod`), the parameter lists of `_search`
# and `search`, several `else:` / `return` / `print(` lines and call targets.
# Restore them from the original file before running.
class VantagePointTree:
def __init__(self, tlsh_hasher, index, threshold=-1):
self.tlsh = tlsh_hasher
self.index = index
self.inner = None
self.outer = None
self.threshold = threshold
# pickling support: replace the hash object by its hex digest
def __getstate__(self):
state = self.__dict__.copy()
# the local name shadows the `tlsh` module inside this method
tlsh = state.pop("tlsh")
if tlsh:
state["digest"] = tlsh.hexdigest()
return state
# unpickling: rebuild the hash object from the stored digest
# NOTE(review): the statements that load `digest` into the new object and that
# apply `state` to the instance are not visible here. `state.pop("digest")`
# raises KeyError when __getstate__ stored no digest.
def __setstate__(self, state):
digest = state.pop("digest")
if digest:
tlsh_hasher = tlsh.Tlsh()
state["tlsh"] = tlsh_hasher
# number of nodes in this subtree, including this node
def num_subtree(self):
return 1 + self.num_inner + self.num_outer
# number of nodes in the inner subtree (0 if there is none)
def num_inner(self):
if not self.inner:
return 0
return self.inner.num_subtree
# number of nodes in the outer subtree (0 if there is none)
def num_outer(self):
if not self.outer:
return 0
return self.outer.num_subtree
# Build a tree from `tlsh_hashers`; `indexes` carries the original positions
# through the recursion. Returns None for empty input.
def build(cls, tlsh_hashers, indexes=None):
num = len(tlsh_hashers)
if not num:
return None
if not indexes:
indexes = np.arange(num)
# the first element becomes the vantage point of this node
tlsh_hasher = tlsh_hashers[0]
idx = indexes[0]
if num == 1:
return cls(tlsh_hasher, idx)
tlsh_hashers = tlsh_hashers[1:]
indexes = indexes[1:]
# distances from the vantage point to all remaining elements
distances = [tlsh_hasher.diff(th) for th in tlsh_hashers]
dist_median = _median(distances)
tree = cls(tlsh_hasher, idx, threshold=dist_median)
inner_tlsh_hashers = list()
inner_indexes = list()
outer_tlsh_hashers = list()
outer_indexes = list()
# NOTE(review): the statements that fill the inner/outer lists and the
# recursive build calls are incomplete in this listing
for dist, tlsh_hasher, idx in zip(distances, tlsh_hashers, indexes):
if dist < dist_median:
tree.inner =, inner_indexes)
tree.outer =, outer_indexes)
return tree
# Print the tree: inner subtree first, then this node, then the outer subtree.
def dump(self, maxdepth=None, _curdepth=None):
if not _curdepth:
_curdepth = 0
# below `maxdepth` only a summary of the hidden subtree is produced
if maxdepth is not None and _curdepth > maxdepth:
" " * _curdepth
+ "... {} hidden (<:{} >:{})".format(
self.num_subtree, self.num_inner, self.num_outer
if self.inner:
self.inner.dump(maxdepth, _curdepth + 1)
print(" " * _curdepth, end="")
if self.threshold == -1:
print("LEAF: idx={}".format(self.index))
print("SPLIT: idx={} T={}".format(self.index, self.threshold)),
if self.outer:
self.outer.dump(maxdepth, _curdepth + 1)
# Recursive search. `_cur_best` is None or a tuple (dist, index, tlsh) and is
# what gets returned.
# NOTE(review): the parameter list is missing. The body reads `tree`,
# `tlsh_hasher`, `margin`, `_cur_best`, `_skip_same`, `_metriccheck`, `_debug`
# and `_debug_stats`.
def _search(
if not tlsh_hasher:
return _cur_best
if not tree:
if _debug:
print("[D]", "~" * _nestlevel() + " " + "no subtree --> {}".format(_cur_best)) # fmt: skip
return _cur_best
# NOTE(review): the arguments of `partial(` are missing; `search` is called
# below as search(subtree, tlsh_hasher, _cur_best=...)
search = partial(
dist = tlsh_hasher.diff(tree.tlsh)
# count the number of distance computations
if _debug and _debug_stats:
_debug_stats[0] += 1
if dist == 0 and _skip_same:
if _debug:
print("[D]", "~" * _nestlevel() + " " + "Found same TLSH at index={}, with _cur_best={}".format(tree.index, _cur_best)) # fmt: skip
# TODO: not sure what best to set here
if _cur_best is not None:
dist = _cur_best[0]
# return _cur_best
dist = tree.threshold
elif _cur_best is None or dist < _cur_best[0]:
_cur_best = (dist, tree.index, tree.tlsh)
if _debug:
print("[D]", "~" * _nestlevel() + " " + "check@{}: {}".format(tree.index, _cur_best)) # fmt: skip
# query lies within the threshold: search the inner subtree first
if dist <= tree.threshold:
if _debug:
print("[D]", "~" * _nestlevel() + " " + "check inner: {} <= {} : i={} o={}".format(dist, tree.threshold, tree.num_inner, tree.num_outer)) # fmt: skip
new_best = search(tree.inner, tlsh_hasher, _cur_best=_cur_best)
if _debug and new_best != _cur_best:
print("[D]", "~" * _nestlevel() + " " + "new inner best --> {}".format(new_best)) # fmt: skip
new_dist = new_best[0]
# the best candidate may still lie across the threshold: also search outer
if dist + new_dist + margin >= tree.threshold:
if _debug:
print("[D]", "~" * _nestlevel() + " " + "over threshold -> check outer: {} + {} + {} >= {} : i={} o={}".format(dist, new_dist, margin, tree.threshold, tree.num_inner, tree.num_outer)) # fmt: skip
return search(tree.outer, tlsh_hasher, _cur_best=new_best)
# optional self check: the skipped outer subtree must not hold a different best
elif _metriccheck and tree.outer:
new_best_outer = search(tree.outer, tlsh_hasher, _cur_best=new_best)
if new_best_outer and new_best_outer[0] != new_best[0]:
raise RuntimeError(
"Metric problem for outer={} with inner={}, dist={}, threshold={}".format(
new_best_outer, new_best, dist, tree.threshold
return new_best
# NOTE(review): an `else:` is presumably missing here - the lines below mirror
# the branch above with the outer subtree searched first
if _debug:
print("[D]", "~" * _nestlevel() + " " + "check outer: {} > {} : i={} o={}".format(dist, tree.threshold, tree.num_inner, tree.num_outer)) # fmt: skip
new_best = search(tree.outer, tlsh_hasher, _cur_best=_cur_best)
if _debug and new_best != _cur_best:
print("[D]", "~" * _nestlevel() + " " + "new outer best --> {}".format(new_best)) # fmt: skip
new_dist = new_best[0]
if dist - new_dist - margin <= tree.threshold:
if _debug:
print("[D]", "~" * _nestlevel() + " " + "below threshold -> check outer: {} - {} - {} < {} : i={} o={}".format(dist, new_dist, margin, tree.threshold, tree.num_inner, tree.num_outer)) # fmt: skip
return search(tree.inner, tlsh_hasher, _cur_best=new_best)
elif _metriccheck and tree.inner:
new_best_inner = search(tree.inner, tlsh_hasher, _cur_best=new_best)
if new_best_inner and new_best_inner[0] != new_best[0]:
raise RuntimeError(
"Metric problem for inner={} with outer={}, dist={}, threshold={}".format(
new_best_inner, new_best, dist, tree.threshold
return new_best
# Public entry point: runs `_search` and returns its result, or None when no
# query hash is given.
# NOTE(review): the parameter list and the arguments of the `_search(` call are missing
def search(
if not tlsh_hasher:
return None
# number of TLSH comparisons
_debug_stats = [0]
result = VantagePointTree._search(
if _debug:
print("[D]", "~ {} TLSH comparisons".format(_debug_stats[0]))
return result
# ----------------------------------------------------------------------------
# brute force approach / linear search
def linear_search(tlsh_hashers, tlsh_hasher, skip_same=True):
    """Find the entry of ``tlsh_hashers`` closest to ``tlsh_hasher`` by brute force.

    Args:
        tlsh_hashers: sequence of objects accepted by ``tlsh_hasher.diff``.
        tlsh_hasher: query object; ``tlsh_hasher.diff(other)`` gives the distance.
        skip_same: if true, candidates at distance 0 are ignored.

    Returns:
        Tuple ``(distance, index, tlsh_hashers[index])`` of the closest entry.

    Raises:
        ValueError: if no candidate is left, i.e. the input is empty or
            ``skip_same`` is set and every distance is 0.
    """
    distances = np.array([tlsh_hasher.diff(th) for th in tlsh_hashers])
    # float copy so that excluded entries can be marked with NaN
    candidates = distances.astype(float)
    if skip_same:
        # NaN hides exact matches from nanargmin; `np.nan` is used because
        # the `np.NaN` alias no longer exists in NumPy 2.0
        candidates[distances == 0] = np.nan
    idx = np.nanargmin(candidates)
    # report the distance from the unmasked array to keep its integer type
    return (distances[idx], idx, tlsh_hashers[idx])
# ----------------------------------------------------------------------------
# NOTE: do not use this. It does not reliably find the nearest entry and cannot skip the query itself,
# but it is based on the code from the TLSH paper
# -
import random
import tlsh # from `py-tlsh`
# ----------------------------------------------------------------------------
# Simple vantage point tree. A node is either a leaf (`Split` is None) or an
# inner node with a split element `Split`, a distance `Threshold` and the two
# children `LC` (distance <= Threshold) and `RC` (distance > Threshold).
# NOTE(review): this listing has lost its indentation and many tokens: an
# attribute reference was stripped throughout (see `def __init__(self, data): = data`,
# `len(`, `for xi in]`, `Y =[i]`; presumably the node's data list - confirm),
# and so were several `return` / `else:` lines and any decorators (the helpers
# take the node `N` explicitly and are called as `VPTSimple.<name>(...)`).
# Restore them from the original file before running.
class VPTSimple:
# based on
def __init__(self, data): = data
self.Split = None
self.Threshold = None
self.LC = None
self.RC = None
# --------------------------------
# distance between two hash objects
def Dist(A, B):
return A.diff(B)
# Pick the threshold T as the (lower) median distance to Y and partition `data`
# into X1 (distance <= T) and X2 (distance > T). Returns (T, X1, X2).
def FindThresholdAndSplits(data, Y):
# Find threshold (T ) s.t. size(X1) ≈ size(X2) where
# X1 = {xi ∈ : Dist(xi, Y ) ≤ T }
# X2 = {xi ∈ : Dist(xi, Y ) > T }
distances = sorted([Y.diff(xi) for xi in data])
T = distances[int((len(data) - 1) / 2)]
X1 = [xi for xi in data if VPTSimple.Dist(xi, Y) <= T]
X2 = [xi for xi in data if VPTSimple.Dist(xi, Y) > T]
# edge case due to duplicate distances near the median
if len(X1) == 0 or len(X2) == 0:
# collapse duplicate distances, and try median again
distances = sorted(set(distances))
# NOTE: this here can still be a single element if all distances are the same
T = distances[int((len(distances) - 1) / 2)]
X1 = [xi for xi in data if VPTSimple.Dist(xi, Y) <= T]
X2 = [xi for xi in data if VPTSimple.Dist(xi, Y) > T]
return T, X1, X2
# Choose a random split element Y of node N and partition the remaining
# elements. Returns None if the node has too few items, else (Y, T, X1, X2).
def SplitMethod(N, nitemsInLeaf):
nitems = len(
if nitems < nitemsInLeaf + 1:
return None
i = random.randint(0, nitems - 1)
Y =[i]
# change: we need to skip the randomly selected element
T, X1, X2 = VPTSimple.FindThresholdAndSplits([:i] +[i + 1:], Y)
return Y, T, X1, X2
# Recursively turn node N into an inner node with two children where possible.
# NOTE(review): the bodies of the two guard `if`s below (presumably `return`) are missing
def TreeBuild(N, nitemsInLeaf):
# change: if less than 3 items, we can't split into left, middle, right
if len( < 3:
Res = VPTSimple.SplitMethod(N, nitemsInLeaf=nitemsInLeaf)
if Res is not None:
Y, T, X1, X2 = Res
# guard against empty branches (e.g. FindThresholdAndSplits() fails somehow)
if len(X1) == 0 or len(X2) == 0:
N.Split = Y
N.Threshold = T
N.LC = VPTSimple(data=X1)
VPTSimple.TreeBuild(N.LC, nitemsInLeaf=nitemsInLeaf)
N.RC = VPTSimple(data=X2)
VPTSimple.TreeBuild(N.RC, nitemsInLeaf=nitemsInLeaf)
# change: clean up to reduce memory since it now is contained in its branches = None
def isLeaf(N):
# if leaf then only `` is assigned, everything else is None
return N.Split is None
# closest element of leaf N to the query S
def closestItem(N, S):
# compute distances
distances = [VPTSimple.Dist(xi, S) for xi in]
# sort by shortest, and return first element
return sorted(zip(, distances), key=lambda t: t[1])[0][0]
# Descend to a single leaf (no backtracking) and return (item, distance).
def Search(N, S):
if VPTSimple.isLeaf(N):
X = VPTSimple.closestItem(N, S)
d = VPTSimple.Dist(X, S)
return X, d
thisDist = VPTSimple.Dist(N.Split, S)
if thisDist <= N.Threshold:
return VPTSimple.Search(N.LC, S)
return VPTSimple.Search(N.RC, S)
# like closestItem(), but an element at distance 0 is treated as the query
# itself; returns None if nothing else is in the leaf
def closestItemNotSelf(N, S):
# compute distances
distances = [VPTSimple.Dist(xi, S) for xi in]
# sort by shortest
distsNdata = sorted(zip(, distances), key=lambda t: t[1])
# skip self if contained
if distsNdata[0][1] == 0:
# if we have multiple hashes in the leaf, return the second best
if len( > 1:
return distsNdata[1][0]
# else None to continue search elsewhere
return None
# and return first element
return distsNdata[0][0]
# like Search(), but skips the query itself; if the left branch only yields
# the query, the right branch is searched as well
def SearchNotSelf(N, S):
if VPTSimple.isLeaf(N):
X = VPTSimple.closestItemNotSelf(N, S)
# if self found, then abort branch
if X is None:
return None
d = VPTSimple.Dist(X, S)
return X, d
thisDist = VPTSimple.Dist(N.Split, S)
if thisDist <= N.Threshold:
Res = VPTSimple.SearchNotSelf(N.LC, S)
# if we found something (not self) return
if Res is not None:
return Res
# otherwise search other branch
return VPTSimple.SearchNotSelf(N.RC, S)
# --------------------------------
# build the tree in place and return self, so calls can be chained
def build(self, nitemsInLeaf=1):
assert nitemsInLeaf >= 1
VPTSimple.TreeBuild(self, nitemsInLeaf=nitemsInLeaf)
return self
# search for S; with `not_self` an exact match (distance 0) is skipped
def search(self, S, not_self=True):
if not_self:
return VPTSimple.SearchNotSelf(self, S)
return VPTSimple.Search(self, S)
# Collection of several VPTSimple trees built over the same data; a search
# asks every tree and keeps the result with the smallest distance.
# NOTE(review): this listing has lost its indentation and several statements:
# the shuffle announced in the comment, the tree build/append in `build`, and
# the search call plus result collection in `search`. Restore before running.
class VPTSimpleForrest:
def __init__(self):
self.trees = []
# Build `nTrees` trees over `data` and return self. If `nTrees` is None, the
# cube root of the data size is used, bounded by the min/max expression below.
def build(self, data, nTrees=None, nitemsInLeaf=1):
if nTrees is None:
nTrees = int(max(3, min(100, len(data) ** (1 / 3))))
self.trees = []
for _ in range(nTrees):
# copy and shuffle lists
data = list(data)
# build trees
tree = VPTSimple(data)
return self
# Search all trees for S; returns the best result (smallest second tuple
# element) or None.
def search(self, S, not_self=True):
results = []
for ti in self.trees:
Res =, not_self=not_self)
if Res is not None:
# find best result or None if None
if len(results) >= 1:
Res = sorted(results, key=lambda ri: ri[1])[0]
return Res
return None
# ----------------------------------------------------------------------------
