Skip to content

Instantly share code, notes, and snippets.

@smcveigh-phunware
Created March 28, 2018 22:03
Show Gist options
  • Save smcveigh-phunware/9e274942908a7cd366ca866f06505be7 to your computer and use it in GitHub Desktop.
Save smcveigh-phunware/9e274942908a7cd366ca866f06505be7 to your computer and use it in GitHub Desktop.
cdist(..., "sqeuclidean") vs. broadcasting
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import logging as log\n",
"import numpy as np\n",
"from typing import Optional\n",
"from scipy.spatial.distance import cdist"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Show up to 10 leading/trailing items per axis and allow 180-character lines\n",
"# when arrays are printed. The line width is set through the public\n",
"# `linewidth` option rather than a private arrayprint module attribute.\n",
"np.set_printoptions(edgeitems=10, linewidth=180)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Record layout for log output; placeholders use str.format syntax (style='{').\n",
"FORMAT = \"{asctime} {name} {levelname:8s} {message}\"\n",
"\n",
"# Configure the root logger to emit DEBUG and above using the layout defined here.\n",
"log.basicConfig(level=log.DEBUG, style='{', format=FORMAT)\n",
"\n",
"# Named logger used by the functions defined in this notebook.\n",
"logger = log.getLogger(\"notebook\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def euclidean_square(x: np.ndarray, xp: Optional[np.ndarray] = None) -> np.ndarray:\n",
"    \"\"\"Pairwise squared Euclidean distances between the rows of two 2D arrays.\n",
"\n",
"    Uses the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, so the full distance\n",
"    matrix is built from two row-norm vectors and one matrix product.\n",
"\n",
"    Args:\n",
"        x: 2D array of shape (n, d), one point per row.\n",
"        xp: optional 2D array of shape (m, d). When None, distances are taken\n",
"            between the rows of x itself.\n",
"\n",
"    Returns:\n",
"        Array of shape (n, m), or (n, n) when xp is None, whose [i, j] entry is\n",
"        the squared distance between x[i] and xp[j]. Entries are never negative.\n",
"\n",
"    Raises:\n",
"        AssertionError: if x or xp is not 2D, or if their column counts differ.\n",
"    \"\"\"\n",
"    assert x.ndim == 2, \"x must be 2D\"\n",
"    logger.debug(\"x.shape = %s\", x.shape)\n",
"\n",
"    x_sq = np.sum(np.square(x), axis=1)\n",
"\n",
"    if xp is None:\n",
"        # Self-distance case: reuse x and its row norms for the second operand.\n",
"        other, other_sq = x, x_sq\n",
"    else:\n",
"        assert xp.ndim == 2, \"xp must be 2D\"\n",
"        assert x.shape[1] == xp.shape[1], \"x and xp must have the same number of columns\"\n",
"        logger.debug(\"xp.shape = %s\", xp.shape)\n",
"        other, other_sq = xp, np.sum(np.square(xp), axis=1)\n",
"\n",
"    # Column of |x_i|^2 plus row of |xp_j|^2 broadcasts to the (n, m) grid.\n",
"    sq_dist = np.reshape(x_sq, (-1, 1)) + np.reshape(other_sq, (1, -1)) - 2 * np.matmul(x, other.T)\n",
"\n",
"    # Floating-point cancellation in the expansion can leave tiny negative values\n",
"    # (typically where the true distance is 0, e.g. the diagonal of the\n",
"    # self-distance matrix). A squared distance cannot be negative, so clamp at 0;\n",
"    # this also keeps a later np.sqrt from producing NaN.\n",
"    return np.maximum(sq_dist, 0)\n",
"\n",
"\n",
"def euclidean(x: np.ndarray, xp: Optional[np.ndarray] = None) -> np.ndarray:\n",
"    \"\"\"Pairwise Euclidean distances between the rows of x and xp.\n",
"\n",
"    Args:\n",
"        x: 2D array, one point per row.\n",
"        xp: optional 2D array with the same number of columns as x. When None,\n",
"            distances are taken between the rows of x itself.\n",
"\n",
"    Returns:\n",
"        Element-wise square root of euclidean_square(x, xp).\n",
"    \"\"\"\n",
"    # Round-off can make a squared distance that should be 0 come out slightly\n",
"    # negative; clamp at 0 first so the square root never yields NaN.\n",
"    return np.sqrt(np.maximum(euclidean_square(x, xp), 0))\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-03-28 15:01:30,413 notebook DEBUG x.shape = (7, 10)\n",
"2018-03-28 15:01:30,414 notebook DEBUG xp.shape = (3, 10)\n"
]
},
{
"data": {
"text/plain": [
"array([[5.4832288 , 3.96775002, 4.59193113],\n",
" [5.44005569, 6.46171414, 3.67261701],\n",
" [5.6897614 , 5.77721062, 6.12629625],\n",
" [2.4223504 , 6.31955375, 7.90340939],\n",
" [2.83818613, 5.30016691, 4.98072687],\n",
" [2.30973936, 4.85547768, 8.65337153],\n",
" [3.08661838, 2.52295335, 4.26490056]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[5.4832288 , 3.96775002, 4.59193113],\n",
" [5.44005569, 6.46171414, 3.67261701],\n",
" [5.6897614 , 5.77721062, 6.12629625],\n",
" [2.4223504 , 6.31955375, 7.90340939],\n",
" [2.83818613, 5.30016691, 4.98072687],\n",
" [2.30973936, 4.85547768, 8.65337153],\n",
" [3.08661838, 2.52295335, 4.26490056]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[False, False, True],\n",
" [ True, False, False],\n",
" [ True, True, True],\n",
" [False, False, False],\n",
" [False, False, False],\n",
" [False, True, False],\n",
" [ True, True, False]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Seeded generator so the random shapes and points are identical on every run.\n",
"rng = np.random.RandomState(0)\n",
"\n",
"# Random problem size: column count and the two row counts, each in [3, 12).\n",
"cols, x_rows, y_rows = rng.randint(low=3, high=12, size=3)\n",
"\n",
"x = rng.uniform(low=-1, high=1, size=(x_rows, cols))\n",
"y = rng.uniform(low=-1, high=1, size=(y_rows, cols))\n",
"\n",
"a = cdist(x, y, 'sqeuclidean')  # reference result from SciPy\n",
"b = euclidean_square(x, y)      # broadcasting implementation from this notebook\n",
"\n",
"display(a, b)\n",
"\n",
"# The two methods round differently, so exact element-wise equality is not\n",
"# expected everywhere; agreement within floating-point tolerance is the real check.\n",
"display(np.equal(a, b))\n",
"display(np.allclose(a, b))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment