// ocr_telegram_bot
// Source: GitHub gist EAimTY/451f7d48ade325777303b7d062039eb9 (last active December 16, 2023 08:42).
// Note: GitHub flagged this file as containing bidirectional Unicode text; review it in an
// editor that reveals hidden Unicode characters.
#![feature(try_blocks)]

use anyhow::Result;
use bytes::BufMut;
use futures_util::{future::BoxFuture, StreamExt};
use leptess::LepTess;
use std::{
    collections::HashMap,
    fmt::{Display, Formatter, Result as FmtResult},
    sync::Arc,
    time::Duration,
};
use tgbot::{
    longpoll::LongPoll,
    methods::{AnswerCallbackQuery, EditMessageText, GetFile, SendMessage},
    types::{
        CallbackQuery, Command, File, InlineKeyboardButton, InlineKeyboardButtonKind,
        InlineKeyboardMarkup, Message, MessageData, PhotoSize, Update, UpdateKind,
    },
    Api, UpdateHandler,
};
use tokio::{
    sync::Mutex,
    task,
    time::{self, Instant},
};
#[tokio::main] | |
async fn main() { | |
let token = "TG_BOT_API_TOKEN"; | |
let api = Api::new(token).unwrap(); | |
LongPoll::new(api.clone(), Handler::new(api)).run().await; | |
} | |
#[derive(Clone)] | |
struct Handler { | |
api: Arc<Api>, | |
pool: Arc<Mutex<SessionPool>>, | |
} | |
impl Handler { | |
fn new(api: Api) -> Self { | |
Self { | |
api: Arc::new(api), | |
pool: SessionPool::new(), | |
} | |
} | |
} | |
impl UpdateHandler for Handler { | |
type Future = BoxFuture<'static, ()>; | |
fn handle(&self, update: Update) -> Self::Future { | |
let cx = self.clone(); | |
Box::pin(async move { | |
let res = match update.kind { | |
UpdateKind::CallbackQuery(cb_query) => { | |
handle_ocr_callback_query(&cx, &cb_query).await | |
} | |
UpdateKind::Message(msg) => try { | |
if !handle_ocr_message(&cx, &msg).await? { | |
if let Ok(cmd) = Command::try_from(msg) { | |
handle_ocr_command(&cx, &cmd).await? | |
} | |
} | |
} | |
_ => Ok(()), | |
}; | |
if let Err(err) = res { | |
eprintln!("{err}"); | |
} | |
}) | |
} | |
} | |
struct Session { | |
user: i64, | |
lang: Option<Language>, | |
relay: Option<[i64; 2]>, | |
c_time: Instant, | |
} | |
impl Session { | |
fn new(user_id: i64) -> Self { | |
Self { | |
user: user_id, | |
lang: None, | |
relay: None, | |
c_time: Instant::now(), | |
} | |
} | |
} | |
struct SessionPool { | |
sessions: HashMap<[i64; 2], Session>, | |
relay: HashMap<[i64; 2], i64>, | |
} | |
impl SessionPool { | |
fn new() -> Arc<Mutex<Self>> { | |
let pool = Arc::new(Mutex::new(Self { | |
sessions: HashMap::new(), | |
relay: HashMap::new(), | |
})); | |
let lifetime = Duration::from_secs(3600); | |
let gc_period = Duration::from_secs(3); | |
let mut interval = time::interval(gc_period); | |
let pool_for_gc = pool.clone(); | |
tokio::spawn(async move { | |
loop { | |
interval.tick().await; | |
let mut pool = pool_for_gc.lock().await; | |
pool.collect_garbage(lifetime); | |
} | |
}); | |
pool | |
} | |
fn collect_garbage(&mut self, lifetime: Duration) { | |
self.sessions.retain(|_, Session { c_time, relay, .. }| { | |
if c_time.elapsed() < lifetime { | |
true | |
} else { | |
relay.map(|relay| self.relay.remove(&relay)); | |
false | |
} | |
}); | |
} | |
} | |
/// OCR target languages offered to the user.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Language {
    English,
    Japanese,
    SimplifiedChinese,
    TraditionalChinese,
}

impl Language {
    const ENG: &'static str = "eng";
    const JPN: &'static str = "jpn";
    const CHI_SIM: &'static str = "chi_sim";
    const CHI_TRA: &'static str = "chi_tra";

    /// Parses a Tesseract traineddata code (`"eng"`, `"jpn"`, `"chi_sim"`, `"chi_tra"`).
    ///
    /// Returns `None` for any other string. Implemented as the inverse of
    /// [`Language::as_tesseract_data_str`] so the two can never drift apart.
    fn from_tesseract_data_str(s: &str) -> Option<Self> {
        Self::iter().find(|lang| lang.as_tesseract_data_str() == s)
    }

    /// Returns the Tesseract traineddata code for this language.
    fn as_tesseract_data_str(&self) -> &'static str {
        match *self {
            Self::English => Self::ENG,
            Self::Japanese => Self::JPN,
            Self::SimplifiedChinese => Self::CHI_SIM,
            Self::TraditionalChinese => Self::CHI_TRA,
        }
    }

    /// Iterates over every supported language, in declaration order.
    fn iter() -> impl Iterator<Item = Self> {
        // Calling the trait function explicitly selects the by-value array iterator on
        // every edition; `[..].into_iter()` only does so on edition 2021 and later.
        IntoIterator::into_iter([
            Self::English,
            Self::Japanese,
            Self::SimplifiedChinese,
            Self::TraditionalChinese,
        ])
    }
}
impl Display for Language { | |
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { | |
let lang_name = match self { | |
Self::English => "English", | |
Self::Japanese => "日本語", | |
Self::SimplifiedChinese => "简体中文", | |
Self::TraditionalChinese => "繁體中文", | |
}; | |
write!(f, "{lang_name}") | |
} | |
} | |
async fn handle_ocr_command(cx: &Handler, cmd: &Command) -> Result<()> { | |
if cmd.get_name() == "/ocr" { | |
let msg = cmd.get_message(); | |
if let Some(user_id) = msg.get_user_id() { | |
let chat_id = msg.get_chat_id(); | |
let msg_id = msg.id; | |
let mut pool = cx.pool.lock().await; | |
let session = Session::new(user_id); | |
pool.sessions.insert([chat_id, msg_id], session); | |
let send_message = SendMessage::new(chat_id, "请选择 OCR 目标语言") | |
.reply_markup(get_lang_select_keyboard()) | |
.reply_to_message_id(msg_id); | |
drop(pool); | |
cx.api.execute(send_message).await?; | |
} | |
} | |
Ok(()) | |
} | |
async fn handle_ocr_callback_query(cx: &Handler, cb_query: &CallbackQuery) -> Result<()> { | |
if let CallbackQuery { | |
id, | |
from: user, | |
message: Some(msg), | |
data: Some(cb_data), | |
.. | |
} = cb_query | |
{ | |
if let (Some(data), Some(cmd_msg)) = (parse_callback_data(cb_data), &msg.reply_to) { | |
let cmd_msg_id = cmd_msg.id; | |
let msg_id = msg.id; | |
let chat_id = msg.get_chat_id(); | |
let user_id = user.id; | |
let mut pool = cx.pool.lock().await; | |
if let Some(session) = pool.sessions.get_mut(&[chat_id, cmd_msg_id]) { | |
if session.user == user_id { | |
let edit_message = if let CallbackData::Select(lang) = data { | |
session.lang = Some(lang); | |
session.relay = Some([chat_id, msg_id]); | |
pool.relay.insert([chat_id, msg_id], cmd_msg_id); | |
EditMessageText::new( | |
chat_id, | |
msg_id, | |
format!("目标语言:{lang},请以需要识别的图片回复此条消息(以图片方式发送)"), | |
) | |
.reply_markup(get_lang_unselect_keyboard()) | |
} else { | |
session.lang = None; | |
EditMessageText::new(chat_id, msg_id, "请选择 OCR 目标语言") | |
.reply_markup(get_lang_select_keyboard()) | |
}; | |
let answer_callback_query = AnswerCallbackQuery::new(id); | |
drop(pool); | |
tokio::try_join!( | |
cx.api.execute(edit_message), | |
cx.api.execute(answer_callback_query) | |
)?; | |
} else { | |
drop(pool); | |
let answer_callback_query = AnswerCallbackQuery::new(id) | |
.text("不是命令触发者") | |
.show_alert(true); | |
cx.api.execute(answer_callback_query).await?; | |
} | |
} else { | |
drop(pool); | |
let answer_callback_query = AnswerCallbackQuery::new(id) | |
.text("找不到会话") | |
.show_alert(true); | |
cx.api.execute(answer_callback_query).await?; | |
} | |
} | |
} | |
Ok(()) | |
} | |
async fn handle_ocr_message(cx: &Handler, msg: &Message) -> Result<bool> { | |
if let (MessageData::Photo { data, .. }, Some(user_id), Some(relay_msg)) = | |
(&msg.data, msg.get_user_id(), msg.reply_to.as_ref()) | |
{ | |
let msg_id = msg.id; | |
let chat_id = msg.get_chat_id(); | |
let relay_msg_id = relay_msg.id; | |
let mut pool = cx.pool.lock().await; | |
if let Some(cmd_msg_id) = pool.relay.get(&[chat_id, relay_msg_id]).copied() { | |
if let Some(Session { | |
user, | |
lang: Some(lang), | |
.. | |
}) = pool.sessions.get(&[chat_id, cmd_msg_id]) | |
{ | |
if user_id == *user { | |
let lang = *lang; | |
pool.sessions.remove(&[chat_id, cmd_msg_id]); | |
pool.relay.remove(&[chat_id, relay_msg_id]); | |
drop(pool); | |
let PhotoSize { file_id, .. } = unsafe { | |
data.iter() | |
.max_by(|a, b| (a.width, a.height).cmp(&(b.width, b.height))) | |
.unwrap_unchecked() | |
}; | |
let get_file = GetFile::new(file_id); | |
if let File { | |
file_path: Some(path), | |
.. | |
} = cx.api.execute(get_file).await? | |
{ | |
let mut stream = cx.api.download_file(path).await?; | |
let mut pic = Vec::new(); | |
while let Some(chunk) = stream.next().await { | |
pic.put_slice(&chunk?); | |
} | |
let mut leptess = LepTess::new(None, lang.as_tesseract_data_str())?; | |
leptess.set_image_from_mem(&pic)?; | |
let res = leptess.get_utf8_text()?; | |
let send_message = | |
SendMessage::new(chat_id, res).reply_to_message_id(msg_id); | |
cx.api.execute(send_message).await?; | |
} else { | |
let send_message = | |
SendMessage::new(chat_id, "图片获取失败").reply_to_message_id(msg_id); | |
cx.api.execute(send_message).await?; | |
} | |
return Ok(true); | |
} | |
} | |
} | |
} | |
Ok(false) | |
} | |
fn get_lang_select_keyboard() -> InlineKeyboardMarkup { | |
let vec = Language::iter() | |
.map(|lang| { | |
vec![InlineKeyboardButton::new( | |
lang.to_string(), | |
InlineKeyboardButtonKind::CallbackData(format!( | |
"ocr-{}", | |
lang.as_tesseract_data_str() | |
)), | |
)] | |
}) | |
.collect(); | |
InlineKeyboardMarkup::from_vec(vec) | |
} | |
fn get_lang_unselect_keyboard() -> InlineKeyboardMarkup { | |
let vec = vec![vec![InlineKeyboardButton::new( | |
"重新选择", | |
InlineKeyboardButtonKind::CallbackData(String::from("ocr-unselect")), | |
)]]; | |
InlineKeyboardMarkup::from_vec(vec) | |
} | |
enum CallbackData { | |
Select(Language), | |
Unselect, | |
} | |
fn parse_callback_data(data: &str) -> Option<CallbackData> { | |
let mut data = data.split('-'); | |
if let (Some("ocr"), Some(target), None) = (data.next(), data.next(), data.next()) { | |
if target == "unselect" { | |
return Some(CallbackData::Unselect); | |
} else if let Some(lang) = Language::from_tesseract_data_str(target) { | |
return Some(CallbackData::Select(lang)); | |
} | |
} | |
None | |
} |
// (GitHub page footer) Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.