Skip to content

Instantly share code, notes, and snippets.

@EAimTY
Last active December 16, 2023 08:42
Show Gist options
  • Save EAimTY/451f7d48ade325777303b7d062039eb9 to your computer and use it in GitHub Desktop.
Save EAimTY/451f7d48ade325777303b7d062039eb9 to your computer and use it in GitHub Desktop.
ocr_telegram_bot
#![feature(try_blocks)]
use anyhow::Result;
use bytes::BufMut;
use futures_util::{future::BoxFuture, StreamExt};
use leptess::LepTess;
use std::{
collections::HashMap,
fmt::{Display, Formatter, Result as FmtResult},
sync::Arc,
time::Duration,
};
use tgbot::types::Update;
use tgbot::{
longpoll::LongPoll,
methods::{AnswerCallbackQuery, EditMessageText, GetFile, SendMessage},
types::{
CallbackQuery, Command, File, InlineKeyboardButton, InlineKeyboardButtonKind,
InlineKeyboardMarkup, Message, MessageData, PhotoSize, UpdateKind,
},
};
use tgbot::{Api, UpdateHandler};
use tokio::{
sync::Mutex,
time::{self, Instant},
};
/// Entry point: long-polls Telegram for updates and dispatches them to [`Handler`].
///
/// The bot token is read from the `TG_BOT_API_TOKEN` environment variable. When the variable
/// is not set, the placeholder literal is used, so setups that patch the token into the source
/// keep working.
#[tokio::main]
async fn main() {
    let token = std::env::var("TG_BOT_API_TOKEN")
        .unwrap_or_else(|_| String::from("TG_BOT_API_TOKEN"));
    // A bad token/config is unrecoverable at startup; fail loudly with a reason.
    let api = Api::new(token.as_str()).expect("failed to create Telegram API client");
    LongPoll::new(api.clone(), Handler::new(api)).run().await;
}
/// Shared bot state handed to every update: the Telegram API client and the OCR session pool.
///
/// Cloning is cheap (two `Arc` reference-count bumps); a clone is moved into each update's
/// future.
#[derive(Clone)]
struct Handler {
    api: Arc<Api>,
    pool: Arc<Mutex<SessionPool>>,
}

impl Handler {
    /// Wraps `api` and creates a fresh session pool.
    ///
    /// Must be called inside a Tokio runtime, because creating the pool spawns its
    /// garbage-collection task.
    fn new(api: Api) -> Self {
        let pool = SessionPool::new();
        let api = Arc::new(api);
        Self { api, pool }
    }
}
impl UpdateHandler for Handler {
    type Future = BoxFuture<'static, ()>;

    /// Dispatches one update.
    ///
    /// Callback queries go to the OCR keyboard handler. Messages are first offered to the OCR
    /// photo handler; if it does not consume them, they are parsed as a command and passed to
    /// the `/ocr` command handler. Other update kinds are ignored. Errors are printed to
    /// stderr rather than propagated.
    fn handle(&self, update: Update) -> Self::Future {
        let handler = self.clone();
        Box::pin(async move {
            let outcome = match update.kind {
                UpdateKind::CallbackQuery(query) => {
                    handle_ocr_callback_query(&handler, &query).await
                }
                UpdateKind::Message(message) => {
                    let consumed = handle_ocr_message(&handler, &message).await;
                    match consumed {
                        Err(err) => Err(err),
                        Ok(true) => Ok(()),
                        // Not an OCR photo reply: maybe it is a command.
                        Ok(false) => match Command::try_from(message) {
                            Ok(command) => handle_ocr_command(&handler, &command).await,
                            Err(_) => Ok(()),
                        },
                    }
                }
                _ => Ok(()),
            };
            if let Err(err) = outcome {
                eprintln!("{err}");
            }
        })
    }
}
/// State of one `/ocr` conversation.
struct Session {
    /// Telegram user id of whoever sent `/ocr`; only this user may drive the session.
    user: i64,
    /// Language picked on the keyboard; `None` until a language is selected (or after the
    /// user chooses to reselect).
    lang: Option<Language>,
    /// `[chat_id, message_id]` of the bot's language prompt, set once a language is selected;
    /// it is the key of this session's entry in the pool's relay map.
    relay: Option<[i64; 2]>,
    /// Creation time, compared against the pool's lifetime by the garbage collector.
    c_time: Instant,
}

impl Session {
    /// Starts a session owned by `user_id`, with no language chosen yet.
    fn new(user_id: i64) -> Self {
        let c_time = Instant::now();
        Self {
            c_time,
            relay: None,
            lang: None,
            user: user_id,
        }
    }
}
struct SessionPool {
sessions: HashMap<[i64; 2], Session>,
relay: HashMap<[i64; 2], i64>,
}
impl SessionPool {
fn new() -> Arc<Mutex<Self>> {
let pool = Arc::new(Mutex::new(Self {
sessions: HashMap::new(),
relay: HashMap::new(),
}));
let lifetime = Duration::from_secs(3600);
let gc_period = Duration::from_secs(3);
let mut interval = time::interval(gc_period);
let pool_for_gc = pool.clone();
tokio::spawn(async move {
loop {
interval.tick().await;
let mut pool = pool_for_gc.lock().await;
pool.collect_garbage(lifetime);
}
});
pool
}
fn collect_garbage(&mut self, lifetime: Duration) {
self.sessions.retain(|_, Session { c_time, relay, .. }| {
if c_time.elapsed() < lifetime {
true
} else {
relay.map(|relay| self.relay.remove(&relay));
false
}
});
}
}
/// OCR target languages offered on the selection keyboard.
#[derive(Clone, Copy)]
enum Language {
    English,
    Japanese,
    SimplifiedChinese,
    TraditionalChinese,
}

impl Language {
    // Tesseract language-data names; they also appear in the keyboard's callback data.
    const ENG: &'static str = "eng";
    const JPN: &'static str = "jpn";
    const CHI_SIM: &'static str = "chi_sim";
    const CHI_TRA: &'static str = "chi_tra";

    /// Returns the language whose Tesseract data name equals `s`, or `None` if unsupported.
    fn from_tesseract_data_str(s: &str) -> Option<Self> {
        Self::iter().find(|candidate| candidate.as_tesseract_data_str() == s)
    }

    /// Returns the Tesseract data name for this language (e.g. `"chi_sim"`).
    fn as_tesseract_data_str(&self) -> &'static str {
        match *self {
            Language::English => Self::ENG,
            Language::Japanese => Self::JPN,
            Language::SimplifiedChinese => Self::CHI_SIM,
            Language::TraditionalChinese => Self::CHI_TRA,
        }
    }

    /// Iterates over every supported language, in keyboard order.
    fn iter() -> impl Iterator<Item = Self> {
        let all = [
            Language::English,
            Language::Japanese,
            Language::SimplifiedChinese,
            Language::TraditionalChinese,
        ];
        IntoIterator::into_iter(all)
    }
}
impl Display for Language {
    /// Writes the language's own-script name, used as the selection-button label and in the
    /// confirmation text.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.write_str(match self {
            Self::English => "English",
            Self::Japanese => "日本語",
            Self::SimplifiedChinese => "简体中文",
            Self::TraditionalChinese => "繁體中文",
        })
    }
}
/// Handles `/ocr`: opens a session owned by the sender, keyed by the command message, and
/// replies to it with the language-selection keyboard.
///
/// Other commands, and messages without a sender user id, are ignored.
async fn handle_ocr_command(cx: &Handler, cmd: &Command) -> Result<()> {
    if cmd.get_name() != "/ocr" {
        return Ok(());
    }
    let command_msg = cmd.get_message();
    let Some(user_id) = command_msg.get_user_id() else {
        return Ok(());
    };
    let chat_id = command_msg.get_chat_id();
    let msg_id = command_msg.id;

    // The guard is a temporary of this statement, so the lock is released before the
    // network call below.
    cx.pool
        .lock()
        .await
        .sessions
        .insert([chat_id, msg_id], Session::new(user_id));

    let prompt = SendMessage::new(chat_id, "请选择 OCR 目标语言")
        .reply_markup(get_lang_select_keyboard())
        .reply_to_message_id(msg_id);
    cx.api.execute(prompt).await?;
    Ok(())
}
/// Handles a press on the OCR inline keyboard attached to the bot's language prompt.
///
/// The session is looked up through the prompt's `reply_to`, which is the `/ocr` command
/// message. Only the command's sender may use the keyboard:
/// - Selecting a language stores it, and registers the prompt in the relay map so a photo
///   reply can be routed back.
/// - "Unselect" clears the language and restores the selection keyboard.
///
/// In both cases the prompt is edited and the query answered. A missing session, or a
/// different presser, gets an alert answer. Queries without a message or data, with
/// unrecognised data, or whose message is not a reply are ignored.
async fn handle_ocr_callback_query(cx: &Handler, cb_query: &CallbackQuery) -> Result<()> {
    let CallbackQuery {
        id,
        from: user,
        message: Some(prompt_msg),
        data: Some(raw_data),
        ..
    } = cb_query
    else {
        return Ok(());
    };
    let (Some(data), Some(cmd_msg)) = (parse_callback_data(raw_data), &prompt_msg.reply_to)
    else {
        return Ok(());
    };
    let chat_id = prompt_msg.get_chat_id();
    let prompt_id = prompt_msg.id;
    let cmd_msg_id = cmd_msg.id;

    // Decide the response while holding the lock; `Err` carries the alert text for the
    // presser. The lock is released at the end of this block, before any network call.
    let outcome = {
        let mut pool = cx.pool.lock().await;
        match pool.sessions.get_mut(&[chat_id, cmd_msg_id]) {
            None => Err("找不到会话"),
            Some(session) if session.user != user.id => Err("不是命令触发者"),
            Some(session) => Ok(match data {
                CallbackData::Select(lang) => {
                    session.lang = Some(lang);
                    session.relay = Some([chat_id, prompt_id]);
                    pool.relay.insert([chat_id, prompt_id], cmd_msg_id);
                    EditMessageText::new(
                        chat_id,
                        prompt_id,
                        format!("目标语言:{lang},请以需要识别的图片回复此条消息(以图片方式发送)"),
                    )
                    .reply_markup(get_lang_unselect_keyboard())
                }
                CallbackData::Unselect => {
                    session.lang = None;
                    EditMessageText::new(chat_id, prompt_id, "请选择 OCR 目标语言")
                        .reply_markup(get_lang_select_keyboard())
                }
            }),
        }
    };

    match outcome {
        Ok(edit_message) => {
            tokio::try_join!(
                cx.api.execute(edit_message),
                cx.api.execute(AnswerCallbackQuery::new(id))
            )?;
        }
        Err(alert) => {
            let answer = AnswerCallbackQuery::new(id).text(alert).show_alert(true);
            cx.api.execute(answer).await?;
        }
    }
    Ok(())
}
async fn handle_ocr_message(cx: &Handler, msg: &Message) -> Result<bool> {
if let (MessageData::Photo { data, .. }, Some(user_id), Some(relay_msg)) =
(&msg.data, msg.get_user_id(), msg.reply_to.as_ref())
{
let msg_id = msg.id;
let chat_id = msg.get_chat_id();
let relay_msg_id = relay_msg.id;
let mut pool = cx.pool.lock().await;
if let Some(cmd_msg_id) = pool.relay.get(&[chat_id, relay_msg_id]).copied() {
if let Some(Session {
user,
lang: Some(lang),
..
}) = pool.sessions.get(&[chat_id, cmd_msg_id])
{
if user_id == *user {
let lang = *lang;
pool.sessions.remove(&[chat_id, cmd_msg_id]);
pool.relay.remove(&[chat_id, relay_msg_id]);
drop(pool);
let PhotoSize { file_id, .. } = unsafe {
data.iter()
.max_by(|a, b| (a.width, a.height).cmp(&(b.width, b.height)))
.unwrap_unchecked()
};
let get_file = GetFile::new(file_id);
if let File {
file_path: Some(path),
..
} = cx.api.execute(get_file).await?
{
let mut stream = cx.api.download_file(path).await?;
let mut pic = Vec::new();
while let Some(chunk) = stream.next().await {
pic.put_slice(&chunk?);
}
let mut leptess = LepTess::new(None, lang.as_tesseract_data_str())?;
leptess.set_image_from_mem(&pic)?;
let res = leptess.get_utf8_text()?;
let send_message =
SendMessage::new(chat_id, res).reply_to_message_id(msg_id);
cx.api.execute(send_message).await?;
} else {
let send_message =
SendMessage::new(chat_id, "图片获取失败").reply_to_message_id(msg_id);
cx.api.execute(send_message).await?;
}
return Ok(true);
}
}
}
}
Ok(false)
}
/// Builds the language-selection keyboard with one language per row.
///
/// Each button is labelled with the language's display name and carries
/// `ocr-<tesseract data name>` as callback data.
fn get_lang_select_keyboard() -> InlineKeyboardMarkup {
    let mut rows = Vec::new();
    for lang in Language::iter() {
        let callback = format!("ocr-{}", lang.as_tesseract_data_str());
        let button = InlineKeyboardButton::new(
            lang.to_string(),
            InlineKeyboardButtonKind::CallbackData(callback),
        );
        rows.push(vec![button]);
    }
    InlineKeyboardMarkup::from_vec(rows)
}
/// Builds a single-button keyboard that sends the user back to language selection.
///
/// The button carries `ocr-unselect` as callback data.
fn get_lang_unselect_keyboard() -> InlineKeyboardMarkup {
    let reselect = InlineKeyboardButton::new(
        "重新选择",
        InlineKeyboardButtonKind::CallbackData("ocr-unselect".to_owned()),
    );
    InlineKeyboardMarkup::from_vec(vec![vec![reselect]])
}
/// Action encoded in an OCR keyboard button's callback data.
enum CallbackData {
    /// `ocr-<lang>`: the user picked a target language.
    Select(Language),
    /// `ocr-unselect`: the user wants to choose the language again.
    Unselect,
}

/// Parses callback data of the exact form `ocr-<target>`, where `<target>` is `unselect` or a
/// supported Tesseract language name.
///
/// Returns `None` for anything else, including a target that itself contains `-`.
fn parse_callback_data(data: &str) -> Option<CallbackData> {
    let target = data.strip_prefix("ocr-")?;
    if target.contains('-') {
        return None;
    }
    match target {
        "unselect" => Some(CallbackData::Unselect),
        name => Language::from_tesseract_data_str(name).map(CallbackData::Select),
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment