@Querela
Last active June 15, 2023 12:03
[python] TLSH Vantage Point Trees
# dedup.py
import os
import json

import numpy as np
import tlsh  # from `py-tlsh`
from joblib import Parallel, delayed, parallel_backend
from sklearn.metrics import pairwise_distances_chunked
from tqdm import tqdm

from vpt import VantagePointTree
# ----------------------------------------------------------------------------
# demo build hashes function
def build_hashes(files):
    tlsh_lookup = dict()
    tlsh_hashers = list()
    duplicates = dict()
    failures = list()
    for file in files:
        hasher = tlsh.Tlsh()
        with open(file, "rb") as fp:
            for line in fp:
                hasher.update(line)
        try:
            hasher.final()
        except ValueError:
            pass
        if not hasher.is_valid:
            failures.append(file)
            continue
        digest = hasher.hexdigest()
        if digest in tlsh_lookup:
            entry = tlsh_lookup[digest]
            if digest not in duplicates:
                duplicates[digest] = [entry]
            duplicates[digest].append(file)
            continue
        tlsh_lookup[digest] = file
        tlsh_hashers.append(hasher)
    return tlsh_lookup, tlsh_hashers, duplicates, failures
# ----------------------------------------------------------------------------
# sklearn `pairwise_distances_chunked` brute-force distance calculation
# can be parallelized with a joblib backend
def find_near_duplicates(tlsh_hashers, max_dist=100):
    Xlen = len(tlsh_hashers)
    X = np.arange(Xlen)[:, np.newaxis]

    def tlsh_sim(row, col):
        row = int(row)
        col = int(col)
        # skip computation if in the lower-left half or on the diagonal
        if row == col:
            return 0
        # print(row, col)  # DEBUG
        # add this check for slice computation
        # (if the whole pairwise matrix can be computed at once, sklearn assumes symmetry?)
        if row > col:
            return -1
        # return (np.arange(len(tlsh_hashers)) + 1)[row] + ((np.arange(len(tlsh_hashers)) + 1) / 10)[col]  # DEBUG
        h_row = tlsh_hashers[row]
        h_col = tlsh_hashers[col]
        return h_row.diff(h_col)

    # working memory (in MiB) for a single row (instead of a warning for `=0`)
    # bytes -> MiB: 8 bytes (float64) per entry * row length
    working_memory = 1 / 2**20 * 8 * Xlen
    # number of rows per chunk
    num_rows = 1
    working_memory *= num_rows

    # max_dist: maximum distance (if None or negative, compute all)
    # reducer function (only keep the column indexes per row that are below the maximum distance)
    def reduce_func(D_chunk, start_row):
        # keep distances that are non-negative, at most max_dist, and not on the diagonal
        idxs = [
            np.flatnonzero((d >= 0) & (d <= max_dist) & (np.arange(Xlen) != pos))
            for pos, d in enumerate(D_chunk, start_row)
        ]
        vals = [d[i] for d, i in zip(D_chunk, idxs)]
        return idxs, vals

    # NOTE: if the whole NxN matrix fits into working_memory, sklearn assumes symmetry and
    # only computes the upper-right half, mirroring it onto the -1 placeholders
    pwdst_iter = pairwise_distances_chunked(
        X,
        X,
        metric=tlsh_sim,
        reduce_func=reduce_func,
        n_jobs=-1,
        working_memory=working_memory,
    )
    below_max_dist_idxs = []
    below_max_dist_vals = []
    for chunk_idxs, chunk_vals in tqdm(
        pwdst_iter, total=round(Xlen / num_rows), desc="Filter max_dist"
    ):
        below_max_dist_idxs.extend(chunk_idxs)
        below_max_dist_vals.extend(chunk_vals)
    below_max_dist_idxs = [
        (row, col) for row, cols in enumerate(below_max_dist_idxs) for col in cols
    ]
    below_max_dist_vals = np.concatenate(below_max_dist_vals)
    return below_max_dist_idxs, below_max_dist_vals
# ----------------------------------------------------------------------------
# iterator to build the full distance matrix (computes every row pair, no shortcuts)
def compute_near_duplicates_matrix_iter(tlsh_hashers):
    X = np.arange(len(tlsh_hashers))[:, np.newaxis]

    def tlsh_sim(row, col):
        row = int(row)
        col = int(col)
        # skip computation if in the lower-left half or on the diagonal
        if row == col:
            return 0
        # add this check for slice computation
        # (if the whole pairwise matrix can be computed at once, sklearn assumes symmetry?)
        if row > col:
            return -1
        h_row = tlsh_hashers[row]
        h_col = tlsh_hashers[col]
        return h_row.diff(h_col)

    # working memory (in MiB) for a single row (instead of a warning for `=0`)
    # bytes -> MiB: 8 bytes (float64) per entry * row length
    working_memory = 1 / 2**20 * 8 * len(tlsh_hashers)
    # number of rows per chunk
    num_rows = 1
    working_memory *= num_rows
    # max_dist: maximum distance (if None or negative, compute all)
    # NOTE: if the whole NxN matrix fits into working_memory, sklearn assumes symmetry and
    # only computes the upper-right half, mirroring it onto the -1 placeholders
    pwdst_iter = pairwise_distances_chunked(
        X,
        X,
        metric=tlsh_sim,
        reduce_func=None,
        n_jobs=-1,
        working_memory=working_memory,
    )
    yield from pwdst_iter
# ----------------------------------------------------------------------------
# parallel search (using the global `globaltree` VPT instance)
def _find_nearest_vpt_one(digest, i, max_dist):
    global globaltree
    tlsh_hash = tlsh.Tlsh()
    tlsh_hash.fromTlshStr(digest)
    result = globaltree.search(tlsh_hash)
    # no result
    if not result:
        return None
    # too distant
    dist, idx, th_other = result
    if dist > max_dist:
        return None
    # precompute for later
    result_key = tuple([dist] + sorted([i, idx]))
    drow = tlsh_hash.hexdigest()
    dcol = th_other.hexdigest()
    return result_key, (drow, dcol)
def find_nearest_with_VPT_parallel(tlsh_hashers, tlsh_lookup, tree, max_dist=100):
    global globaltree
    globaltree = tree
    results = dict()
    seen = set()
    tlsh_hashers = tqdm(tlsh_hashers, desc="Search for most similar")
    # Parallel(return_generator=True) ?
    # n_jobs=-3 here will overwrite parallel_backend() settings
    for result in Parallel(prefer="processes")(
        delayed(_find_nearest_vpt_one)(th.hexdigest(), i, max_dist)
        for i, th in enumerate(tlsh_hashers)
    ):
        if not result:
            continue
        result_key, (drow, dcol) = result
        dist = result_key[0]
        if result_key in seen:
            continue
        if dist not in results:
            results[dist] = list()
        erow = tlsh_lookup[drow]
        ecol = tlsh_lookup[dcol]
        results[dist].append(((drow, dcol), (erow, ecol)))
        seen.add(result_key)
    del globaltree
    return results
# ----------------------------------------------------------------------------
# non-parallel variant of `find_nearest_with_VPT_parallel`
def find_nearest_with_VPT(tlsh_hashers, tlsh_lookup, tree, max_dist=100):
    results = dict()
    seen = set()
    for i, th in enumerate(tqdm(tlsh_hashers, desc="Search for most similar")):
        result = tree.search(th)
        if not result:
            continue
        dist, idx, th_other = result
        if dist > max_dist:
            continue
        result_key = tuple([dist] + sorted([i, idx]))
        if result_key in seen:
            continue
        if dist not in results:
            results[dist] = list()
        drow = tlsh_hashers[i].hexdigest()
        dcol = tlsh_hashers[idx].hexdigest()
        erow = tlsh_lookup[drow]  # file
        ecol = tlsh_lookup[dcol]
        results[dist].append(((drow, dcol), (erow, ecol)))
        seen.add(result_key)
    return results
# ----------------------------------------------------------------------------
# json encoder for serializing binary strings
class ByteJSONEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, bytes):
            return o.decode()
        return super().default(o)
# ----------------------------------------------------------------------------
# caching of TLSH hashes
def _load_tlsh_file(filename, encoding="utf-8"):
    tlsh_lookup = dict()
    tlsh_hashers = list()
    with open(filename, "rb") as fp:
        for line in tqdm(fp, desc="Load TLSH hashes"):
            digest, sourcefile = line.rstrip().split(b"\t")
            digest = digest.decode(encoding)
            sourcefile = sourcefile.decode(encoding)
            tlsh_lookup[digest] = sourcefile
            # regenerate TLSH hash object
            tlsh_hash = tlsh.Tlsh()
            tlsh_hash.fromTlshStr(digest)
            tlsh_hashers.append(tlsh_hash)
    return tlsh_lookup, tlsh_hashers


def _save_tlsh_file(filename, tlsh_lookup, encoding="utf-8"):
    if not tlsh_lookup:
        return
    with open(filename, "wb") as fp:
        for digest, sourcefile in tlsh_lookup.items():
            fp.write(digest.encode(encoding))
            fp.write(b"\t")
            fp.write(sourcefile.encode(encoding))
            fp.write(b"\n")
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    # fmt: off
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("file_or_folder", type=str, help="Input file or folder with source files")
    parser.add_argument("-t", "--tlshfile", dest="tlshfilename", type=str, help="Input/Output file for TLSH hashes")
    parser.add_argument("--errorfile", dest="errorfilename", type=str, help="Output file for errors")
    parser.add_argument("--dupsfile", dest="duplicatefilename", type=str, help="Output file for duplicates (same TLSH)")
    subparsers = parser.add_subparsers(dest="action", required=True, help="Actions")
    parser_maxdist = subparsers.add_parser("maxdist", help="Compute TLSH distances below 'max_dist'")
    parser_matrix = subparsers.add_parser("matrix", help="Compute full NxN TLSH distance matrix")
    parser_mostsimilar = subparsers.add_parser("mostsimilar", help="Compute most similar TLSH document pairs")
    parser_hash = subparsers.add_parser("hash", help="Precompute TLSH hashes only (requires -t/--tlshfile to be set)")
    parser_maxdist.add_argument("-d", "--maxdist", dest="max_dist", default=100, type=int)
    parser_maxdist.add_argument("-r", "--resultfile", dest="resultfilename", default="result.json", type=str)
    parser_matrix.add_argument("-r", "--resultfile", dest="resultfilename", default="result.txt", type=str)
    parser_mostsimilar.add_argument("-r", "--resultfile", dest="resultfilename", default="result.json", type=str)
    parser_mostsimilar.add_argument("-d", "--maxdist", dest="max_dist", default=100, type=int)
    # fmt: on

    args = parser.parse_args()

    if not os.path.exists(args.file_or_folder):
        raise FileNotFoundError(args.file_or_folder)
    if args.action == "hash" and not args.tlshfilename:
        parser.error("Action 'hash' requires '-t/--tlshfile' to be set!")
    if args.action == "maxdist" and args.max_dist < 0:
        parser.error("Action 'maxdist' requires '-d/--maxdist' to be at least 0!")

    # ------------------------------------------

    if (
        args.action != "hash"
        and args.tlshfilename
        and os.path.exists(args.tlshfilename)
    ):
        print("[!] Found existing TLSH hash file. Loading pre-computed hashes ...")
        tlsh_lookup, tlsh_hashers = _load_tlsh_file(args.tlshfilename)
        print("[*] Number of hashes: {}".format(len(tlsh_lookup)))
    else:
        files = [args.file_or_folder]
        if os.path.isdir(args.file_or_folder):
            files = [os.path.join(args.file_or_folder, f) for f in os.listdir(args.file_or_folder)]
        tlsh_lookup, tlsh_hashers, duplicates, failures = build_hashes(files)
        print("[*] Number of hashes: {}".format(len(tlsh_lookup)))
        print(
            "[*] Number of duplicates: {} ({} total)".format(
                len(duplicates), sum(map(len, duplicates.values()))
            )
        )
        print("[*] Number of failures (for TLSH hash computation): {}".format(len(failures)))  # fmt: skip
        if args.tlshfilename:
            _save_tlsh_file(args.tlshfilename, tlsh_lookup)
            print("[*] TLSH hashes stored in '{}'".format(args.tlshfilename))
        if args.errorfilename:
            # store failures as [entry1..N]
            with open(args.errorfilename, "w") as fp:
                json.dump(failures, fp, indent=2, cls=ByteJSONEncoder)
            print(
                "[*] Errors (TLSH computation) stored in '{}'".format(
                    args.errorfilename
                )
            )
        del failures
        if args.duplicatefilename:
            # store duplicates as digest -> [entry1..N]
            with open(args.duplicatefilename, "w") as fp:
                json.dump(duplicates, fp, indent=2, cls=ByteJSONEncoder)
            print("[*] Duplicates stored in '{}'".format(args.duplicatefilename))
        del duplicates

    # ------------------------------------------

    if args.action == "maxdist":
        print("[*] Find near duplicates with max_dist={} ...".format(args.max_dist))
        below_max_dist_idxs, below_max_dist_vals = find_near_duplicates(
            tlsh_hashers, max_dist=args.max_dist
        )
        print("[*] Number of near duplicates: {}".format(len(below_max_dist_vals)))
        if len(below_max_dist_vals):
            vals, cnts = np.unique(below_max_dist_vals, return_counts=True)
            print(
                "[*] Near duplicates: {}".format(
                    ", ".join("[{}]: {}".format(int(v), c) for v, c in zip(vals, cnts))
                )
            )
        # store near duplicates as D -> ((digestA, digestB), (entryA, entryB))
        result = dict()
        for val in np.unique(below_max_dist_vals):
            result[int(val)] = list()
        for (row, col), val in zip(below_max_dist_idxs, below_max_dist_vals):
            val = int(val)
            if val not in result:
                result[val] = list()
            drow = tlsh_hashers[row].hexdigest()
            dcol = tlsh_hashers[col].hexdigest()
            erow = tlsh_lookup[drow]  # file
            ecol = tlsh_lookup[dcol]
            result[val].append(((drow, dcol), (erow, ecol)))
        with open(args.resultfilename, "w") as fp:
            json.dump(result, fp, indent=2, cls=ByteJSONEncoder)
        print("[*] Results stored in '{}'".format(args.resultfilename))

    # ------------------------------------------

    elif args.action == "matrix":
        print("[*] Compute complete near document duplicate matrix ...")
        matrix_iter = compute_near_duplicates_matrix_iter(tlsh_hashers)
        with open(args.resultfilename, "w") as fp:
            for chunks in tqdm(matrix_iter, desc="Compute dist matrix"):
                np.savetxt(fp, chunks, fmt="%u")
        print("[*] Results stored in '{}'".format(args.resultfilename))

    # ------------------------------------------

    elif args.action == "mostsimilar":
        print(
            "[*] Compute most similar near document duplicates with max_dist={} ...".format(
                args.max_dist
            )
        )
        print("[*] Build Vantage Point Tree ...")
        tree = VantagePointTree.build(tlsh_hashers)
        with parallel_backend("multiprocessing", n_jobs=8):
            results = find_nearest_with_VPT_parallel(
                tlsh_hashers, tlsh_lookup, tree, max_dist=args.max_dist
            )
        results = {key: results[key] for key in sorted(results.keys())}
        with open(args.resultfilename, "w") as fp:
            json.dump(results, fp, indent=2, cls=ByteJSONEncoder)
        print("[*] Results stored in '{}'".format(args.resultfilename))

Timing experiments

Setup code

import tlsh
import dedup
import vpt
import vpt_simple

def load_hashes(fn="hashes.tlsh.tsv"):
  with open(fn, "r") as fp:
    return [line.split("\t", 1)[0] for line in fp]


hashes = load_hashes()[:10]

def h2t(h):
  t = tlsh.Tlsh()
  t.fromTlshStr(h)
  return t


data = list(map(h2t, hashes))

tree = vpt_simple.VPTSimple(data).build()

tree2 = vpt.VantagePointTree.build(data)

forrest = vpt_simple.VPTSimpleForrest().build(data, 3)

for row in dedup.compute_near_duplicates_matrix_iter(data):
  print(row)

Timing code (ignoring accuracy)

dists_naive = [dedup.linear_search(data, xi)[0] for xi in data]
dists_vpt = [tree2.search(xi)[0] for xi in data]
dists_vptpaper = [r[1] if r else r for r in [tree.search(xi) for xi in data]]
dists_vptforrestpaper = [r[1] if r else r for r in [forrest.search(xi) for xi in data]]

python3 -m timeit

SETUP_CONF='FN_HASHES="hashes.tlsh.tsv";N_HASHES=100000;'
SETUP_PRE='import tlsh,vpt,vpt_simple,dedup;fp=open(FN_HASHES);hashes=[line.split("\t", 1)[0] for line in fp];fp.close();hashes=hashes[:N_HASHES];data=[tlsh.Tlsh() for _ in range(N_HASHES)];[t.fromTlshStr(h) for t,h in zip(data,hashes)];'
SETUP_N=''
SETUP_VPT='tree=vpt.VantagePointTree.build(data)'
SETUP_PT='tree=vpt_simple.VPTSimple(data).build()'
SETUP_PF='tree=vpt_simple.VPTSimpleForrest().build(data,5)'
SETUP_PF='tree=vpt_simple.VPTSimpleForrest().build(data,10)'  # overrides the 5-tree setup above

TEST_N='dists_gold=[dedup.linear_search(data,xi)[0] for xi in data]'
TEST_VPT='dists=[tree.search(xi)[0] for xi in data]'
TEST_PT='dists=[r[1] if r else r for r in (tree.search(xi) for xi in data)]'
TEST_PF='dists=[r[1] if r else r for r in (tree.search(xi) for xi in data)]'
python3 -m timeit -s "${SETUP_CONF}${SETUP_PRE}${SETUP_N}" "${TEST_N}"
python3 -m timeit -s "${SETUP_CONF}${SETUP_PRE}${SETUP_VPT}" "${TEST_VPT}"
python3 -m timeit -s "${SETUP_CONF}${SETUP_PRE}${SETUP_PT}" "${TEST_PT}"
python3 -m timeit -s "${SETUP_CONF}${SETUP_PRE}${SETUP_PF}" "${TEST_PF}"

Results

NOTE: accuracy is ignored here; the PT results in particular will be quite poor

N=10

  • N: 500 loops, best of 5: 401 usec per loop
  • VPT: 2000 loops, best of 5: 193 usec per loop
  • PT: 5000 loops, best of 5: 43.2 usec per loop
  • PF: 500 loops, best of 5: 232 usec per loop

N=100

  • N: 50 loops, best of 5: 9 msec per loop
  • VPT: 20 loops, best of 5: 14.1 msec per loop
  • PT: 500 loops, best of 5: 629 usec per loop
  • PF: 100 loops, best of 5: 3.37 msec per loop

N=1000

  • N: 1 loop, best of 5: 551 msec per loop
  • VPT: 1 loop, best of 5: 983 msec per loop
  • PT: 50 loops, best of 5: 9.1 msec per loop
  • PF: 5 loops, best of 5: 47.4 msec per loop

N=1000 with NTREES=10

  • PF: 5 loops, best of 5: 93.8 msec per loop

N=10000, NTREES=10

  • N: 1 loop, best of 5: 53.5 sec per loop
  • VPT: 1 loop, best of 5: 64.1 sec per loop
  • PT: 2 loops, best of 5: 111 msec per loop
  • PF: 1 loop, best of 5: 1.28 sec per loop

...

TODO: how to combine timing with accuracy? Run until the results are correct? But how would that work for PT (it might never be correct, so it would have to abort with a failure state)?
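
One option (a rough sketch, not measured here) is to time the queries as above and separately report an accuracy score against the linear-search gold standard, e.g. the fraction of queries whose returned distance matches the exhaustive result. This assumes the `data`, `tree`, `tree2` and `forrest` objects from the setup code; the `accuracy` helper is new:

import dedup

def accuracy(found_dists, gold_dists):
  # a query counts as correct if the returned distance equals the gold distance
  hits = sum(1 for f, g in zip(found_dists, gold_dists) if f is not None and f == g)
  return hits / len(gold_dists)

dists_gold = [dedup.linear_search(data, xi)[0] for xi in data]
dists_vpt = [tree2.search(xi)[0] for xi in data]
dists_pt = [r[1] if r else None for r in (tree.search(xi) for xi in data)]
dists_pf = [r[1] if r else None for r in (forrest.search(xi) for xi in data)]

print("VPT:", accuracy(dists_vpt, dists_gold))
print("PT: ", accuracy(dists_pt, dists_gold))
print("PF: ", accuracy(dists_pf, dists_gold))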

# vpt.py
# https://github.com/trendmicro/tlsh/blob/master/tlshCluster/pylib/hac_lib.py
# https://tlsh.org/papers.html
# https://tlsh.org/papersDir/COINS_2020_camera_ready.pdf
# https://fribbels.github.io/vptree/writeup
import sys
from functools import partial

import numpy as np
import tlsh  # from `py-tlsh`
# ----------------------------------------------------------------------------
# vantage point tree search
def _median(arr):
    if len(arr) % 2 == 0:
        arr += [-1]
    return int(np.median(arr))


def _nestlevel():
    frame = sys._getframe(1)
    name = frame.f_code.co_name
    level = 1
    while frame.f_back:
        frame = frame.f_back
        if frame.f_code.co_name != name:
            break
        level += 1
    return level
class VantagePointTree:
    def __init__(self, tlsh_hasher, index, threshold=-1):
        self.tlsh = tlsh_hasher
        self.index = index
        self.inner = None
        self.outer = None
        self.threshold = threshold

    def __getstate__(self):
        state = self.__dict__.copy()
        tlsh = state.pop("tlsh")
        if tlsh:
            state["digest"] = tlsh.hexdigest()
        return state

    def __setstate__(self, state):
        digest = state.pop("digest")
        if digest:
            tlsh_hasher = tlsh.Tlsh()
            tlsh_hasher.fromTlshStr(digest)
            state["tlsh"] = tlsh_hasher
        self.__dict__.update(state)

    @property
    def num_subtree(self):
        return 1 + self.num_inner + self.num_outer

    @property
    def num_inner(self):
        if not self.inner:
            return 0
        return self.inner.num_subtree

    @property
    def num_outer(self):
        if not self.outer:
            return 0
        return self.outer.num_subtree

    @classmethod
    def build(cls, tlsh_hashers, indexes=None):
        num = len(tlsh_hashers)
        if not num:
            return None
        if not indexes:
            indexes = np.arange(num)
        tlsh_hasher = tlsh_hashers[0]
        idx = indexes[0]
        if num == 1:
            return cls(tlsh_hasher, idx)
        tlsh_hashers = tlsh_hashers[1:]
        indexes = indexes[1:]
        distances = [tlsh_hasher.diff(th) for th in tlsh_hashers]
        dist_median = _median(distances)
        tree = cls(tlsh_hasher, idx, threshold=dist_median)
        inner_tlsh_hashers = list()
        inner_indexes = list()
        outer_tlsh_hashers = list()
        outer_indexes = list()
        for dist, tlsh_hasher, idx in zip(distances, tlsh_hashers, indexes):
            if dist < dist_median:
                inner_tlsh_hashers.append(tlsh_hasher)
                inner_indexes.append(idx)
            else:
                outer_tlsh_hashers.append(tlsh_hasher)
                outer_indexes.append(idx)
        tree.inner = cls.build(inner_tlsh_hashers, inner_indexes)
        tree.outer = cls.build(outer_tlsh_hashers, outer_indexes)
        return tree

    def dump(self, maxdepth=None, _curdepth=None):
        if not _curdepth:
            _curdepth = 0
        if maxdepth is not None and _curdepth > maxdepth:
            print(
                " " * _curdepth
                + "... {} hidden (<:{} >:{})".format(
                    self.num_subtree, self.num_inner, self.num_outer
                )
            )
            return
        if self.inner:
            self.inner.dump(maxdepth, _curdepth + 1)
        print(" " * _curdepth, end="")
        if self.threshold == -1:
            print("LEAF: idx={}".format(self.index))
        else:
            print("SPLIT: idx={} T={}".format(self.index, self.threshold))
        if self.outer:
            self.outer.dump(maxdepth, _curdepth + 1)
    @staticmethod
    def _search(
        tree,
        tlsh_hasher,
        margin=20,
        _cur_best=None,
        _skip_same=True,
        _metriccheck=False,
        _debug=False,
        _debug_stats=None,
    ):
        if not tlsh_hasher:
            return _cur_best
        if not tree:
            if _debug:
                print("[D]", "~" * _nestlevel() + " " + "no subtree --> {}".format(_cur_best))  # fmt: skip
            return _cur_best
        search = partial(
            VantagePointTree._search,
            margin=margin,
            _skip_same=_skip_same,
            _metriccheck=_metriccheck,
            _debug=_debug,
            _debug_stats=_debug_stats,
        )
        dist = tlsh_hasher.diff(tree.tlsh)
        if _debug and _debug_stats:
            _debug_stats[0] += 1
        if dist == 0 and _skip_same:
            if _debug:
                print("[D]", "~" * _nestlevel() + " " + "Found same TLSH at index={}, with _cur_best={}".format(tree.index, _cur_best))  # fmt: skip
            # TODO: not sure what is best to set here
            if _cur_best is not None:
                dist = _cur_best[0]
                # return _cur_best
            else:
                dist = tree.threshold
        elif _cur_best is None or dist < _cur_best[0]:
            _cur_best = (dist, tree.index, tree.tlsh)
            if _debug:
                print("[D]", "~" * _nestlevel() + " " + "check@{}: {}".format(tree.index, _cur_best))  # fmt: skip
        if dist <= tree.threshold:
            if _debug:
                print("[D]", "~" * _nestlevel() + " " + "check inner: {} <= {} : i={} o={}".format(dist, tree.threshold, tree.num_inner, tree.num_outer))  # fmt: skip
            new_best = search(tree.inner, tlsh_hasher, _cur_best=_cur_best)
            if _debug and new_best != _cur_best:
                print("[D]", "~" * _nestlevel() + " " + "new inner best --> {}".format(new_best))  # fmt: skip
            new_dist = new_best[0]
            if dist + new_dist + margin >= tree.threshold:
                if _debug:
                    print("[D]", "~" * _nestlevel() + " " + "over threshold -> check outer: {} + {} + {} >= {} : i={} o={}".format(dist, new_dist, margin, tree.threshold, tree.num_inner, tree.num_outer))  # fmt: skip
                return search(tree.outer, tlsh_hasher, _cur_best=new_best)
            elif _metriccheck and tree.outer:
                new_best_outer = search(tree.outer, tlsh_hasher, _cur_best=new_best)
                if new_best_outer and new_best_outer[0] != new_best[0]:
                    raise RuntimeError(
                        "Metric problem for outer={} with inner={}, dist={}, threshold={}".format(
                            new_best_outer, new_best, dist, tree.threshold
                        )
                    )
            return new_best
        else:
            if _debug:
                print("[D]", "~" * _nestlevel() + " " + "check outer: {} > {} : i={} o={}".format(dist, tree.threshold, tree.num_inner, tree.num_outer))  # fmt: skip
            new_best = search(tree.outer, tlsh_hasher, _cur_best=_cur_best)
            if _debug and new_best != _cur_best:
                print("[D]", "~" * _nestlevel() + " " + "new outer best --> {}".format(new_best))  # fmt: skip
            new_dist = new_best[0]
            if dist - new_dist - margin <= tree.threshold:
                if _debug:
                    print("[D]", "~" * _nestlevel() + " " + "below threshold -> check inner: {} - {} - {} <= {} : i={} o={}".format(dist, new_dist, margin, tree.threshold, tree.num_inner, tree.num_outer))  # fmt: skip
                return search(tree.inner, tlsh_hasher, _cur_best=new_best)
            elif _metriccheck and tree.inner:
                new_best_inner = search(tree.inner, tlsh_hasher, _cur_best=new_best)
                if new_best_inner and new_best_inner[0] != new_best[0]:
                    raise RuntimeError(
                        "Metric problem for inner={} with outer={}, dist={}, threshold={}".format(
                            new_best_inner, new_best, dist, tree.threshold
                        )
                    )
            return new_best
    def search(
        self,
        tlsh_hasher,
        margin=20,
        _skip_same=True,
        _metriccheck=False,
        _debug=False,
    ):
        if not tlsh_hasher:
            return None
        # number of TLSH comparisons
        _debug_stats = [0]
        result = VantagePointTree._search(
            self,
            tlsh_hasher,
            margin=margin,
            _skip_same=_skip_same,
            _metriccheck=_metriccheck,
            _debug=_debug,
            _debug_stats=_debug_stats,
        )
        if _debug:
            print("[D]", "~ {} TLSH comparisons".format(_debug_stats[0]))
        return result
# ----------------------------------------------------------------------------
# brute force approach / linear search
def linear_search(tlsh_hashers, tlsh_hasher, skip_same=True):
    distances = np.array([tlsh_hasher.diff(th) for th in tlsh_hashers])
    if skip_same:
        distances = [dist if dist != 0 else np.NaN for dist in distances]
    idx = np.nanargmin(distances)
    return (distances[idx], idx, tlsh_hashers[idx])
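
# Example usage (a hypothetical sketch, not part of the original gist): build a
# tree over the hashers returned by dedup.build_hashes() and query it. `margin`
# trades search effort against accuracy, `_debug` prints the visited nodes and
# the number of TLSH comparisons; `search()` returns `(dist, index, tlsh)` or None.
#
#   tree = VantagePointTree.build(tlsh_hashers)
#   tree.dump(maxdepth=2)
#   result = tree.search(tlsh_hashers[0], margin=20, _debug=True)
#   if result:
#       dist, idx, other = result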
# ----------------------------------------------------------------------------
# vpt_simple.py
# NOTE: do not use this. It does not reliably find the nearest element and cannot
# properly skip the query itself, but it is based on the TLSH paper code
# - https://tlsh.org/papersDir/COINS_2020_camera_ready.pdf
import random

import tlsh  # from `py-tlsh`
# ----------------------------------------------------------------------------
class VPTSimple:
    # based on https://tlsh.org/papersDir/COINS_2020_camera_ready.pdf
    def __init__(self, data):
        self.data = data
        self.Split = None
        self.Threshold = None
        self.LC = None
        self.RC = None

    # --------------------------------

    @staticmethod
    def Dist(A, B):
        return A.diff(B)

    @staticmethod
    def FindThresholdAndSplits(data, Y):
        # Find threshold (T) s.t. size(X1) ≈ size(X2) where
        #   X1 = {xi ∈ N.data : Dist(xi, Y) ≤ T}
        #   X2 = {xi ∈ N.data : Dist(xi, Y) > T}
        distances = sorted([Y.diff(xi) for xi in data])
        T = distances[int((len(data) - 1) / 2)]
        X1 = [xi for xi in data if VPTSimple.Dist(xi, Y) <= T]
        X2 = [xi for xi in data if VPTSimple.Dist(xi, Y) > T]
        # edge case due to duplicate distances near the median
        if len(X1) == 0 or len(X2) == 0:
            # collapse duplicate distances, and try the median again
            distances = sorted(set(distances))
            # NOTE: this can still be a single element if all distances are the same
            T = distances[int((len(distances) - 1) / 2)]
            X1 = [xi for xi in data if VPTSimple.Dist(xi, Y) <= T]
            X2 = [xi for xi in data if VPTSimple.Dist(xi, Y) > T]
        return T, X1, X2
    @staticmethod
    def SplitMethod(N, nitemsInLeaf):
        nitems = len(N.data)
        if nitems < nitemsInLeaf + 1:
            return None
        else:
            i = random.randint(0, nitems - 1)
            Y = N.data[i]
            # change: we need to skip the randomly selected element
            T, X1, X2 = VPTSimple.FindThresholdAndSplits(N.data[:i] + N.data[i + 1:], Y)
            return Y, T, X1, X2

    @staticmethod
    def TreeBuild(N, nitemsInLeaf):
        # change: if less than 3 items, we can't split into left, middle, right
        if len(N.data) < 3:
            return
        Res = VPTSimple.SplitMethod(N, nitemsInLeaf=nitemsInLeaf)
        if Res is not None:
            Y, T, X1, X2 = Res
            # guard against empty branches (e.g. FindThresholdAndSplits() fails somehow)
            if len(X1) == 0 or len(X2) == 0:
                return
            N.Split = Y
            N.Threshold = T
            N.LC = VPTSimple(data=X1)
            VPTSimple.TreeBuild(N.LC, nitemsInLeaf=nitemsInLeaf)
            N.RC = VPTSimple(data=X2)
            VPTSimple.TreeBuild(N.RC, nitemsInLeaf=nitemsInLeaf)
            # change: clean up N.data to reduce memory since it now is contained in its branches
            N.data = None

    @staticmethod
    def isLeaf(N):
        # if leaf then only `N.data` is assigned, everything else is None
        return N.Split is None
    @staticmethod
    def closestItem(N, S):
        # compute distances
        distances = [VPTSimple.Dist(xi, S) for xi in N.data]
        # sort by shortest, and return first element
        return sorted(zip(N.data, distances), key=lambda t: t[1])[0][0]

    @staticmethod
    def Search(N, S):
        if VPTSimple.isLeaf(N):
            X = VPTSimple.closestItem(N, S)
            d = VPTSimple.Dist(X, S)
            return X, d
        else:
            thisDist = VPTSimple.Dist(N.Split, S)
            if thisDist <= N.Threshold:
                return VPTSimple.Search(N.LC, S)
            else:
                return VPTSimple.Search(N.RC, S)

    @staticmethod
    def closestItemNotSelf(N, S):
        # compute distances
        distances = [VPTSimple.Dist(xi, S) for xi in N.data]
        # sort by shortest
        distsNdata = sorted(zip(N.data, distances), key=lambda t: t[1])
        # skip self if contained
        if distsNdata[0][1] == 0:
            # if we have multiple hashes in the leaf, return the second best
            if len(N.data) > 1:
                return distsNdata[1][0]
            # else None to continue search elsewhere
            return None
        # and return first element
        return distsNdata[0][0]

    @staticmethod
    def SearchNotSelf(N, S):
        if VPTSimple.isLeaf(N):
            X = VPTSimple.closestItemNotSelf(N, S)
            # if self found, then abort branch
            if X is None:
                return None
            d = VPTSimple.Dist(X, S)
            return X, d
        thisDist = VPTSimple.Dist(N.Split, S)
        if thisDist <= N.Threshold:
            Res = VPTSimple.SearchNotSelf(N.LC, S)
            # if we found something (not self) return
            if Res is not None:
                return Res
        # otherwise search other branch
        return VPTSimple.SearchNotSelf(N.RC, S)

    # --------------------------------

    def build(self, nitemsInLeaf=1):
        assert nitemsInLeaf >= 1
        VPTSimple.TreeBuild(self, nitemsInLeaf=nitemsInLeaf)
        return self

    def search(self, S, not_self=True):
        if not_self:
            return VPTSimple.SearchNotSelf(self, S)
        return VPTSimple.Search(self, S)
class VPTSimpleForrest:
    def __init__(self):
        self.trees = []

    def build(self, data, nTrees=None, nitemsInLeaf=1):
        if nTrees is None:
            nTrees = int(max(3, min(100, len(data) ** (1 / 3))))
        self.trees = []
        for _ in range(nTrees):
            # copy and shuffle the list
            data = list(data)
            random.shuffle(data)
            # build tree
            tree = VPTSimple(data)
            tree.build(nitemsInLeaf=nitemsInLeaf)
            self.trees.append(tree)
        return self

    def search(self, S, not_self=True):
        results = []
        for ti in self.trees:
            Res = ti.search(S, not_self=not_self)
            if Res is not None:
                results.append(Res)
        # return the best result, or None if nothing was found
        if len(results) >= 1:
            Res = sorted(results, key=lambda ri: ri[1])[0]
            return Res
        else:
            return None
# ----------------------------------------------------------------------------