Skip to content

Instantly share code, notes, and snippets.

@kalloc
Created December 31, 2022 09:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kalloc/ae9a334346a2439b8c99143d392a8897 to your computer and use it in GitHub Desktop.
Save kalloc/ae9a334346a2439b8c99143d392a8897 to your computer and use it in GitHub Desktop.
from telethon import TelegramClient, events, sync
import numpy as np
import asyncio
import random
from telethon.tl.functions.account import UpdateProfileRequest
from telethon.tl.functions.messages import SendReactionRequest
from telethon.tl.types import IpPort, ReactionEmoji
import torch
import os
import sys
import logging
import time
from transformers import BertTokenizer, BertForSequenceClassification
def _load_bert_classifier(model_name):
    """Load a BERT tokenizer and sequence-classification model by id.

    Args:
        model_name: Hugging Face model id passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


def load_toxic_model(model_name='Skoltech/russian-sensitive-topics'):
    """Return the ``(tokenizer, model)`` pair used by the toxicity check.

    Args:
        model_name: Hugging Face model id; defaults to the original id.

    NOTE(review): despite this function's name, the default id is a
    *sensitive-topics* checkpoint, not a dedicated toxicity classifier.
    ``predict_toxic`` treats any non-zero argmax class as "toxic", so with
    this checkpoint "toxic" just means "top class is not class 0".
    Confirm the intended model id and pass it via ``model_name``.
    """
    return _load_bert_classifier(model_name)


def load_sensistive_model(model_name='apanc/russian-sensitive-topics'):
    """Return the ``(tokenizer, model)`` pair for sensitive-topic detection.

    Args:
        model_name: Hugging Face model id; defaults to the original id.
    """
    return _load_bert_classifier(model_name)


# Correctly spelled alias; the misspelled name is kept for existing callers.
load_sensitive_model = load_sensistive_model


def load_inappropriate_model(model_name='apanc/russian-inappropriate-messages'):
    """Return the ``(tokenizer, model)`` pair for inappropriate-message detection.

    Args:
        model_name: Hugging Face model id; defaults to the original id.
    """
    return _load_bert_classifier(model_name)
# Root logging: timestamp, logger name and level prefixed to every record.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_env(name, message):
    """Return environment variable *name*, prompting for it if it is unset.

    Args:
        name: Environment variable to look up.
        message: Prompt passed to ``input()`` when *name* is not set.

    Returns:
        The variable's value (an empty string counts as set), otherwise the
        text the user typed.
    """
    try:
        return os.environ[name]
    except KeyError:
        return input(message)
# Load all three classifiers once at import time; the predict_* helpers
# read these module-level globals.
sens_tokenizer, sens_model = load_sensistive_model()
inapp_tokenizer, inapp_model = load_inappropriate_model()
toxic_tokenizer, toxic_model = load_toxic_model()

import json

# Maps a sensitive-topics class index, as a *string* key, to a topic name;
# looked up by adjust_multilabel(). JSON is UTF-8 by specification, so don't
# rely on the platform's locale encoding when reading it.
with open("id2topic.json", encoding="utf-8") as f:
    target_vaiables_id2topic_dict = json.load(f)
def adjust_multilabel(y, is_pred=False):
    """Map each row of class scores to its topic name, skipping class 0.

    Args:
        y: Iterable of per-sample score rows (e.g. a batch of logits).
        is_pred: Unused; kept so existing call sites keep working.

    Returns:
        Topic names from ``target_vaiables_id2topic_dict``, one for each row
        whose highest-scoring class index is not 0, in row order.

    NOTE(review): despite the name, only the single top class per row is
    used (argmax), not a multilabel threshold. Confirm this is intended.
    """
    top_classes = (str(int(np.argmax(row))) for row in y)
    return [target_vaiables_id2topic_dict[idx] for idx in top_classes if idx != '0']
# Get your own api_id and api_hash from https://my.telegram.org, under
# "API Development". They are read from the TG_API_ID / TG_API_HASH
# environment variables, or prompted for interactively when unset.
api_id = int(get_env('TG_API_ID', 'Telegram api_id: '))
api_hash = get_env('TG_API_HASH', 'Telegram api_hash: ')

# Telethon client; 'wuzabot' is the session name passed to Telethon.
client = TelegramClient('wuzabot', api_id, api_hash, device_model="Linux")
# Called without await: relies on the `telethon ... sync` import at the top.
client.start()
me = client.get_me()
def predict_toxic(text):
    """Return True if the toxicity model's top class for *text* is non-zero.

    Args:
        text: Message text; only its first 511 characters are classified.

    Returns:
        ``True`` when the argmax class index is not 0, else ``False``.
    """
    # The character slice alone does not bound the token count. BERT splits
    # each punctuation character into its own wordpiece, so 511 chars plus
    # [CLS]/[SEP] can exceed the 512-position limit; truncate at token level
    # too, like the other predict_* helpers.
    batch = toxic_tokenizer.encode(
        text[:511],
        return_tensors='pt',
        truncation=True,
        max_length=512,
    )
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        output = toxic_model(batch)
    y_pred = np.argmax(output.logits.detach().numpy(), axis=1)
    return bool(y_pred[0])
def predict_inapp(text):
    """Return True if the inappropriate-messages model's top class is non-zero.

    Args:
        text: Message text; only its first 511 characters are classified,
            and tokenization is additionally truncated to 512 tokens.

    Returns:
        ``True`` when the argmax class index is not 0, else ``False``.
    """
    encoded = inapp_tokenizer.batch_encode_plus(
        [text[:511]],
        max_length=512,
        truncation=True,
        return_token_type_ids=False,
    )
    input_ids = torch.tensor(encoded['input_ids'])
    attention_mask = torch.tensor(encoded['attention_mask'])
    # Inference only (as in predict_sens_topics): skip autograd bookkeeping.
    with torch.no_grad():
        model_output = inapp_model(input_ids, attention_mask)
    return bool(torch.argmax(model_output['logits'], dim=1)[0])
def predict_sens_topics(text):
    """Return the sensitive-topic names predicted for *text*.

    Args:
        text: Message text; only its first 511 characters are classified,
            and tokenization is additionally truncated to 512 tokens.

    Returns:
        A list of topic names produced by ``adjust_multilabel``. It is empty
        when the top class is 0.
    """
    encoded = sens_tokenizer.batch_encode_plus(
        [text[:511]],
        max_length=512,
        truncation=True,
        return_token_type_ids=False,
    )
    input_ids = torch.tensor(encoded['input_ids'])
    attention_mask = torch.tensor(encoded['attention_mask'])
    with torch.no_grad():
        logits = sens_model(input_ids, attention_mask)['logits']
    return adjust_multilabel(logits, is_pred=True)
async def process_message(message, target_sender_id=87677941):
    """Classify one chat message, react if it is from the target, and log it.

    Runs the toxicity, sensitive-topics and inappropriate-message checks on
    ``message.text``. If the sender's id equals ``target_sender_id`` and the
    text is flagged toxic or inappropriate, a reaction is sent (💩 when
    toxic, otherwise 🤮), followed by a random 1-3 second pause. Every text
    message is then printed with its labels, and its id is written to
    ``/tmp/shit_last`` as the resume point for the next run.

    Args:
        message: A Telethon message object; messages without text are ignored.
        target_sender_id: Sender id whose flagged messages get a reaction.

    NOTE(review): the predict_* calls are synchronous model inference and
    block the event loop while they run.
    """
    if not message.text:
        return
    text = message.text

    is_toxic = predict_toxic(text)
    labels = predict_sens_topics(text)
    is_inapp = predict_inapp(text)

    # Telethon may leave `sender` unset (None), and legacy Chat senders have
    # no `username`, so read both attributes defensively.
    sender = message.sender
    sender_id = getattr(sender, 'id', None)

    if sender_id == target_sender_id and (is_toxic or is_inapp):
        print("Ivan", "shit him", message.id, text)
        reaction = ReactionEmoji("💩") if is_toxic else ReactionEmoji("🤮")
        await client(SendReactionRequest(
            peer=message.peer_id,
            msg_id=message.id,
            reaction=[reaction],
        ))
        # Random pause after reacting before handling the next message.
        await asyncio.sleep(random.randrange(1, 4))

    if is_toxic:
        labels.append('toxic')
    if is_inapp:
        labels.append('inappropriate')
    print("Message",
          getattr(sender, 'username', None),
          sender_id,
          text,
          labels)

    # Persist progress so the next run's catch-up starts after this message.
    with open("/tmp/shit_last", "w") as progress_file:
        progress_file.write(str(message.id))
# Resume point for main(): the id of the last message process_message() saved,
# or None when there is no usable record (missing, unreadable, or not an int).
try:
    with open("/tmp/shit_last", "r") as last_file:
        offset_id = int(last_file.read())
except (OSError, ValueError):
    offset_id = None
async def main():
    """Replay the monitored chat's history after ``offset_id``, oldest first.

    Does nothing when no resume point was loaded (``offset_id`` is falsy).
    Each message is handed to ``process_message``.
    """
    if not offset_id:
        return
    history = client.iter_messages(-1001085244538, reverse=True, offset_id=offset_id)
    async for message in history:
        await process_message(message=message)
@client.on(events.MessageDeleted(chats=[-1001085244538]))
async def on_delete(event):
    """Print deletion events from the monitored chat as JSON."""
    payload = event.to_json()
    print(payload)
@client.on(events.NewMessage(chats=[-1001085244538]))
async def handler(event):
    """Pass each new message in the monitored chat to process_message."""
    incoming = event.message
    await process_message(message=incoming)
# Catch up on missed messages, then exit. The commented call would instead
# keep the client running so the registered event handlers keep firing.
try:
    client.loop.run_until_complete(main())
    # client.run_until_disconnected()
finally:
    # Release the Telegram connection even if main() raised.
    client.disconnect()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment