Skip to content

Instantly share code, notes, and snippets.

@alexminza
Created August 11, 2023 09:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save alexminza/f05fb189de8d8eeb53ab62b35fbc4660 to your computer and use it in GitHub Desktop.
Save alexminza/f05fb189de8d8eeb53ab62b35fbc4660 to your computer and use it in GitHub Desktop.
Semantic Kernel ConversationalRetrievalPlugin
using System.ComponentModel;
using System.Text;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.SkillDefinition;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using Microsoft.SemanticKernel.Memory;
/// <summary>
/// Semantic Kernel plugin implementing a conversational retrieval QnA flow modelled on
/// LangChain's ConversationalRetrievalChain. Each <see cref="QnA"/> call runs the pipeline
/// CondenseQuestion -> SearchMemory -> RephraseQuestion -> AnswerContextQuestion and records
/// the question/answer pair in an in-memory chat history owned by this instance.
/// </summary>
internal class ConversationalRetrievalPlugin
{
    // Prompt templates are modelled on LangChain's conversational retrieval prompts:
    // https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/conversational_retrieval/prompts.py

    // Context variable names written by this plugin; presumably referenced by the semantic
    // function prompt templates under resourcesPath (TODO confirm against the prompt files).
    private const string SearchResultsKey = "context";
    private const string ChatHistoryKey = "chat_history";
    private const string QuestionKey = "question";

    private readonly IKernel _kernel;
    private readonly ChatHistory _chatHistory;
    private readonly string _memoryCollectionName;
    private readonly bool _rephraseQuestion;
    private readonly int _searchLimit;
    private readonly double _searchMinRelevanceScore;
    private readonly ISKFunction _condenseQuestionFunction;
    private readonly ISKFunction _answerContextQuestionFunction;
    private readonly ISKFunction _searchMemoryFunction;
    private readonly ISKFunction _rephraseQuestionFunction;

    /// <summary>
    /// Creates the plugin and registers its semantic and native functions with <paramref name="kernel"/>.
    /// </summary>
    /// <param name="kernel">Kernel used to import functions, run the pipeline and access memory.</param>
    /// <param name="memoryCollectionName">Memory collection searched for relevant fragments.</param>
    /// <param name="rephraseQuestion">
    /// When true, the condensed (standalone) question is answered; when false, the user's original
    /// question is answered. Memory search always uses the condensed question.
    /// </param>
    /// <param name="resourcesPath">
    /// Directory passed to <c>ImportSemanticSkillFromDirectory</c>; it must provide the
    /// <c>ConversationalRetrievalPlugin</c> skill with <c>CondenseQuestion</c> and
    /// <c>AnswerContextQuestion</c> functions.
    /// </param>
    /// <param name="searchLimit">Maximum number of memory results to include as context.</param>
    /// <param name="searchMinRelevanceScore">Minimum relevance score for a memory result to be included.</param>
    /// <exception cref="ArgumentNullException"><paramref name="kernel"/> or <paramref name="memoryCollectionName"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="searchLimit"/> is not positive.</exception>
    public ConversationalRetrievalPlugin(
        IKernel kernel,
        string memoryCollectionName,
        bool rephraseQuestion = true,
        string resourcesPath = "./Resources",
        int searchLimit = 5,
        double searchMinRelevanceScore = 0.5)
    {
        ArgumentNullException.ThrowIfNull(kernel);
        ArgumentNullException.ThrowIfNull(memoryCollectionName);
        if (searchLimit <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(searchLimit), searchLimit, "Search limit must be positive.");
        }

        this._kernel = kernel;
        this._chatHistory = new ChatHistory();
        this._memoryCollectionName = memoryCollectionName;
        this._rephraseQuestion = rephraseQuestion;
        this._searchLimit = searchLimit;
        this._searchMinRelevanceScore = searchMinRelevanceScore;

        // Semantic (prompt-based) functions loaded from disk.
        var semanticFunctions = kernel.ImportSemanticSkillFromDirectory(resourcesPath, "ConversationalRetrievalPlugin");
        this._condenseQuestionFunction = semanticFunctions["CondenseQuestion"];
        this._answerContextQuestionFunction = semanticFunctions["AnswerContextQuestion"];

        // Native functions: the [SKFunction] methods declared on this instance.
        var nativeFunctions = kernel.ImportSkill(this, nameof(ConversationalRetrievalPlugin));
        this._searchMemoryFunction = nativeFunctions[nameof(SearchMemory)];
        this._rephraseQuestionFunction = nativeFunctions[nameof(RephraseQuestion)];
    }

    /// <summary>
    /// Answers the user question in <c>context.Variables.Input</c> using the chat history and
    /// documents retrieved from memory, then appends the turn to the chat history.
    /// </summary>
    /// <param name="context">Context whose input is the user question; its variables are passed to the pipeline.</param>
    /// <param name="logger">Not used by this method.</param>
    /// <param name="cancellationToken">Token forwarded to the kernel pipeline.</param>
    /// <returns>The pipeline's resulting context; its <c>Result</c> is the answer.</returns>
    [SKFunction, Description("QnA function for having a conversation based on retrieved documents.")]
    [SKParameter("input", "User question")]
    public async Task<SKContext> QnA(SKContext context, ILogger? logger, CancellationToken cancellationToken = default)
    {
        // https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/conversational_retrieval/base.py
        string question = context.Variables.Input;
        context.Variables.Set(QuestionKey, question);

        // Render only the *previous* turns: the current question is supplied separately via
        // QuestionKey, so adding it to the history first would duplicate it in the prompt.
        context.Variables.Set(ChatHistoryKey, GetChatHistoryText(this._chatHistory));

        var response = await this._kernel.RunAsync(
            variables: context.Variables,
            cancellationToken: cancellationToken,
            this._condenseQuestionFunction,
            this._searchMemoryFunction,
            this._rephraseQuestionFunction,
            this._answerContextQuestionFunction);

        // Record the turn only after the pipeline returns, so a run that throws
        // (e.g. on cancellation) does not leave an unanswered user message in the history.
        this._chatHistory.AddUserMessage(question);
        this._chatHistory.AddAssistantMessage(response.Result);
        return response;
    }

    /// <summary>
    /// Pipeline step that chooses which question is answered. When rephrasing is disabled,
    /// the input (the condensed question) is replaced with the original user question.
    /// </summary>
    /// <param name="context">Pipeline context; reads <c>QuestionKey</c> when rephrasing is disabled.</param>
    /// <param name="logger">Not used by this method.</param>
    /// <returns>The same <paramref name="context"/>, possibly with an updated input.</returns>
    [SKFunction, Description("Use original user question or rephrased question.")]
    [SKParameter("input", "Rephrased user question")]
    [SKParameter(QuestionKey, "Original user question")]
    public SKContext RephraseQuestion(SKContext context, ILogger? logger)
    {
        if (!this._rephraseQuestion)
        {
            string originalQuestion = context.Variables[QuestionKey];
            context.Variables.Update(originalQuestion);
        }

        return context;
    }

    /// <summary>
    /// Pipeline step that searches the configured memory collection with the current input and
    /// stores the formatted results under <c>SearchResultsKey</c>.
    /// </summary>
    /// <param name="context">Pipeline context whose input is used as the search query.</param>
    /// <param name="logger">Not used by this method.</param>
    /// <param name="cancellationToken">Token forwarded to the memory search.</param>
    /// <returns>The same <paramref name="context"/> with the search results variable set.</returns>
    [SKFunction, Description("Search memory documents for fragments relevant to the user question.")]
    [SKParameter("input", "User question")]
    public async Task<SKContext> SearchMemory(SKContext context, ILogger? logger, CancellationToken cancellationToken = default)
    {
        // https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/Skills/Skills.Core/TextMemorySkill.cs
        // https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/KernelSyntaxExamples/Example15_MemorySkill.cs#L63
        var searchResults = await this._kernel.Memory.SearchAsync(
                collection: this._memoryCollectionName,
                query: context.Variables.Input,
                limit: this._searchLimit,
                minRelevanceScore: this._searchMinRelevanceScore,
                cancellationToken: cancellationToken)
            .ToListAsync(cancellationToken: cancellationToken);

        context.Variables.Set(SearchResultsKey, GetSearchResultsText(searchResults));
        return context;
    }

    /// <summary>
    /// Formats non-system chat messages as one <c>"{role}: {content}"</c> line each.
    /// </summary>
    private static string GetChatHistoryText(ChatHistory chatHistory)
    {
        var chatHistorySB = new StringBuilder();
        foreach (var chatMessage in chatHistory.Messages)
        {
            if (chatMessage.Role != AuthorRole.System)
            {
                chatHistorySB.AppendLine($"{chatMessage.Role}: {chatMessage.Content}");
            }
        }

        return chatHistorySB.ToString();
    }

    /// <summary>
    /// Formats each memory result as its id line, its text, and a blank separator line.
    /// </summary>
    private static string GetSearchResultsText(IEnumerable<MemoryQueryResult> searchResults)
    {
        var searchResultsSB = new StringBuilder();
        foreach (var searchResult in searchResults)
        {
            searchResultsSB.AppendLine(searchResult.Metadata.Id);
            searchResultsSB.AppendLine(searchResult.Metadata.Text);
            searchResultsSB.AppendLine();
        }

        return searchResultsSB.ToString();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment