Created
March 5, 2018 10:51
-
-
Save prakhar21/cc275b6f321e85bfcdc0f68af2fab895 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import time | |
import random | |
import re | |
from slackclient import SlackClient | |
from keras.applications.vgg16 import VGG16 | |
from keras.preprocessing.image import load_img | |
from keras.preprocessing.image import img_to_array | |
from keras.applications.vgg16 import preprocess_input | |
from keras.applications.vgg16 import decode_predictions | |
from PIL import Image | |
import requests | |
# Seconds to sleep between RTM reads in the main polling loop.
RTM_READ_DELAY = 1
# Not referenced anywhere else in this file.
EXAMPLE_COMMAND = "do"
# A message that begins with a direct mention, e.g. "<@U123> text":
# group 1 captures the mentioned user ID, group 2 the remaining text.
MENTION_REGEX = "^<@(|[WU].+?)>(.*)"

# The bot's own Slack user ID; set from `auth.test` once connected.
starterbot_id = None

# Slack client, authenticated with the SLACK_BOT_TOKEN environment variable.
slack_client = SlackClient(os.environ.get('SLACK_BOT_TOKEN'))
def load_model():
    """Build and return a Keras ``VGG16`` model using its default arguments."""
    return VGG16()
def load_image(image):
    """Load the image file at *image*, resized to 224x224 pixels.

    The 224x224 size matches what ``preprocess`` feeds to the model.
    """
    size = (224, 224)
    return load_img(image, target_size=size)
def preprocess(img):
    """Turn the image file *img* into a one-image batch for the VGG16 model.

    The image is loaded and resized by ``load_image`` and converted to an
    array. A leading batch axis of length 1 is added, and Keras'
    ``vgg16.preprocess_input`` is applied.
    """
    pixels = img_to_array(load_image(img))
    # img_to_array yields a 3-D array; prepend a batch axis of size 1.
    rows, cols, chans = pixels.shape
    batch = pixels.reshape((1, rows, cols, chans))
    return preprocess_input(batch)
def parse_direct_mention(message_text):
    """Split a message that starts with a direct mention.

    Returns ``(user_id, rest)``: the mentioned user ID and the remaining text
    with surrounding whitespace stripped. Returns ``(None, None)`` when the
    message does not start with a mention.
    """
    found = re.search(MENTION_REGEX, message_text)
    if found is None:
        return (None, None)
    mentioned_user, remainder = found.groups()
    return (mentioned_user, remainder.strip())
def parse_bot_commands(slack_events):
    """Return the first actionable command in a batch of Slack RTM events.

    Two kinds of event are actionable:

    * a ``message`` event with no ``subtype`` whose text starts with a direct
      mention of the bot (``starterbot_id``). The text after the mention is
      the command.
    * a ``message`` event that has a ``subtype`` and carries a ``file``. Its
      ``url_private_download`` URL is the command.

    Example event::

        {u'source_team': u'T90HNMX8B', u'text': u'how are you ?',
         u'ts': u'1520229941.000201', u'user': u'U90L44L0L',
         u'team': u'T90HNMX8B', u'type': u'message', u'channel': u'D9KB0A5UP'}

    :param slack_events: iterable of event dicts (as returned by ``rtm_read``).
    :returns: ``(command, channel)`` for the first actionable event, else
        ``(None, None)``.
    """
    for event in slack_events:
        # Not every RTM payload has a "type" key (e.g. acknowledgements of
        # sent messages carry "reply_to" instead), so never index it directly.
        if event.get("type") != "message":
            continue
        if "subtype" not in event:
            user_id, message = parse_direct_mention(event.get("text", ""))
            if user_id == starterbot_id:
                return message, event["channel"]
        elif "file" in event:
            uri = event["file"].get("url_private_download")
            # Skip file events without a download link instead of crashing.
            if uri:
                return uri, event["channel"]
    return None, None
def predict(image, model):
    """Download the image at URL *image* and classify it with *model*.

    The previous version built a ``wget`` shell command from the URL, which
    comes from an untrusted chat message, so the URL could inject shell
    commands. That version also saved the file under its URL basename, where
    wget silently writes ``name.1`` if ``name`` exists, so a stale earlier
    image could be classified. This version downloads with ``requests`` into
    a private temporary file that is always removed afterwards.

    :param image: URL of the image (a chat link or a Slack
        ``url_private_download`` link).
    :param model: Keras model whose ``predict`` is called (``load_model()``'s
        result in this script).
    :returns: ``decode_predictions`` output for the single image.
    :raises requests.RequestException: if the download fails or the server
        answers with an HTTP error status.
    """
    import tempfile

    url = str(image)
    headers = {}
    token = os.environ.get('SLACK_BOT_TOKEN')
    # Private Slack file links need the bot token as a Bearer header. Only
    # send it to Slack's file host so an arbitrary posted link cannot
    # harvest the token. The trailing '/' pins the host exactly.
    if token and url.startswith('https://files.slack.com/'):
        headers['Authorization'] = 'Bearer ' + token
    resp = requests.get(url, headers=headers, timeout=30)
    resp.raise_for_status()

    fd, path = tempfile.mkstemp()
    try:
        with os.fdopen(fd, 'wb') as fh:
            fh.write(resp.content)
        batch = preprocess(path)
    finally:
        os.remove(path)

    y_pred = model.predict(batch)
    return decode_predictions(y_pred)
def handle_command(command, channel, model):
    """Execute a bot command and post the reply to *channel*.

    If *command* contains an http(s) URL, the first one is downloaded and
    classified by ``predict`` using *model*. The top label is posted inside a
    randomly chosen phrase (e.g. "It is a Golden Retriever"). If there is no
    URL, or downloading/classifying fails, "Not sure what you mean." is posted.

    :param command: command text (the text after a mention, or a file
        download URL, as produced by ``parse_bot_commands``).
    :param channel: Slack channel ID to reply in.
    :param model: Keras model passed through to ``predict``.
    """
    fallback = "Not sure what you mean."
    # Slack wraps links as <url> or <url|label>, so stop at '<', '>', '|' and
    # whitespace. The old pattern's "[$-_@.&+]" was a character *range* from
    # '$' to '_'; it accidentally accepted '>' and kept the closing bracket
    # of Slack's link markup in the URL.
    urls = re.findall(r'https?://[^\s<>|]+', command)
    if not urls:
        # Default response is help text for the user
        response = fallback
    else:
        candidate_fillers = ["It looks like a <FILL> to me.",
                             "It is a <FILL>",
                             "I think it's a <FILL>",
                             "Well, that's a <FILL>"]
        try:
            # predict returns decode_predictions output: one list per image of
            # (class_id, class_name, score) tuples, best first, e.g.
            # [[(u'n02099601', u'golden_retriever', 0.43866554), ...]]
            top_label = predict(urls[0], model)[0][0][1]
        except Exception as exc:
            # Per-command boundary: a bad link or unreadable image must not
            # crash the polling loop, which has no error handling of its own.
            # An empty prediction list also lands here instead of being posted
            # raw.
            print("Could not classify %s: %s" % (urls[0], exc))
            response = fallback
        else:
            pretty = top_label.replace("_", " ").title()
            response = random.choice(candidate_fillers).replace("<FILL>", pretty)
    # Sends the response back to the channel
    slack_client.api_call(
        "chat.postMessage",
        channel=channel,
        text=response,
    )
def _run_bot():
    """Load the classifier, connect to Slack RTM and serve commands forever."""
    global starterbot_id
    model = load_model()
    if not slack_client.rtm_connect(with_team_state=False):
        print("Connection failed. Exception traceback printed above.")
        return
    print("Starter Bot connected and running!")
    # The bot's own user ID comes from the `auth.test` Web API method.
    # parse_bot_commands compares direct mentions against it.
    starterbot_id = slack_client.api_call("auth.test")["user_id"]
    while True:
        command, channel = parse_bot_commands(slack_client.rtm_read())
        if command:
            handle_command(command, channel, model)
        time.sleep(RTM_READ_DELAY)


if __name__ == '__main__':
    _run_bot()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment