
@bulldra
Created April 12, 2025 05:55
tweet_classfication.py
import argparse
import asyncio
import json
import os
import re
import unicodedata
from datetime import datetime
import aiohttp
import chardet
import ftfy
import torch
from bs4 import BeautifulSoup
from scipy.special import expit
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
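
# Load the multilingual tweet-topic classifier once at module load time and run it
# on Apple Silicon's MPS backend when available, otherwise on CPU.
# class_mapping maps label indices to the model's topic label names.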
MODEL = "cardiffnlp/tweet-topic-large-multilingual"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)
class_mapping = model.config.id2label
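
# Map the model's topic labels to the Japanese hashtags used in the output.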
CATEGORY_TO_HASHTAG = {
    "arts_&_culture": "#アート",
    "business_&_entrepreneurs": "#ビジネス",
    "celebrity_&_pop_culture": "#ポップカルチャー",
    "diaries_&_daily_life": "#日常",
    "family": "#家族",
    "fashion_&_style": "#ファッション",
    "film_tv_&_video": "#ビデオ",
    "fitness_&_health": "#健康",
    "food_&_dining": "#グルメ",
    "gaming": "#ゲーム",
    "learning_&_educational": "#学習",
    "music": "#音楽",
    "news_&_social_concern": "#ニュース",
    "other_hobbies": "#その他",
    "relationships": "#人間関係",
    "science_&_technology": "#テクノロジー",
    "sports": "#スポーツ",
    "travel_&_adventure": "#旅行",
    "youth_&_student_life": "#若者",
}

def parse_args():
    parser = argparse.ArgumentParser(description="ツイートをトピックごとに分類します")
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="処理するツイートの最大数(デフォルト: 無制限、0で無制限)",
    )
    return parser.parse_args()
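
# Repair mojibake with ftfy, normalize to NFC, and fold full-width digits and
# letters down to ASCII so hashtags and URLs compare consistently.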
def sanitize_unicode_text(text):
    if not text:
        return text
    fixed_text = ftfy.fix_text(text)
    fixed_text = unicodedata.normalize("NFC", fixed_text)
    # Convert full-width digits to half-width
    for i in range(10):
        fixed_text = fixed_text.replace(chr(0xFF10 + i), str(i))
    # Convert full-width letters to half-width
    for i in range(26):
        # Uppercase A-Z
        fixed_text = fixed_text.replace(chr(0xFF21 + i), chr(0x0041 + i))
        # Lowercase a-z
        fixed_text = fixed_text.replace(chr(0xFF41 + i), chr(0x0061 + i))
    return fixed_text
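
# Read the Twitter archive export (./input/data/tweets.js), strip the
# "window.YTD.tweets.part0 = " JavaScript prefix, drop superseded versions of
# edited tweets (keeping only the final edit), and sort newest first.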
def load_tweets():
    with open("./input/data/tweets.js", "r", encoding="utf-8") as f:
        content = f.read()
    json_str = re.sub(r"^window\.YTD\.tweets\.part0 = ", "", content)
    tweets_data = json.loads(json_str)
    edited_tweet_ids = set()
    for tweet in tweets_data:
        tweet_info = tweet["tweet"]
        if "edit_info" in tweet_info and "edit" in tweet_info["edit_info"]:
            edit_control = tweet_info["edit_info"]["edit"].get(
                "editControlInitial", {}
            )
            edit_tweet_ids = edit_control.get("editTweetIds", [])
            if edit_tweet_ids:
                edited_tweet_ids.update(edit_tweet_ids[:-1])
    tweets_data = [
        tweet
        for tweet in tweets_data
        if tweet["tweet"].get("id_str") not in edited_tweet_ids
    ]
    for tweet in tweets_data:
        tweet_info = tweet["tweet"]
        tweet_info["datetime"] = datetime.strptime(
            tweet_info["created_at"], "%a %b %d %H:%M:%S +0000 %Y"
        )
    tweets_data.sort(key=lambda x: x["tweet"]["datetime"], reverse=True)
    return tweets_data
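
# Expanded-URL metadata is cached in ./output/url_cache.json so repeated runs
# don't re-fetch pages that were already resolved.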
def load_url_cache():
    cache_file = "./output/url_cache.json"
    if os.path.exists(cache_file):
        with open(cache_file, "r", encoding="utf-8") as f:
            return json.load(f)
    return {}

def save_url_cache(cache):
    os.makedirs("./output", exist_ok=True)
    with open("./output/url_cache.json", "w", encoding="utf-8") as f:
        json.dump(cache, f, ensure_ascii=False, indent=2)
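
# Pull keywords and description out of a page's <meta> tags (including
# og:description) for use as extra context and hashtags.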
def get_meta_info(soup):
    meta_info = {
        "keywords": [],
        "description": "",
    }
    keywords_meta = soup.find("meta", {"name": ["keywords", "Keywords"]}) or soup.find(
        "meta", {"property": "keywords"}
    )
    if keywords_meta:
        keywords = [k.strip() for k in keywords_meta.get("content", "").split(",")]
        meta_info["keywords"] = [k for k in keywords if k]
    desc_meta = soup.find(
        "meta", {"name": ["description", "Description"]}
    ) or soup.find("meta", {"property": "og:description"})
    if desc_meta:
        meta_info["description"] = desc_meta.get("content", "").strip()
    return meta_info
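
# Fetch a single expanded URL with a browser-like User-Agent and a 5-second
# timeout, detect the encoding with chardet, and extract the page title and
# meta info. On network errors the expanded URL itself is used as the title.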
async def get_url_info(session, original_url, expanded_url):
    try:
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
            "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124"
        }
        async with session.get(expanded_url, headers=headers, timeout=5) as response:
            content = await response.read()
            detected = chardet.detect(content)
            content = content.decode(detected["encoding"] or "utf-8", errors="ignore")
            content = sanitize_unicode_text(content)
            expanded_url = str(response.url)
            soup = BeautifulSoup(content, "html.parser")
            title = (
                soup.title.string.strip()
                if soup.title and soup.title.string
                else str(response.url)
            )
            title = sanitize_unicode_text(title)
            meta_info = get_meta_info(soup)
            if meta_info["description"]:
                meta_info["description"] = sanitize_unicode_text(
                    meta_info["description"]
                )
            meta_info["keywords"] = [
                sanitize_unicode_text(keyword) for keyword in meta_info["keywords"]
            ]
            return {
                "original_url": original_url,
                "title": title,
                "url": expanded_url,
                "meta": meta_info,
            }
    except (aiohttp.ClientError, asyncio.TimeoutError) as ex:
        print(f"URLの取得に失敗しました: {expanded_url} - {str(ex)}")
        return {
            "original_url": original_url,
            "title": expanded_url,
            "url": expanded_url,
            "meta": {},
        }

async def process_urls_batch(urls_batch):
    async with aiohttp.ClientSession() as session:
        tasks = []
        for url in urls_batch:
            original_url = str(url["url"])
            expanded_url = str(url["expanded_url"])
            tasks.append(get_url_info(session, original_url, expanded_url))
        if not tasks:
            return []
        results = await asyncio.gather(*tasks)
        return [r for r in results if r is not None]
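
# Rewrite t.co short URLs in the tweet text as Markdown links using the cached
# page title, and append any cached meta keywords as extra hashtags; uncached
# URLs are linked with their expanded form.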
def convert_urls_to_markdown(url_cache, tweet_urls, content):
    for tweet_url in tweet_urls:
        original_url = str(tweet_url["url"])
        expanded_url = str(tweet_url["expanded_url"])
        if original_url in url_cache:
            cache_data = url_cache[original_url]
            title = cache_data["title"]
            url = cache_data["url"]
            meta_info = cache_data["meta"]
            content = content.replace(original_url, f"[{title}]({url})")
            keywords = meta_info.get("keywords", [])
            if keywords:
                content = content + " " + " ".join(f"#{k}" for k in keywords)
        else:
            content = content.replace(original_url, f"[{expanded_url}]({expanded_url})")
    return content
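
# For a batch of tweets, fetch every uncached URL that is not a
# twitter.com/x.com/twimg.com/twitpic.com link, update the cache, then rewrite
# each tweet's full_text into tweet["tweet"]["content"].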
async def process_tweets_batch(tweets_batch, url_cache):
    urls_to_process = []
    for tweet in tweets_batch:
        tweet_urls = tweet["tweet"]["entities"]["urls"]
        for tweet_url in tweet_urls:
            if tweet_url["url"] not in url_cache and not re.match(
                r"https?://(x\.com|twitter\.com|twimg\.com|twitpic\.com)/",
                tweet_url["expanded_url"],
            ):
                urls_to_process.append(tweet_url)
    if urls_to_process:
        results = await process_urls_batch(urls_to_process)
        for result in results:
            url_cache[result["original_url"]] = {
                "title": result.get("title", ""),
                "url": result.get("url", ""),
                "meta": result.get("meta", {}),
            }
    for tweet in tweets_batch:
        content = tweet["tweet"]["full_text"]
        tweet_urls = tweet["tweet"]["entities"]["urls"]
        content = convert_urls_to_markdown(url_cache, tweet_urls, content)
        tweet["tweet"]["content"] = content
def classify_tweet(text):
    tokens = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        output = model(**tokens)
    scores = output[0][0].cpu().numpy()
    probabilities = expit(scores)
    predictions = probabilities > 0.7
    hashtags = []
    for i, prediction in enumerate(predictions):
        if prediction:
            hashtags.append(
                CATEGORY_TO_HASHTAG.get(class_mapping[i], "#" + class_mapping[i])
            )
    return hashtags
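
# Pipeline: load the archive, expand URLs in batches of 5 (saving the cache as
# it goes), classify each tweet and merge in hashtags already present in the
# text, then write the results to ./output/tweets_part_N.json files capped at
# roughly 10 MB each.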
def main():
    args = parse_args()
    tweets_data = load_tweets()
    print(f"{len(tweets_data)}件のツイートを読み込みました。")
    limit = args.limit if args.limit > 0 else len(tweets_data)
    tweets_to_process = tweets_data[:limit]
    print(f"最新の{len(tweets_to_process)}件のツイートを処理します。")
    url_cache = load_url_cache()
    total = len(tweets_to_process)
    BATCH_SIZE = 5
    for i in tqdm(
        range(0, len(tweets_to_process), BATCH_SIZE),
        desc="URL展開処理中",
        total=(total + BATCH_SIZE - 1) // BATCH_SIZE,
    ):
        if i % 1000 == 0:
            save_url_cache(url_cache)
        batch = tweets_to_process[i : i + BATCH_SIZE]
        asyncio.run(process_tweets_batch(batch, url_cache))
    save_url_cache(url_cache)
    for i, tweet in enumerate(tqdm(tweets_to_process, desc="Tweet判定処理中")):
        content = tweet["tweet"]["content"]
        content = sanitize_unicode_text(content)
        tweet["tweet"]["content"] = content
        hashtags = classify_tweet(content)
        for content_hashtag in re.findall(r"#\w+", content):
            content = content.replace(content_hashtag, "", 1).strip()
            if content_hashtag not in hashtags:
                hashtags.append(content_hashtag)
        if content.startswith("RT @") and "#RT" not in hashtags:
            hashtags.append("#RT")
        if content.startswith("@") and "#Reply" not in hashtags:
            hashtags.append("#Reply")
        content = content.replace("\r", "\\r")
        content = content.replace("\n", "\\n")
        tweet["tweet"]["hashtags"] = hashtags
        tweet["tweet"]["clean_content"] = sanitize_unicode_text(content)
    MAX_FILE_SIZE = 10 * 1024 * 1024
    output_data = []
    file_counter = 1
    total_size = 0
    os.makedirs("./output", exist_ok=True)
    with tqdm(total=len(tweets_to_process), desc="ファイル出力中") as pbar:
        for tweet in tweets_to_process:
            timestamp = tweet["tweet"]["created_at"]
            content = tweet["tweet"]["clean_content"]
            dt = datetime.strptime(timestamp, "%a %b %d %H:%M:%S +0000 %Y")
            formatted_timestamp = dt.strftime("%Y-%m-%d %H:%M:%S")
            hashtags = tweet["tweet"].get("hashtags", [])
            tweet_data = {
                "timestamp": formatted_timestamp,
                "content": content,
                "hashtags": hashtags,
            }
            output_data.append(tweet_data)
            tweet_json = json.dumps(tweet_data, ensure_ascii=False)
            tweet_size = len(tweet_json.encode("utf-8"))
            total_size += tweet_size
            pbar.update(1)
            if total_size >= MAX_FILE_SIZE or tweet == tweets_to_process[-1]:
                file_name = f"./output/tweets_part_{file_counter}.json"
                with open(file_name, "w", encoding="utf-8") as f:
                    json.dump(output_data, f, ensure_ascii=False, indent=2)
                print(
                    f"ツイートを出力しました: {file_name} "
                    f"({len(output_data)}件, {total_size / 1024 / 1024:.2f}MB)"
                )
                output_data = []
                total_size = 0
                file_counter += 1

if __name__ == "__main__":
    main()
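
Usage sketch (an assumption based on the hard-coded paths and the --limit flag, not spelled out elsewhere in the gist): unpack a Twitter/X archive export so that ./input/data/tweets.js sits next to the script, install the dependencies (aiohttp, chardet, ftfy, torch, beautifulsoup4, scipy, tqdm, transformers), then run "python tweet_classfication.py --limit 100" to process only the newest 100 tweets, or omit --limit to process everything. The output JSON parts and the URL cache are written to ./output/.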