tweet_classfication.py
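
# Reads a Twitter/X archive (tweets.js), classifies each tweet into topics with
# cardiffnlp/tweet-topic-large-multilingual, maps the topics to Japanese hashtags,
# expands non-Twitter URLs into Markdown links, and writes the result as JSON.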
import argparse
import asyncio
import json
import os
import re
import unicodedata
from datetime import datetime

import aiohttp
import chardet
import ftfy
import torch
from bs4 import BeautifulSoup
from scipy.special import expit
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL = "cardiffnlp/tweet-topic-large-multilingual"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)
class_mapping = model.config.id2label
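
# class_mapping maps label indices to the model's topic names (e.g. "music",
# "sports"); CATEGORY_TO_HASHTAG below translates those names into the Japanese
# hashtags emitted in the output.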

CATEGORY_TO_HASHTAG = {
    "arts_&_culture": "#アート",
    "business_&_entrepreneurs": "#ビジネス",
    "celebrity_&_pop_culture": "#ポップカルチャー",
    "diaries_&_daily_life": "#日常",
    "family": "#家族",
    "fashion_&_style": "#ファッション",
    "film_tv_&_video": "#ビデオ",
    "fitness_&_health": "#健康",
    "food_&_dining": "#グルメ",
    "gaming": "#ゲーム",
    "learning_&_educational": "#学習",
    "music": "#音楽",
    "news_&_social_concern": "#ニュース",
    "other_hobbies": "#その他",
    "relationships": "#人間関係",
    "science_&_technology": "#テクノロジー",
    "sports": "#スポーツ",
    "travel_&_adventure": "#旅行",
    "youth_&_student_life": "#若者",
}


def parse_args():
    parser = argparse.ArgumentParser(description="Classify tweets by topic")
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Maximum number of tweets to process (default: 0 = no limit)",
    )
    return parser.parse_args()
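
# Example invocation (processes only the 100 most recent tweets):
#   python tweet_classfication.py --limit 100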


def sanitize_unicode_text(text):
    if not text:
        return text
    fixed_text = ftfy.fix_text(text)
    fixed_text = unicodedata.normalize("NFC", fixed_text)
    # Convert full-width digits to half-width
    for i in range(10):
        fixed_text = fixed_text.replace(chr(0xFF10 + i), str(i))
    # Convert full-width Latin letters to half-width
    for i in range(26):
        # Uppercase A-Z
        fixed_text = fixed_text.replace(chr(0xFF21 + i), chr(0x0041 + i))
        # Lowercase a-z
        fixed_text = fixed_text.replace(chr(0xFF41 + i), chr(0x0061 + i))
    return fixed_text
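
# Illustrative example (not from the original script): full-width digits and
# Latin letters are folded to ASCII, e.g.
#   sanitize_unicode_text("ＡＢＣ１２３") -> "ABC123"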


def load_tweets():
    with open("./input/data/tweets.js", "r", encoding="utf-8") as f:
        content = f.read()
    json_str = re.sub(r"^window\.YTD\.tweets\.part0 = ", "", content)
    tweets_data = json.loads(json_str)
    edited_tweet_ids = set()
    for tweet in tweets_data:
        tweet_info = tweet["tweet"]
        if "edit_info" in tweet_info and "edit" in tweet_info["edit_info"]:
            edit_control = tweet_info["edit_info"]["edit"].get(
                "editControlInitial", {}
            )
            edit_tweet_ids = edit_control.get("editTweetIds", [])
            if edit_tweet_ids:
                edited_tweet_ids.update(edit_tweet_ids[:-1])
    tweets_data = [
        tweet
        for tweet in tweets_data
        if tweet["tweet"].get("id_str") not in edited_tweet_ids
    ]
    for tweet in tweets_data:
        tweet_info = tweet["tweet"]
        tweet_info["datetime"] = datetime.strptime(
            tweet_info["created_at"], "%a %b %d %H:%M:%S +0000 %Y"
        )
    tweets_data.sort(key=lambda x: x["tweet"]["datetime"], reverse=True)
    return tweets_data
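
# Expected input: the tweets.js file from a Twitter/X archive export, a JSON
# array prefixed with "window.YTD.tweets.part0 = ". For edited tweets, only the
# last ID in editTweetIds is kept; earlier versions are discarded.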


def load_url_cache():
    cache_file = "./output/url_cache.json"
    if os.path.exists(cache_file):
        with open(cache_file, "r", encoding="utf-8") as f:
            return json.load(f)
    return {}


def save_url_cache(cache):
    os.makedirs("./output", exist_ok=True)
    with open("./output/url_cache.json", "w", encoding="utf-8") as f:
        json.dump(cache, f, ensure_ascii=False, indent=2)
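
# Cache layout (as written by process_tweets_batch below):
#   {original_url: {"title": ..., "url": expanded_url, "meta": {...}}}
# so repeated runs do not re-fetch URLs that were already expanded.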


def get_meta_info(soup):
    meta_info = {
        "keywords": [],
        "description": "",
    }
    keywords_meta = soup.find("meta", {"name": ["keywords", "Keywords"]}) or soup.find(
        "meta", {"property": "keywords"}
    )
    if keywords_meta:
        keywords = [k.strip() for k in keywords_meta.get("content", "").split(",")]
        meta_info["keywords"] = [k for k in keywords if k]
    desc_meta = soup.find(
        "meta", {"name": ["description", "Description"]}
    ) or soup.find("meta", {"property": "og:description"})
    if desc_meta:
        meta_info["description"] = desc_meta.get("content", "").strip()
    return meta_info


async def get_url_info(session, original_url, expanded_url):
    try:
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
            "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124"
        }
        async with session.get(expanded_url, headers=headers, timeout=5) as response:
            content = await response.read()
            detected = chardet.detect(content)
            content = content.decode(detected["encoding"] or "utf-8", errors="ignore")
            content = sanitize_unicode_text(content)
            expanded_url = str(response.url)
            soup = BeautifulSoup(content, "html.parser")
            title = (
                soup.title.string.strip()
                if soup.title and soup.title.string
                else str(response.url)
            )
            title = sanitize_unicode_text(title)
            meta_info = get_meta_info(soup)
            if meta_info["description"]:
                meta_info["description"] = sanitize_unicode_text(
                    meta_info["description"]
                )
            meta_info["keywords"] = [
                sanitize_unicode_text(keyword) for keyword in meta_info["keywords"]
            ]
            return {
                "original_url": original_url,
                "title": title,
                "url": expanded_url,
                "meta": meta_info,
            }
    except (aiohttp.ClientError, asyncio.TimeoutError) as ex:
        print(f"Failed to fetch URL: {expanded_url} - {str(ex)}")
        return {
            "original_url": original_url,
            "title": expanded_url,
            "url": expanded_url,
            "meta": {},
        }


async def process_urls_batch(urls_batch):
    async with aiohttp.ClientSession() as session:
        tasks = []
        for url in urls_batch:
            original_url = str(url["url"])
            expanded_url = str(url["expanded_url"])
            tasks.append(get_url_info(session, original_url, expanded_url))
        if not tasks:
            return []
        results = await asyncio.gather(*tasks)
        return [r for r in results if r is not None]


def convert_urls_to_markdown(url_cache, tweet_urls, content):
    for tweet_url in tweet_urls:
        original_url = str(tweet_url["url"])
        expanded_url = str(tweet_url["expanded_url"])
        if original_url in url_cache:
            cache_data = url_cache[original_url]
            title = cache_data["title"]
            url = cache_data["url"]
            meta_info = cache_data["meta"]
            content = content.replace(original_url, f"[{title}]({url})")
            keywords = meta_info.get("keywords", [])
            if keywords:
                content = content + " " + " ".join(f"#{k}" for k in keywords)
        else:
            content = content.replace(original_url, f"[{expanded_url}]({expanded_url})")
    return content
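
# Illustrative rewrite (hypothetical URL and title): a cached link such as
#   "https://t.co/abc" becomes "[Example Title](https://example.com/page) #keyword"
# while uncached links fall back to "[expanded_url](expanded_url)".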


async def process_tweets_batch(tweets_batch, url_cache):
    urls_to_process = []
    for tweet in tweets_batch:
        tweet_urls = tweet["tweet"]["entities"]["urls"]
        for tweet_url in tweet_urls:
            if tweet_url["url"] not in url_cache and not re.match(
                r"https?://(x\.com|twitter\.com|twimg\.com|twitpic\.com)/",
                tweet_url["expanded_url"],
            ):
                urls_to_process.append(tweet_url)
    if urls_to_process:
        results = await process_urls_batch(urls_to_process)
        for result in results:
            url_cache[result["original_url"]] = {
                "title": result.get("title", ""),
                "url": result.get("url", ""),
                "meta": result.get("meta", {}),
            }
    for tweet in tweets_batch:
        content = tweet["tweet"]["full_text"]
        tweet_urls = tweet["tweet"]["entities"]["urls"]
        content = convert_urls_to_markdown(url_cache, tweet_urls, content)
        tweet["tweet"]["content"] = content


def classify_tweet(text):
    tokens = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        output = model(**tokens)
    scores = output[0][0].cpu().numpy()
    probabilities = expit(scores)
    predictions = probabilities > 0.7
    hashtags = []
    for i, prediction in enumerate(predictions):
        if prediction:
            hashtags.append(
                CATEGORY_TO_HASHTAG.get(class_mapping[i], "#" + class_mapping[i])
            )
    return hashtags
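
# The model is multi-label: each topic score goes through a sigmoid (expit) and
# every topic above the 0.7 threshold yields one hashtag, so a tweet can receive
# zero, one, or several hashtags. Made-up example:
#   classify_tweet("Reviewed the new smartphone camera") -> ["#テクノロジー"]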


def main():
    args = parse_args()
    tweets_data = load_tweets()
    print(f"Loaded {len(tweets_data)} tweets.")
    limit = args.limit if args.limit > 0 else len(tweets_data)
    tweets_to_process = tweets_data[:limit]
    print(f"Processing the {len(tweets_to_process)} most recent tweets.")
    url_cache = load_url_cache()
    total = len(tweets_to_process)
    BATCH_SIZE = 5
    for i in tqdm(
        range(0, len(tweets_to_process), BATCH_SIZE),
        desc="Expanding URLs",
        total=(total + BATCH_SIZE - 1) // BATCH_SIZE,
    ):
        if i % 1000 == 0:
            save_url_cache(url_cache)
        batch = tweets_to_process[i : i + BATCH_SIZE]
        asyncio.run(process_tweets_batch(batch, url_cache))
    save_url_cache(url_cache)
    for i, tweet in enumerate(tqdm(tweets_to_process, desc="Classifying tweets")):
        content = tweet["tweet"]["content"]
        content = sanitize_unicode_text(content)
        tweet["tweet"]["content"] = content
        hashtags = classify_tweet(content)
        for content_hashtag in re.findall(r"#\w+", content):
            content = content.replace(content_hashtag, "", 1).strip()
            if content_hashtag not in hashtags:
                hashtags.append(content_hashtag)
        if content.startswith("RT @") and "#RT" not in hashtags:
            hashtags.append("#RT")
        if content.startswith("@") and "#Reply" not in hashtags:
            hashtags.append("#Reply")
        content = content.replace("\r", "\\r")
        content = content.replace("\n", "\\n")
        tweet["tweet"]["hashtags"] = hashtags
        tweet["tweet"]["clean_content"] = sanitize_unicode_text(content)
    MAX_FILE_SIZE = 10 * 1024 * 1024
    output_data = []
    file_counter = 1
    total_size = 0
    os.makedirs("./output", exist_ok=True)
    with tqdm(total=len(tweets_to_process), desc="Writing output files") as pbar:
        for tweet in tweets_to_process:
            timestamp = tweet["tweet"]["created_at"]
            content = tweet["tweet"]["clean_content"]
            dt = datetime.strptime(timestamp, "%a %b %d %H:%M:%S +0000 %Y")
            formatted_timestamp = dt.strftime("%Y-%m-%d %H:%M:%S")
            hashtags = tweet["tweet"].get("hashtags", [])
            tweet_data = {
                "timestamp": formatted_timestamp,
                "content": content,
                "hashtags": hashtags,
            }
            output_data.append(tweet_data)
            tweet_json = json.dumps(tweet_data, ensure_ascii=False)
            tweet_size = len(tweet_json.encode("utf-8"))
            total_size += tweet_size
            pbar.update(1)
            if total_size >= MAX_FILE_SIZE or tweet == tweets_to_process[-1]:
                file_name = f"./output/tweets_part_{file_counter}.json"
                with open(file_name, "w", encoding="utf-8") as f:
                    json.dump(output_data, f, ensure_ascii=False, indent=2)
                print(
                    f"Wrote {file_name} "
                    f"({len(output_data)} tweets, {total_size / 1024 / 1024:.2f}MB)"
                )
                output_data = []
                total_size = 0
                file_counter += 1


if __name__ == "__main__":
    main()