Skip to content

Instantly share code, notes, and snippets.

@trpfrog
Created February 3, 2023 07:00
Show Gist options
  • Save trpfrog/ec2e810b1558bde7fb3af5c83d1fec78 to your computer and use it in GitHub Desktop.
Save trpfrog/ec2e810b1558bde7fb3af5c83d1fec78 to your computer and use it in GitHub Desktop.
AIつまみロボのPython版ソースコード
import os
import json
import re
import itertools
import requests
import time
import datetime
import random
import math
from typing import Optional
import tweepy
import flask
import functions_framework
# Configuration read from environment variables at import time; a missing
# variable raises KeyError as soon as the module is loaded.
# Hugging Face Inference API token (sent as a Bearer token by query()).
HF_TOKEN = os.environ["HF_TOKEN"]
# Model id appended to the Inference API URL by query().
GPT2_MODEL_NAME = os.environ["GPT2_MODEL_NAME"]
# JSON object passed as keyword arguments to tweepy.Client; its
# "consumer_secret" entry also keys the webhook CRC response HMAC.
TWITTER_API_KEY = json.loads(os.environ["TWITTER_TOKEN_JSON"])
# JSON object whose values are iterables of words masked in generated text
# (presumably category -> list of words; TODO confirm the exact schema).
NG_WORDS = json.loads(os.environ["NG_WORDS"])
# Shared secret expected in the "trpfrog-webhook-token" query parameter.
TRPFROG_WEBHOOK_TOKEN = os.environ["TRPFROG_WEBHOOK_TOKEN"]
def query(text, retries=2, do_sleep=True):
    """Request completions for *text* from the Hugging Face Inference API.

    Args:
        text: Prompt sent as ``inputs``.
        retries: How many more times to retry while the API reports that the
            model is still loading (an error body carrying ``estimated_time``).
        do_sleep: Sleep 10 s after a successful call so replies are not posted
            suspiciously fast. Retried calls skip this sleep.

    Returns:
        The decoded JSON body of the successful (HTTP 200) response.

    Raises:
        Exception: The API returned a non-retryable error, or retries ran out.
    """
    payload = {
        "inputs": text,
        "parameters": {
            "top_k": 50,
            "top_p": 0.95,
            "return_full_text": False,  # exclude the prompt from the output
            "num_return_sequences": 6,
            "repetition_penalty": 1.2,
            "max_new_tokens": 20,
            "temperature": 1,
        },
        "options": {
            "use_cache": False
        }
    }
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    api_url = f"https://api-inference.huggingface.co/models/{GPT2_MODEL_NAME}"
    response = requests.post(api_url, headers=headers, json=payload)

    if response.status_code == 200:
        result = response.json()
        if do_sleep:
            time.sleep(10)  # replying too quickly looks odd
        return result

    # Error bodies are not guaranteed to be JSON (e.g. an HTML gateway page).
    # Treat an undecodable body as a non-retryable error instead of letting a
    # decode error escape in place of the intended message.
    try:
        result = response.json()
    except ValueError:
        result = None

    if isinstance(result, dict) and "estimated_time" in result and retries > 0:
        wait = float(result["estimated_time"]) * 1.05
        print(f"Waiting for API... ({wait:.3f} sec)")
        time.sleep(wait)
        return query(text, retries=retries - 1, do_sleep=False)

    print(response)
    print(result if result is not None else response.text)
    raise Exception("API is not responding")
def random_heart() -> str:
    """Return one heart emoji picked uniformly at random."""
    hearts = ('❤️', '🧡', '💛', '💚', '💙', '💜', '💗', '💖', '💓', '💕', '💝')
    picked = random.choice(hearts)
    return picked
def mesugakinize(s: str, *, make_heart=None) -> str:
    """Append a heart emoji after every listed abusive word in *s*.

    Args:
        s: Text to decorate.
        make_heart: Zero-argument callable producing the decoration for each
            match. Defaults to random_heart, so each match gets its own
            random heart.

    Returns:
        *s* with every listed word followed by a heart.
    """
    if make_heart is None:
        make_heart = random_heart
    abusive_words = [
        "うるせえ", "お前", "馬鹿", "糞",
        "バカモン", "ばかもん", "バカモノ", "ばかもの",
        "バカ", "アホ", "ドジ", "マヌケ", "カス", "クソ", "クズ",
        "ばか", "あほ", "どじ", "まぬけ", "かす", "くそ", "くず"
    ]
    # One regex pass with the longest alternatives first. The previous chained
    # str.replace() re-matched short words inside already-decorated long ones
    # (e.g. "バカモン" became "バカ💗モン❤️").
    pattern = "|".join(
        re.escape(word) for word in sorted(abusive_words, key=len, reverse=True)
    )
    return re.sub(pattern, lambda m: m.group(0) + make_heart(), s)
def softmax(arr: list[float], temperature: float = 1.0) -> list[float]:
    """Return the temperature-scaled softmax of *arr*.

    Each element becomes exp(x / T) / sum(exp(y / T)). The largest scaled
    value is subtracted before exponentiating so large inputs do not overflow,
    for either sign of *temperature*.

    Args:
        arr: Scores to normalize.
        temperature: Divisor applied to every score; must be non-zero.

    Returns:
        Probabilities in the same order as *arr* (empty for empty input).
    """
    if not arr:
        return []
    scaled = [x / temperature for x in arr]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    # Sum once; the previous version re-summed for every element (O(n^2)).
    total = sum(exps)
    return [e / total for e in exps]
class Pipeline:
def __init__(self, prompt: str):
self.prompt = prompt
@staticmethod
def cleaning(s: str) -> str:
s = s.replace(" ", "")
s = re.sub(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-…]+", "", s)
for ng_word in itertools.chain(*NG_WORDS.values()):
s = s.replace(ng_word, "💗💗")
s = re.sub(r"[「」\s]", "", s)
s = re.sub(r"[殺死]", "💗", s)
s = re.sub(r"[@@]", "", s)
return mesugakinize(s)
@staticmethod
def _softmax_with_filter(arr: list[float], temperature: float = 1.0) -> list[float]:
for i in range(len(arr)):
if arr[i] <= 0:
arr[i] = -1e9
return softmax(arr, temperature)
@staticmethod
def select_from_sample(generated: list[str], temperature: float = 1.0) -> str:
probabilities = Pipeline._softmax_with_filter([len(s) for s in generated], temperature)
if all(p < 1e-5 for p in probabilities):
return ''
else:
return random.choices(generated, weights=probabilities, k=1)[0]
def generate(self, return_all=False, do_sleep=True) -> list[str]:
data = query(self.prompt, do_sleep=do_sleep)
if data[0]["generated_text"] is not None:
generated = list(map(lambda s: self.cleaning(s["generated_text"]), data))
if return_all:
return generated
else:
return Pipeline.select_from_sample(generated)
else:
print("Something went wrong...")
print(data)
raise Exception
class ReplyPipeline(Pipeline):
    """Generate the bot's reply using a 「…」 dialogue-style prompt.

    The prompt is 「me」「other」わし「, so the model continues with the bot's own
    utterance. Each candidate is cut at the 」 closing that utterance.
    """

    def __init__(self, *, other: str, me: Optional[str] = None) -> None:
        prompt = ""
        if me is not None:
            prompt += f"「{me}」"
        if other is not None:
            prompt += f"「{other}」"
        prompt += "わし「"
        super().__init__(prompt)

    @staticmethod
    def parse_reply_reaction(text: str):
        """Return *text* up to and including the 」 that closes the open 「.

        The prompt ends with an unclosed 「, so scanning starts at depth 1 and
        nested 「…」 pairs are skipped over. If the bracket never closes, the
        whole text is returned.
        """
        depth = 1
        for idx, c in enumerate(text):
            if c == '「':
                depth += 1
            elif c == '」':
                depth -= 1
                if depth == 0:
                    return text[:idx + 1]
        return text

    @staticmethod
    def cleaning(s: str) -> str:
        """Cut the raw candidate at the end of the bot's utterance, then clean it.

        Pipeline.cleaning strips every 「 and 」, so the cut must happen on the
        raw model output. Cutting afterwards (as before) was a no-op and let the
        model's continuation of the dialogue leak into the reply.
        """
        return Pipeline.cleaning(ReplyPipeline.parse_reply_reaction(s))

    def generate(self, return_all=False, do_sleep=True) -> list[str] | str:
        """Return all cleaned reply candidates, or one picked by length.

        Truncation already happens in cleaning(), which Pipeline.generate
        applies to every candidate.
        """
        generated = super().generate(return_all=True, do_sleep=do_sleep)
        if return_all:
            return generated
        return Pipeline.select_from_sample(generated)
class FrogRobo:
    """Twitter bot that posts and replies with model-generated text."""

    def __init__(self):
        self.client = tweepy.Client(**TWITTER_API_KEY)
        # The bot's own user id, hard-coded to skip an API round trip
        # (equivalent to self.client.get_me().data.id).
        self.id = 2744579940

    def crawl_timeline(self, target_minutes=15) -> str | None:
        """Pick the URL-stripped text of one eligible home-timeline tweet.

        Args:
            target_minutes: Only consider tweets from the last N minutes;
                a value <= 0 fetches the timeline without a time bound.

        Returns:
            The text of a randomly chosen eligible tweet, or None when no
            tweet qualifies. Eligible means not sensitive, not a reply, not a
            retweet, and at least 10 characters once URLs are removed.
        """
        # Non-default tweet fields must be requested explicitly; without them
        # possibly_sensitive / in_reply_to_user_id come back as None and the
        # filters below never exclude anything.
        fields = ["possibly_sensitive", "in_reply_to_user_id"]
        if target_minutes > 0:
            start = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=target_minutes)
            # strftime always yields "YYYY-MM-DDTHH:MM:SSZ"; the previous
            # isoformat().split('.') produced "...+00:00Z" whenever the
            # microsecond happened to be 0.
            tweets = self.client.get_home_timeline(
                start_time=start.strftime("%Y-%m-%dT%H:%M:%SZ"),
                tweet_fields=fields,
            )
        else:
            tweets = self.client.get_home_timeline(tweet_fields=fields)
        if not tweets.data:
            return None

        candidates = []  # (tweet id, URL-stripped text)
        for tweet in tweets.data:
            if tweet.possibly_sensitive:
                continue
            if tweet.in_reply_to_user_id is not None:
                continue
            if tweet.text.startswith("RT"):
                continue
            text = re.sub(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-…]+", "", tweet.text)
            # Measure length after URL removal so a bare link does not qualify.
            if len(text) < 10:
                continue
            candidates.append((int(tweet.id), text))
        if not candidates:
            return None

        ids = [tweet_id for tweet_id, _ in candidates]
        min_id, max_id = min(ids), max(ids)
        id_range = max_id - min_id + 1
        # Ids are normalized to [0, 1) before softmax, so the weights differ by
        # less than a factor of e; presumably this mildly favors newer tweets,
        # since tweet ids are time-ordered.
        weights = softmax([(x - min_id) / id_range for x in ids])
        return random.choices(candidates, weights=weights, k=1)[0][1]

    def general_tweet(self):
        """Post a standalone tweet, riffing on a recent timeline tweet if any.

        Tries the last 15 minutes, then 60 minutes, then the unbounded
        timeline, and falls back to a fixed prompt when nothing qualifies.
        """
        for minutes in (15, 60, -1):
            target_tweet = self.crawl_timeline(target_minutes=minutes)
            if target_tweet is not None:
                prompt = ReplyPipeline(other=target_tweet)
                break
        else:
            prompt = Pipeline("今から面白いこと言います。")
        text = prompt.generate()
        # select_from_sample() returns '' when every candidate is empty; skip
        # posting rather than sending an empty tweet.
        if not text:
            print("Generated text is empty; skipping tweet")
            return
        self.client.create_tweet(text=text)

    def debug_dm(self, text: str):
        """Send *text* as a direct message to the fixed debug account."""
        DM_TARGET_USER_ID = 92482871
        self.client.create_dm(text=text, participant_id=DM_TARGET_USER_ID)

    def reply_tweet(self, tweet_dict, ignore_probablity=0.1):
        """React to one incoming tweet event by replying or just liking it.

        Args:
            tweet_dict: Tweet object from the webhook payload; this reads its
                "id", "user" -> "id", "text" and "in_reply_to_status_id_str".
            ignore_probablity: Probability of only liking the tweet instead of
                replying (the misspelled name is kept for compatibility).
        """
        tweet_id = tweet_dict["id"]
        if int(tweet_dict["user"]["id"]) == self.id:
            return  # never react to our own tweets
        if tweet_dict["text"].startswith("RT "):
            return

        in_reply_to = tweet_dict["in_reply_to_status_id_str"]
        my_tweet = None
        if in_reply_to is not None:
            try:
                res = self.client.get_tweet(in_reply_to)
            except Exception as e:
                # The parent tweet is optional context; reply without it.
                print(e)
            else:
                if (res.data is not None) and (res.data.text is not None):
                    my_tweet = res.data.text

        if random.random() > ignore_probablity:
            # Remove the mention case-insensitively instead of lowercasing the
            # whole tweet, which altered the text the model was prompted with.
            text = re.sub("@frogrobo", "", tweet_dict["text"], flags=re.IGNORECASE).strip()
            reply = ReplyPipeline(me=my_tweet, other=text).generate()
            if not reply:
                print("Generated reply is empty; skipping reply")
                return
            self.client.create_tweet(text=reply, in_reply_to_tweet_id=tweet_id)
        else:
            self.client.like(tweet_id)
def webhook_challenge(request: flask.Request):
    """Answer the webhook CRC (challenge-response check) GET request.

    When the query string carries ``crc_token``, respond with JSON
    ``{"response_token": "sha256=<base64 HMAC-SHA256 of the token>"}``, keyed
    by the app's consumer secret. Without the token, return a JSON error body.
    """
    import base64
    import hashlib
    import hmac

    params = request.args
    if 'crc_token' not in params:
        return json.dumps({"error":"No Content"})

    secret = TWITTER_API_KEY['consumer_secret'].encode()
    challenge = params.get('crc_token').encode()
    signature = hmac.new(secret, msg=challenge, digestmod=hashlib.sha256).digest()
    body = {'response_token': 'sha256=' + base64.b64encode(signature).decode()}
    return json.dumps(body), 200, {'Content-Type': 'application/json'}
def is_access_allowed(request: flask.Request):
    """Return True if the request carries the expected shared webhook token.

    The token is read from the ``trpfrog-webhook-token`` query parameter.
    """
    import hmac

    token = request.args.get('trpfrog-webhook-token')
    if token is None:
        return False
    # Constant-time comparison, so response timing does not reveal how much of
    # the secret matched. Compare as UTF-8 bytes: compare_digest raises
    # TypeError on non-ASCII str arguments.
    return hmac.compare_digest(token.encode(), TRPFROG_WEBHOOK_TOKEN.encode())
def webhook(request: flask.Request):
    """Handle a POST to the bot endpoint.

    Requests without a valid ``trpfrog-webhook-token`` get HTTP 403.
    Otherwise the request is dispatched on its payload:

    * JSON with a non-empty ``tweet_create_events`` list: react to each tweet.
    * No JSON body and a ``ping`` query arg: DM a test generation ("pong: ...").
    * No JSON body and a ``general`` query arg: post a standalone tweet.

    Returns 'ok' for every authorized request.
    """
    data = request.get_json(silent=True)
    event_name = 'tweet_create_events'
    if not is_access_allowed(request):
        # .get(): a missing token must yield 403, not a KeyError (HTTP 500).
        token = request.args.get('trpfrog-webhook-token') or 'none'
        return json.dumps({"error":"Forbidden", "message": f"{token} is invalid token."}), 403, {'Content-Type': 'application/json'}
    print(data, type(data))
    if data is not None and event_name in data and len(data[event_name]) > 0:
        print("start replying")
        robot = FrogRobo()
        for tweet_dict in data[event_name]:
            try:
                robot.reply_tweet(tweet_dict)
            except Exception as e:
                # One failing event must not abort the rest, but log it rather
                # than discarding it silently.
                print(f"Failed to handle tweet event: {e!r}")
    elif data is None and 'ping' in request.args:
        print("start pinging")
        prompt = Pipeline("今から面白いこと言います。")
        # Generate once and DM the result, falling back to 'uouo'. The previous
        # code passed an unsupported `retries` kwarg (always a TypeError) and
        # then called generate() again unguarded, discarding the fallback.
        try:
            generated = prompt.generate(do_sleep=False)
        except Exception as e:
            print(e)
            generated = 'uouo'
        FrogRobo().debug_dm("pong: " + generated)
    elif data is None and 'general' in request.args:
        print("start simple tweeting")
        FrogRobo().general_tweet()
    return 'ok'
@functions_framework.http
def start_bot(request: flask.Request):
    """HTTP Cloud Function entry point.

    GET requests are the webhook CRC challenge (handled by webhook_challenge);
    POST requests carry tweet events or manual triggers (handled by webhook).
    Any other method gets 405 instead of falling through and returning None,
    which Flask cannot turn into a response.

    Args:
        request (flask.Request): The request object.
        <https://flask.palletsprojects.com/en/1.1.x/api/#incoming-request-data>
    Returns:
        The response text, or any set of values that can be turned into a
        Response object using `make_response`
        <https://flask.palletsprojects.com/en/1.1.x/api/#flask.make_response>.
    """
    print("called!", request.args)
    print(request)
    if request.method == 'GET':
        return webhook_challenge(request)
    if request.method == 'POST':
        return webhook(request)
    return 'Method Not Allowed', 405, {'Allow': 'GET, POST'}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment