Created
July 19, 2024 20:02
-
-
Save benwtrent/bfda4ddb6eecaf6b2cd4bfc6b63b8425 to your computer and use it in GitHub Desktop.
simple notebook for vector testing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import pyarrow.parquet as pq\n", | |
"\n", | |
"# Read only the 'emb' column from each of the 10 parquet shards.\n", | |
"tbls = [pq.read_table('data/%d-en.parquet' % i, columns=['emb']) for i in range(10)]\n", | |
"# Join the per-shard columns into one array of rows (presumably one embedding per row - confirm).\n", | |
"np_total = np.concatenate([tbl[0].to_numpy() for tbl in tbls])\n", | |
"# Rebuild as a single dense array from the list of per-row vectors.\n", | |
"flat_ds = list(np_total)\n", | |
"np_flat_ds = np.array(flat_ds)\n", | |
"print(np_flat_ds.shape)\n", | |
"# Shuffle rows in place with a fixed seed so Restart & Run All yields the same query/doc split.\n", | |
"rng = np.random.default_rng(42)\n", | |
"rng.shuffle(np_flat_ds)\n", | |
"# The first 50 shuffled rows are held out as queries; the remainder are the searchable docs.\n", | |
"query_vectors = np_flat_ds[:50]\n", | |
"doc_vectors = np_flat_ds[50:]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def simple_binary_quantize(vec):\n", | |
"    \"\"\"Binary-quantize float vector(s): one bit per dimension, set when the value is > 0.\n", | |
"\n", | |
"    Bits are packed along the last axis into uint8 bytes (8 dimensions per byte,\n", | |
"    zero-padded at the end when the dimension count is not a multiple of 8).\n", | |
"    Accepts a single 1-D vector or a batch whose last axis is the vector dimension.\n", | |
"    \"\"\"\n", | |
"    # atleast_1d keeps list and scalar inputs working\n", | |
"    bits = np.atleast_1d(vec) > 0\n", | |
"    # axis=-1 packs each row separately instead of flattening the whole batch\n", | |
"    return np.packbits(bits, axis=-1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Binary-quantize the doc vectors and the query vectors, one row at a time.\n", | |
"binary_vectors, binary_queries = (\n", | |
"    np.apply_along_axis(simple_binary_quantize, 1, vectors)\n", | |
"    for vectors in (doc_vectors, query_vectors)\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# For every query, compare the exact dot-product ranking with the binary (hamming) ranking.\n", | |
"# kthK is the number of docs that must be gathered in hamming order to cover all of the\n", | |
"# true top-K neighbors; the best and worst queries (judged by kth100) are kept for plotting.\n", | |
"all_kth5, all_kth10, all_kth50, all_kth100 = [], [], [], []\n", | |
"best_kth5, best_kth10, best_kth50, best_kth100 = None, None, None, None\n", | |
"best_hamming_distance = np.zeros(doc_vectors.shape[0])\n", | |
"best_distance = np.zeros(doc_vectors.shape[0])\n", | |
"best_sorted_hamming_distances = np.zeros(doc_vectors.shape[0])\n", | |
"\n", | |
"worst_kth5, worst_kth10, worst_kth50, worst_kth100 = None, None, None, None\n", | |
"worst_hamming_distance = np.zeros(doc_vectors.shape[0])\n", | |
"worst_distance = np.zeros(doc_vectors.shape[0])\n", | |
"worst_sorted_hamming_distances = np.zeros(doc_vectors.shape[0])\n", | |
"num_dims = doc_vectors.shape[1]\n", | |
"# walk over every query that exists instead of assuming there are exactly 50 of them\n", | |
"for query_vector, binary_query in zip(query_vectors, binary_queries):\n", | |
"    # exact scores: dot product of the query against every doc vector\n", | |
"    distances = np.dot(doc_vectors, query_vector)\n", | |
"    # indices of the 100 highest-scoring docs, best first\n", | |
"    nearest_indices = np.argsort(distances)[::-1][:100]\n", | |
"    # XOR the packed query bits with every packed doc, then count the set bits:\n", | |
"    # that count is the hamming distance\n", | |
"    xor_bits = np.bitwise_xor(binary_vectors, binary_query)\n", | |
"    hamming_distances = np.unpackbits(xor_bits, axis=1).sum(axis=1)\n", | |
"    # turn the distance into the fraction of matching bits, so higher means more similar\n", | |
"    hamming_distances = (num_dims - hamming_distances) / num_dims\n", | |
"    # sort the hamming scores, descending\n", | |
"    sorted_hamming_distances = np.sort(hamming_distances)[::-1]\n", | |
"    # for k = 5, 10, 50, 100: take the weakest hamming score among the true top-k\n", | |
"    # neighbors and count how many docs score at least that well\n", | |
"    kth5, kth10, kth50, kth100 = (\n", | |
"        int(np.count_nonzero(hamming_distances >= np.min(hamming_distances[nearest_indices[:k]])))\n", | |
"        for k in (5, 10, 50, 100)\n", | |
"    )\n", | |
"    all_kth5.append(kth5)\n", | |
"    all_kth10.append(kth10)\n", | |
"    all_kth50.append(kth50)\n", | |
"    all_kth100.append(kth100)\n", | |
"    # update the best case (fewest docs needed to cover the true top 100)\n", | |
"    if best_kth100 is None or kth100 < best_kth100:\n", | |
"        best_kth5 = kth5\n", | |
"        best_kth10 = kth10\n", | |
"        best_kth50 = kth50\n", | |
"        best_kth100 = kth100\n", | |
"        best_hamming_distance = hamming_distances\n", | |
"        best_distance = distances\n", | |
"        best_sorted_hamming_distances = sorted_hamming_distances\n", | |
"    # update the worst case (most docs needed to cover the true top 100)\n", | |
"    if worst_kth100 is None or kth100 > worst_kth100:\n", | |
"        worst_kth5 = kth5\n", | |
"        worst_kth10 = kth10\n", | |
"        worst_kth50 = kth50\n", | |
"        worst_kth100 = kth100\n", | |
"        worst_hamming_distance = hamming_distances\n", | |
"        worst_distance = distances\n", | |
"        worst_sorted_hamming_distances = sorted_hamming_distances" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"def plot_kth_vs_distances(title, distances, hamming_distances, sorted_hamming_distances, kth5, kth10, kth50, kth100):\n", | |
"    \"\"\"Scatter exact scores against hamming scores and mark the four kth cut-offs.\n", | |
"\n", | |
"    distances and hamming_distances are plotted point for point. Each kth value is\n", | |
"    treated as a count of docs, so the score of the kth-ranked doc is read from\n", | |
"    index kth - 1 of sorted_hamming_distances (expected to be sorted best-first).\n", | |
"    \"\"\"\n", | |
"    plt.scatter(distances, hamming_distances)\n", | |
"    # label x & y\n", | |
"    plt.xlabel(\"Dot Distance\")\n", | |
"    plt.ylabel(\"Hamming Distance\")\n", | |
"    # add a horizontal line at the hamming score of the kth-ranked doc for each k\n", | |
"    for name, kth, color in (\n", | |
"        ('kth5', kth5, 'tab:green'),\n", | |
"        ('kth10', kth10, 'tab:blue'),\n", | |
"        ('kth50', kth50, 'tab:orange'),\n", | |
"        ('kth100', kth100, 'tab:red'),\n", | |
"    ):\n", | |
"        # kth is 1-based; clamp at 0 so a zero count cannot wrap around to the last element\n", | |
"        cut = sorted_hamming_distances[max(kth - 1, 0)]\n", | |
"        plt.axhline(y=cut, color=color, linestyle='--', label=f'{name}:{kth}')\n", | |
"    # add the title\n", | |
"    plt.title(title)\n", | |
"    plt.legend()\n", | |
"    plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# plot the best case first, then the worst case\n", | |
"for case_title, case_args in (\n", | |
"    (\"e5-small best\", (best_distance, best_hamming_distance, best_sorted_hamming_distances, best_kth5, best_kth10, best_kth50, best_kth100)),\n", | |
"    (\"e5-small worst\", (worst_distance, worst_hamming_distance, worst_sorted_hamming_distances, worst_kth5, worst_kth10, worst_kth50, worst_kth100)),\n", | |
"):\n", | |
"    plot_kth_vs_distances(case_title, *case_args)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment