diff --git a/Cargo.lock b/Cargo.lock index dcfb7ed..66b7df3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,6 +31,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "anyhow" +version = "1.0.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f9b8508dccb7687a1d6c4ce66b2b0ecef467c94667de27d8d7fe1f8d2a9cdc" + [[package]] name = "async-trait" version = "0.1.53" @@ -110,6 +116,7 @@ dependencies = [ "eyre", "fs-err", "itertools", + "lazy_static", "log", "minspan", "parse_duration", @@ -121,6 +128,7 @@ dependencies = [ "serde_json", "shellexpand", "sodiumoxide", + "sql-builder", "sqlx", "tokio", "urlencoding", @@ -2095,6 +2103,16 @@ dependencies = [ "lock_api", ] +[[package]] +name = "sql-builder" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1008d95d2ec2d062959352527be30e10fec42a1aa5e5a48d990a5ff0fb9bdc0" +dependencies = [ + "anyhow", + "thiserror", +] + [[package]] name = "sqlformat" version = "0.1.8" diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml index aba5039..9d76b78 100644 --- a/atuin-client/Cargo.toml +++ b/atuin-client/Cargo.toml @@ -46,6 +46,8 @@ sqlx = { version = "0.5", features = [ minspan = "0.1.1" regex = "1.5.4" fs-err = "2.7" +sql-builder = "3" +lazy_static = "1" # sync urlencoding = { version = "2.1.0", optional = true } diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs index 3c3167c..14f01cd 100644 --- a/atuin-client/src/database.rs +++ b/atuin-client/src/database.rs @@ -5,11 +5,11 @@ use std::str::FromStr; use async_trait::async_trait; use chrono::prelude::*; use chrono::Utc; - -use itertools::Itertools; -use regex::Regex; - use fs_err as fs; +use itertools::Itertools; +use lazy_static::lazy_static; +use regex::Regex; +use sql_builder::{esc, quote, SqlBuilder, SqlName}; use sqlx::{ sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, Result, Row, @@ -219,50 +219,30 @@ impl Database for Sqlite { ) -> Result> { debug!("listing history"); - // gotta get that query builder in soon cuz I kinda hate this - let query = if unique { - "where timestamp = ( - select max(timestamp) from history - where h.command = history.command - )" - } else { - "" + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + query.field("*").order_desc("timestamp"); + + match filter { + FilterMode::Global => &mut query, + FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)), + FilterMode::Session => query.and_where_eq("session", quote(&context.session)), + FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)), + }; + + if unique { + query.and_where_eq( + "timestamp", + "(select max(timestamp) from history where h.command = history.command)", + ); } - .to_string(); - let mut join = if unique { "and" } else { "where" }.to_string(); + if let Some(max) = max { + query.limit(max); + } - let filter_query = match filter { - FilterMode::Global => { - join = "".to_string(); - "".to_string() - } - FilterMode::Host => format!("hostname = '{}'", context.hostname).to_string(), - FilterMode::Session => format!("session = '{}'", context.session).to_string(), - FilterMode::Directory => format!("cwd = '{}'", context.cwd).to_string(), - }; + let query = query.sql().expect("bug in list query. please report"); - let filter = if filter_query.is_empty() { - "".to_string() - } else { - format!("{} {}", join, filter_query) - }; - - let limit = if let Some(max) = max { - format!("limit {}", max) - } else { - "".to_string() - }; - - let query = format!( - "select * from history h - {} {} - order by timestamp desc - {}", - query, filter, limit, - ); - - let res = sqlx::query(query.as_str()) + let res = sqlx::query(&query) .map(Self::query_history) .fetch_all(&self.pool) .await?; @@ -339,108 +319,78 @@ impl Database for Sqlite { context: &Context, query: &str, ) -> Result> { - let orig_query = query; - let query = query.to_string().replace('*', "%"); // allow wildcard char - let limit = limit.map_or("".to_owned(), |l| format!("limit {}", l)); + let mut sql = SqlBuilder::select_from("history"); - let (query_sql, query_params) = match search_mode { - SearchMode::Prefix => ("command like ?1".to_string(), vec![format!("{}%", query)]), - SearchMode::FullText => ("command like ?1".to_string(), vec![format!("%{}%", query)]), + sql.group_by("command") + .having("max(timestamp)") + .order_desc("timestamp"); + + if let Some(limit) = limit { + sql.limit(limit); + } + + match filter { + FilterMode::Global => &mut sql, + FilterMode::Host => sql.and_where_eq("hostname", quote(&context.hostname)), + FilterMode::Session => sql.and_where_eq("session", quote(&context.session)), + FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)), + }; + + let orig_query = query; + let query = query.replace('*', "%"); // allow wildcard char + + match search_mode { + SearchMode::Prefix => sql.and_where_like_left("command", query), + SearchMode::FullText => sql.and_where_like_any("command", query), SearchMode::Fuzzy => { - let split_regex = Regex::new(r" +").unwrap(); - let terms: Vec<&str> = split_regex.split(query.as_str()).collect(); - let mut query_sql = std::string::String::new(); - let mut query_params = Vec::with_capacity(terms.len()); - let mut was_or = false; - for (i, query_part) in terms.into_iter().enumerate() { + // don't recompile the regex on successive calls! + lazy_static! { + static ref SPLIT_REGEX: Regex = Regex::new(r" +").unwrap(); + } + + let mut is_or = false; + for query_part in SPLIT_REGEX.split(query.as_str()) { // TODO smart case mode could be made configurable like in fzf - let (operator, glob) = if query_part.contains(char::is_uppercase) { - ("glob", '*') + let (is_glob, glob) = if query_part.contains(char::is_uppercase) { + (true, "*") } else { - ("like", '%') + (false, "%") }; + let (is_inverse, query_part) = match query_part.strip_prefix('!') { Some(stripped) => (true, stripped), None => (false, query_part), }; - match query_part { - "|" => { - if !was_or { - query_sql.push_str(" OR "); - was_or = true; - continue; - } else { - query_params.push(format!("{glob}|{glob}")); - } + + let param = if query_part == "|" { + if !is_or { + is_or = true; + continue; + } else { + format!("{glob}|{glob}") } - exact_prefix if query_part.starts_with('^') => query_params.push(format!( - "{term}{glob}", - term = exact_prefix.strip_prefix('^').unwrap() - )), - exact_suffix if query_part.ends_with('$') => query_params.push(format!( - "{glob}{term}", - term = exact_suffix.strip_suffix('$').unwrap() - )), - exact if query_part.starts_with('\'') => query_params.push(format!( - "{glob}{term}{glob}", - term = exact.strip_prefix('\'').unwrap() - )), - exact if is_inverse => { - query_params.push(format!("{glob}{term}{glob}", term = exact)) - } - _ => { - query_params.push(query_part.split("").join(glob.to_string().as_str())) - } - } - if i > 0 && !was_or { - query_sql.push_str(" AND "); - } - if is_inverse { - query_sql.push_str("NOT "); - } - query_sql - .push_str(format!("command {} ?{}", operator, query_params.len()).as_str()); - was_or = false; + } else if let Some(term) = query_part.strip_prefix('^') { + format!("{term}{glob}") + } else if let Some(term) = query_part.strip_suffix('$') { + format!("{glob}{term}") + } else if let Some(term) = query_part.strip_prefix('\'') { + format!("{glob}{term}{glob}") + } else if is_inverse { + format!("{glob}{term}{glob}", term = query_part) + } else { + query_part.split("").join(glob) + }; + + sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or); + is_or = false; } - (query_sql, query_params) + &mut sql } }; - let filter_base = if query_sql.is_empty() { - "".to_string() - } else { - "and".to_string() - }; + let query = sql.sql().expect("bug in search query. please report"); - let filter_query = match filter { - FilterMode::Global => String::from(""), - FilterMode::Session => format!("session = '{}'", context.session), - FilterMode::Directory => format!("cwd = '{}'", context.cwd), - FilterMode::Host => format!("hostname = '{}'", context.hostname), - }; - - let filter_sql = if filter_query.is_empty() { - "".to_string() - } else { - format!("{} {}", filter_base, filter_query) - }; - - let sql = format!( - "select * from history h - where {} {} - group by command - having max(timestamp) - order by timestamp desc {}", - query_sql.as_str(), - filter_sql.as_str(), - limit.clone() - ); - - let res = query_params - .iter() - .fold(sqlx::query(sql.as_str()), |query, query_param| { - query.bind(query_param) - }) + let res = sqlx::query(&query) .map(Self::query_history) .fetch_all(&self.pool) .await?; @@ -687,3 +637,43 @@ mod test { assert!(duration < Duration::from_secs(15)); } } + +trait SqlBuilderExt { + fn fuzzy_condition( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self; +} + +impl SqlBuilderExt for SqlBuilder { + /// adapted from the sql-builder *like functions + fn fuzzy_condition( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self { + let mut cond = field.to_string(); + if inverse { + cond.push_str(" NOT"); + } + if glob { + cond.push_str(" GLOB '"); + } else { + cond.push_str(" LIKE '"); + } + cond.push_str(&esc(&mask.to_string())); + cond.push('\''); + if is_or { + self.or_where(cond) + } else { + self.and_where(cond) + } + } +}