Skip to content

Instantly share code, notes, and snippets.

@prakhar21
Created March 5, 2018 10:51
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save prakhar21/cc275b6f321e85bfcdc0f68af2fab895 to your computer and use it in GitHub Desktop.
import os
import time
import random
import re
from slackclient import SlackClient
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.applications.vgg16 import decode_predictions
from PIL import Image
import requests
# Seconds to sleep between successive RTM reads in the main loop.
RTM_READ_DELAY = 1
# Example command keyword; not referenced elsewhere in this script.
EXAMPLE_COMMAND = "do"
# A message that starts with a user mention "<@ID>": group 1 is the
# mentioned ID (empty, or starting with W/U), group 2 is the text after it.
MENTION_REGEX = "^<@(|[WU].+?)>(.*)"

# This bot's own Slack user ID; stays None until the script entry point
# fills it in from the `auth.test` Web API call.
starterbot_id = None

# Slack client authenticated with the bot token taken from the environment.
slack_client = SlackClient(os.environ.get('SLACK_BOT_TOKEN'))
def load_model():
    """Build the VGG16 classifier with Keras's default constructor arguments.

    Keras documents ImageNet weights as the default for ``VGG16()``, which is
    what ``decode_predictions`` later expects.
    """
    return VGG16()
def load_image(image):
    """Open the image file at path *image*, resized to 224x224 (VGG16's input size)."""
    vgg_input_size = (224, 224)
    return load_img(image, target_size=vgg_input_size)
def preprocess(img):
    """Turn the image file at path *img* into a one-image VGG16 input batch.

    The image is loaded at 224x224 and converted to an array. A leading batch
    axis of size 1 is prepended, and VGG16's ``preprocess_input``
    normalisation is applied to the result.
    """
    pixels = img_to_array(load_image(img))
    rows, cols, depth = pixels.shape[0], pixels.shape[1], pixels.shape[2]
    batch = pixels.reshape((1, rows, cols, depth))
    return preprocess_input(batch)
def parse_direct_mention(message_text):
    """Split a message that begins with a user mention.

    Returns ``(mentioned_user_id, remaining_text_stripped)`` when
    *message_text* matches ``MENTION_REGEX``; otherwise ``(None, None)``.
    """
    found = re.search(MENTION_REGEX, message_text)
    if found is None:
        return None, None
    mentioned_id, remainder = found.group(1), found.group(2)
    return mentioned_id, remainder.strip()
def parse_bot_commands(slack_events):
    """Return the first actionable ``(command, channel)`` pair in a batch of RTM events.

    Two kinds of event are actionable:

    * a plain message (no ``subtype``) whose text starts with a mention of
      this bot -> ``(text after the mention, channel)``;
    * a message with a ``subtype`` that carries a ``file`` with a
      ``url_private_download`` -> ``(that URL, channel)``.

    Events are scanned in order and the first actionable one wins; the rest of
    the batch is not examined. Returns ``(None, None)`` if nothing is actionable.

    Example of an incoming message event:
    {u'source_team': u'T90HNMX8B', u'text': u'how are you ?', u'ts': u'1520229941.000201',
     u'user': u'U90L44L0L', u'team': u'T90HNMX8B', u'type': u'message', u'channel': u'D9KB0A5UP'}
    """
    for event in slack_events:
        # Some RTM payloads (e.g. reply acknowledgements) carry no "type" at all.
        if event.get("type") != "message":
            continue
        if "subtype" not in event:
            user_id, message = parse_direct_mention(event.get("text", ""))
            # Require an actual mention: otherwise None == None (before the bot
            # ID is known) would end the scan on an unrelated message.
            if user_id is not None and user_id == starterbot_id:
                return message, event["channel"]
        elif "file" in event:
            uri = event["file"].get("url_private_download")
            if uri:
                return uri, event["channel"]
    return None, None
def predict(image, model):
    """Download the image at URL *image* and classify it with *model*.

    The image is fetched over HTTP into a fresh temporary file. It is run
    through ``preprocess`` and the temporary file is then deleted. Returns
    ``decode_predictions`` applied to ``model.predict`` on that batch.

    The bot token (env ``SLACK_BOT_TOKEN``) is attached as a Bearer header
    only for ``slack.com`` hosts, which serve private Slack file downloads.
    It is never sent to other, user-supplied URLs.

    Raises ``requests`` exceptions on network/HTTP failure (including non-2xx
    status). Also propagates whatever image loading or prediction raises for
    unreadable content.
    """
    import tempfile
    from urllib.parse import urlparse

    parsed = urlparse(image)
    host = parsed.hostname or ""
    headers = {}
    if host == "slack.com" or host.endswith(".slack.com"):
        token = os.environ.get('SLACK_BOT_TOKEN')
        if token:
            headers["Authorization"] = "Bearer " + token

    # Download in-process rather than shelling out to wget: the URL comes from
    # user-typed Slack text and must never be interpreted by a shell.
    resp = requests.get(image, headers=headers, timeout=30)
    resp.raise_for_status()

    # A fresh temp file per request: wget saved repeats as "name.1", so a
    # reused filename used to classify a stale earlier download.
    suffix = os.path.splitext(parsed.path)[1]
    fd, path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as fh:
            fh.write(resp.content)
        batch = preprocess(path)
    finally:
        os.remove(path)

    y_pred = model.predict(batch)
    return decode_predictions(y_pred)
def _first_url(text):
    """Return the first http(s) URL found in *text*, or None if there is none.

    Slack delivers links as ``<https://...>`` or ``<https://...|label>`` and
    escapes ``&`` as ``&amp;``. The character class ``[$-_]`` below is the
    range 0x24-0x5F, which includes ``>``. So the match is cut at the first
    ``>``, and ``&amp;`` is turned back into ``&``.
    """
    urls = re.findall(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', text)
    if not urls:
        return None
    return urls[0].split(">", 1)[0].replace("&amp;", "&")


def _describe(predictions):
    """Build a chat reply naming the top-1 class in *predictions*.

    *predictions* is ``decode_predictions`` output for a one-image batch, e.g.
    ``[[(u'n02099601', u'golden_retriever', 0.43866554), ...]]``. The top
    label is prettified ('golden_retriever' -> 'Golden Retriever') and placed
    into a randomly chosen sentence. Raises IndexError if there is no
    prediction.
    """
    candidate_fillers = ["It looks like a <FILL> to me.",
                         "It is a <FILL>",
                         "I think it's a <FILL>",
                         "Well, that's a <FILL>"]
    label = predictions[0][0][1].replace("_", " ").title()
    return random.choice(candidate_fillers).replace("<FILL>", label)


def handle_command(command, channel, model):
    """Reply in *channel* to a bot command.

    If *command* contains an http(s) URL, the first one is downloaded and
    classified with *model*, and the reply names the top-1 label. Otherwise,
    or if classification fails for any reason, the reply is a fallback
    message. The reply is posted via ``chat.postMessage``.
    """
    fallback = "Not sure what you mean."
    url = _first_url(command)
    if url is None:
        response = fallback
    else:
        try:
            response = _describe(predict(url, model))
        except Exception as exc:
            # Per-command boundary: a bad URL or unreadable image must not
            # kill the bot's read loop.
            print("Could not classify %s: %s" % (url, exc))
            response = fallback
    # Sends the response back to the channel
    slack_client.api_call(
        "chat.postMessage",
        channel=channel,
        text=response
    )
def _run_bot():
    """Load the model, connect to Slack RTM, and answer commands forever."""
    global starterbot_id
    model = load_model()
    if not slack_client.rtm_connect(with_team_state=False):
        print("Connection failed. Exception traceback printed above.")
        return
    print("Starter Bot connected and running!")
    # Look up this bot's own user ID (Web API `auth.test`); parse_bot_commands
    # compares mentions against it.
    starterbot_id = slack_client.api_call("auth.test")["user_id"]
    while True:
        pending, where = parse_bot_commands(slack_client.rtm_read())
        if pending:
            handle_command(pending, where, model)
        time.sleep(RTM_READ_DELAY)


if __name__ == '__main__':
    _run_bot()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment