Skip to content

Instantly share code, notes, and snippets.

@alexminza
Created October 2, 2023 18:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alexminza/8b954d057db25f51e8b4827f9e43418b to your computer and use it in GitHub Desktop.
Save alexminza/8b954d057db25f51e8b4827f9e43418b to your computer and use it in GitHub Desktop.
Semantic Kernel AlgoliaMemory
using System.Runtime.CompilerServices;
using Microsoft.SemanticKernel.Memory;
using Algolia.Search.Clients;
using Algolia.Search.Models.Search;
using System.Text.Json;
/// <summary>
/// Read-only <see cref="ISemanticTextMemory"/> backed by Algolia keyword search.
/// A memory "collection" is an Algolia index and a memory record is an Algolia object.
/// </summary>
/// <remarks>
/// Only the read operations are implemented; <see cref="RemoveAsync"/>, <see cref="SaveInformationAsync"/>
/// and <see cref="SaveReferenceAsync"/> throw <see cref="NotImplementedException"/>.
/// Results never carry an embedding and always report a relevance of 0.
/// </remarks>
public class AlgoliaMemory : ISemanticTextMemory
{
    private readonly ISearchClient _searchClient;
    private readonly IEnumerable<string>? _attributesToRetrieve;
    private readonly IDictionary<string, string>? _attributesMapping;

    /// <summary>Creates a memory that reads from Algolia through <paramref name="searchClient"/>.</summary>
    /// <param name="searchClient">Client used to list indices, initialize indices and run queries.</param>
    /// <param name="attributesToRetrieve">
    /// Optional value for Algolia's <c>attributesToRetrieve</c> search parameter; <c>null</c> leaves it unset.
    /// </param>
    /// <param name="attributesMapping">
    /// Optional map from memory attribute names (<c>objectID</c>, <c>text</c>, <c>description</c>,
    /// <c>externalSourceName</c>, <c>isReference</c>) to the Algolia record attribute holding that value.
    /// Attributes without an entry are read under their own name; mapping an attribute to an empty or
    /// whitespace name suppresses it.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="searchClient"/> is <c>null</c>.</exception>
    public AlgoliaMemory(ISearchClient searchClient, IEnumerable<string>? attributesToRetrieve = null, IDictionary<string, string>? attributesMapping = null)
    {
        this._searchClient = searchClient ?? throw new ArgumentNullException(nameof(searchClient));
        this._attributesToRetrieve = attributesToRetrieve;
        this._attributesMapping = attributesMapping;
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Fetches the object whose <c>objectID</c> is <paramref name="key"/> from the index named <paramref name="collection"/>.
    /// <paramref name="withEmbedding"/> is ignored because no embedding is stored.
    /// </remarks>
    public async Task<MemoryQueryResult?> GetAsync(string collection, string key, bool withEmbedding = false, CancellationToken cancellationToken = default)
    {
        var searchIndex = this._searchClient.InitIndex(collection);

        // NOTE(review): Algolia appears to report a missing objectID by throwing rather than returning null;
        // confirm and consider mapping that case to null to honour the nullable return contract.
        var objectResult = await searchIndex
            .GetObjectAsync<Dictionary<string, object?>>(objectId: key, ct: cancellationToken)
            .ConfigureAwait(false);

        // Guard against a null deserialization result instead of failing inside BuildMemoryQueryResult.
        return objectResult is null ? null : this.BuildMemoryQueryResult(objectResult);
    }

    /// <inheritdoc/>
    /// <remarks>Returns the names of all indices visible to the search client.</remarks>
    public async Task<IList<string>> GetCollectionsAsync(CancellationToken cancellationToken = default)
    {
        var indices = await this._searchClient.ListIndicesAsync(ct: cancellationToken).ConfigureAwait(false);
        var collections = indices.Items
            .Select(index => index.Name)
            .ToList();
        return collections;
    }

    /// <inheritdoc/>
    /// <exception cref="NotImplementedException">Always; this memory is read-only.</exception>
    public Task RemoveAsync(string collection, string key, CancellationToken cancellationToken = default)
    {
        throw new NotImplementedException();
    }

    /// <inheritdoc/>
    /// <exception cref="NotImplementedException">Always; this memory is read-only.</exception>
    public Task<string> SaveInformationAsync(string collection, string text, string id, string? description = null, string? additionalMetadata = null, CancellationToken cancellationToken = default)
    {
        throw new NotImplementedException();
    }

    /// <inheritdoc/>
    /// <exception cref="NotImplementedException">Always; this memory is read-only.</exception>
    public Task<string> SaveReferenceAsync(string collection, string text, string externalId, string externalSourceName, string? description = null, string? additionalMetadata = null, CancellationToken cancellationToken = default)
    {
        throw new NotImplementedException();
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Runs an Algolia keyword search against the index named <paramref name="collection"/> and returns at most
    /// <paramref name="limit"/> hits from the first page. <paramref name="minRelevanceScore"/> (expected in [0, 1])
    /// is translated to Algolia's <c>relevancyStrictness</c> percentage. <paramref name="withEmbeddings"/> is ignored.
    /// </remarks>
    public async IAsyncEnumerable<MemoryQueryResult> SearchAsync(string collection, string query, int limit = 1, double minRelevanceScore = 0.7, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var searchIndex = this._searchClient.InitIndex(collection);
        var algoliaQuery = new Query(query)
        {
            AttributesToRetrieve = this._attributesToRetrieve, //https://www.algolia.com/doc/api-reference/api-parameters/attributesToRetrieve/
            RelevancyStrictness = ToRelevancyStrictness(minRelevanceScore), //https://www.algolia.com/doc/api-reference/api-parameters/relevancyStrictness/
            Length = limit, //https://www.algolia.com/doc/api-reference/api-parameters/length/
            Offset = 0 //https://www.algolia.com/doc/api-reference/api-parameters/offset/
        };

        var searchResponse = await searchIndex
            .SearchAsync<Dictionary<string, object?>>(query: algoliaQuery, ct: cancellationToken)
            .ConfigureAwait(false);

        foreach (var searchResult in searchResponse.Hits)
        {
            yield return this.BuildMemoryQueryResult(searchResult);
        }
    }

    /// <summary>
    /// Converts a relevance threshold in [0, 1] to Algolia's integer <c>relevancyStrictness</c> percentage in [0, 100].
    /// </summary>
    /// <remarks>
    /// Rounds instead of truncating. For example, <c>0.29 * 100</c> evaluates to <c>28.999999999999996</c>,
    /// which a plain <c>(int)</c> cast would turn into 28. Out-of-range inputs are clamped.
    /// NaN is treated as 0 (no strictness filter).
    /// </remarks>
    private static int ToRelevancyStrictness(double minRelevanceScore)
    {
        if (double.IsNaN(minRelevanceScore))
        {
            return 0;
        }

        double clamped = Math.Clamp(minRelevanceScore, 0d, 1d);
        return (int)Math.Round(clamped * 100d, MidpointRounding.AwayFromZero);
    }

    /// <summary>Maps an Algolia record (search hit or fetched object) to a <see cref="MemoryQueryResult"/>.</summary>
    /// <param name="searchResult">Raw Algolia record keyed by attribute name.</param>
    /// <returns>
    /// A result whose metadata fields are read through the attribute mapping; missing values become
    /// <c>false</c> or an empty string. <c>additionalMetadata</c> is the JSON of the whole record.
    /// Relevance is always 0 and the embedding is always empty.
    /// </returns>
    protected MemoryQueryResult BuildMemoryQueryResult(Dictionary<string, object?> searchResult)
    {
        bool isReference = Convert.ToBoolean(this.GetMappedAttributeValue(searchResult, "isReference") ?? false);
        string id = Convert.ToString(this.GetMappedAttributeValue(searchResult, "objectID")) ?? string.Empty;
        string text = Convert.ToString(this.GetMappedAttributeValue(searchResult, "text")) ?? string.Empty;
        string description = Convert.ToString(this.GetMappedAttributeValue(searchResult, "description")) ?? string.Empty;
        string externalSourceName = Convert.ToString(this.GetMappedAttributeValue(searchResult, "externalSourceName")) ?? string.Empty;
        string additionalMetadata = SerializeMetadata(searchResult);

        var memoryMetadata = new MemoryRecordMetadata(isReference: isReference, id: id, text: text, description: description, externalSourceName: externalSourceName, additionalMetadata: additionalMetadata);
        return new MemoryQueryResult(metadata: memoryMetadata, relevance: 0, embedding: ReadOnlyMemory<float>.Empty);
    }

    /// <summary>Reads a memory attribute from an Algolia record, honouring the configured attribute mapping.</summary>
    /// <param name="searchResult">Raw Algolia record keyed by attribute name.</param>
    /// <param name="attribute">Memory attribute name, e.g. <c>objectID</c> or <c>text</c>.</param>
    /// <returns>
    /// The value stored under the mapped name (or under <paramref name="attribute"/> itself when it has no
    /// mapping entry). Returns <c>null</c> when the record lacks it or the mapping resolves to an empty or
    /// whitespace name.
    /// </returns>
    protected object? GetMappedAttributeValue(Dictionary<string, object?> searchResult, string attribute)
    {
        // A missing mapping entry must fall back to the attribute's own name. TryGetValue's out argument
        // is only used on a hit because it is reset to null on a miss.
        string? mappedName = attribute;
        if (this._attributesMapping is not null && this._attributesMapping.TryGetValue(attribute, out var configuredName))
        {
            mappedName = configuredName;
        }

        if (string.IsNullOrWhiteSpace(mappedName))
        {
            return null;
        }

        return searchResult.GetValueOrDefault(mappedName);
    }

    /// <summary>Serializes an Algolia record to a JSON object whose values are all strings.</summary>
    /// <param name="searchResult">Raw Algolia record keyed by attribute name.</param>
    /// <returns>
    /// JSON text in which string values are kept as-is and every other value (including <c>null</c>)
    /// is replaced by its own JSON serialization.
    /// </returns>
    protected static string SerializeMetadata(Dictionary<string, object?> searchResult)
    {
        var additionalMetadataDict = searchResult.ToDictionary(kv => kv.Key, kv => kv.Value is string stringValue
            ? stringValue
            : JsonSerializer.Serialize(kv.Value));
        string additionalMetadata = JsonSerializer.Serialize(additionalMetadataDict);
        return additionalMetadata;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment