Skip to content

Instantly share code, notes, and snippets.

@jjerphan
Created June 23, 2021 11:39
Show Gist options
  • Save jjerphan/8bc10072e02637d318653aaa493430fd to your computer and use it in GitHub Desktop.
Save jjerphan/8bc10072e02637d318653aaa493430fd to your computer and use it in GitHub Desktop.
import numpy as np
from sklearn.neighbors import DistanceMetric
from .common import Benchmark
class DistanceMetricBenchmark(Benchmark):
    """Shared fixture for the ``DistanceMetric`` pairwise benchmarks.

    The benchmarks are parametrised over ``n`` (number of rows of ``X`` and
    ``Y``) and ``d`` (number of columns). Subclasses extend ``setup`` to
    build the metric they measure.
    """

    param_names = ["n", "d"]
    params = ([100, 1000, 10_000], [5, 10, 100])

    def setup(self, n, d):
        """Create the seeded random inputs shared by every subclass.

        Parameters
        ----------
        n : int
            Number of rows of ``X`` and ``Y``.
        d : int
            Number of columns of ``X`` and ``Y``; ``V`` and ``VI`` are
            ``(d, d)``.
        """
        # Fixed seed so that every run works on the same data.
        rng = np.random.RandomState(0)
        sample_shape = (n, d)

        # The draw order (X, then Y, then V) determines the generated values.
        self.rng = rng
        self.X = rng.random_sample(sample_shape)
        self.Y = rng.random_sample(sample_shape)
        self.V = rng.random_sample((d, d))

        # Product of V with its own transpose, hence a symmetric matrix.
        self.VI = self.V.dot(self.V.T)
class EuclideanDistanceBenchmark(DistanceMetricBenchmark):
    """Benchmark ``pairwise`` for the ``"euclidean"`` metric."""

    def setup(self, n, d):
        """Generate the shared inputs, then build the metric under test."""
        super().setup(n, d)
        metric = DistanceMetric.get_metric("euclidean")
        self.dist_metric = metric

    def time_pairwise(self, n, d):
        """Run ``pairwise`` on ``X`` and ``Y``; the result is discarded."""
        self.dist_metric.pairwise(self.X, self.Y)
class ManhattanDistanceBenchmark(DistanceMetricBenchmark):
    """Benchmark ``pairwise`` for the ``"manhattan"`` metric."""

    def setup(self, n, d):
        """Generate the shared inputs, then build the metric under test."""
        super().setup(n, d)
        metric = DistanceMetric.get_metric("manhattan")
        self.dist_metric = metric

    def time_pairwise(self, n, d):
        """Run ``pairwise`` on ``X`` and ``Y``; the result is discarded."""
        self.dist_metric.pairwise(self.X, self.Y)
class ChebyshevDistanceBenchmark(DistanceMetricBenchmark):
    """Benchmark ``pairwise`` for the ``"chebyshev"`` metric."""

    def setup(self, n, d):
        """Generate the shared inputs, then build the metric under test."""
        super().setup(n, d)
        metric = DistanceMetric.get_metric("chebyshev")
        self.dist_metric = metric

    def time_pairwise(self, n, d):
        """Run ``pairwise`` on ``X`` and ``Y``; the result is discarded."""
        self.dist_metric.pairwise(self.X, self.Y)
class MinkowskiDistanceBenchmark(DistanceMetricBenchmark):
    """Benchmark ``pairwise`` for the ``"minkowski"`` metric with ``p=1.5``."""

    def setup(self, n, d):
        """Generate the shared inputs, then build the metric under test."""
        super().setup(n, d)
        metric_kwargs = {"p": 1.5}
        self.dist_metric = DistanceMetric.get_metric(
            "minkowski", **metric_kwargs
        )

    def time_pairwise(self, n, d):
        """Run ``pairwise`` on ``X`` and ``Y``; the result is discarded."""
        self.dist_metric.pairwise(self.X, self.Y)
class WMinkowskiDistanceBenchmark(DistanceMetricBenchmark):
    """Benchmark ``pairwise`` for the ``"wminkowski"`` metric with ``p=1.5``.

    The weight vector ``w`` has one random entry per column of ``X``.
    """

    # NOTE(review): "wminkowski" appears to be deprecated in more recent
    # scikit-learn releases -- confirm against the targeted version.

    def setup(self, n, d):
        """Generate the shared inputs, then build the weighted metric."""
        super().setup(n, d)
        # Drawn after X, Y and V from the same seeded generator.
        weights = self.rng.random_sample(d)
        self.dist_metric = DistanceMetric.get_metric(
            "wminkowski", p=1.5, w=weights
        )

    def time_pairwise(self, n, d):
        """Run ``pairwise`` on ``X`` and ``Y``; the result is discarded."""
        self.dist_metric.pairwise(self.X, self.Y)
class SEuclideanDistanceBenchmark(DistanceMetricBenchmark):
    """Benchmark ``pairwise`` for the ``"seuclidean"`` metric.

    The metric is built with a 1-D vector ``V`` holding one random entry
    per column of ``X``.
    """

    def setup(self, n, d):
        """Generate the shared inputs, then build the standardized metric.

        Parameters
        ----------
        n : int
            Number of rows of ``X`` and ``Y``.
        d : int
            Number of columns of ``X`` and ``Y``, and length of ``V``.
        """
        super().setup(n, d)
        # "seuclidean" takes one scaling value per feature, i.e. a vector of
        # length d. The (d, d) matrix ``self.V`` from the base fixture is the
        # wrong shape for it, so a dedicated 1-D vector is drawn here (after
        # X, Y and V, from the same seeded generator).
        variances = self.rng.random_sample(d)
        self.dist_metric = DistanceMetric.get_metric("seuclidean", V=variances)

    def time_pairwise(self, n, d):
        """Run ``pairwise`` on ``X`` and ``Y``; the result is discarded."""
        self.dist_metric.pairwise(self.X, self.Y)
class MahalanobisDistanceBenchmark(DistanceMetricBenchmark):
    """Benchmark ``pairwise`` for the ``"mahalanobis"`` metric.

    The metric is built with ``VI`` set to the ``(d, d)`` matrix prepared
    by the base ``setup``.
    """

    def setup(self, n, d):
        """Generate the shared inputs, then build the metric under test."""
        super().setup(n, d)
        inverse_cov = self.VI
        self.dist_metric = DistanceMetric.get_metric(
            "mahalanobis", VI=inverse_cov
        )

    def time_pairwise(self, n, d):
        """Run ``pairwise`` on ``X`` and ``Y``; the result is discarded."""
        self.dist_metric.pairwise(self.X, self.Y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment