Created
August 11, 2023 09:49
-
-
Save alexminza/f05fb189de8d8eeb53ab62b35fbc4660 to your computer and use it in GitHub Desktop.
Semantic Kernel ConversationalRetrievalPlugin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.ComponentModel; | |
using System.Text; | |
using Microsoft.Extensions.Logging; | |
using Microsoft.SemanticKernel; | |
using Microsoft.SemanticKernel.Orchestration; | |
using Microsoft.SemanticKernel.SkillDefinition; | |
using Microsoft.SemanticKernel.AI.ChatCompletion; | |
using Microsoft.SemanticKernel.Memory; | |
/// <summary>
/// Semantic Kernel plugin that answers questions about documents stored in kernel memory while
/// keeping a running chat history. Modelled on LangChain's ConversationalRetrievalChain:
/// condense the question using the history, search memory, optionally fall back to the
/// original question, then answer from the retrieved context.
/// </summary>
class ConversationalRetrievalPlugin
{
    private readonly IKernel _kernel;
    private readonly ChatHistory _chatHistory;
    private readonly string _memoryCollectionName;
    private readonly bool _rephraseQuestion;
    private readonly int _searchLimit;
    private readonly double _minRelevanceScore;
    private readonly ISKFunction _condenseQuestionFunction;
    private readonly ISKFunction _answerContextQuestionFunction;
    private readonly ISKFunction _searchMemoryFunction;
    private readonly ISKFunction _rephraseQuestionFunction;

    // Context variable names shared between the native functions and the semantic prompts.
    // Names follow LangChain's conversational retrieval prompts:
    // https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/conversational_retrieval/prompts.py
    private const string SearchResultsKey = "context";
    private const string ChatHistoryKey = "chat_history";
    private const string QuestionKey = "question";

    /// <summary>
    /// Creates the plugin and registers its semantic and native functions with <paramref name="kernel"/>.
    /// </summary>
    /// <param name="kernel">Kernel used to run the pipeline and to access memory.</param>
    /// <param name="memoryCollectionName">Memory collection searched for relevant fragments.</param>
    /// <param name="rephraseQuestion">
    /// When <c>true</c>, the condensed (rephrased) question is passed to the answer step; when
    /// <c>false</c>, the user's original question is restored before answering (the condensed
    /// question is still used for the memory search).
    /// </param>
    /// <param name="resourcesPath">
    /// Directory containing the "ConversationalRetrievalPlugin" semantic function folder
    /// (expected to provide "CondenseQuestion" and "AnswerContextQuestion").
    /// </param>
    /// <param name="searchLimit">Maximum number of memory fragments to retrieve (default 5).</param>
    /// <param name="minRelevanceScore">Minimum relevance score for a fragment to be included (default 0.5).</param>
    /// <exception cref="ArgumentNullException"><paramref name="kernel"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="memoryCollectionName"/> is null, empty or whitespace.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="searchLimit"/> is not positive.</exception>
    public ConversationalRetrievalPlugin(
        IKernel kernel,
        string memoryCollectionName,
        bool rephraseQuestion = true,
        string resourcesPath = "./Resources",
        int searchLimit = 5,
        double minRelevanceScore = 0.5)
    {
        ArgumentNullException.ThrowIfNull(kernel);
        if (string.IsNullOrWhiteSpace(memoryCollectionName))
        {
            throw new ArgumentException("Memory collection name must be a non-empty string.", nameof(memoryCollectionName));
        }

        if (searchLimit <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(searchLimit), searchLimit, "Search limit must be positive.");
        }

        this._kernel = kernel;
        this._chatHistory = new ChatHistory();
        this._memoryCollectionName = memoryCollectionName;
        this._rephraseQuestion = rephraseQuestion;
        this._searchLimit = searchLimit;
        this._minRelevanceScore = minRelevanceScore;

        #region Define plugin semantic functions
        var semanticFunctions = kernel.ImportSemanticSkillFromDirectory(resourcesPath, "ConversationalRetrievalPlugin");
        this._condenseQuestionFunction = semanticFunctions["CondenseQuestion"];
        this._answerContextQuestionFunction = semanticFunctions["AnswerContextQuestion"];

        var nativeFunctions = kernel.ImportSkill(this, nameof(ConversationalRetrievalPlugin));
        this._searchMemoryFunction = nativeFunctions[nameof(SearchMemory)];
        this._rephraseQuestionFunction = nativeFunctions[nameof(RephraseQuestion)];
        #endregion
    }

    /// <summary>
    /// Answers the user question in <c>input</c> using the chat history and retrieved documents.
    /// Runs the pipeline CondenseQuestion → SearchMemory → RephraseQuestion → AnswerContextQuestion
    /// and records the question and answer in the plugin's chat history.
    /// </summary>
    /// <remarks>
    /// If the pipeline throws (including cancellation), the question is removed from the history
    /// again and the exception is rethrown. Assumes the instance is not used by concurrent calls.
    /// </remarks>
    [SKFunction, Description("QnA function for having a conversation based on retrieved documents.")]
    [SKParameter("input", "User question")]
    public async Task<SKContext> QnA(SKContext context, ILogger? logger, CancellationToken cancellationToken = default)
    {
        // https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/conversational_retrieval/base.py
        string question = context.Variables.Input;
        context.Variables.Set(QuestionKey, question);

        // The history handed to the prompts includes the current question.
        this._chatHistory.AddUserMessage(question);
        context.Variables.Set(ChatHistoryKey, GetChatHistoryText(this._chatHistory));

        SKContext response;
        try
        {
            response = await this._kernel.RunAsync(
                variables: context.Variables,
                cancellationToken: cancellationToken,
                this._condenseQuestionFunction,
                this._searchMemoryFunction,
                this._rephraseQuestionFunction,
                this._answerContextQuestionFunction);
        }
        catch
        {
            // Roll back the unanswered question so a failed or cancelled turn does not leave
            // two consecutive user messages in the history seen by later turns.
            this._chatHistory.Messages.RemoveAt(this._chatHistory.Messages.Count - 1);
            throw;
        }

        this._chatHistory.AddAssistantMessage(response.Result);
        return response;
    }

    /// <summary>
    /// Pipeline step placed between the memory search and the answer step. When rephrasing is
    /// disabled, replaces <c>input</c> (the condensed question) with the original question
    /// stored under <c>question</c>; otherwise leaves the context unchanged.
    /// </summary>
    [SKFunction, Description("Use original user question or rephrased question.")]
    [SKParameter("input", "Rephrased user question")]
    [SKParameter(QuestionKey, "Original user question")]
    public SKContext RephraseQuestion(SKContext context, ILogger? logger)
    {
        if (!this._rephraseQuestion)
        {
            string originalQuestion = context.Variables[QuestionKey];
            context.Variables.Update(originalQuestion);
        }

        return context;
    }

    /// <summary>
    /// Searches the configured memory collection using <c>input</c> as the query and stores the
    /// formatted results (id and text of each hit) in the <c>context</c> variable.
    /// </summary>
    [SKFunction, Description("Search memory documents for fragments relevant to the user question.")]
    [SKParameter("input", "User question")]
    public async Task<SKContext> SearchMemory(SKContext context, ILogger? logger, CancellationToken cancellationToken = default)
    {
        // https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/Skills/Skills.Core/TextMemorySkill.cs
        // https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/KernelSyntaxExamples/Example15_MemorySkill.cs#L63
        var searchResults = await this._kernel.Memory.SearchAsync(
                collection: this._memoryCollectionName,
                query: context.Variables.Input,
                limit: this._searchLimit,
                minRelevanceScore: this._minRelevanceScore,
                cancellationToken: cancellationToken)
            .ToListAsync(cancellationToken: cancellationToken);

        context.Variables.Set(SearchResultsKey, GetSearchResultsText(searchResults));
        return context;
    }

    /// <summary>
    /// Renders the non-system messages of <paramref name="chatHistory"/> as "role: content" lines.
    /// </summary>
    private static string GetChatHistoryText(ChatHistory chatHistory)
    {
        var chatHistorySB = new StringBuilder();
        foreach (var chatMessage in chatHistory.Messages)
        {
            if (chatMessage.Role != AuthorRole.System)
            {
                chatHistorySB.AppendLine($"{chatMessage.Role}: {chatMessage.Content}");
            }
        }

        return chatHistorySB.ToString();
    }

    /// <summary>
    /// Renders each search hit as its metadata id line, its text, and a blank separator line.
    /// </summary>
    private static string GetSearchResultsText(IEnumerable<MemoryQueryResult> searchResults)
    {
        var searchResultsSB = new StringBuilder();
        foreach (var searchResult in searchResults)
        {
            searchResultsSB.AppendLine(searchResult.Metadata.Id);
            searchResultsSB.AppendLine(searchResult.Metadata.Text);
            searchResultsSB.AppendLine();
        }

        return searchResultsSB.ToString();
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment