Skip to content

Instantly share code, notes, and snippets.

@russcam
Created December 31, 2023 00:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save russcam/09cf54eff3bedf0df17d03af53057e36 to your computer and use it in GitHub Desktop.
Save russcam/09cf54eff3bedf0df17d03af53057e36 to your computer and use it in GitHub Desktop.
using System.Globalization;
using System.IO.Compression;
using System.Text.Json;
using Azure.AI.OpenAI;
using CsvHelper;
using CsvHelper.Configuration;
using CsvHelper.Configuration.Attributes;
using CsvHelper.TypeConversion;
using Qdrant.Client;
using Qdrant.Client.Grpc;
// Script entry: download the OpenAI-published Wikipedia embeddings if they are not already on disk,
// load them into a Qdrant collection with "title" and "content" named vectors, then run two sample
// similarity queries and print the ranked titles.
var embeddingsUrl =
    new Uri("https://cdn.openai.com/API/examples/data/vector_database_wikipedia_articles_embedded.zip");
var name = @"C:\vector_database_wikipedia_articles_embedded";
var zipPath = $"{name}.zip";
var csvPath = Path.Combine(name, "vector_database_wikipedia_articles_embedded.csv");

// Test for the CSV itself (not just the directory) so an interrupted earlier run is retried
// instead of failing later when the CSV is opened.
if (!File.Exists(csvPath))
{
    using var httpClient = new HttpClient();
    // FileMode.Create (not CreateNew) so a partial .zip left by an interrupted download is
    // overwritten rather than raising IOException on every subsequent run.
    await using (var stream = await httpClient.GetStreamAsync(embeddingsUrl))
    await using (var destination = new FileStream(zipPath, FileMode.Create))
    {
        await stream.CopyToAsync(destination);
    }

    // overwriteFiles: true tolerates files left behind by a partially completed extraction.
    ZipFile.ExtractToDirectory(zipPath, name, overwriteFiles: true);
}

var collectionName = "Articles";
const int BatchSize = 1000;
var client = new QdrantClient("localhost");

// ReadRecords is a lazy iterator: every enumeration reopens the CSV from the start. Walk a single
// enumerator so the first record (needed up front for the vector size) is not read, and upserted, twice.
using (var recordEnumerator = ReadRecords(csvPath).GetEnumerator())
{
    if (!recordEnumerator.MoveNext())
    {
        throw new InvalidOperationException($"No records found in '{csvPath}'.");
    }

    var firstRecord = recordEnumerator.Current;
    // Both named vectors are sized from the content embedding; this assumes the title embedding
    // has the same dimension — TODO confirm if a different dataset is used.
    var size = (ulong)firstRecord.ContentVector.Length;

    try
    {
        await client.DeleteCollectionAsync(collectionName);
    }
    catch (QdrantException)
    {
        // Deliberate best effort: start from a clean collection, ignoring a failed delete
        // (presumably because the collection does not exist yet).
    }

    await client.CreateCollectionAsync(collectionName,
        new VectorParamsMap
        {
            Map =
            {
                ["title"] = new VectorParams { Distance = Distance.Cosine, Size = size },
                ["content"] = new VectorParams { Distance = Distance.Cosine, Size = size },
            }
        });

    // Upload in batches of BatchSize points; `>=` keeps the flush correct for any batch size,
    // since the list already starts with the first record.
    var points = new List<PointStruct>(BatchSize) { RecordToPointStruct(firstRecord) };
    while (recordEnumerator.MoveNext())
    {
        points.Add(RecordToPointStruct(recordEnumerator.Current));
        if (points.Count >= BatchSize)
        {
            await client.UpsertAsync(collectionName, points);
            points.Clear();
        }
    }

    if (points.Count > 0)
    {
        await client.UpsertAsync(collectionName, points);
    }
}

var count = await client.CountAsync(collectionName);
Console.WriteLine($"Count of points: {count}");

// Prefer the OPENAI_API_KEY environment variable so the key does not have to be committed to source.
var openAiApiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") is { Length: > 0 } envKey
    ? envKey
    : "<insert your OpenAI API key>";
var openAIClient = new OpenAIClient(openAiApiKey);

PrintResults(await Query(client, openAIClient, "modern art in Europe", collectionName));
Console.WriteLine();
PrintResults(await Query(client, openAIClient, "Famous battles in Scottish history", collectionName, "content"));
return;

// Prints a 1-based ranked list of the "title" payload of each result, with its score rounded to 3 places.
static void PrintResults(IReadOnlyList<ScoredPoint> results)
{
    for (var i = 0; i < results.Count; i++)
    {
        var point = results[i];
        Console.WriteLine($"{i + 1}. {point.Payload["title"].StringValue} (Score: {Math.Round(point.Score, 3)})");
    }
}
// Embeds `query` with the "text-embedding-ada-002" deployment and returns the `topK` nearest
// points of `collectionName`, searching against the named vector `vectorName`.
static async Task<IReadOnlyList<ScoredPoint>> Query(
    QdrantClient client,
    OpenAIClient openAIClient,
    string query,
    string collectionName,
    string vectorName = "title",
    ulong topK = 20)
{
    var embeddingOptions = new EmbeddingsOptions
    {
        DeploymentName = "text-embedding-ada-002"
    };
    embeddingOptions.Input.Add(query);

    var embeddingResponse = await openAIClient.GetEmbeddingsAsync(embeddingOptions);
    // Only one input string was sent, so the single embedding lives at index 0.
    var queryVector = embeddingResponse.Value.Data[0].Embedding;

    return await client.SearchAsync(collectionName, queryVector, vectorName: vectorName, limit: topK);
}
// Converts one CSV row into a Qdrant point: the row Id becomes the point id, the two embeddings
// become the "title" and "content" named vectors, and url/title/text are stored as payload.
static PointStruct RecordToPointStruct(CsvRecord record)
{
    var (id, url, title, text, titleVector, contentVector, _) = record;

    var point = new PointStruct
    {
        Id = (ulong)id,
        Vectors = new Dictionary<string, float[]>
        {
            ["title"] = titleVector,
            ["content"] = contentVector,
        },
    };

    point.Payload["url"] = url;
    point.Payload["title"] = title;
    point.Payload["text"] = text;
    return point;
}
// Lazily yields CsvRecord rows from the CSV file at `name`. Nothing is opened until enumeration
// starts, and each new enumeration reopens the file. Header names and member names are both
// normalised by ToSnakeCase before matching.
static IEnumerable<CsvRecord> ReadRecords(string name)
{
    var configuration = new CsvConfiguration(CultureInfo.InvariantCulture)
    {
        PrepareHeaderForMatch = args => ToSnakeCase(args.Header)
    };

    using var streamReader = new StreamReader(name);
    using var csvReader = new CsvReader(streamReader, configuration);
    foreach (var row in csvReader.GetRecords<CsvRecord>())
    {
        yield return row;
    }

    // Inserts "_" before every upper-case character except the first, then lower-cases the result
    // (e.g. "TitleVector" -> "title_vector").
    static string ToSnakeCase(string header) =>
        string.Concat(header.Select((ch, index) => index > 0 && char.IsUpper(ch) ? "_" + ch : ch.ToString()))
            .ToLowerInvariant();
}
/// <summary>
/// One row of <c>vector_database_wikipedia_articles_embedded.csv</c>, as produced by <c>ReadRecords</c>.
/// Column headers are matched after the snake_case normalisation in <c>ReadRecords</c>'
/// <c>PrepareHeaderForMatch</c>, so e.g. <c>TitleVector</c> is expected to match a <c>title_vector</c>
/// column — confirm against the actual CSV header.
/// </summary>
/// <param name="Id">Row identifier; cast to <see cref="ulong"/> and used as the Qdrant point id.</param>
/// <param name="Url">Article URL; stored in the point payload under <c>"url"</c>.</param>
/// <param name="Title">Article title; stored in the payload under <c>"title"</c> and printed with query results.</param>
/// <param name="Text">Article text; stored in the payload under <c>"text"</c>.</param>
/// <param name="TitleVector">Title embedding, parsed from its CSV text by <see cref="StringToEmbeddingConverter"/>; indexed as the <c>"title"</c> named vector.</param>
/// <param name="ContentVector">Content embedding, parsed the same way; indexed as the <c>"content"</c> named vector, and the first row's length sets the collection's vector size.</param>
/// <param name="VectorId">Read from the CSV but not referenced by the code shown here.</param>
public record CsvRecord(
int Id,
string Url,
string Title,
string Text,
[TypeConverter(typeof(StringToEmbeddingConverter))]
float[] TitleVector,
[TypeConverter(typeof(StringToEmbeddingConverter))]
float[] ContentVector,
int VectorId);
/// <summary>
/// CsvHelper type converter that parses an embedding stored as a JSON array of numbers
/// (e.g. <c>[0.1, -0.2, 0.3]</c>) into a <see cref="float"/> array.
/// </summary>
public class StringToEmbeddingConverter : DefaultTypeConverter
{
    /// <summary>Deserializes <paramref name="text"/> as a JSON <c>float[]</c>.</summary>
    /// <param name="text">The raw field text.</param>
    /// <param name="row">The row being read; passed through to the base converter on failure.</param>
    /// <param name="memberMapData">The member being populated; passed through to the base converter on failure.</param>
    /// <returns>The parsed embedding.</returns>
    /// <remarks>
    /// A blank field, malformed JSON, or the JSON literal <c>null</c> is treated as a conversion
    /// failure. It is delegated to <c>base.ConvertFromString</c>, which in CsvHelper raises a
    /// <c>TypeConverterException</c> carrying the row/field context.
    /// </remarks>
    public override object ConvertFromString(string? text, IReaderRow row, MemberMapData memberMapData)
    {
        if (!string.IsNullOrWhiteSpace(text))
        {
            try
            {
                // Deserialize returns null for the JSON literal "null"; don't let a null float[]
                // reach the record. Only a non-null result counts as success.
                if (JsonSerializer.Deserialize<float[]>(text) is { } embedding)
                {
                    return embedding;
                }
            }
            catch (JsonException)
            {
                // Fall through to the base failure path, which reports the failing row and field.
            }
        }

        // Standard CsvHelper failure path: throws TypeConverterException (or returns a configured default).
        return base.ConvertFromString(text, row, memberMapData)!;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment