-
-
Save trpfrog/ec2e810b1558bde7fb3af5c83d1fec78 to your computer and use it in GitHub Desktop.
AIつまみロボのPython版ソースコード
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
import re | |
import itertools | |
import requests | |
import time | |
import datetime | |
import random | |
import math | |
from typing import Optional | |
import tweepy | |
import flask | |
import functions_framework | |
# Runtime configuration. Every value is read from the process environment at
# import time; a missing variable raises KeyError immediately.

# Bearer token sent in the Authorization header of inference requests.
HF_TOKEN = os.environ["HF_TOKEN"]

# Model identifier appended to the Hugging Face inference URL.
GPT2_MODEL_NAME = os.environ["GPT2_MODEL_NAME"]

# JSON object of credentials. It is expanded as keyword arguments to
# tweepy.Client and its "consumer_secret" entry signs webhook challenges.
TWITTER_API_KEY = json.loads(os.environ["TWITTER_TOKEN_JSON"])

# JSON object whose values are collections of strings to censor.
NG_WORDS = json.loads(os.environ["NG_WORDS"])

# Shared secret callers must pass as the "trpfrog-webhook-token" query parameter.
TRPFROG_WEBHOOK_TOKEN = os.environ["TRPFROG_WEBHOOK_TOKEN"]
def query(text, retries=2, do_sleep=True):
    """Request text continuations for ``text`` from the Hugging Face Inference API.

    Args:
        text: Prompt sent as the ``inputs`` field of the request.
        retries: Number of additional attempts allowed while the API answers
            with an ``estimated_time`` field in a non-200 response.
        do_sleep: When true, pause 10 seconds after a first-try success so the
            bot does not answer too quickly. No pause is added after a retry,
            because the retry wait has already delayed the answer.

    Returns:
        The decoded JSON body of the HTTP 200 response.

    Raises:
        Exception: If the API returns an error that cannot be retried, or a
            body that is not valid JSON.
        requests.RequestException: If the HTTP request fails or times out.
    """
    payload = {
        "inputs": text,
        "parameters": {
            "top_k": 50,
            "top_p": 0.95,
            "return_full_text": False,  # exclude the prompt from the output
            "num_return_sequences": 6,
            "repetition_penalty": 1.2,
            "max_new_tokens": 20,
            "temperature": 1,
        },
        "options": {
            "use_cache": False
        }
    }
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    api_url = f"https://api-inference.huggingface.co/models/{GPT2_MODEL_NAME}"
    # Without a timeout a stalled connection would block the caller forever.
    request_timeout_sec = 60

    attempts_left = retries
    sleep_after_success = do_sleep
    while True:
        response = requests.post(
            api_url, headers=headers, json=payload, timeout=request_timeout_sec
        )
        try:
            result = response.json()
        except ValueError:
            # Error pages are not always JSON; treat them as a failed attempt
            # instead of leaking a decode error to the caller.
            result = None

        if response.status_code == 200 and result is not None:
            if sleep_after_success:
                time.sleep(10)  # replying instantly would feel unnatural
            return result

        can_retry = (
            isinstance(result, dict)
            and "estimated_time" in result
            and attempts_left > 0
        )
        if not can_retry:
            print(response)
            print(result)
            raise Exception("API is not responding")

        # Wait slightly longer than the estimate reported by the API.
        t = float(result["estimated_time"]) * 1.05
        print(f"Waiting for API... ({t:.3f} sec)")
        time.sleep(t)
        attempts_left -= 1
        sleep_after_success = False
def random_heart() -> str:
    """Return one heart emoji chosen uniformly at random."""
    hearts = ('❤️', '🧡', '💛', '💚', '💙', '💜', '💗', '💖', '💓', '💕', '💝')
    return random.choice(hearts)
def mesugakinize(s: str) -> str:
    """Append a random heart emoji after every abusive word found in ``s``.

    The text is scanned once, left to right, preferring the longest listed
    word at each position. A word that contains a shorter listed word
    (e.g. "バカモン" contains "バカ") therefore gets exactly one heart at its
    end, rather than a second heart inserted in the middle of the word.

    Args:
        s: Text to decorate.

    Returns:
        ``s`` with a heart appended after each matched word; unchanged when
        nothing matches.
    """
    abusive_words = [
        "うるせえ", "お前", "馬鹿", "糞",
        "バカモン", "ばかもん", "バカモノ", "ばかもの",
        "バカ", "アホ", "ドジ", "マヌケ", "カス", "クソ", "クズ",
        "ばか", "あほ", "どじ", "まぬけ", "かす", "くそ", "くず"
    ]
    # Regex alternation takes the first branch that matches, so longer words
    # must come before the shorter words they contain.
    longest_first = sorted(abusive_words, key=len, reverse=True)
    pattern = "|".join(re.escape(word) for word in longest_first)
    return re.sub(pattern, lambda m: m.group(0) + random_heart(), s)
def softmax(arr: list[float], temperature: float = 1.0) -> list[float]:
    """Return the softmax of ``arr`` after dividing each value by ``temperature``.

    Args:
        arr: Input scores. An empty list yields an empty list.
        temperature: Divisor applied to every score before exponentiation.
            Must be non-zero; zero raises ZeroDivisionError.

    Returns:
        A list of the same length whose entries sum to 1.
    """
    if not arr:
        return []
    scaled = [x / temperature for x in arr]
    # Subtract the max of the *scaled* values so the largest exponent is 0.
    # This stays overflow-safe even when temperature is negative.
    peak = max(scaled)
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)  # computed once instead of once per element
    return [e / total for e in exps]
class Pipeline: | |
def __init__(self, prompt: str): | |
self.prompt = prompt | |
@staticmethod | |
def cleaning(s: str) -> str: | |
s = s.replace(" ", "") | |
s = re.sub(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-…]+", "", s) | |
for ng_word in itertools.chain(*NG_WORDS.values()): | |
s = s.replace(ng_word, "💗💗") | |
s = re.sub(r"[「」\s]", "", s) | |
s = re.sub(r"[殺死]", "💗", s) | |
s = re.sub(r"[@@]", "", s) | |
return mesugakinize(s) | |
@staticmethod | |
def _softmax_with_filter(arr: list[float], temperature: float = 1.0) -> list[float]: | |
for i in range(len(arr)): | |
if arr[i] <= 0: | |
arr[i] = -1e9 | |
return softmax(arr, temperature) | |
@staticmethod | |
def select_from_sample(generated: list[str], temperature: float = 1.0) -> str: | |
probabilities = Pipeline._softmax_with_filter([len(s) for s in generated], temperature) | |
if all(p < 1e-5 for p in probabilities): | |
return '' | |
else: | |
return random.choices(generated, weights=probabilities, k=1)[0] | |
def generate(self, return_all=False, do_sleep=True) -> list[str]: | |
data = query(self.prompt, do_sleep=do_sleep) | |
if data[0]["generated_text"] is not None: | |
generated = list(map(lambda s: self.cleaning(s["generated_text"]), data)) | |
if return_all: | |
return generated | |
else: | |
return Pipeline.select_from_sample(generated) | |
else: | |
print("Something went wrong...") | |
print(data) | |
raise Exception | |
class ReplyPipeline(Pipeline):
    """Pipeline whose prompt is a bracketed dialogue ending in an open "わし「" turn."""

    def __init__(self, *, other: str, me: Optional[str] = None) -> None:
        """Build the prompt from an optional earlier line and the other party's line.

        Args:
            other: The text being replied to; skipped when ``None``.
            me: Optional preceding line, placed before ``other``.
        """
        turns = [f"「{line}」" for line in (me, other) if line is not None]
        super().__init__("".join(turns) + "わし「")

    @staticmethod
    def parse_reply_reaction(text: str):
        """Return ``text`` up to and including the 」 that closes the open turn.

        Scanning starts at nesting depth 1 because the prompt ends with an
        unclosed 「. If the bracket is never closed, the whole text is returned.
        """
        depth = 1
        for end, ch in enumerate(text, start=1):
            if ch == '「':
                depth += 1
            elif ch == '」':
                depth -= 1
                if depth == 0:
                    return text[:end]
        return text

    def generate(self, return_all=False, do_sleep=True) -> list[str] | str:
        """Generate replies, cut at the end of the bot's own turn.

        The raw model output is truncated with ``parse_reply_reaction`` before
        it is cleaned. ``cleaning`` strips every 「 and 」, so truncating
        afterwards (as ``Pipeline.generate`` output would force) could never
        find the closing bracket.

        Args:
            return_all: Return every candidate instead of one sample.
            do_sleep: Forwarded to ``query``.

        Returns:
            A list of candidates when ``return_all`` is true, otherwise a
            single string chosen by ``select_from_sample``.

        Raises:
            Exception: If the API response holds no usable generated text.
        """
        data = query(self.prompt, do_sleep=do_sleep)
        if not data or data[0]["generated_text"] is None:
            print("Something went wrong...")
            print(data)
            raise Exception("Text generation API returned no usable candidates")
        generated = [
            self.cleaning(self.parse_reply_reaction(item["generated_text"]))
            for item in data
        ]
        if return_all:
            return generated
        return Pipeline.select_from_sample(generated)
class FrogRobo: | |
def __init__(self): | |
self.client = tweepy.Client(**TWITTER_API_KEY) | |
self.id = 2744579940 # self.client.get_me().data.id | |
def crawl_timeline(self, target_minutes=15) -> str | None: | |
# 1 hours ago as ISO format | |
if target_minutes > 0: | |
start_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=target_minutes) | |
start_time = start_time.isoformat().split('.')[0] + 'Z' | |
tweets = self.client.get_home_timeline(start_time=start_time) | |
else: | |
tweets = self.client.get_home_timeline() | |
if tweets.data is None or len(tweets.data) <= 0: | |
return None | |
target_tweets = [] | |
for tweet in tweets.data: | |
if tweet.possibly_sensitive: | |
continue | |
if tweet.in_reply_to_user_id is not None: | |
continue | |
if len(tweet.text) < 10: | |
continue | |
if tweet.text.startswith("RT"): | |
continue | |
tweet.text = re.sub(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-…]+", "", tweet.text) | |
target_tweets.append(tweet) | |
if len(target_tweets) <= 0: | |
return None | |
ids = [int(tweet.id) for tweet in target_tweets] | |
min_id, max_id = min(ids), max(ids) | |
id_range = max_id - min_id + 1 | |
probs = softmax([(x - min_id) / id_range for x in ids]) | |
return random.choices(target_tweets, weights=probs, k=1)[0].text | |
def general_tweet(self): | |
crawl_range_minutes = [15, 60, -1] | |
for m in crawl_range_minutes: | |
target_tweet = self.crawl_timeline(target_minutes=m) | |
if target_tweet is not None: | |
prompt = ReplyPipeline(other=target_tweet) | |
break | |
else: | |
prompt = Pipeline("今から面白いこと言います。") | |
self.client.create_tweet(text=prompt.generate()) | |
def debug_dm(self, text: str): | |
DM_TARGET_USER_ID = 92482871 | |
self.client.create_dm(text=text, participant_id=DM_TARGET_USER_ID) | |
def reply_tweet(self, tweet_dict, ignore_probablity=0.1): | |
tweet_id = tweet_dict["id"] | |
if int(tweet_dict["user"]["id"]) == self.id: | |
return | |
if tweet_dict["text"].startswith("RT "): | |
return | |
in_reply_to = tweet_dict["in_reply_to_status_id_str"] | |
my_tweet = None | |
try: | |
if in_reply_to is not None: | |
res = self.client.get_tweet(in_reply_to) | |
if (res.data is not None) and (res.data.text is not None): | |
my_tweet = res.data.text | |
except Exception as e: | |
print(e) | |
if random.random() > ignore_probablity: | |
text = tweet_dict["text"].lower().replace("@frogrobo", "").strip() | |
prompt = ReplyPipeline(me=my_tweet, other=text) | |
tweet = prompt.generate() | |
self.client.create_tweet(text=tweet, in_reply_to_tweet_id=tweet_id) | |
else: | |
self.client.like(tweet_id) | |
def webhook_challenge(request: flask.Request):
    """Answer a ``crc_token`` challenge sent as a query parameter.

    The token is signed with HMAC-SHA256 using the ``consumer_secret`` from
    ``TWITTER_API_KEY`` and returned base64-encoded with a ``sha256=`` prefix.

    Returns:
        ``(json_body, 200, headers)`` holding ``response_token`` when the
        request has a ``crc_token``; otherwise a bare JSON error string.
    """
    import base64
    import hashlib
    import hmac

    query_args = request.args
    if 'crc_token' not in query_args:
        return json.dumps({"error":"No Content"})

    secret = TWITTER_API_KEY['consumer_secret'].encode()
    challenge = query_args.get('crc_token').encode()
    signature = hmac.new(secret, msg=challenge, digestmod=hashlib.sha256).digest()
    body = {'response_token': 'sha256=' + base64.b64encode(signature).decode()}
    return json.dumps(body), 200, {'Content-Type': 'application/json'}
def is_access_allowed(request: flask.Request):
    """Return True when the request carries the expected webhook token.

    The token is read from the ``trpfrog-webhook-token`` query parameter and
    compared with ``TRPFROG_WEBHOOK_TOKEN``.
    """
    import hmac

    supplied = request.args.get('trpfrog-webhook-token')
    if supplied is None:
        return False
    # Constant-time comparison avoids leaking how much of the secret matched.
    # Both sides are encoded because compare_digest rejects non-ASCII str.
    return hmac.compare_digest(supplied.encode(), TRPFROG_WEBHOOK_TOKEN.encode())
def webhook(request: flask.Request):
    """Handle an authenticated POST: reply to tweets, ping, or post a tweet.

    Behaviour depends on the request:
      * JSON body with a non-empty ``tweet_create_events`` list: reply to each tweet.
      * no JSON body and a ``ping`` query parameter: DM a generated "pong".
      * no JSON body and a ``general`` query parameter: post a stand-alone tweet.

    Returns:
        ``'ok'`` on success, or a ``(json_body, 403, headers)`` tuple when the
        webhook token is missing or wrong.
    """
    event_name = 'tweet_create_events'
    if not is_access_allowed(request):
        # .get(): indexing would raise on a missing token instead of
        # producing the intended 403 response.
        token = request.args.get('trpfrog-webhook-token') or 'none'
        body = {"error": "Forbidden", "message": f"{token} is invalid token."}
        return json.dumps(body), 403, {'Content-Type': 'application/json'}

    data = request.get_json(silent=True)
    print(data, type(data))
    if data is not None and event_name in data and len(data[event_name]) > 0:
        print("start replying")
        robot = FrogRobo()
        for tweet_dict in data[event_name]:
            try:
                robot.reply_tweet(tweet_dict)
            except Exception as e:
                # Keep processing the remaining tweets, but leave a trace.
                print(f"Failed to reply to a tweet: {e!r}")
    elif data is None and 'ping' in request.args:
        print("start pinging")
        prompt = Pipeline("今から面白いこと言います。")
        try:
            # generate() takes no retry argument, so only do_sleep is passed.
            generated = prompt.generate(do_sleep=False)
        except Exception as e:
            print(f"Generation failed during ping: {e!r}")
            generated = 'uouo'
        FrogRobo().debug_dm("pong: " + generated)
    elif data is None and 'general' in request.args:
        print("start simple tweeting")
        FrogRobo().general_tweet()
    return 'ok'
@functions_framework.http
def start_bot(request: flask.Request):
    """HTTP Cloud Function entry point.

    GET requests are answered by ``webhook_challenge``; POST requests are
    handled by ``webhook``. Any other method gets a 405 response.

    Args:
        request (flask.Request): The request object.
        <https://flask.palletsprojects.com/en/1.1.x/api/#incoming-request-data>
    Returns:
        The response text, or any set of values that can be turned into a
        Response object using `make_response`
        <https://flask.palletsprojects.com/en/1.1.x/api/#flask.make_response>.
    """
    # NOTE(review): request.args includes the trpfrog-webhook-token query
    # parameter, so this writes the shared secret to the logs. Consider
    # redacting it.
    print("called!", request.args)
    print(request)
    if request.method == 'GET':
        return webhook_challenge(request)
    if request.method == 'POST':
        return webhook(request)
    # Falling through would return None, which is not a valid response.
    return (
        json.dumps({"error": "Method Not Allowed"}),
        405,
        {'Content-Type': 'application/json', 'Allow': 'GET, POST'},
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment