Skip to content

Instantly share code, notes, and snippets.

@mcmonkey4eva
Last active June 22, 2023 12:04
Show Gist options
  • Save mcmonkey4eva/8b0e2a9ebf04f41b16011e0ebbd5fc9c to your computer and use it in GitHub Desktop.
Save mcmonkey4eva/8b0e2a9ebf04f41b16011e0ebbd5fc9c to your computer and use it in GitHub Desktop.
TestDiscordAIBot

This is a bare-minimum, 10-minute slap-together. It doesn't work well and I don't intend to maintain it; it's posted only in case you want to look at it or use it as a reference for a slap-together of your own.

I absolutely did not do things properly when making this.

using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System.Net;
using System.Net.Http.Json;
using System.Text;
using System.Web;
using Discord;
using Discord.WebSocket;
using FreneticUtilities.FreneticToolkit;
using Microsoft.Extensions.DependencyInjection;
using System.Xml.Linq;
namespace TestDiscordAIBot;
/// <summary>
/// Generation settings for a single text-generation request.
/// Every field is copied, under a JSON key spelled exactly like the field name,
/// into the request body that <c>TextGenAPI.SendRequest</c> posts to <c>/api/v1/generate</c>.
/// Field names are therefore snake_case on purpose and must not be renamed.
/// </summary>
public class LLMParams
{
    /// <summary>Maximum number of new tokens to request (default 500).</summary>
    public int max_new_tokens = 500;

    /// <summary>Sent as <c>do_sample</c> (default true).</summary>
    public bool do_sample = true;

    // Sampling / penalty knobs, all sent verbatim as floats.
    public float temperature = 0.7f,
        top_p = 0.1f,
        typical_p = 1f,
        repetition_penalty = 1.18f,
        encoder_repetition_penalty = 1f;

    // Integer knobs, all sent verbatim.
    public int top_k = 40,
        min_length = 0,
        no_repeat_ngram_size = 0,
        num_beams = 1;

    /// <summary>Sent as <c>penalty_alpha</c> (default 0).</summary>
    public float penalty_alpha = 0f;

    /// <summary>Sent as <c>length_penalty</c>; note it is an int here (default 1).</summary>
    public int length_penalty = 1;

    /// <summary>Sent as <c>early_stopping</c> (default false).</summary>
    public bool early_stopping = false;

    /// <summary>Sent as <c>seed</c> (default -1).</summary>
    public int seed = -1;

    // Token-handling flags.
    public bool add_bos_token = false,
        skip_special_tokens = true;

    /// <summary>
    /// Strings sent as <c>stopping_strings</c>; empty by default.
    /// The bot in this program sets this to <c>"\n###"</c>.
    /// </summary>
    public string[] stopping_strings = Array.Empty<string>();
}
/// <summary>
/// Minimal client for a text-generation-webui style HTTP API: posts a prompt plus
/// <see cref="LLMParams"/> to <c>{URLBase}/api/v1/generate</c> and returns the generated text.
/// </summary>
public static class TextGenAPI
{
    /// <summary>Single shared client, reused for every request rather than created per call.</summary>
    public static HttpClient Client = new();

    /// <summary>Base URL of the generation server (no <c>/api/...</c> suffix; a trailing '/' is tolerated).</summary>
    public static string URLBase = "YOUR TEXT GEN WEBUI HERE"; // !!!!!!!!!!!!!! FILL ME IN

    /// <summary>UTF-8 without a byte-order mark, used to encode the JSON request body.</summary>
    public static UTF8Encoding Encoding = new(false);

    static TextGenAPI()
    {
        Client.DefaultRequestHeaders.Add("user-agent", "TestDiscordAIBot/1.0");
    }

    /// <summary>
    /// Blocking wrapper around <see cref="SendRequestAsync"/>, kept for synchronous callers.
    /// </summary>
    /// <param name="prompt">Full prompt text to complete.</param>
    /// <param name="llmParam">Generation settings; must not be null.</param>
    /// <returns>The <c>results[0].text</c> value from the API response.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="llmParam"/> is null.</exception>
    /// <exception cref="HttpRequestException">Transport failure or a non-success HTTP status.</exception>
    /// <exception cref="InvalidOperationException">The response JSON has no <c>results[0].text</c>.</exception>
    public static string SendRequest(string prompt, LLMParams llmParam)
    {
        // GetAwaiter().GetResult() surfaces the original exception instead of an AggregateException;
        // SendRequestAsync uses ConfigureAwait(false), so this cannot deadlock on a captured context.
        return SendRequestAsync(prompt, llmParam).GetAwaiter().GetResult();
    }

    /// <summary>
    /// Sends <paramref name="prompt"/> with the settings in <paramref name="llmParam"/> to the generate
    /// endpoint and returns the generated text. Request, response and result are echoed to the console.
    /// </summary>
    /// <param name="prompt">Full prompt text to complete.</param>
    /// <param name="llmParam">Generation settings; must not be null.</param>
    /// <param name="cancel">Cancels the HTTP request.</param>
    /// <returns>The <c>results[0].text</c> value from the API response.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="llmParam"/> is null.</exception>
    /// <exception cref="HttpRequestException">Transport failure or a non-success HTTP status.</exception>
    /// <exception cref="InvalidOperationException">The response JSON has no <c>results[0].text</c>.</exception>
    public static async Task<string> SendRequestAsync(string prompt, LLMParams llmParam, CancellationToken cancel = default)
    {
        ArgumentNullException.ThrowIfNull(llmParam);
        JObject jData = new()
        {
            ["prompt"] = prompt,
            ["max_new_tokens"] = llmParam.max_new_tokens,
            ["do_sample"] = llmParam.do_sample,
            ["temperature"] = llmParam.temperature,
            ["top_p"] = llmParam.top_p,
            ["typical_p"] = llmParam.typical_p,
            ["repetition_penalty"] = llmParam.repetition_penalty,
            ["encoder_repetition_penalty"] = llmParam.encoder_repetition_penalty,
            ["top_k"] = llmParam.top_k,
            ["min_length"] = llmParam.min_length,
            ["no_repeat_ngram_size"] = llmParam.no_repeat_ngram_size,
            ["num_beams"] = llmParam.num_beams,
            ["penalty_alpha"] = llmParam.penalty_alpha,
            ["length_penalty"] = llmParam.length_penalty,
            ["early_stopping"] = llmParam.early_stopping,
            ["seed"] = llmParam.seed,
            ["add_bos_token"] = llmParam.add_bos_token,
            ["skip_special_tokens"] = llmParam.skip_special_tokens,
            // JToken.FromObject(null) throws, so treat a null array as "no stopping strings".
            ["stopping_strings"] = JToken.FromObject(llmParam.stopping_strings ?? Array.Empty<string>())
        };
        string serialized = JsonConvert.SerializeObject(jData);
        Console.WriteLine($"will send: {serialized}");
        // Dispose both the request content and the response so pooled connections are released promptly.
        using StringContent body = new(serialized, Encoding, "application/json");
        using HttpResponseMessage response = await Client.PostAsync($"{URLBase.TrimEnd('/')}/api/v1/generate", body, cancel).ConfigureAwait(false);
        Console.WriteLine($"Response type: {(int)response.StatusCode} {response.StatusCode}");
        string responseText = await response.Content.ReadAsStringAsync(cancel).ConfigureAwait(false);
        Console.WriteLine($"Response text: {responseText}");
        if (!response.IsSuccessStatusCode)
        {
            // Fail with the server's message instead of letting JObject.Parse choke on an error page.
            throw new HttpRequestException($"Text generation API returned {(int)response.StatusCode} {response.StatusCode}: {responseText}", null, response.StatusCode);
        }
        // SelectToken returns null (rather than throwing) when any step of the path is missing.
        string result = JObject.Parse(responseText).SelectToken("results[0].text")?.ToString()
            ?? throw new InvalidOperationException($"Text generation API response has no results[0].text: {responseText}");
        Console.WriteLine($"Result text: {result}");
        return result;
    }
}
/// <summary>
/// Entry point of the bot. It replies to guild messages that mention the bot or that reply to one of
/// the bot's own messages, generating the reply text with <see cref="TextGenAPI"/>. It also offers a
/// console prompt for manual testing.
/// </summary>
public static class Program
{
    /// <summary>
    /// Prompt prefix for every conversation. The placeholders <c>{{user}}</c>, <c>{{username}}</c>,
    /// <c>{{helper}}</c> and <c>{{date}}</c> are substituted per message.
    /// </summary>
    public static string PrePrompt = "YOUR PRE PROMPT HERE"; // !!!!!!!!!!!!!! FILL ME IN

    /// <summary>The Discord client, assigned in <see cref="Main"/>.</summary>
    public static DiscordSocketClient Client;

    /// <summary>ASCII letters and digits; usernames are trimmed to these characters before being used as speaker names.</summary>
    public static AsciiMatcher AlphanumericMatcher = new(AsciiMatcher.BothCaseLetters + AsciiMatcher.Digits);

    /// <summary>Minimal snapshot of a Discord message, enough to rebuild a reply chain.</summary>
    /// <param name="Content">Message text.</param>
    /// <param name="RefId">ID of the message this one replies to, or 0 if none.</param>
    /// <param name="Author">Author user ID, or 0 if unknown.</param>
    /// <param name="AuthorName">Author username, or "" if unknown.</param>
    public record class CachedMessage(string Content, ulong RefId, ulong Author, string AuthorName);

    /// <summary>
    /// Message ID → snapshot. A null value records that the message could not be fetched.
    /// Access only under <see cref="MessageCacheLock"/>: messages are handled on thread-pool threads.
    /// NOTE(review): entries are never evicted, so this grows for the lifetime of the process.
    /// </summary>
    public static Dictionary<ulong, CachedMessage> MessageCache = new();

    /// <summary>Guards <see cref="MessageCache"/>.</summary>
    private static readonly object MessageCacheLock = new();

    /// <summary>Discord's maximum message content length; longer replies are rejected.</summary>
    private const int MaxReplyLength = 2000;

    /// <summary>
    /// Returns a snapshot of message <paramref name="id"/> from channel <paramref name="channel"/>.
    /// On a cache miss it fetches the message from Discord, blocking until the fetch completes.
    /// </summary>
    /// <returns>
    /// The snapshot, or null when the message could not be found or the channel is not a cached text channel.
    /// A missing message is cached as null. An unusable channel is not cached, so a later call can retry.
    /// </returns>
    public static CachedMessage GetMessageCached(ulong channel, ulong id)
    {
        lock (MessageCacheLock)
        {
            if (MessageCache.TryGetValue(id, out CachedMessage res))
            {
                return res;
            }
        }
        if (Client.GetChannel(channel) is not SocketTextChannel textChannel)
        {
            return null;
        }
        // Log the requested id: the fetched message may be null (deleted/inaccessible).
        Console.WriteLine($"Must fill cache on message {id}");
        IMessage message = textChannel.GetMessageAsync(id).GetAwaiter().GetResult();
        CachedMessage cache = message is null
            ? null
            : new CachedMessage(message.Content, message.Reference?.MessageId.GetValueOrDefault(0) ?? 0, message.Author?.Id ?? 0, message.Author?.Username ?? "");
        lock (MessageCacheLock)
        {
            MessageCache[id] = cache;
        }
        return cache;
    }

    /// <summary>
    /// Trims <paramref name="name"/> to ASCII letters and digits.
    /// Falls back to "User" when fewer than 3 characters remain.
    /// </summary>
    private static string SanitizeName(string name)
    {
        string trimmed = AlphanumericMatcher.TrimToMatches(name);
        return trimmed.Length < 3 ? "User" : trimmed;
    }

    /// <summary>
    /// Walks a reply chain backwards, starting at the message with ID <paramref name="firstRefId"/>.
    /// The chain must alternate: bot message, then the user message it replied to, and so on.
    /// Each pair becomes a "name: text\nhelper: text\n" transcript entry, oldest first.
    /// </summary>
    /// <returns>
    /// The transcript, or null when the chain is not a conversation with the bot. That covers a
    /// referenced message that is not the bot's, a bot message with no reference, or a referenced
    /// user message that cannot be fetched.
    /// </returns>
    private static string BuildPriorConversation(ulong channelId, ulong firstRefId, string helper)
    {
        CachedMessage botTurn = GetMessageCached(channelId, firstRefId);
        if (botTurn is null)
        {
            return null;
        }
        string prior = "";
        while (botTurn is not null)
        {
            if (botTurn.Author != Client.CurrentUser.Id || botTurn.RefId == 0)
            {
                return null;
            }
            CachedMessage userTurn = GetMessageCached(channelId, botTurn.RefId);
            if (userTurn is null)
            {
                return null;
            }
            // NOTE(review): earlier turns are labelled with the author's sanitized username, whereas the
            // current turn is labelled with `user` ("### Human") in HandleMessage; confirm this mix is intended.
            prior = $"{SanitizeName(userTurn.AuthorName)}: {userTurn.Content}\n{helper}: {botTurn.Content}\n{prior}";
            // The chain ends (keeping what was collected) when the user turn has no reference
            // or the message it references can't be fetched.
            botTurn = userTurn.RefId == 0 ? null : GetMessageCached(channelId, userTurn.RefId);
        }
        return prior;
    }

    /// <summary>
    /// Handles one incoming message: decides whether the bot was addressed, builds the prompt,
    /// queries the LLM and posts the reply. Blocks for the duration of the LLM call, so it must not
    /// run on Discord.Net's gateway task.
    /// </summary>
    private static void HandleMessage(IMessage message, LLMParams llmParams)
    {
        if (message.Content is null || message.Author.IsBot || message.Author.IsWebhook || message is not IUserMessage userMessage || message.Channel is not IGuildChannel guildChannel)
        {
            return;
        }
        string prePrompt, user, helper;
        string rawUser = AlphanumericMatcher.TrimToMatches(message.Author.Username);
        user = rawUser.Length < 3 ? "User" : rawUser;
        ulong guild = guildChannel.GuildId;
        prePrompt = PrePrompt;
        /* // !!!!!!!!!!!!!! FILL ME IN - OPTIONAL PROMPT SWAPPER PER GUILD
        if (guild == 123ul)
        {
            prePrompt = PrePromptA;
            user = "User";
            helper = "Helper";
        }
        else if (guild == 456ul)
        {
            prePrompt = PrePromptB;
            helper = "Llama";
        }
        else if (guild == 789ul)
        {
            prePrompt = PrePromptC;
        }
        else
        {
            Console.WriteLine("Bad guild");
            return;
        }*/
        // These unconditional assignments currently override anything the swapper above would set.
        user = "### Human";
        helper = "### Assistant";
        string mention = $"<@{Client.CurrentUser.Id}>";
        string nickMention = $"<@!{Client.CurrentUser.Id}>";
        bool isSelfRef = message.Content.Contains(mention) || message.Content.Contains(nickMention);
        string prior = "";
        var reference = message.Reference;
        // Only same-channel replies that carry a message ID are treated as conversation continuations.
        if (reference is not null && reference.ChannelId == message.Channel.Id && reference.MessageId.IsSpecified)
        {
            prior = BuildPriorConversation(message.Channel.Id, reference.MessageId.Value, helper);
            if (prior is null)
            {
                return;
            }
            isSelfRef = true;
        }
        if (!isSelfRef)
        {
            return;
        }
        // Invariant culture: in custom formats ':' is the culture's time separator, not a literal colon.
        string date = DateTimeOffset.Now.ToString("yyyy-MM-dd HH:mm", System.Globalization.CultureInfo.InvariantCulture);
        prePrompt = prePrompt.Replace("{{user}}", user).Replace("{{username}}", rawUser).Replace("{{helper}}", helper).Replace("{{date}}", date);
        string input = message.Content.Replace(mention, "").Replace(nickMention, "").Trim();
        Console.WriteLine($"Got input: {prior} {input}");
        const string noPrePromptTag = "[nopreprompt]";
        if (input.StartsWith(noPrePromptTag, StringComparison.Ordinal))
        {
            prePrompt = "";
            input = input[noPrePromptTag.Length..].Trim();
        }
        else
        {
            input = input.Replace("\n", " ");
        }
        using (message.Channel.EnterTypingState())
        {
            string res = TextGenAPI.SendRequest($"{prePrompt}{prior}{user}: {input}\n{helper}:", llmParams);
            // Ordinal: culture-aware IndexOf can fail to match "\n" (e.g. inside "\r\n") under ICU.
            int line = res.IndexOf("\n###", StringComparison.Ordinal);
            if (line != -1)
            {
                res = res[..line];
            }
            Console.WriteLine($"\n\nUser: {input}\n{helper}:{res}\n\n");
            // Escape markdown/mention syntax and defang links so the model can't ping or embed anything.
            res = res.Replace("\\", "\\\\").Replace("<", "\\<").Replace(">", "\\>").Replace("@", "\\@ ")
                .Replace("http://", "").Replace("https://", "").Trim();
            if (string.IsNullOrWhiteSpace(res))
            {
                res = "[Error]";
            }
            if (res.Length > MaxReplyLength)
            {
                // Truncate instead of having Discord reject the reply; never split a surrogate pair.
                int cut = MaxReplyLength;
                if (char.IsHighSurrogate(res[cut - 1]))
                {
                    cut--;
                }
                res = res[..cut];
            }
            userMessage.ReplyAsync(res, allowedMentions: AllowedMentions.None).GetAwaiter().GetResult();
        }
    }

    /// <summary>
    /// Connects the bot to Discord and then runs a console loop. Each console line is sent to the LLM
    /// as a bare "User:/Helper:" prompt. The method returns when standard input is closed.
    /// </summary>
    public static void Main()
    {
        Console.WriteLine("Starting...");
        DiscordSocketConfig config = new()
        {
            MessageCacheSize = 50,
            AlwaysDownloadUsers = true,
            GatewayIntents = GatewayIntents.AllUnprivileged | GatewayIntents.MessageContent
        };
        Client = new DiscordSocketClient(config);
        // Show Discord.Net's own log output (connection state, warnings, errors) on the console.
        Client.Log += (log) =>
        {
            Console.WriteLine(log.ToString());
            return Task.CompletedTask;
        };
        Client.Ready += () =>
        {
            Console.WriteLine("Bot ready.");
            return Task.CompletedTask;
        };
        LLMParams llmParams = new() { stopping_strings = new[] { "\n###" } };
        Client.MessageReceived += (message) =>
        {
            // Event handlers run on the gateway task. The LLM call can take a long time, so do the work
            // on the thread pool rather than blocking the gateway. Log failures, since nothing observes the task.
            _ = Task.Run(() =>
            {
                try
                {
                    HandleMessage(message, llmParams);
                }
                catch (Exception ex)
                {
                    Console.WriteLine($"Failed to handle message {message.Id}: {ex}");
                }
            });
            return Task.CompletedTask;
        };
        Console.WriteLine("Logging in to Discord...");
        // Prefer the environment variable so the token need not be committed to source.
        string token = Environment.GetEnvironmentVariable("DISCORD_BOT_TOKEN") ?? "YOUR TOKEN HERE"; // !!!!!!!!!!!!!! FILL ME IN (or set DISCORD_BOT_TOKEN)
        Client.LoginAsync(TokenType.Bot, token).GetAwaiter().GetResult();
        Console.WriteLine("Connecting to Discord...");
        Client.StartAsync().GetAwaiter().GetResult();
        Console.WriteLine("Running Discord!");
        while (true)
        {
            string input = Console.ReadLine();
            if (input is null)
            {
                return;
            }
            input = input.Replace("\n", " ");
            string fullPrompt = $"User: {input}\nHelper: ";
            try
            {
                string res = TextGenAPI.SendRequest(fullPrompt, llmParams);
                Console.WriteLine($"AI says back: {res}");
                int line = res.IndexOf('\n');
                if (line != -1)
                {
                    res = res[..line];
                }
                Console.WriteLine($"\n\nUser: {input}\nHelper: {res}\n\n");
            }
            catch (Exception ex)
            {
                // Keep the bot (and this prompt) alive when the generation server misbehaves.
                Console.WriteLine($"Request failed: {ex}");
            }
        }
    }
}
<Project Sdk="Microsoft.NET.Sdk">
<!-- Build definition for the TestDiscordAIBot console program. -->
<PropertyGroup>
<!-- Console executable; the entry point is Program.Main. -->
<OutputType>Exe</OutputType>
<TargetFramework>net7.0</TargetFramework>
<!-- The C# source uses Task, Dictionary, Console and similar types without explicit using directives; ImplicitUsings supplies those namespaces. -->
<ImplicitUsings>enable</ImplicitUsings>
</PropertyGroup>
<ItemGroup>
<!-- Discord gateway/REST client (DiscordSocketClient and the Discord.* types). -->
<PackageReference Include="Discord.Net" Version="3.10.0" />
<!-- Provides AsciiMatcher (FreneticUtilities.FreneticToolkit), used to trim usernames. -->
<PackageReference Include="FreneticLLC.FreneticUtilities" Version="1.0.4" />
<!-- JObject/JsonConvert, used to build the generation request and parse its response. -->
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
</ItemGroup>
</Project>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment