Skip to content

Instantly share code, notes, and snippets.

@NotSoSuper
Created December 1, 2016 20:22
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NotSoSuper/59f0bf08cbc2b511df27805d4500743e to your computer and use it in GitHub Desktop.
Save NotSoSuper/59f0bf08cbc2b511df27805d4500743e to your computer and use it in GitHub Desktop.
Some utility functions for Discord bots.
import asyncio
import discord
import aiohttp
import re
import aiohttp
import random
import sys
import os
import aiosocks
import time
import traceback
import inspect
import io
import math
import async_timeout
from pymysql.converters import escape_item, escape_string, encoders
from contextlib import redirect_stdout
import linecache
from discord.ext import commands
from discord.ext.commands.errors import CommandNotFound, CommandError
from discord.ext.commands.context import Context
from io import BytesIO
class DataProtocol(asyncio.SubprocessProtocol):
    """Subprocess protocol that buffers pipe output and signals completion.

    ``exit_future`` is resolved with True by whichever of these happens
    first: a pipe closing, the process exiting, or the connection being lost.
    """

    def __init__(self, exit_future):
        self.exit_future = exit_future
        self.output = bytearray()

    def _finish(self):
        # asyncio invokes several of the callbacks below for the same process;
        # only the first may resolve the future. A second set_result() would
        # raise InvalidStateError (also skipped if the awaiting side cancelled).
        if not self.exit_future.done():
            self.exit_future.set_result(True)

    def pipe_data_received(self, fd, data):
        # Data is appended regardless of which pipe (fd) it arrived on.
        self.output.extend(data)

    def process_exited(self):
        self._finish()

    def pipe_connection_lost(self, fd, exc):
        self._finish()

    def connection_lost(self, exc):
        self._finish()
class Funcs():
    """Utility helpers shared across the bot.

    Holds the bot instance and a database cursor and provides command-prefix
    resolution, blacklist / disabled-command checks, a command dispatcher with
    a timeout, small aiohttp download helpers, image-argument collection for
    image commands, and an eval REPL.
    """

    def __init__(self, bot, cursor):
        self.bot = bot
        # Database cursor; used below as ``cursor.execute(sql).fetchall()``
        # returning rows indexable by column name (e.g. ``row['prefix']``).
        self.cursor = cursor
        # Google API keys, one per line. Blank lines (e.g. from a trailing
        # newline) are skipped so google_keys() never hands out ''. The file
        # is closed deterministically instead of leaking the handle.
        with open(self.discord_path('utils/keys.txt')) as keys_file:
            self.bot.google_api_keys = [
                key.strip() for key in keys_file.read().split('\n') if key.strip()
            ]
        self.bot.google_count = 0
        # Content-Type values accepted by isimage().
        self.image_mimes = ['image/png', 'image/pjpeg', 'image/jpeg', 'image/x-icon']

    def discord_path(self, path):
        """Return ``path`` joined onto the directory of the running script (sys.argv[0])."""
        return os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), path)

    def files_path(self, path):
        """Return ``path`` inside the script's ``files/`` directory."""
        return self.discord_path('files/' + path)

    async def prefix_check(self, s, prefix, prefix_set):
        """Return True if message text ``s`` should be treated as a command.

        A custom (database) prefix is trusted as-is. For the default prefix
        the text must start with exactly one prefix occurrence, so e.g.
        ``"..."`` is not mistaken for a command.
        """
        if prefix_set:
            return True
        # The previous character scan also accepted a prefix appearing
        # anywhere in the message ("hello." passed); anchor it to the start.
        return s.startswith(prefix) and not s.startswith(prefix * 2)

    async def get_prefix(self, message):
        """Resolve the command prefix for ``message``.

        Lookup order: per-channel prefix, then per-server prefix, then the
        default (``','`` in dev mode, ``'.'`` otherwise). Database lookups
        are skipped in private channels and for the ``<default>prefix``
        command so a broken custom prefix can always be reset.

        Returns:
            ``([prefix, mention], check)`` where ``check`` is True if the
            message starts with the bot mention or passes prefix_check().
        """
        prefix = ',' if self.bot.dev_mode else '.'
        prefix_set = False
        if not message.channel.is_private and not message.content.startswith(prefix + "prefix"):
            sql_channel = "SELECT prefix,channel FROM `prefix_channel` WHERE server={0} AND channel={1}"
            sql_channel = sql_channel.format(message.server.id, message.channel.id)
            for row in self.cursor.execute(sql_channel).fetchall() or ():
                # The DB may return the channel as an int while discord.py ids
                # are strings; compare as strings (as command_check does).
                if str(row['channel']) == str(message.channel.id):
                    prefix = row['prefix']
                    prefix_set = True
                    break
            # Fall back to the server prefix whenever no channel prefix
            # matched, not only when the channel query returned no rows.
            if not prefix_set:
                sql = "SELECT prefix FROM `prefix` WHERE server={0}".format(message.server.id)
                result = self.cursor.execute(sql).fetchall()
                if len(result) != 0:
                    prefix = result[0]['prefix']
                    prefix_set = True
        if prefix_set:
            prefix = prefix.lower()
        mention = commands.bot.when_mentioned(self.bot, message)
        if message.content.startswith(mention):
            check = True
        else:
            check = await self.prefix_check(message.content, prefix, prefix_set)
        return [prefix, mention], check

    async def is_blacklisted(self, message):
        """Return True if the bot should ignore ``message``.

        Ignored when the bot cannot read/send in the channel, or the author
        is globally blacklisted, muted on the server (unless server owner),
        server-blacklisted, or writing in a blacklisted channel. The bot
        owner is never blacklisted.
        """
        try:
            perms = message.channel.permissions_for(message.server.me)
            if perms.send_messages is False or perms.read_messages is False:
                return True
        except Exception:
            # Best effort: e.g. message.server is None in private channels.
            pass
        if message.author.id == self.bot.owner.id:
            return False
        global_blacklist_result = self.cursor.execute('SELECT * FROM `global_blacklist` WHERE user={0}'.format(message.author.id)).fetchall()
        if message.channel.is_private:
            # The owner already returned above, so any global entry blocks.
            # (Previously referenced an undefined name ``bot`` here.)
            return len(global_blacklist_result) != 0
        muted_check_result = self.cursor.execute('SELECT * FROM `muted` WHERE server={0} AND id={1}'.format(message.server.id, message.author.id)).fetchall()
        if len(muted_check_result) != 0 and message.server.owner != message.author:
            return True
        server_blacklist_result = self.cursor.execute('SELECT * FROM `blacklist` WHERE server={0} AND user={1}'.format(message.server.id, message.author.id)).fetchall()
        channel_blacklist_result = self.cursor.execute('SELECT * FROM `channel_blacklist` WHERE server={0} AND channel={1}'.format(message.server.id, message.channel.id)).fetchall()
        if len(global_blacklist_result) != 0 or len(server_blacklist_result) != 0:
            return True
        if len(channel_blacklist_result) != 0:
            # Messages mentioning 'blacklist' still get through, presumably so
            # the channel can be un-blacklisted -- confirm intent.
            return 'blacklist' not in message.content
        return False

    async def command_check(self, message, command, prefix):
        """Return True if ``command`` is disabled for this message.

        Checks a global disable first, then per-server rows of
        ``command_blacklist`` of type server/channel/role/user. A server-wide
        disable is overridden when the channel topic contains ``[command]``
        or ``{command}``. Server, channel and role disables send a notice
        (with re-enable instructions for admins); user disables are silent.
        """
        if message.author.id == self.bot.owner.id:
            return False
        sql = 'SELECT * FROM `command_blacklist` WHERE type="global" AND command={0}'
        sql = sql.format(self.escape(command))
        if len(self.cursor.execute(sql).fetchall()) != 0:
            return True
        if message.channel.is_private:
            return False
        sql = 'SELECT * FROM `command_blacklist` WHERE server={0}'
        sql = sql.format(message.server.id)
        result = self.cursor.execute(sql).fetchall()
        topic_match = False
        if message.channel.topic is not None:
            topic_regex = re.compile(r"((\[|\{)" + re.escape(command) + r"(\]|\}))", re.I | re.S)
            topic_match = bool(topic_regex.findall(message.channel.topic.lower()))
        is_admin = False
        try:
            perms = message.channel.permissions_for(message.author)
            is_admin = bool(perms.administrator or perms.manage_server or perms.manage_roles)
        except Exception:
            # Best effort: treat as non-admin if permissions can't be read.
            pass
        for row in result:
            if row['command'] != command:
                continue
            if row['type'] == 'server':
                if topic_match:
                    return False
                await self.bot.send_message(message.channel, ':no_entry: **That command is disabled on this server**{0}'.format("\n`{0}command enable {1}` to enable the command.\n**Alternatively** place `[{1}]` in the channel topic or name.".format(prefix, command) if is_admin else ''))
                return True
            elif row['type'] == 'channel':
                if str(row['channel']) == str(message.channel.id):
                    await self.bot.send_message(message.channel, ':no_entry: **That command is disabled in this channel**{0}'.format("\n`{0}command enable channel {1}` {2} to enable the command.".format(prefix, command, message.channel.mention) if is_admin else ''))
                    return True
            elif row['type'] == 'role':
                for role in message.author.roles:
                    if str(role.id) == str(row['role']):
                        # role.mention now fills {1}; it used to be passed as a
                        # stray positional to send_message and {1} raised
                        # IndexError. NOTE(review): the hint says
                        # "enable channel" for a role -- confirm the command.
                        hint = "\n`{0}command enable channel {1}` {2} to enable the command.".format(prefix, command, role.mention) if is_admin else ''
                        await self.bot.send_message(message.channel, ':no_entry: **That command is disabled for role: {1}**{0}'.format(hint, role.mention))
                        return True
            elif row['type'] == 'user':
                if str(row['user']) == str(message.author.id):
                    return True
        return False

    async def process_commands(self, message, command, prefix):
        """Invoke the registered command named ``command`` for ``message``.

        Builds a Context, dispatches the ``command`` event, runs the command
        with a 60 second timeout, and dispatches ``command_completion`` on
        success. CommandErrors are routed to ``ctx.command.dispatch_error``;
        a timeout sends a warning to the channel.

        Raises:
            Exception: if ``command`` is not in ``self.bot.commands``.
        """
        # These look unused but are presumably read by discord.py's Bot.say()
        # (used by repl below), which searches caller frames for these local
        # names to find the destination -- verify before removing.
        _internal_channel = message.channel
        _internal_author = message.author
        view = commands.view.StringView(message.content)
        view.skip_string(prefix)
        invoker = view.get_word()
        ctx = Context(bot=self.bot, invoked_with=invoker, message=message, view=view, prefix=prefix)
        try:
            cmd = self.bot.commands[command]
        except KeyError:
            raise Exception('wot')
        self.bot.dispatch('command', cmd, ctx)
        try:
            with async_timeout.timeout(60):
                await cmd.invoke(ctx)
        except CommandError as e:
            # ctx.command (set during invoke) rather than cmd, presumably so a
            # subcommand's own error handler is used -- confirm.
            ctx.command.dispatch_error(e, ctx)
        except asyncio.TimeoutError:
            await self.bot.send_message(message.channel, ':warning: **Command timed out, don\'t be an asshole.**')
            return
        else:
            self.bot.dispatch('command_completion', cmd, ctx)

    async def queue_message(self, channel_id:str, msg:str):
        """POST ``msg`` for ``channel_id`` to the external message-queue service.

        On timeout or client error, waits 5 seconds and gives up silently.
        """
        message_id = random.randint(0, 1000000)
        payload = {'key':'keeee', 'id': message_id, 'channel_id': channel_id, 'message': msg}
        try:
            with aiohttp.ClientSession() as session:
                with aiohttp.Timeout(15):
                    async with session.post('http://no:2221/queue', data=payload):
                        pass
        except (asyncio.TimeoutError, aiohttp.errors.ClientConnectionError, aiohttp.errors.ClientError):
            await asyncio.sleep(5)
            return

    async def _fetch_mime(self, url):
        """Return the lower-cased media type of ``url`` if it answers 200, else None.

        Content-Type parameters such as ``; charset=...`` are dropped.
        Request errors and timeouts (5s) propagate to the caller.
        """
        with aiohttp.ClientSession() as session:
            with aiohttp.Timeout(5):
                async with session.get(url) as resp:
                    if resp.status != 200:
                        return None
                    mime = resp.headers.get('Content-type', '')
                    return mime.split(';', 1)[0].strip().lower()

    async def isimage(self, url:str):
        """Return True if ``url`` answers 200 with a type in ``self.image_mimes``.

        Always returns a bool; non-200 responses used to return None, which
        get_images' ``is False`` checks accepted as a valid image.
        """
        try:
            mime = await self._fetch_mime(url)
        except Exception:
            return False
        return mime in self.image_mimes

    async def isgif(self, url:str):
        """Return True if ``url`` answers 200 with Content-Type ``image/gif``."""
        try:
            mime = await self._fetch_mime(url)
        except Exception:
            return False
        return mime == "image/gif"

    async def download(self, url:str, path:str):
        """Download ``url`` and write the body to the file ``path``.

        Returns False on timeout (5s), otherwise None; other errors propagate.
        """
        try:
            with aiohttp.ClientSession() as session:
                with aiohttp.Timeout(5):
                    async with session.get(url) as resp:
                        data = await resp.read()
                        with open(path, "wb") as f:
                            f.write(data)
        except asyncio.TimeoutError:
            return False

    async def bytes_download(self, url:str):
        """Return the body of ``url`` as a BytesIO, or False on timeout (5s)."""
        try:
            with aiohttp.ClientSession() as session:
                with aiohttp.Timeout(5):
                    async with session.get(url) as resp:
                        data = await resp.read()
                        return BytesIO(data)  # positioned at offset 0
        except asyncio.TimeoutError:
            return False

    async def get_json(self, url:str):
        """Return the decoded JSON body of ``url``; {} on timeout (5s) or decode failure."""
        try:
            with aiohttp.ClientSession() as session:
                with aiohttp.Timeout(5):
                    async with session.get(url) as resp:
                        try:
                            return await resp.json()
                        except Exception:
                            return {}
        except asyncio.TimeoutError:
            return {}

    async def run_process(self, code, response=False):
        """Run argv list ``code`` as a subprocess on the bot's loop and wait for it.

        Args:
            code: program and arguments for ``loop.subprocess_exec``.
            response: when True, return the collected pipe output.

        Returns:
            The output decoded as ASCII with trailing whitespace stripped
            when ``response`` is True, otherwise True. Returns False if
            starting the process timed out.
        """
        try:
            loop = self.bot.loop
            exit_future = loop.create_future()
            create = loop.subprocess_exec(lambda: DataProtocol(exit_future),
                                          *code, stdin=None, stderr=None)
            # NOTE(review): the 30s timeout bounds only process creation; the
            # wait on exit_future below has no timeout.
            transport, protocol = await asyncio.wait_for(create, timeout=30)
            await exit_future
            transport.close()
            if response:
                data = bytes(protocol.output)
                return data.decode('ascii').rstrip()
            return True
        except asyncio.TimeoutError:
            return False

    async def proxy_request(self, url, **kwargs):
        """POST to ``url`` through a SOCKS5 proxy and return the response.

        Keyword args: ``post``, ``post_data``, ``headers``, ``j``.
        NOTE(review): ``post`` and ``j`` are True unless the caller passes
        ``{}`` explicitly (a missing kwarg gives None, and None != {}), so by
        default ``post_data`` is sent and JSON is returned -- confirm intent.
        """
        post = kwargs.get('post')
        post = True if post != {} else False
        post_data = kwargs.get('post_data')
        headers = kwargs.get('headers')
        j = kwargs.get('j')
        j = True if j != {} else False
        proxy_addr = aiosocks.Socks5Addr('proxy-nl.privateinternetaccess.com', 1080)
        proxy_auth = aiosocks.Socks5Auth('', password='')
        proxy_connection = aiosocks.connector.SocksConnector(proxy=proxy_addr, proxy_auth=proxy_auth, remote_resolve=True)
        with aiohttp.ClientSession(connector=proxy_connection) as session:
            async with session.post(url, data=post_data if post else None, headers=headers) as resp:
                if j:
                    return await resp.json()
                else:
                    return await resp.text()

    async def truncate(self, channel, msg):
        """Send ``msg`` to ``channel`` in 1999-character chunks.

        Sleeps 0.21s between chunks. If sending fails, the exception is sent
        to the channel instead.
        """
        if len(msg) == 0:
            return
        split = [msg[i:i + 1999] for i in range(0, len(msg), 1999)]
        try:
            for s in split:
                await self.bot.send_message(channel, s)
                await asyncio.sleep(0.21)
        except Exception as e:
            await self.bot.send_message(channel, e)

    async def _reject(self, ctx, text):
        """Send ``text`` to the invoking channel, reset the command cooldown, return False."""
        await self.bot.send_message(ctx.message.channel, text)
        ctx.command.reset_cooldown(ctx)
        return False

    async def get_images(self, ctx, **kwargs):
        """Collect and validate image (or gif) URLs for an image command.

        Candidates are the ``urls`` kwarg, avatars of mentioned users (image
        commands only) and message attachments. If none validates, the last
        25 messages before the invocation are searched for an attachment or
        embed URL that passes the check.

        Keyword args:
            limit: maximum number of URLs. It is raised by one per mention
                and by one when ``scale`` is set, presumably because the raw
                mention / scale argument also occupies a ``urls`` slot.
            urls: a single argument string, or a tuple/list of them.
            gif: validate gifs instead of still images.
            scale: when truthy, the maximum numeric scale argument accepted
                (the bot owner is exempt). It also switches the return shape.

        Returns:
            A list of URLs, or ``(urls, int_scale, scale_msg)`` when ``scale``
            is set. Returns False after sending an error message and
            resetting the cooldown. Returns None if an unexpected exception
            occurred (it is printed).
        """
        try:
            message = ctx.message
            channel = message.channel
            limit = kwargs.pop('limit', None)
            urls = kwargs.pop('urls', [])
            gif = kwargs.pop('gif', False)
            scale = kwargs.pop('scale', None)
            check_func = self.isgif if gif else self.isimage
            # Normalise to a fresh list. A list argument (including the []
            # default) used to be wrapped as [[...]] and crash on startswith().
            if urls is None:
                urls = []
            elif isinstance(urls, (tuple, list)):
                urls = list(urls)
            else:
                urls = [urls]
            scale_msg = None
            int_scale = None
            if not gif:
                for mention in message.mentions:
                    urls.append(mention.avatar_url)
                    if limit:
                        limit += 1
            for attachment in message.attachments:
                urls.append(attachment['url'])
            if scale and limit:
                limit += 1
            if limit and urls and len(urls) > limit:
                return await self._reject(ctx, ':no_entry: `Max image limit (<= {0})`'.format(limit))
            if gif:
                other_check = self.isimage
                wrong_type_msg = ":warning: This command is for gifs, not images (use `magik`)!"
                none_valid_msg = 'Invalid or Non-Gifs(s)!'
                one_invalid_msg = ':warning: Gif `{0}` is Invalid!'
            else:
                other_check = self.isgif
                wrong_type_msg = ":warning: This command is for images, not gifs (use `gmagik` or `gascii`)!"
                none_valid_msg = 'Invalid or Non-Image(s)!'
                one_invalid_msg = ':warning: Image `{0}` is Invalid!'
            img_urls = []
            count = 1
            for url in urls:
                if url.startswith('<@'):
                    # Raw mention text; its avatar URL was added above.
                    continue
                if scale:
                    try:
                        floored = math.floor(float(url))
                    except (TypeError, ValueError, OverflowError):
                        floored = None
                    # Non-negative numbers are taken as the scale argument.
                    if floored is not None and str(floored).isdigit():
                        int_scale = int(floored)
                        scale_msg = '`Scale: {0}`\n'.format(int_scale)
                        if int_scale > scale and message.author.id != self.bot.owner.id:
                            int_scale = scale
                            scale_msg = '`Scale: {0} (Limit: <= {1})`\n'.format(int_scale, scale)
                        continue
                if not await check_func(url):
                    if await other_check(url):
                        return await self._reject(ctx, wrong_type_msg)
                    if not img_urls:
                        return await self._reject(ctx, none_valid_msg)
                    await self.bot.send_message(channel, one_invalid_msg.format(count))
                    continue
                img_urls.append(url)
                count += 1
            if not img_urls:
                async for m in self.bot.logs_from(channel, before=message, limit=25):
                    candidate = None
                    if m.attachments:
                        candidate = m.attachments[0]['url']
                    elif m.embeds:
                        # Not every embed has a url; skip those instead of
                        # raising KeyError and aborting the whole search.
                        candidate = m.embeds[0].get('url')
                    if candidate and await check_func(candidate):
                        img_urls.append(candidate)
                        break
            if not img_urls:
                return await self._reject(ctx, ":no_entry: Please input url(s){0}or attachment(s).".format(', mention(s) ' if not gif else ' '))
            if scale:
                return img_urls, int_scale, scale_msg
            return img_urls
        except Exception as e:
            print(e)

    async def google_keys(self):
        """Return the next key from ``bot.google_api_keys``, cycling round-robin.

        Raises IndexError if no keys are configured.
        """
        keys = self.bot.google_api_keys
        if self.bot.google_count >= len(keys):
            self.bot.google_count = 0
        key = keys[self.bot.google_count]
        self.bot.google_count += 1
        return str(key)

    def write_last_time(self):
        """Write the current Unix time to this shard's ``last_time_<shard>.txt``."""
        path = self.files_path('last_time_{0}.txt'.format(self.bot.shard_id))
        utc = str(int(time.time()))
        with open(path, 'wb') as f:
            f.write(utc.encode())

    def get_last_time(self):
        """Return the Unix time stored by write_last_time(), or False if unavailable."""
        path = self.files_path('last_time_{0}.txt'.format(self.bot.shard_id))
        try:
            with open(path, 'r') as f:
                return int(f.read())
        except (OSError, ValueError):
            # Missing/unreadable file or non-integer contents.
            return False

    def restart_program(self):
        """Replace the current process with a fresh run of the same command line."""
        python = sys.executable
        os.execl(python, python, * sys.argv)

    async def cleanup_code(self, content):
        """Automatically removes code blocks from the code.

        ```lang\\n...\\n``` fences lose the fence and the language line.
        Single-line ```code``` and text wrapped in single backticks are
        stripped. If the result starts with 'http', it is fetched and the
        response text is returned instead.
        """
        if content.startswith('```') and content.endswith('```'):
            body = content[3:-3]
            if '\n' in body:
                # Drop the opening line (language tag such as "py") and the
                # newline before the closing fence. Unlike splitting off the
                # last line, this keeps code written on the closing-fence line.
                body = body.split('\n', 1)[1]
                if body.endswith('\n'):
                    body = body[:-1]
            clean = body
        else:
            clean = content.strip('` \n')
        if clean.startswith('http'):
            with aiohttp.ClientSession() as session:
                async with session.get(clean) as r:
                    code = await r.text()
                    clean = code
        return clean

    def get_syntax_error(self, e):
        """Format SyntaxError ``e`` as a ```py block with a caret under the offset."""
        return '```py\n{0.text}{1:>{0.offset}}\n{2}: {0}```'.format(e, '^', type(e).__name__)

    async def repl(self, ctx, code):
        """Evaluate Python from ``code`` and return the formatted result.

        SECURITY: runs arbitrary code with eval/exec and exposes the bot and
        DB cursor. No permission check is done here; callers must restrict
        this to the bot owner.

        Returns:
            'exit' for quit/exit. False on a syntax error, which is reported
            via bot.say. Otherwise a ```py block with captured stdout plus the
            result or traceback, or None if nothing was produced.
        """
        msg = ctx.message
        variables = {
            'ctx': ctx,
            'bot': self.bot,
            'message': msg,
            'server': msg.server,
            'channel': msg.channel,
            'author': msg.author,
            'last': None,
            'commands': commands,
            'discord': discord,
            'asyncio': asyncio,
            'cursor': self.cursor
        }
        cleaned = await self.cleanup_code(code)
        if cleaned in ('quit', 'exit', 'exit()'):
            await self.bot.say('Exiting.')
            return 'exit'
        executor = exec
        if cleaned.count('\n') == 0:
            # Single line: try it as an expression so its value is shown.
            try:
                code = compile(cleaned, '<repl session>', 'eval')
            except SyntaxError:
                pass
            else:
                executor = eval
        if executor is exec:
            try:
                code = compile(cleaned, '<repl session>', 'exec')
            except SyntaxError as e:
                await self.bot.say(self.get_syntax_error(e))
                return False
        fmt = None
        stdout = io.StringIO()
        try:
            with redirect_stdout(stdout):
                result = executor(code, variables)
                if inspect.isawaitable(result):
                    result = await result
        except Exception:
            value = stdout.getvalue()
            fmt = '```py\n{}{}\n```'.format(value, traceback.format_exc())
        else:
            value = stdout.getvalue()
            if result is not None:
                fmt = '```py\n{}{}\n```'.format(value, result)
                # Note: variables is rebuilt on every call, so 'last' does not
                # carry over to the next invocation.
                variables['last'] = result
            elif value:
                fmt = '```py\n{}\n```'.format(value)
        return fmt

    async def command_help(self, ctx):
        """Send the help pages for the invoked (sub)command to the channel.

        The first newline of each page is replaced by "fix\\n", presumably so
        that an opening ``` fence becomes ```fix for highlighting.
        """
        if ctx.invoked_subcommand:
            cmd = ctx.invoked_subcommand
        else:
            cmd = ctx.command
        pages = self.bot.formatter.format_help_for(ctx, cmd)
        for page in pages:
            await self.bot.send_message(ctx.message.channel, page.replace("\n", "fix\n", 1))

    def escape(self, obj, mapping=None):
        """Return ``obj`` as a MySQL literal (strings quoted and escaped).

        ``mapping=None`` makes pymysql's escape_item use its default
        ``encoders`` table, identical to the previous ``mapping=encoders``
        default without binding it at class-definition time.
        """
        if isinstance(obj, str):
            return "'" + escape_string(obj) + "'"
        return escape_item(obj, 'utf8mb4', mapping=mapping)
# async def is_above(ctx, user):
# u1 = ctx.author
# if u1 == bot.owner:
# return True
# u2 = user
# server = ctx.message.server
# channel = ctx.message.channel
# if server.owner == u1:
# return True
# elif server.owner == u2:
# return '`User is the server owner.`'
# if channel.permissions_for(u1).administrator and channel.permissions_for(u2).administrator and u1.top_role.position > u2.top_role.position:
# return True
# elif u1.top_role == u2.top_role:
# return '`Same role.`'
# if u1.top_role.position > u2.top_role.position:
# return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment