@vikramsoni2
Created December 13, 2023 21:42
Memory module for LLMs to maintain conversation history during inference
import numpy as np
from pydantic import BaseModel
from utils.config import Config
from utils.models import ChatMessage
from tokenizers import Tokenizer
from typing import Dict, List, Optional, Union, Literal
DEFAULT_TOKEN_LIMIT_RATIO = 0.75
DEFAULT_TOKENIZER = Tokenizer.from_pretrained(Config.LLAMA_MODEL)


class Message(BaseModel):
    role: str
    content: str


class Memory:
    """
    A class to manage conversation context for language models. It handles message storage
    and token counting, and prepares prompts for different language model providers.

    Attributes:
        context_window (int): The maximum number of tokens to consider in the conversation history.
        tokenizer (Tokenizer): A tokenizer for encoding messages. Defaults to DEFAULT_TOKENIZER.
        token_limit (int): The maximum number of tokens allowed. Defaults to 75% of context_window.
        system_prompt (Message): An optional system-level prompt included at the start of the conversation.
    """

    def __init__(self,
                 context_window: int,
                 tokenizer: Optional[Tokenizer] = None,
                 token_limit: Optional[int] = None,
                 system_prompt: Optional[str] = None):
        """
        Initializes the Memory class with a context window, tokenizer, token limit, and system prompt.

        Args:
            context_window (int): The maximum number of tokens to consider in the conversation history.
            tokenizer (Optional[Tokenizer]): A tokenizer for encoding messages. Defaults to DEFAULT_TOKENIZER.
            token_limit (Optional[int]): The maximum number of tokens allowed. Defaults to 75% of context_window.
            system_prompt (Optional[str]): An optional system-level prompt to include at the start of the conversation.
        """
        self.context_window = context_window
        self.tokenizer = tokenizer or DEFAULT_TOKENIZER
        self.token_limit = token_limit or int(context_window * DEFAULT_TOKEN_LIMIT_RATIO)
        self.system_prompt = Message(role='system', content=system_prompt) if system_prompt else None
        self.system_prompt_size = len(self.tokenizer.encode(system_prompt)) if system_prompt else 0
        self.conversation = []   # list of Message objects in chronological order
        self.token_counts = []   # token count of each message in self.conversation

    def set_messages(self, messages: List[Union[Message, ChatMessage]]):
        """
        Adds messages to the conversation memory and calculates the token count for each message.

        Args:
            messages (List[Union[Message, ChatMessage]]): A list of Message or ChatMessage objects to add to the memory.
        """
        for message in messages:
            # If the message is a ChatMessage, convert it to a Message first
            if isinstance(message, ChatMessage):
                role = 'user' if message.type == 'user' else 'assistant'
                message = Message(role=role, content=message.text)
            self.conversation.append(message)
            self.token_counts.append(len(self.tokenizer.encode(message.content)))

    def get_messages(self, n: Optional[int] = None) -> List[Message]:
        """
        Retrieves a specified number of recent messages from memory, prefixed with the system prompt if one is set.

        Args:
            n (Optional[int]): The number of messages to retrieve. Defaults to all messages.

        Returns:
            List[Message]: The system prompt (if any) followed by the most recent messages.
        """
        if n is None:
            n = len(self.conversation)
        prefix = [self.system_prompt] if self.system_prompt else []
        return prefix + self.conversation[-n:]

    def get_context_size(self, n_messages: Optional[int] = None) -> int:
        """
        Calculates the total size of a specified number of recent messages in terms of tokens.

        Args:
            n_messages (Optional[int]): The number of recent messages to consider. Defaults to all messages.

        Returns:
            int: Total token count of the specified messages, including the system prompt.
        """
        if n_messages is None:
            n_messages = len(self.token_counts)
        return sum([self.system_prompt_size] + self.token_counts[-n_messages:])

    def _get_llama_prompt(self, n_messages: int) -> str:
        """
        Internal helper that formats the last n messages as a LLaMA-style prompt string.

        Args:
            n_messages (int): Number of messages to include in the prompt.

        Returns:
            str: A string formatted as a prompt for LLaMA.
        """
        prompt = ""
        sys_prompt = f"<<sys>>{self.system_prompt.content}<</sys>>" if self.system_prompt else ""
        for i, message in enumerate(self.conversation[-n_messages:]):
            if message.role == 'user':
                # The system prompt is only injected into the first user turn
                prefix = f"<s>[INST]{sys_prompt}" if i == 0 else "<s>[INST]"
                prompt += f"{prefix}{message.content}[/INST]"
            else:
                prompt += f"{message.content}</s>"
        return prompt
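
    # Illustrative sketch of the string _get_llama_prompt builds, assuming a
    # system prompt, one prior user/assistant exchange, and a new user turn
    # (the <<sys>> tag casing follows this module rather than any particular
    # reference template):
    #
    #   <s>[INST]<<sys>>You are helpful.<</sys>>Hi there[/INST]Hello! How can I help?</s><s>[INST]What is 2+2?[/INST]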

    def get_prompt(self, query: str, provider: Literal["openai", "llama"] = "openai") -> Union[List[Dict[str, str]], str]:
        """
        Generates a prompt for a given query and provider.

        Args:
            query (str): The user's query to add to the conversation.
            provider (Literal["openai", "llama"]): The provider for which the prompt is being generated. Defaults to "openai".

        Returns:
            Union[List[Dict[str, str]], str]: A prompt formatted for the specified provider.
        """
        new_message = Message(role='user', content=query)
        self.set_messages([new_message])

        # Find how many of the most recent messages fit within the token limit:
        # take the cumulative sum of token counts from newest to oldest and count
        # how many stay below the limit.
        array_np = np.array(self.token_counts)
        cumsum_reversed = np.cumsum(array_np[::-1])
        message_count = np.sum(cumsum_reversed < self.token_limit)

        # If the oldest message in the window is an assistant response, drop it
        # so the window always starts with a user turn.
        if message_count > 0 and self.conversation[-message_count].role == 'assistant':
            message_count -= 1

        if provider == 'llama':
            return self._get_llama_prompt(message_count)
        return [msg.model_dump() for msg in self.get_messages(message_count)]
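
    # Worked example of the windowing above with hypothetical counts: if
    # token_counts is [120, 300, 80, 50] and token_limit is 400, the reversed
    # cumulative sums are [50, 130, 430, 550], so message_count is 2 and only
    # the last two messages (plus the system prompt, if any) are kept.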

    def reset(self):
        """
        Resets the conversation memory.

        Returns:
            Memory: The Memory instance itself after resetting.
        """
        self.conversation = []
        self.token_counts = []
        return self
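

# Minimal usage sketch, assuming Config.LLAMA_MODEL points at a tokenizer that
# Tokenizer.from_pretrained can load:
if __name__ == "__main__":
    memory = Memory(context_window=4096, system_prompt="You are a helpful assistant.")

    # OpenAI-style chat format: a list of {"role": ..., "content": ...} dicts.
    print(memory.get_prompt("What is the capital of France?", provider="openai"))

    # Record the assistant's reply so it becomes part of the history.
    memory.set_messages([Message(role='assistant', content="The capital of France is Paris.")])

    # LLaMA-style format: a single prompt string with [INST] ... [/INST] markers.
    print(memory.get_prompt("And roughly how many people live there?", provider="llama"))

    # Clear the stored conversation between sessions.
    memory.reset()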