Skip to content

Instantly share code, notes, and snippets.

@iwalton3
Created May 31, 2023 02:32
Show Gist options
  • Save iwalton3/55a0dff6a53ccc0fa832d6df23c1cded to your computer and use it in GitHub Desktop.
Discord Exllama Chatbot
#!/usr/bin/env python3
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import argparse
import torch
from timeit import default_timer as timer
torch.set_grad_enabled(False)
torch.cuda._lazy_init()
import asyncio
import traceback
import discord
from discord import app_commands
import re
from threading import RLock
# Command-line interface. Model/tokenizer/config paths are required; every
# other flag tunes generation behaviour and has a sensible default.
parser = argparse.ArgumentParser(description = "Simple chatbot example for ExLlama")
parser.add_argument("-t", "--tokenizer", type = str, help = "Tokenizer model path", required = True)
parser.add_argument("-c", "--config", type = str, help = "Model config path (config.json)", required = True)
parser.add_argument("-m", "--model", type = str, help = "Model weights path (.pt or .safetensors file)", required = True)
parser.add_argument("-a", "--attention", type = ExLlamaConfig.AttentionMethod.argparse, choices = list(ExLlamaConfig.AttentionMethod), help="Attention method", default = ExLlamaConfig.AttentionMethod.SWITCHED)
parser.add_argument("-mm", "--matmul", type = ExLlamaConfig.MatmulMethod.argparse, choices = list(ExLlamaConfig.MatmulMethod), help="Matmul method", default = ExLlamaConfig.MatmulMethod.SWITCHED)
# fix: help text said "Matmul method" (copy-paste from the line above)
parser.add_argument("-mlp", "--mlp", type = ExLlamaConfig.MLPMethod.argparse, choices = list(ExLlamaConfig.MLPMethod), help="MLP method", default = ExLlamaConfig.MLPMethod.SWITCHED)
parser.add_argument("-s", "--stream", type = int, help = "Stream layer interval", default = 0)
parser.add_argument("-gs", "--gpu_split", type = str, help = "Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. -gs 20,7,7")
parser.add_argument("-dq", "--dequant", type = str, help = "Number of layers (per GPU) to de-quantize at load time")
parser.add_argument("-l", "--length", type = int, help = "Maximum sequence length", default = 2048)
parser.add_argument("-l-out", "--length-out", type = int, help = "Maximum output", default = 768)
parser.add_argument("-l-grace", "--length-grace", type = int, help = "Space to leave in token pool", default = 768)
# Sampling parameters forwarded to the generator settings in get_generator().
parser.add_argument("-temp", "--temperature", type = float, help = "Temperature", default = 0.72)
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K", default = 500)
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P", default = 0.65)
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P", default = 0.00)
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty", default = 1.1)
parser.add_argument("-repps", "--repetition_penalty_sustain", type = int, help = "Past length for repetition penalty", default = 256)
parser.add_argument("-beams", "--beams", type = int, help = "Number of beams for beam search", default = 1)
parser.add_argument("-beamlen", "--beam_length", type = int, help = "Number of future tokens to consider", default = 1)
args = parser.parse_args()
# It seems the shallow LoRA needs more prompting
# Inserted at the top of the prompt (followed by '<!end!>\n') when
# use_system_prompt is True; see _send_next_message.
system_prompt = (
'(The system prompt for your chatbot goes here. It is inserted before messages in a conversation.)'
)
use_system_prompt = True
# NOTE(review): these two module-level variables look superseded by the
# per-conversation state on the Bot class — they appear unused; confirm.
last_bot_message = None
messages = []
# hex encoded since I am tired of seeing these every time I open the file
# put any bad words you don't want in bot output here
# A response containing any of these substrings is rejected and regenerated
# (see the retry loop in _send_next_message).
slurs = [
    '\x66\x61\x67',
    '\x66\x61\x67\x67\x6f\x74',
    '\x74\x72\x61\x6e\x6e\x79',
    '\x6e\x69\x67\x67\x65\x72',
    '\x6e\x69\x67\x67\x61',
    '\x72\x65\x74\x61\x72\x64'
]
# Discord user mentions, e.g. <@1234> or <@!1234>; group 1 is the user id.
user_regex = re.compile(r'<@!?(\d+)>')
# Custom emotes <a?:name:id>; rewritten to plain :name: before prompting.
emote_regex = re.compile(r'<a?:([a-zA-Z0-9_]+):\d+>')
# NOTE(review): the next two patterns are not referenced in this file; they
# may be leftovers from an earlier emote-restoration feature — confirm.
emote_replace_regex = re.compile(r':([a-zA-Z0-9_]+):')
word_regex = re.compile(r'([a-zA-Z]+)')
# Serializes all model access: only one generation may run at a time even
# though each conversation has its own Bot/generator.
promptLock = RLock()
# Instantiate model and generator
# Build the ExLlama config from CLI flags, then load the (shared) model,
# KV cache, and tokenizer used by every Bot instance.
config = ExLlamaConfig(args.config)
config.model_path = args.model
config.attention_method = args.attention
config.matmul_method = args.matmul
config.stream_layer_interval = args.stream
config.mlp_method = args.mlp
# --length has an argparse default, so this guard is effectively always true
if args.length is not None: config.max_seq_len = args.length
config.set_auto_map(args.gpu_split)
config.set_dequant(args.dequant)
model = ExLlama(config)
cache = ExLlamaCache(model)
tokenizer = ExLlamaTokenizer(args.tokenizer)
def get_generator():
    """Build a fresh ExLlamaGenerator bound to the shared model, tokenizer
    and cache, with sampling settings taken from the CLI arguments."""
    gen = ExLlamaGenerator(model, tokenizer, cache)
    settings = ExLlamaGenerator.Settings()
    settings.temperature = args.temperature
    settings.top_k = args.top_k
    settings.top_p = args.top_p
    settings.min_p = args.min_p
    settings.token_repetition_penalty_max = args.repetition_penalty
    settings.token_repetition_penalty_sustain = args.repetition_penalty_sustain
    # penalty decays over half of the sustain window
    settings.token_repetition_penalty_decay = settings.token_repetition_penalty_sustain // 2
    settings.beams = args.beams
    settings.beam_length = args.beam_length
    gen.settings = settings
    return gen
def message_wordcount_ok(messages):
    """Return True if the conversation fits the token budget.

    The budget is the maximum sequence length minus the grace space reserved
    for the model's output. Despite the name, this counts tokens, not words.
    """
    # note this approach using the tokenizer isn't great...
    # we may optimize it later
    total_tokens = sum(tokenizer.encode(text).shape[-1] for text, _author in messages)
    return total_tokens <= args.length - args.length_grace
def ensure_wordcount_ok(messages):
    """Trim the oldest messages in place until the conversation fits the
    token budget (see message_wordcount_ok).

    Messages are dropped two at a time — typically a user message and the
    bot's reply.
    """
    while messages and not message_wordcount_ok(messages):
        messages.pop(0)
        # fix: guard the second pop — with an odd-length history the original
        # unconditional pop(0) raised IndexError on an empty list
        if messages:
            messages.pop(0)
class Bot:
    """Per-conversation chatbot state.

    Holds the message history, a dedicated generator with an incremental
    prompt cache, and the Discord plumbing for sending, splitting, and
    retracting responses. One instance exists per guild or per DM user.
    """

    def __init__(self, context=""):
        self.messages = []            # list of (text, author_name) tuples
        self.last_bot_message = None  # last discord.Message the bot sent (for /no strikethrough)
        self.context = context        # log prefix, e.g. "(guild name) "
        self.lock = asyncio.Lock()    # serializes message handling per conversation
        self.generator = get_generator()
        self.gen_cache = ""           # prompt text already fed to the generator
        self.needs_regen = True       # force a full re-tokenize on the next prompt

    def sendPrompt(self, prompt: str) -> str:
        """Run the model on `prompt` and return the generated reply text.

        Blocking — callers run this via asyncio.to_thread. When the new
        prompt extends the previously-fed one, only the new suffix is
        tokenized and fed (incremental cache); otherwise the generator is
        restarted from scratch.
        """
        with promptLock:
            with torch.no_grad():
                start = timer()
                try:
                    if prompt.startswith(self.gen_cache) and not self.needs_regen:
                        # fast path: feed only the newly appended text
                        self.generator.gen_feed_tokens(tokenizer.encode(prompt[len(self.gen_cache):]))
                        self.gen_cache = prompt
                    else:
                        self.gen_cache = prompt
                        self.generator.gen_begin(tokenizer.encode(prompt))
                    # fix: mark the cache valid so the incremental path above can
                    # be taken on later calls — the original never cleared
                    # needs_regen, leaving the fast path unreachable
                    self.needs_regen = False
                    num_res_tokens = 0
                    res_line = ""
                    print(f'{self.context}bot: ', flush=True, end='')
                    self.generator.begin_beam_search()
                    for i in range(args.length_out):
                        gen_token = self.generator.beam_search()
                        if gen_token.item() == tokenizer.eos_token_id:
                            # keep generating across EOS by substituting a newline
                            self.generator.replace_last_token(tokenizer.newline_token_id)
                        num_res_tokens += 1
                        text = tokenizer.decode(self.generator.sequence_actual[:, -num_res_tokens:][0])
                        new_text = text[len(res_line):]
                        res_line += new_text
                        print(new_text, end="", flush=True)
                        # <!end!> is the end string for the model between messages
                        if res_line.endswith("<!end!>"):
                            break
                    self.generator.end_beam_search()
                    self.gen_cache += res_line
                    end = timer()
                    print(f'\n[generated {num_res_tokens} tokens in {end - start:.2f} s at {num_res_tokens/(end - start):.2f} t/s]', flush=True)
                    return res_line.replace("<!end!>", "")
                except:
                    # don't know state when it failed, so clear cache
                    self.needs_regen = True
                    raise

    async def reset(self):
        """Forget the conversation history."""
        self.messages = []

    async def send_next_message(self, channel: discord.TextChannel, interaction: discord.Interaction = None):
        """Generate and deliver the next bot reply, trimming history to fit first."""
        ensure_wordcount_ok(self.messages)
        if interaction is not None:
            await self._send_next_message(channel, interaction)
        else:
            # show the typing indicator while generating
            async with channel.typing():
                await self._send_next_message(channel)

    async def _send_next_message(self, channel: discord.TextChannel, interaction: discord.Interaction = None):
        """Build the prompt, generate (with slur-filter retries), and send."""
        # Work on a copy; append the model's <!end!> separator wherever the
        # speaker changes between consecutive messages.
        prompt_messages = [[message, author] for message, author in self.messages]
        last_message = None
        for message in prompt_messages:
            if last_message is not None and message[1] != last_message[1]:
                last_message[0] += '<!end!>'
            last_message = message
        prompt = ''
        if use_system_prompt:
            # fix: compare the author field — the original compared the whole
            # [text, author] pair to 'bot', which was always False, so leading
            # bot messages were never stripped
            while len(prompt_messages) > 0 and prompt_messages[0][1] == 'bot':
                prompt_messages.pop(0)
            prompt += system_prompt + '<!end!>\n'
        prompt += '\n'.join(msg[0] for msg in prompt_messages)
        tries = 0
        while True:
            try:
                def task():
                    return self.sendPrompt(prompt + '<!end!>\n')
                # run the blocking generation off the event loop
                response = await asyncio.to_thread(task)
            except Exception:
                print("Generation failed!", flush=True)
                traceback.print_exc()
                response = None
            # accept only successful generations that contain no banned words
            if response is not None and not any(slur in response.lower() for slur in slurs):
                break
            tries += 1
            if tries > 3:
                if interaction is not None:
                    await interaction.followup.send('I\'m sorry, I\'m having trouble coming up with a response. Try saying something else!')
                else:
                    await channel.send('I\'m sorry, I\'m having trouble coming up with a response. Try saying something else!')
                return
        self.messages.append((response, 'bot'))
        if len(response) < 2000:
            if interaction is not None:
                await interaction.followup.send(response)
                self.last_bot_message = await interaction.original_response()
            else:
                self.last_bot_message = await channel.send(response)
        else:
            # Discord caps messages at 2000 chars: split on newlines into chunks
            acc_messages = []
            acc_text = ""
            for message in response.split('\n'):
                if len(acc_text) + len(message) > 2000:
                    acc_messages.append(acc_text)
                    acc_text = ""
                acc_text += message + '\n'
            acc_messages.append(acc_text)
            if interaction is not None:
                await interaction.followup.send(acc_messages.pop(0))
                self.last_bot_message = await interaction.original_response()
            for message in acc_messages:
                await asyncio.sleep(0.5)
                self.last_bot_message = await channel.send(message)

    async def no(self, channel: discord.TextChannel, interaction: discord.Interaction = None, should_defer=True):
        """Retract the bot's last reply (strike it through) and regenerate."""
        self.messages = self.messages[:-1]
        if self.last_bot_message:
            # strike through the rejected reply, stripping any existing ~~
            await self.last_bot_message.edit(content='~~' + self.last_bot_message.content.replace('~~', '') + '~~')
            self.last_bot_message = None
        if should_defer:
            await interaction.response.defer()
        await self.send_next_message(channel, interaction)

    async def get_response(self, message: discord.Message):
        """Handle an incoming Discord message: commands, bookkeeping, replies."""
        if message.author.bot:
            # track bot-authored messages so !no / /no can strike them through
            self.last_bot_message = message
            return
        if message.content.startswith(';'):
            # leading ';' lets users chat without triggering the bot
            return
        if message.content.startswith('!reset'):
            await message.channel.send('Context reset!')
            await self.reset()
            return
        if message.content.startswith('!help'):
            if message.channel.guild is not None:
                await message.delete()
            await message.channel.send('Chatbot Commands:\n\n !help - display this message\n !reset - forget current discussion\n !no - reject and send a new response')
            return
        # strip the model's message separator so users can't inject it
        message.content = message.content.replace('<!end!>', '')
        async with self.lock:
            def find_user(match):
                # resolve a <@id> mention to @username via the mention list
                user_id = int(match.group(1))
                for user in message.mentions:
                    if user.id == user_id:
                        return f"@{user.name}"
                return "@unknown"
            if message.content == '!n' or message.content == '!no':
                # text-command variant of /no: retract without regenerating
                if message.channel.guild is not None:
                    await message.delete()
                self.messages = self.messages[:-1]
                if self.last_bot_message:
                    await self.last_bot_message.edit(content='~~' + self.last_bot_message.content.replace('~~', '') + '~~')
                    self.last_bot_message = None
            else:
                message_string = message.content
                message_string = user_regex.sub(find_user, message_string)
                message_string = emote_regex.sub(r':\1:', message_string)
                currMsg = message_string
                print(f'{self.context}{message.author.name}: {message_string}', flush=True)
                self.messages.append((currMsg, message.author.name))
                await self.send_next_message(message.channel)
# Per-conversation Bot instances: one per guild id, one per DM user id.
server_bots = {}
dm_bots = {}
class MyClient(discord.Client):
    """Discord client that routes each message to a per-guild or per-DM Bot."""

    async def on_ready(self):
        # publish the slash-command tree once connected
        await tree.sync()
        self.idleTimer = None
        print(f'Logged on as {self.user}!', flush=True)

    async def on_message(self, message: discord.Message):
        if message.guild is None:
            # direct message: one Bot per user id, created lazily
            bot = dm_bots.get(message.author.id)
            if bot is None:
                bot = dm_bots[message.author.id] = Bot(f"({message.author.name}) ")
            await bot.get_response(message)
        elif message.channel.name == 'ai-friend':
            # guild message: only the dedicated channel is handled
            bot = server_bots.get(message.guild.id)
            if bot is None:
                bot = server_bots[message.guild.id] = Bot(f"({message.guild.name}) ")
            await bot.get_response(message)
# Gateway intents: message content and member data are privileged and must
# also be enabled in the Discord developer portal.
intents = discord.Intents.default()
intents.message_content = True
intents.members = True
client = MyClient(intents=intents)
# Slash-command tree; synced to Discord in MyClient.on_ready.
tree = app_commands.CommandTree(client)
@tree.command(name='reset', description='Forget the current discussion')
async def reset(interaction: discord.Interaction):
    """Slash command: clear the conversation history for this DM or guild."""
    if interaction.channel.guild is None:
        if interaction.user.id not in dm_bots:
            dm_bots[interaction.user.id] = Bot(f"({interaction.user.name}) ")
        await dm_bots[interaction.user.id].reset()
        await interaction.response.send_message('Context reset!')
        return
    if interaction.channel.name == 'ai-friend':
        if interaction.guild.id not in server_bots:
            server_bots[interaction.guild.id] = Bot(f"({interaction.guild.name}) ")
        await server_bots[interaction.guild.id].reset()
        await interaction.response.send_message('Context reset!')
        return
    # fix: always answer the interaction — the original fell through silently
    # in other guild channels, so Discord showed "The application did not respond"
    await interaction.response.send_message('This command only works in #ai-friend or a DM.', ephemeral=True)
@tree.command(name='no', description='Reject the last response and send a new one')
async def no(interaction: discord.Interaction):
    """Slash command: retract the bot's last reply and generate a new one."""
    should_defer = True
    if interaction.channel.guild is None:
        if interaction.user.id not in dm_bots:
            dm_bots[interaction.user.id] = Bot(f"({interaction.user.name}) ")
        await dm_bots[interaction.user.id].no(interaction.channel, interaction, should_defer)
        return
    if interaction.channel.name == 'ai-friend':
        if interaction.guild.id not in server_bots:
            server_bots[interaction.guild.id] = Bot(f"({interaction.guild.name}) ")
        await server_bots[interaction.guild.id].no(interaction.channel, interaction, should_defer)
        return
    # fix: always answer the interaction — the original fell through silently
    # in other guild channels, so Discord showed "The application did not respond"
    await interaction.response.send_message('This command only works in #ai-friend or a DM.', ephemeral=True)
client.run('(Discord token goes here!)')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment