Created
December 31, 2023 00:44
-
-
Save russcam/09cf54eff3bedf0df17d03af53057e36 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Globalization; | |
using System.IO.Compression; | |
using System.Text.Json; | |
using Azure.AI.OpenAI; | |
using CsvHelper; | |
using CsvHelper.Configuration; | |
using CsvHelper.Configuration.Attributes; | |
using CsvHelper.TypeConversion; | |
using Qdrant.Client; | |
using Qdrant.Client.Grpc; | |
// Sample program: downloads a zip of pre-embedded Wikipedia articles, loads the rows into a
// Qdrant collection with two named vectors ("title" and "content"), then runs two example
// searches whose query text is embedded through OpenAI.

// Number of points sent to Qdrant per upsert request.
const int batchSize = 1000;

var embeddingsUrl =
    new Uri("https://cdn.openai.com/API/examples/data/vector_database_wikipedia_articles_embedded.zip");
var name = @"C:\vector_database_wikipedia_articles_embedded";

// Download and extract only when the extracted directory is missing.
if (!Directory.Exists(name))
{
    using var httpClient = new HttpClient();
    await using (var stream = await httpClient.GetStreamAsync(embeddingsUrl))
    // FileMode.Create (not CreateNew): a zip left behind by an interrupted earlier run is
    // overwritten instead of making this line throw.
    await using (var destination = new FileStream($"{name}.zip", FileMode.Create))
        await stream.CopyToAsync(destination);
    ZipFile.ExtractToDirectory($"{name}.zip", name);
}

// ReadRecords is a lazy iterator: every enumeration re-opens the CSV and starts from row one.
var records = ReadRecords(Path.Combine(name, "vector_database_wikipedia_articles_embedded.csv"));

// Peek at the first row only to learn the vector dimensions for the collection schema.
var firstRecord = records.FirstOrDefault()
    ?? throw new InvalidOperationException("The embeddings CSV contains no records.");
var titleSize = (ulong)firstRecord.TitleVector.Length;
var contentSize = (ulong)firstRecord.ContentVector.Length;

var collectionName = "Articles";
var client = new QdrantClient("localhost");

try
{
    await client.DeleteCollectionAsync(collectionName);
}
catch (QdrantException)
{
    // Deliberate best-effort: the collection is (re)created right below, so a failed delete is
    // ignored. Presumably this covers the collection not existing yet -- confirm.
}

await client.CreateCollectionAsync(collectionName,
    new VectorParamsMap
    {
        Map =
        {
            ["title"] = new VectorParams { Distance = Distance.Cosine, Size = titleSize },
            ["content"] = new VectorParams { Distance = Distance.Cosine, Size = contentSize },
        }
    });

// Start with an empty batch: the loop below enumerates from the first row again, so seeding
// the list with firstRecord would send that row twice.
var points = new List<PointStruct>(batchSize);
foreach (var record in records)
{
    points.Add(RecordToPointStruct(record));
    if (points.Count == batchSize)
    {
        await client.UpsertAsync(collectionName, points);
        points.Clear();
    }
}

// Flush the final, partially filled batch.
if (points.Count > 0)
{
    await client.UpsertAsync(collectionName, points);
}

var count = await client.CountAsync(collectionName);
Console.WriteLine($"Count of points: {count}");

// Prefer the OPENAI_API_KEY environment variable so the key need not be written into source;
// the original placeholder remains the fallback.
var openAiApiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? "<insert your OpenAI API key>";
var openAIClient = new OpenAIClient(openAiApiKey);

// Search against the default "title" vector.
var results = await Query(client, openAIClient, "modern art in Europe", collectionName);
PrintResults(results);

Console.WriteLine();

// Search against the "content" vector.
results = await Query(client, openAIClient, "Famous battles in Scottish history", collectionName, "content");
PrintResults(results);

return;

// Writes one numbered line per hit: the "title" payload value and the score rounded to 3 decimals.
static void PrintResults(IReadOnlyList<ScoredPoint> scoredPoints)
{
    for (var i = 0; i < scoredPoints.Count; i++)
    {
        var point = scoredPoints[i];
        Console.WriteLine($"{i + 1}. {point.Payload["title"].StringValue} (Score: {Math.Round(point.Score, 3)})");
    }
}
// Embeds the query text with the "text-embedding-ada-002" deployment and searches the given
// collection's named vector with the resulting embedding, returning at most topK scored points.
static async Task<IReadOnlyList<ScoredPoint>> Query(
    QdrantClient client,
    OpenAIClient openAIClient,
    string query,
    string collectionName,
    string vectorName = "title",
    ulong topK = 20)
{
    var embeddingsOptions = new EmbeddingsOptions { DeploymentName = "text-embedding-ada-002" };
    embeddingsOptions.Input.Add(query);

    var embeddingsResponse = await openAIClient.GetEmbeddingsAsync(embeddingsOptions);

    // Exactly one input was sent, so only the first embedding in the response is used.
    var queryVector = embeddingsResponse.Value.Data[0].Embedding;

    var matches = await client.SearchAsync(collectionName, queryVector, vectorName: vectorName, limit: topK);
    return matches;
}
// Converts one CSV row into a Qdrant point: the row id becomes the point id, the two embeddings
// become the "title" and "content" named vectors, and url/title/text are stored as payload.
static PointStruct RecordToPointStruct(CsvRecord record)
{
    var namedVectors = new Dictionary<string, float[]>
    {
        ["title"] = record.TitleVector,
        ["content"] = record.ContentVector
    };

    var point = new PointStruct
    {
        Id = (ulong)record.Id,
        Vectors = namedVectors
    };

    var payload = point.Payload;
    payload["url"] = record.Url;
    payload["title"] = record.Title;
    payload["text"] = record.Text;

    return point;
}
// Lazily streams CsvRecord rows from the CSV file at the given path. The file is opened when
// enumeration starts and closed when enumeration finishes or the enumerator is disposed.
static IEnumerable<CsvRecord> ReadRecords(string name)
{
    // Normalises a header for matching: inserts '_' before every upper-case character except
    // the first, then lower-cases the result (e.g. "TitleVector" -> "title_vector").
    static string ToSnakeCase(string header) =>
        string.Concat(header.Select((ch, index) =>
            index > 0 && char.IsUpper(ch) ? "_" + ch : ch.ToString())).ToLowerInvariant();

    using (var textReader = new StreamReader(name))
    {
        var configuration = new CsvConfiguration(CultureInfo.InvariantCulture)
        {
            PrepareHeaderForMatch = args => ToSnakeCase(args.Header)
        };

        using (var csvReader = new CsvReader(textReader, configuration))
        {
            foreach (var row in csvReader.GetRecords<CsvRecord>())
            {
                yield return row;
            }
        }
    }
}
/// <summary>
/// One row of the embedded-articles CSV, materialised by CsvHelper through this positional constructor.
/// </summary>
/// <param name="Id">Row identifier; used as the Qdrant point id.</param>
/// <param name="Url">Stored in the point payload under "url".</param>
/// <param name="Title">Stored in the point payload under "title".</param>
/// <param name="Text">Stored in the point payload under "text".</param>
/// <param name="TitleVector">Embedding for the "title" named vector, parsed by <see cref="StringToEmbeddingConverter"/>.</param>
/// <param name="ContentVector">Embedding for the "content" named vector, parsed by <see cref="StringToEmbeddingConverter"/>.</param>
/// <param name="VectorId">Read from the CSV but not used elsewhere in this file.</param>
public record CsvRecord(
    int Id,
    string Url,
    string Title,
    string Text,
    [TypeConverter(typeof(StringToEmbeddingConverter))] float[] TitleVector,
    [TypeConverter(typeof(StringToEmbeddingConverter))] float[] ContentVector,
    int VectorId);
/// <summary>
/// CsvHelper type converter that parses a CSV field containing a JSON array into a <c>float[]</c>.
/// </summary>
public class StringToEmbeddingConverter : DefaultTypeConverter
{
    /// <summary>Deserialises the field text as a JSON array of floats.</summary>
    /// <param name="text">Raw field text.</param>
    /// <param name="row">Current reader row (unused).</param>
    /// <param name="memberMapData">Mapping data for the target member (unused).</param>
    /// <returns>The deserialised <c>float[]</c>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
    public override object ConvertFromString(string? text, IReaderRow row, MemberMapData memberMapData)
    {
        if (text is null)
        {
            throw new ArgumentNullException(nameof(text));
        }

        float[]? embedding = JsonSerializer.Deserialize<float[]>(text);

        // The null-forgiving operator mirrors the original contract: a JSON "null" field is not rejected here.
        return embedding!;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment