Skip to content

Instantly share code, notes, and snippets.

@ergo70
Last active April 5, 2024 09:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ergo70/18bb47f4d6b43d51b7049f2f1b82dd31 to your computer and use it in GitHub Desktop.
Save ergo70/18bb47f4d6b43d51b7049f2f1b82dd31 to your computer and use it in GitHub Desktop.
Cosine similarity function for float4 vectors stored as PostgreSQL bytea, plus casts between bytea and float4[].
#include <math.h>
#include "postgres.h"
#include "fmgr.h"
#include "utils/array.h"
#include "access/htup.h"
#include "catalog/pg_type.h"
#include "utils/lsyscache.h" // Required for building with PGXS (at least on macOS)
/*
CREATE FUNCTION cast_bytea_to_float4_array(bytea) RETURNS float4[]
AS '$libdir/bytea2float4vec', 'cast_bytea_to_float4_array'
LANGUAGE C strict immutable parallel safe;
CREATE FUNCTION cast_float4_array_to_bytea(float4[]) RETURNS bytea
AS '$libdir/bytea2float4vec', 'cast_float4_array_to_bytea'
LANGUAGE C strict immutable parallel safe;
CREATE FUNCTION cosine_similarity_bytea(bytea, bytea) RETURNS float8
AS '$libdir/bytea2float4vec', 'cosine_similarity_bytea'
LANGUAGE C strict immutable parallel safe;
CREATE CAST (float4[] AS bytea) WITH FUNCTION cast_float4_array_to_bytea(float4[]) AS assignment;
CREATE CAST (bytea AS float4[]) WITH FUNCTION cast_bytea_to_float4_array(bytea) AS assignment;
*/
PG_MODULE_MAGIC;
PGDLLEXPORT PG_FUNCTION_INFO_V1(cast_bytea_to_float4_array);
Datum cast_bytea_to_float4_array(PG_FUNCTION_ARGS)
{
bytea *a = PG_GETARG_BYTEA_PP(0);
Oid elemtype = FLOAT4OID;
uint32 data_length_a = VARSIZE_ANY(a) - VARHDRSZ;
float *readptr = (float *)VARDATA_ANY(a);
ArrayType *retval = NULL;
Datum *elements = NULL;
int16 typlen = 0;
bool typbyval;
char typalign;
int ndims = 1;
int dims[MAXDIM];
int lbs[MAXDIM];
int num_elements = data_length_a / sizeof(float);
dims[0] = num_elements;
lbs[0] = 1;
elements = (Datum *)palloc0(num_elements * sizeof(Datum));
for (int i = 0; i < num_elements; i++)
{
elements[i] = Float4GetDatum(*readptr);
readptr++;
}
get_typlenbyvalalign(elemtype, &typlen, &typbyval, &typalign);
retval = construct_md_array(elements, NULL, ndims, dims, lbs,
elemtype, typlen, typbyval, typalign);
pfree(elements);
PG_RETURN_ARRAYTYPE_P(retval);
}
PGDLLEXPORT PG_FUNCTION_INFO_V1(cast_float4_array_to_bytea);
Datum cast_float4_array_to_bytea(PG_FUNCTION_ARGS)
{
ArrayType *a = PG_GETARG_ARRAYTYPE_P(0);
Oid elemtypeA = ARR_ELEMTYPE(a);
Datum *datumsA = NULL;
int countA = 0;
int16 elemWidthA;
bool elemTypeByValA;
char elemAlignmentCodeA;
bytea *retval = NULL;
char *writeptr = NULL;
float fieldA = 0.0;
if (elemtypeA != FLOAT4OID)
ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("float4 OID array needed. Got %d", elemtypeA)));
if (ARR_NDIM(a) != 1)
ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("One-dimensional array needed. Got %d", ARR_NDIM(a))));
get_typlenbyvalalign(elemtypeA, &elemWidthA, &elemTypeByValA, &elemAlignmentCodeA);
deconstruct_array(a, elemtypeA, elemWidthA, elemTypeByValA, elemAlignmentCodeA, &datumsA, NULL, &countA);
retval = palloc(VARHDRSZ + (countA * elemWidthA));
writeptr = (char *)VARDATA(retval);
for (int i = 0; i < countA; i++)
{
fieldA = DatumGetFloat4(datumsA[i]);
memcpy(writeptr, &fieldA, elemWidthA);
writeptr += sizeof(float);
}
SET_VARSIZE(retval, VARHDRSZ + (countA * elemWidthA));
PG_RETURN_BYTEA_P(retval);
}
PGDLLEXPORT PG_FUNCTION_INFO_V1(cosine_similarity_bytea);
Datum cosine_similarity_bytea(PG_FUNCTION_ARGS)
{
bytea *a = PG_GETARG_BYTEA_PP(0);
bytea *b = PG_GETARG_BYTEA_PP(1);
uint32 data_length_a = VARSIZE_ANY(a) - VARHDRSZ;
uint32 data_length_b = VARSIZE_ANY(b) - VARHDRSZ;
float *fa = (float *)VARDATA_ANY(a);
float *fb = (float *)VARDATA_ANY(b);
float distance = 0.0f;
float norma = 0.0f;
float normb = 0.0f;
float8 similarity = -666.0;
if ((data_length_a % sizeof(float) != 0) || (data_length_b % sizeof(float) != 0))
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("Vector size does not match sizeof(float)")));
if (data_length_a != data_length_b)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("Different vector dimensions %d and %d", (data_length_a / sizeof(float)), (data_length_b / sizeof(float)))));
for (int i = 0; i < data_length_a; i += sizeof(float))
{
distance += *fa * *fb;
norma += *fa * *fa;
normb += *fb * *fb;
fa++;
fb++;
}
similarity = (double)distance / sqrt((double)norma * (double)normb);
PG_RETURN_FLOAT8(similarity);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment