Add /sql command.

This commit is contained in:
AB
2021-01-03 22:37:37 +03:00
parent 3fd5b124f3
commit 3f00505659
6 changed files with 215 additions and 70 deletions

View File

@@ -34,4 +34,5 @@ serde_json = "1.0"
markov = "1.1.0" markov = "1.1.0"
rand = "0.7.3" rand = "0.7.3"
mystem = "0.2.1" mystem = "0.2.1"
async-trait = "0.1.42" async-trait = "0.1.42"
sqlparser = "0.7.0"

View File

@@ -1,16 +1,22 @@
use crate::db; use crate::db;
use crate::errors::Error; use crate::errors::Error;
use crate::errors::Error::SQLInvalidCommand;
use async_trait::async_trait; use async_trait::async_trait;
use html_escape::encode_text; use html_escape::encode_text;
use markov::Chain; use markov::Chain;
use mystem::Case::Nominative; use mystem::Case::Nominative;
use mystem::Gender::Feminine; use mystem::Gender::Feminine;
use mystem::Tense::{Inpresent, Past};
use mystem::Person::First;
use mystem::MyStem; use mystem::MyStem;
use mystem::Person::First;
use mystem::Tense::{Inpresent, Past};
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand::Rng; use rand::Rng;
use regex::Regex; use regex::Regex;
use rusqlite::types::FromSql;
use rusqlite::{CachedStatement, Rows, ToSql};
use sqlparser::ast::Statement;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use telegram_bot::prelude::*; use telegram_bot::prelude::*;
use telegram_bot::{Api, Message, ParseMode}; use telegram_bot::{Api, Message, ParseMode};
@@ -29,14 +35,9 @@ pub struct Markov {
pub struct Omedeto { pub struct Omedeto {
pub data: String, pub data: String,
} }
pub struct Sql {
// pub enum Command { pub data: String,
// Here(Here), }
// Top { data: String },
// MarkovAll { data: String },
// Markov { data: String },
// Omedeto { data: String },
// }
#[async_trait] #[async_trait]
pub trait Execute { pub trait Execute {
@@ -49,6 +50,189 @@ pub trait Execute {
) -> Result<(), Error>; ) -> Result<(), Error>;
} }
#[async_trait]
impl Execute for Sql {
async fn run(&self, api: Api, message: Message) -> Result<(), Error> {
let mut sql = self.data.to_uppercase();
let is_head = if sql.starts_with('-') {
sql = sql.replacen("-", "", 1);
false
} else {
true
};
let dialect = GenericDialect {};
let ast: Result<Vec<Statement>, Error> = match Parser::parse_sql(&dialect, &sql) {
Ok(ast) => Ok(ast),
Err(_) => {
warn!("Invalid SQL - {}", sql);
Err(SQLInvalidCommand)
}
};
let ast = match ast {
Err(e) => {
let _ = api
.send(
message
.text_reply(format!("❌ Invalid SQL. Syntax error ❌"))
.parse_mode(ParseMode::Html),
)
.await;
return Err(SQLInvalidCommand);
}
Ok(ast) => ast,
};
let msg: Result<String, Error> = match ast.len() {
l if l > 1 => {
//Max 1 request per message allowed only.
Err(Error::SQLBannedCommand)
}
_ => match ast[0] {
sqlparser::ast::Statement::Query { .. } => {
let conn = db::open()?;
let mut x = match conn.prepare_cached(&sql) {
Ok(mut stmt) => {
let query = match stmt.query(rusqlite::NO_PARAMS) {
Err(e) => Err(SQLInvalidCommand),
Ok(mut rows) => {
let mut res: Vec<Vec<String>> = match rows.column_names() {
Some(n) => vec![n
.into_iter()
.map(|s| {
let t = String::from(s);
if t.len() > 10 {
"EMSGSIZE".to_string()
} else {
t
}
})
.collect()],
None => return Err(SQLInvalidCommand),
};
let index_count = match rows.column_count() {
Some(c) => c,
None => return Err(SQLInvalidCommand),
};
while let Some(row) = rows.next().unwrap() {
let mut tmp: Vec<String> = Vec::new();
for i in 0..index_count {
match row.get(i).unwrap_or(None) {
Some(rusqlite::types::Value::Text(t)) => {
tmp.push(t)
}
Some(rusqlite::types::Value::Integer(t)) => {
tmp.push(t.to_string())
}
Some(rusqlite::types::Value::Blob(t)) => {
tmp.push("Binary".to_string())
}
Some(rusqlite::types::Value::Real(t)) => {
tmp.push(t.to_string())
}
Some(rusqlite::types::Value::Null) => {
tmp.push("Null".to_string())
}
None => tmp.push("Null".to_string()),
};
}
res.push(tmp);
}
// add Header
let mut msg = if is_head {
let mut x = String::from("<b>");
for head in res[0].iter() {
x = format!("{} {}", x, head);
}
format!("{}{}", x, "</b>\n")
} else {
String::new()
};
// remove header
res.remove(0);
msg = format!("{}{}", msg, "<pre>");
for line in res.iter() {
for field in line.iter() {
msg = format!("{}{}", msg, format!("{} ", field));
}
msg = format!("{}{}", msg, "\n");
}
msg = format!("{}{}", msg, "</pre>");
msg = if msg.len() > 4096 {
"🚫 Result is too big. Use LIMIT 🚫".into()
} else {
msg
};
Ok(msg)
}
};
query
}
Err(e) => Err(Error::SQLITE3Error(e)),
};
x
}
_ => {
warn!("SELECT requests allowed only.");
Err(Error::SQLBannedCommand)
}
},
};
match msg {
Ok(msg) => {
match api
.send(message.text_reply(msg).parse_mode(ParseMode::Html))
.await
{
Ok(_) => debug!("/sql command sent to {}", message.chat.id()),
Err(_) => warn!("/sql command sent failed to {}", message.chat.id()),
}
}
Err(e) => match e {
Error::SQLITE3Error(e) => {
let _ = api
.send(
message
.text_reply(format!("❌ An error ocurred {}", e))
.parse_mode(ParseMode::Html),
)
.await;
}
Error::SQLBannedCommand => {
let _ = api
.send(
message
.text_reply(format!("🚫 SELECT requests allowed only 🚫"))
.parse_mode(ParseMode::Html),
)
.await;
}
Error::SQLInvalidCommand => {
let _ = api
.send(
message
.text_reply(format!("🚫 Invalid SQL. Check DB scheme. 🚫"))
.parse_mode(ParseMode::Html),
)
.await;
}
_ => {}
},
}
Ok(())
}
async fn run_mystem(
&self,
api: Api,
message: Message,
mystem: &mut MyStem,
) -> Result<(), Error> {
unimplemented!()
}
}
#[async_trait] #[async_trait]
impl Execute for Here { impl Execute for Here {
async fn run(&self, api: Api, message: Message) -> Result<(), Error> { async fn run(&self, api: Api, message: Message) -> Result<(), Error> {
@@ -76,8 +260,6 @@ impl Execute for Here {
Ok(_) => debug!("/here command sent to {}", message.chat.id()), Ok(_) => debug!("/here command sent to {}", message.chat.id()),
Err(_) => warn!("/here command sent failed to {}", message.chat.id()), Err(_) => warn!("/here command sent failed to {}", message.chat.id()),
} }
//api.send(message.chat.text("Text to message chat")).await?;
//api.send(message.from.text("Private text")).await?;
Ok(()) Ok(())
} }
@@ -113,8 +295,6 @@ impl Execute for Top {
Ok(_) => debug!("/top command sent to {}", message.chat.id()), Ok(_) => debug!("/top command sent to {}", message.chat.id()),
Err(_) => warn!("/top command sent failed to {}", message.chat.id()), Err(_) => warn!("/top command sent failed to {}", message.chat.id()),
} }
//api.send(message.chat.text("Text to message chat")).await?;
//api.send(message.from.text("Private text")).await?;
Ok(()) Ok(())
} }
@@ -387,7 +567,6 @@ impl Execute for Omedeto {
verbs_i.pop().unwrap_or(placeholders.choose(&mut rand::thread_rng()).unwrap().to_string()), verbs_i.pop().unwrap_or(placeholders.choose(&mut rand::thread_rng()).unwrap().to_string()),
verbs_i.pop().unwrap_or(placeholders.choose(&mut rand::thread_rng()).unwrap().to_string()), verbs_i.pop().unwrap_or(placeholders.choose(&mut rand::thread_rng()).unwrap().to_string()),
); );
//debug!("{:?}", result);
match api match api
.send( .send(
message message

View File

@@ -81,7 +81,7 @@ pub(crate) fn get_conf(id: telegram_bot::ChatId) -> Result<Conf, errors::Error>
} }
} }
/* #[allow(dead_code)]
pub(crate) fn get_confs() -> Result<Vec<Conf>> { pub(crate) fn get_confs() -> Result<Vec<Conf>> {
let conn = open()?; let conn = open()?;
let mut stmt = conn.prepare("SELECT id, title, date FROM conf")?; let mut stmt = conn.prepare("SELECT id, title, date FROM conf")?;
@@ -100,7 +100,7 @@ pub(crate) fn get_confs() -> Result<Vec<Conf>> {
Ok(confs) Ok(confs)
} }
*/
pub(crate) async fn get_messages_random_all() -> Result<Vec<String>, Error> { pub(crate) async fn get_messages_random_all() -> Result<Vec<String>, Error> {
let conn = open()?; let conn = open()?;
let mut stmt = conn.prepare_cached("SELECT text FROM messages ORDER BY RANDOM() LIMIT 50")?; let mut stmt = conn.prepare_cached("SELECT text FROM messages ORDER BY RANDOM() LIMIT 50")?;
@@ -345,7 +345,6 @@ pub(crate) async fn get_file(file_id: String) -> Result<i64, errors::Error> {
Ok(id) => Ok(id), Ok(id) => Ok(id),
Err(_) => Err(errors::Error::FileNotFound), Err(_) => Err(errors::Error::FileNotFound),
}; };
file_rowid file_rowid
} }

View File

@@ -20,6 +20,8 @@ pub enum Error {
JsonParseError(serde_error), JsonParseError(serde_error),
PopenError(popen_error), PopenError(popen_error),
MystemError(mystem_error), MystemError(mystem_error),
SQLBannedCommand,
SQLInvalidCommand,
} }
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {

View File

@@ -1,57 +1,11 @@
//use crate::commands::Command; //use crate::commands::Command;
use crate::commands::{Execute, Here, Markov, MarkovAll, Omedeto, Top}; use crate::commands::{Execute, Here, Markov, MarkovAll, Omedeto, Sql, Top};
use crate::db; use crate::db;
use crate::errors; use crate::errors;
use crate::utils; use crate::utils;
use mystem::MyStem; use mystem::MyStem;
use telegram_bot::*; use telegram_bot::*;
// struct Command {
// command: Commands,
// explicit: bool,
// rest: String,
// }
// async fn detector(msg: String, me: &User) -> Result<Command, ()> {
// let cleaned_message = msg.replace(&format!("@{}", me.clone().username.unwrap()), "");
// match cleaned_message.as_str() {
// "/here" => Ok(Command::Here {
// data: "".to_string(),
// }),
// s if s.contains("/here") => Ok(Command::Here {
// data: s.to_string(),
// }),
// "/top" => Ok(Command::Top {
// data: "".to_string(),
// }),
// "/stat" => Ok(Command::Top {
// data: "".to_string(),
// }),
// s if s.contains(|z| z == "/top" || z == "/stat") => Ok(Command::Top {
// data: s.to_string(),
// }),
// "/markov_all" => Ok(Command::MarkovAll {
// data: "".to_string(),
// }),
// s if s.contains("/markov_all") => Ok(Command::MarkovAll {
// data: s.to_string(),
// }),
// "/markov" => Ok(Command::Markov {
// data: "".to_string(),
// }),
// s if s.contains("/markov") => Ok(Command::Markov {
// data: s.to_string(),
// }),
// "/omedeto" => Ok(Command::Omedeto {
// data: "".to_string(),
// }),
// s if s.contains("/Omedeto") => Ok(Command::Omedeto {
// data: s.to_string(),
// }),
// _ => Err(()),
// }
// }
pub async fn handler( pub async fn handler(
api: Api, api: Api,
message: Message, message: Message,
@@ -71,9 +25,7 @@ pub async fn handler(
data data
); );
db::add_sentence(&message, mystem).await?; db::add_sentence(&message, mystem).await?;
let cleaned_message = data let cleaned_message = data.replace(&format!("@{}", me.clone().username.unwrap()), "");
.replace(&format!("@{}", me.clone().username.unwrap()), "");
debug!("Cleaned - {}", cleaned_message);
match cleaned_message.as_str() { match cleaned_message.as_str() {
s if s.contains("/here") => { s if s.contains("/here") => {
Here { Here {
@@ -82,6 +34,13 @@ pub async fn handler(
.run(api, message) .run(api, message)
.await? .await?
} }
s if s.to_string().starts_with("/sql") => {
Sql {
data: s.replace("/sql ", ""),
}
.run(api, message)
.await?
}
"/top" => { "/top" => {
Top { Top {
data: "".to_string(), data: "".to_string(),

View File

@@ -49,7 +49,12 @@ async fn main() -> Result<(), errors::Error> {
if let UpdateKind::Message(message) = update.kind { if let UpdateKind::Message(message) = update.kind {
db::add_conf(message.clone()).await?; db::add_conf(message.clone()).await?;
db::add_user(message.clone()).await?; db::add_user(message.clone()).await?;
handlers::handler(api.clone(), message, token.clone(), &mut mystem, me.clone()).await?; match handlers::handler(api.clone(), message, token.clone(), &mut mystem, me.clone())
.await
{
Ok(_) => {}
Err(e) => warn!("An error occurred handling command. {:?}", e),
}
} }
} }
Ok(()) Ok(())