Merge pull request #6 from house-of-vanity/sql

Simplify SQL command. Add limit.
This commit is contained in:
House of Vanity
2021-01-05 03:46:03 +03:00
committed by GitHub
3 changed files with 135 additions and 184 deletions

View File

@ -1,6 +1,6 @@
use crate::db;
use crate::errors::Error;
use crate::errors::Error::SQLInvalidCommand;
use crate::errors::Error::{SQLInvalidCommand, SQLITE3Error};
use async_trait::async_trait;
use html_escape::encode_text;
use markov::Chain;
@ -39,18 +39,18 @@ pub struct Sql {
#[async_trait]
pub trait Execute {
async fn run(&self, api: Api, message: Message) -> Result<(), Error>;
async fn run(&self, api: &Api, message: &Message) -> Result<(), Error>;
async fn run_mystem(
&self,
api: Api,
message: Message,
api: &Api,
message: &Message,
mystem: &mut MyStem,
) -> Result<(), Error>;
}
#[async_trait]
impl Execute for Sql {
async fn run(&self, api: Api, message: Message) -> Result<(), Error> {
async fn run(&self, api: &Api, message: &Message) -> Result<(), Error> {
let mut sql = self.data.to_uppercase();
let is_head = if sql.starts_with('-') {
sql = sql.replacen("-", "", 1);
@ -59,173 +59,111 @@ impl Execute for Sql {
true
};
let dialect = GenericDialect {};
let ast: Result<Vec<Statement>, Error> = match Parser::parse_sql(&dialect, &sql) {
Ok(ast) => Ok(ast),
let ast: Vec<Statement> = match Parser::parse_sql(&dialect, &sql) {
Ok(ast) => ast,
Err(_) => {
warn!("Invalid SQL - {}", sql);
Err(SQLInvalidCommand)
}
};
let ast = match ast {
Err(_) => {
let _ = api
.send(
message
.text_reply(format!("❌ Invalid SQL. Syntax error ❌"))
.parse_mode(ParseMode::Html),
)
.await;
return Err(SQLInvalidCommand);
}
Ok(ast) => ast,
};
let msg: Result<String, Error> = match ast.len() {
match ast.len() {
l if l > 1 => {
//Max 1 request per message allowed only.
Err(Error::SQLBannedCommand)
return Err(Error::SQLBannedCommand(
"🚫 One statement per message allowed 🚫".into(),
))
}
_ => match ast[0] {
sqlparser::ast::Statement::Query { .. } => {
let conn = db::open()?;
let x = match conn.prepare_cached(&sql) {
Ok(mut stmt) => {
let query = match stmt.query(rusqlite::NO_PARAMS) {
Err(_) => Err(SQLInvalidCommand),
Ok(mut rows) => {
let mut res: Vec<Vec<String>> = match rows.column_names() {
Some(n) => vec![n
.into_iter()
.map(|s| {
let t = String::from(s);
if t.len() > 10 {
"EMSGSIZE".to_string()
} else {
t
}
})
.collect()],
None => return Err(SQLInvalidCommand),
};
let index_count = match rows.column_count() {
Some(c) => c,
None => return Err(SQLInvalidCommand),
};
while let Some(row) = rows.next().unwrap() {
let mut tmp: Vec<String> = Vec::new();
for i in 0..index_count {
match row.get(i).unwrap_or(None) {
Some(rusqlite::types::Value::Text(t)) => {
tmp.push(t)
}
Some(rusqlite::types::Value::Integer(t)) => {
tmp.push(t.to_string())
}
Some(rusqlite::types::Value::Blob(_)) => {
tmp.push("Binary".to_string())
}
Some(rusqlite::types::Value::Real(t)) => {
tmp.push(t.to_string())
}
Some(rusqlite::types::Value::Null) => {
tmp.push("Null".to_string())
}
None => tmp.push("Null".to_string()),
};
}
res.push(tmp);
}
// add Header
let mut msg = if is_head {
let mut x = String::from("<b>");
for head in res[0].iter() {
x = format!("{} {}", x, head);
}
format!("{}{}", x, "</b>\n")
} else {
String::new()
};
// remove header
res.remove(0);
msg = format!("{}{}", msg, "<pre>");
for line in res.iter() {
for field in line.iter() {
msg = format!("{}{}", msg, format!("{} ", field));
}
msg = format!("{}{}", msg, "\n");
}
msg = format!("{}{}", msg, "</pre>");
msg = if msg.len() > 4096 {
"🚫 Result is too big. Use LIMIT 🚫".into()
} else {
msg
};
Ok(msg)
}
};
query
}
Err(e) => Err(Error::SQLITE3Error(e)),
};
x
}
_ => {
warn!("SELECT requests allowed only.");
Err(Error::SQLBannedCommand)
}
},
};
match msg {
Ok(msg) => {
match api
.send(message.text_reply(msg).parse_mode(ParseMode::Html))
.await
{
Ok(_) => debug!("/sql command sent to {}", message.chat.id()),
Err(_) => warn!("/sql command sent failed to {}", message.chat.id()),
}
}
Err(e) => match e {
Error::SQLITE3Error(e) => {
let _ = api
.send(
message
.text_reply(format!("❌ An error occurred {}", e))
.parse_mode(ParseMode::Html),
)
.await;
}
Error::SQLBannedCommand => {
let _ = api
.send(
message
.text_reply(format!("🚫 SELECT requests allowed only 🚫"))
.parse_mode(ParseMode::Html),
)
.await;
}
Error::SQLInvalidCommand => {
let _ = api
.send(
message
.text_reply(format!("🚫 Invalid SQL. Check DB scheme. 🚫"))
.parse_mode(ParseMode::Html),
)
.await;
}
_ => {}
},
_ => (),
}
match ast[0] {
sqlparser::ast::Statement::Query { .. } => {}
_ => {
return Err(Error::SQLBannedCommand(
"🚫 SELECT requests allowed only 🚫".into(),
))
}
}
let conn = db::open()?;
let mut stmt = conn.prepare_cached(&sql)?;
let mut rows = match stmt.query(rusqlite::NO_PARAMS) {
Err(e) => return Err(SQLITE3Error(e)),
Ok(mut rows) => rows,
};
let mut res: Vec<Vec<String>> = match rows.column_names() {
Some(n) => vec![n
.into_iter()
.map(|s| {
let t = String::from(s);
if t.len() > 10 {
"EMSGSIZE".to_string()
} else {
t
}
})
.collect()],
None => return Err(SQLInvalidCommand),
};
let index_count = match rows.column_count() {
Some(c) => c,
None => return Err(SQLInvalidCommand),
};
while let Some(row) = rows.next().unwrap() {
let mut tmp: Vec<String> = Vec::new();
for i in 0..index_count {
match row.get(i).unwrap_or(None) {
Some(rusqlite::types::Value::Text(t)) => tmp.push(t),
Some(rusqlite::types::Value::Integer(t)) => tmp.push(t.to_string()),
Some(rusqlite::types::Value::Blob(_)) => tmp.push("Binary".to_string()),
Some(rusqlite::types::Value::Real(t)) => tmp.push(t.to_string()),
Some(rusqlite::types::Value::Null) => tmp.push("Null".to_string()),
None => tmp.push("Null".to_string()),
};
}
res.push(tmp);
}
if res.len() > 100 {
return Err(Error::SQLResultTooLong(
"SQL result too long. Lines limit is 100. Use LIMIT".to_string(),
));
}
// add Header
let mut msg = if is_head {
let mut x = String::from("<b>");
for head in res[0].iter() {
x = format!("{} {}", x, head);
}
format!("{}{}", x, "</b>\n")
} else {
String::new()
};
// remove header
res.remove(0);
msg = format!("{}{}", msg, "<pre>");
for line in res.iter() {
for field in line.iter() {
msg = format!("{}{}", msg, format!("{} ", field));
}
msg = format!("{}{}", msg, "\n");
}
msg = format!("{}{}", msg, "</pre>");
msg = if msg.len() > 4096 {
"🚫 Result is too big. Use LIMIT 🚫".into()
} else {
msg
};
Ok(())
}
#[allow(unused_variables)]
async fn run_mystem(
&self,
api: Api,
message: Message,
api: &Api,
message: &Message,
mystem: &mut MyStem,
) -> Result<(), Error> {
unimplemented!()
@ -234,7 +172,7 @@ impl Execute for Sql {
#[async_trait]
impl Execute for Here {
async fn run(&self, api: Api, message: Message) -> Result<(), Error> {
async fn run(&self, api: &Api, message: &Message) -> Result<(), Error> {
let members: Vec<telegram_bot::User> = db::get_members(message.chat.id()).unwrap();
for u in &members {
debug!("Found user {:?} in chat {}", u, message.chat.id());
@ -265,8 +203,8 @@ impl Execute for Here {
#[allow(unused_variables)]
async fn run_mystem(
&self,
api: Api,
message: Message,
api: &Api,
message: &Message,
mystem: &mut MyStem,
) -> Result<(), Error> {
unimplemented!()
@ -275,7 +213,7 @@ impl Execute for Here {
#[async_trait]
impl Execute for Top {
async fn run(&self, api: Api, message: Message) -> Result<(), Error> {
async fn run(&self, api: &Api, message: &Message) -> Result<(), Error> {
let top = db::get_top(&message).await?;
let mut msg = "<b>Your top using words:</b>\n<pre>".to_string();
let mut counter = 1;
@ -300,8 +238,8 @@ impl Execute for Top {
#[allow(unused_variables)]
async fn run_mystem(
&self,
api: Api,
message: Message,
api: &Api,
message: &Message,
mystem: &mut MyStem,
) -> Result<(), Error> {
unimplemented!()
@ -310,7 +248,7 @@ impl Execute for Top {
#[async_trait]
impl Execute for MarkovAll {
async fn run(&self, api: Api, message: Message) -> Result<(), Error> {
async fn run(&self, api: &Api, message: &Message) -> Result<(), Error> {
let messages = db::get_messages_random_all().await?;
let mut chain = Chain::new();
chain.feed(messages);
@ -334,8 +272,8 @@ impl Execute for MarkovAll {
#[allow(unused_variables)]
async fn run_mystem(
&self,
api: Api,
message: Message,
api: &Api,
message: &Message,
mystem: &mut MyStem,
) -> Result<(), Error> {
unimplemented!()
@ -344,7 +282,7 @@ impl Execute for MarkovAll {
#[async_trait]
impl Execute for Markov {
async fn run(&self, api: Api, message: Message) -> Result<(), Error> {
async fn run(&self, api: &Api, message: &Message) -> Result<(), Error> {
let messages = db::get_messages_random_group(&message).await?;
let mut chain = Chain::new();
chain.feed(messages);
@ -368,8 +306,8 @@ impl Execute for Markov {
#[allow(unused_variables)]
async fn run_mystem(
&self,
api: Api,
message: Message,
api: &Api,
message: &Message,
mystem: &mut MyStem,
) -> Result<(), Error> {
unimplemented!()
@ -379,15 +317,15 @@ impl Execute for Markov {
#[async_trait]
impl Execute for Omedeto {
#[allow(unused_variables)]
async fn run(&self, api: Api, message: Message) -> Result<(), Error> {
async fn run(&self, api: &Api, message: &Message) -> Result<(), Error> {
unimplemented!()
}
#[warn(unused_must_use)]
async fn run_mystem(
&self,
api: Api,
message: Message,
api: &Api,
message: &Message,
mystem: &mut MyStem,
) -> Result<(), Error> {
let all_msg = db::get_messages_user_all(&message).await?;

View File

@ -20,9 +20,11 @@ pub enum Error {
JsonParseError(serde_error),
PopenError(popen_error),
MystemError(mystem_error),
SQLBannedCommand,
SQLBannedCommand(String),
SQLInvalidCommand,
SQLResultTooLong(String),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "An error occurred.")

View File

@ -31,49 +31,60 @@ pub async fn handler(
Here {
data: "".to_string(),
}
.run(api, message)
.run(&api, &message)
.await?
}
s if s.to_string().starts_with("/sql") => {
s if s.to_string().starts_with("/sql") => match {
Sql {
data: s.replace("/sql ", ""),
}
.run(api, message)
.await?
}
.run(&api, &message)
.await
} {
Ok(_) => debug!("/sql command sent to {}", message.chat.id()),
Err(e) => {
api.send(
message
.text_reply(format!("Error: {:#?}", e))
.parse_mode(ParseMode::Html),
)
.await?;
()
}
},
"/top" => {
Top {
data: "".to_string(),
}
.run(api, message)
.run(&api, &message)
.await?
}
"/stat" => {
Top {
data: "".to_string(),
}
.run(api, message)
.run(&api, &message)
.await?
}
"/markov_all" => {
MarkovAll {
data: "".to_string(),
}
.run(api, message)
.run(&api, &message)
.await?
}
"/markov" => {
Markov {
data: "".to_string(),
}
.run(api, message)
.run(&api, &message)
.await?
}
"/omedeto" => {
Omedeto {
data: "".to_string(),
}
.run_mystem(api, message, mystem)
.run_mystem(&api, &message, mystem)
.await?
}
_ => (),