Skip to content

Instantly share code, notes, and snippets.

@benwtrent
Created July 19, 2024 20:02
Show Gist options
  • Save benwtrent/bfda4ddb6eecaf6b2cd4bfc6b63b8425 to your computer and use it in GitHub Desktop.
Save benwtrent/bfda4ddb6eecaf6b2cd4bfc6b63b8425 to your computer and use it in GitHub Desktop.
simple notebook for vector testing
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pyarrow.parquet as pq\n",
"\n",
"NUM_SHARDS = 10   # reads data/0-en.parquet .. data/9-en.parquet\n",
"NUM_QUERIES = 50  # rows held out as queries; the remaining rows are the documents\n",
"SEED = 42         # fixed seed so the query/document split is reproducible across runs\n",
"\n",
"# read only the 'emb' column from each shard\n",
"tbls = [pq.read_table(f'data/{i}-en.parquet', columns=['emb']) for i in range(NUM_SHARDS)]\n",
"np_total = np.concatenate([tbl[0].to_numpy() for tbl in tbls])\n",
"# each element of np_total is one embedding vector; stack them into a (num_vectors, dim) matrix\n",
"np_flat_ds = np.stack(np_total)\n",
"print(np_flat_ds.shape)\n",
"# shuffle rows in place with a seeded generator, then split off the queries\n",
"np.random.default_rng(SEED).shuffle(np_flat_ds)\n",
"query_vectors = np_flat_ds[:NUM_QUERIES]\n",
"doc_vectors = np_flat_ds[NUM_QUERIES:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def simple_binary_quantize(vec):\n",
"    \"\"\"Binary-quantize an embedding: one bit per dimension, set where the component is > 0.\n",
"\n",
"    Bits are packed 8 per uint8 along the last axis (zero-padded to a multiple of 8),\n",
"    so a 1-D vector of d floats becomes ceil(d / 8) bytes and a 2-D batch is packed\n",
"    row by row.\n",
"    \"\"\"\n",
"    bits = vec > 0\n",
"    return np.packbits(bits, axis=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# binary quantize the vectors: one bit per dimension (component > 0), packed 8 bits per byte, row by row.\n",
"# Same result as np.apply_along_axis(simple_binary_quantize, 1, ...) but without a Python-level loop over rows.\n",
"binary_vectors = np.packbits(doc_vectors > 0, axis=1)\n",
"binary_queries = np.packbits(query_vectors > 0, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For each query, count how many documents must be retrieved by binary (bit-match) similarity\n",
"# before the candidate set covers the true top-k neighbours by dot product, for k = 5, 10, 50, 100.\n",
"# Also keep the per-document scores of the best and worst query, ranked by the k=100 count.\n",
"all_kth5, all_kth10, all_kth50, all_kth100 = [], [], [], []\n",
"best_kth5 = best_kth10 = best_kth50 = best_kth100 = None\n",
"best_hamming_distance = best_distance = best_sorted_hamming_distances = None\n",
"worst_kth5 = worst_kth10 = worst_kth50 = worst_kth100 = None\n",
"worst_hamming_distance = worst_distance = worst_sorted_hamming_distances = None\n",
"num_dims = doc_vectors.shape[1]\n",
"\n",
"\n",
"def candidates_needed(bit_similarity, true_indices):\n",
"    \"\"\"Return how many docs score >= the lowest bit similarity found among true_indices.\"\"\"\n",
"    cut = np.min(bit_similarity[true_indices])\n",
"    return np.count_nonzero(bit_similarity >= cut)\n",
"\n",
"\n",
"for i in range(len(query_vectors)):\n",
"    query_vector = query_vectors[i]\n",
"    binary_query = binary_queries[i]\n",
"    # exact ranking: dot product against every document, keep the 100 highest-scoring indices\n",
"    distances = np.dot(doc_vectors, query_vector)\n",
"    nearest_indices = np.argsort(distances)[::-1][:100]\n",
"    # Hamming distance = number of set bits in (doc XOR query); unpack the bytes to count them\n",
"    xor_bits = np.bitwise_xor(binary_vectors, binary_query)\n",
"    hamming_distances = np.unpackbits(xor_bits, axis=1).sum(axis=1)\n",
"    # Convert to a similarity in [0, 1]: fraction of the num_dims bits that match (higher = closer).\n",
"    # Padding bits added by packbits are 0 in both operands, so they never count as mismatches.\n",
"    hamming_distances = (num_dims - hamming_distances) / num_dims\n",
"    sorted_hamming_distances = np.sort(hamming_distances)[::-1]  # descending\n",
"    kth5 = candidates_needed(hamming_distances, nearest_indices[:5])\n",
"    kth10 = candidates_needed(hamming_distances, nearest_indices[:10])\n",
"    kth50 = candidates_needed(hamming_distances, nearest_indices[:50])\n",
"    kth100 = candidates_needed(hamming_distances, nearest_indices[:100])\n",
"    all_kth5.append(kth5)\n",
"    all_kth10.append(kth10)\n",
"    all_kth50.append(kth50)\n",
"    all_kth100.append(kth100)\n",
"    # best case: fewest candidates needed for k=100 (strict <, so the earliest query wins ties)\n",
"    if best_kth100 is None or kth100 < best_kth100:\n",
"        best_kth5, best_kth10, best_kth50, best_kth100 = kth5, kth10, kth50, kth100\n",
"        best_hamming_distance = hamming_distances\n",
"        best_distance = distances\n",
"        best_sorted_hamming_distances = sorted_hamming_distances\n",
"    # worst case: most candidates needed for k=100 (strict >, so the earliest query wins ties)\n",
"    if worst_kth100 is None or kth100 > worst_kth100:\n",
"        worst_kth5, worst_kth10, worst_kth50, worst_kth100 = kth5, kth10, kth50, kth100\n",
"        worst_hamming_distance = hamming_distances\n",
"        worst_distance = distances\n",
"        worst_sorted_hamming_distances = sorted_hamming_distances"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def plot_kth_vs_distances(title, distances, hamming_distances, sorted_hamming_distances, kth5, kth10, kth50, kth100):\n",
"    \"\"\"Scatter each document's dot-product score against its binary bit similarity for one query.\n",
"\n",
"    A dashed horizontal line marks, for k = 5, 10, 50, 100, the bit-similarity cut-off\n",
"    needed to retrieve the true top-k. Expects sorted_hamming_distances in descending\n",
"    order and each kthN to be the number of documents scoring >= that cut-off (as the\n",
"    evaluation loop computes it), so the cut-off itself sits at index kthN - 1.\n",
"    \"\"\"\n",
"    plt.scatter(distances, hamming_distances)\n",
"    plt.xlabel(\"Dot Product\")\n",
"    plt.ylabel(\"Hamming Similarity (fraction of matching bits)\")\n",
"    kth_lines = [\n",
"        ('kth5', kth5, 'tab:green'),\n",
"        ('kth10', kth10, 'tab:blue'),\n",
"        ('kth50', kth50, 'tab:orange'),\n",
"        ('kth100', kth100, 'tab:red'),\n",
"    ]\n",
"    for name, kth, color in kth_lines:\n",
"        # index kth - 1 is the cut-off itself; index kth would be the first value below it\n",
"        # (and out of range when every document is a candidate)\n",
"        plt.axhline(y=sorted_hamming_distances[kth - 1], color=color, linestyle='--', label=f'{name}:{kth}')\n",
"    plt.title(title)\n",
"    plt.legend()\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot the best- and worst-case queries (fewest / most binary candidates needed to cover the true top-100).\n",
"cases = [\n",
"    (\"e5-small best\", best_distance, best_hamming_distance, best_sorted_hamming_distances,\n",
"     best_kth5, best_kth10, best_kth50, best_kth100),\n",
"    (\"e5-small worst\", worst_distance, worst_hamming_distance, worst_sorted_hamming_distances,\n",
"     worst_kth5, worst_kth10, worst_kth50, worst_kth100),\n",
"]\n",
"for title, *plot_args in cases:\n",
"    plot_kth_vs_distances(title, *plot_args)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment