atuin/atuin-client/src/database.rs

680 lines
20 KiB
Rust
Raw Normal View History

use std::env;
use std::path::Path;
use std::str::FromStr;
use async_trait::async_trait;
2021-04-25 14:27:51 -06:00
use chrono::prelude::*;
use chrono::Utc;
use fs_err as fs;
2021-06-01 01:38:19 -06:00
use itertools::Itertools;
use lazy_static::lazy_static;
use regex::Regex;
use sql_builder::{esc, quote, SqlBuilder, SqlName};
use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
Result, Row,
2021-04-25 14:27:51 -06:00
};
2021-02-14 10:18:02 -07:00
use super::history::History;
use super::ordering;
use super::settings::{FilterMode, SearchMode};
/// Ambient shell state used to scope history queries.
///
/// `current_context` fills these fields from the running process; the
/// non-global filter modes compare them against the `session`, `cwd` and
/// `hostname` columns of the history table.
#[derive(Debug, Clone)]
pub struct Context {
    /// Shell session identifier (read from `ATUIN_SESSION` by `current_context`).
    session: String,
    /// Working directory as a display string; empty when it could not be determined.
    cwd: String,
    /// Host identifier; `current_context` formats it as `hostname:username`.
    hostname: String,
}
/// Builds the [`Context`] for the shell this process is running in.
///
/// The host identifier is `hostname:username`, and the working directory
/// falls back to an empty string when it cannot be read.
///
/// # Panics
///
/// Panics when the `ATUIN_SESSION` environment variable is not set.
pub fn current_context() -> Context {
    let session = env::var("ATUIN_SESSION")
        .expect("failed to find ATUIN_SESSION - check your shell setup");

    let hostname = [whoami::hostname(), whoami::username()].join(":");

    let cwd = env::current_dir()
        .map(|dir| dir.display().to_string())
        .unwrap_or_default();

    Context {
        session,
        cwd,
        hostname,
    }
}
#[async_trait]
pub trait Database {
async fn save(&mut self, h: &History) -> Result<()>;
async fn save_bulk(&mut self, h: &[History]) -> Result<()>;
2021-03-19 18:50:31 -06:00
async fn load(&self, id: &str) -> Result<History>;
async fn list(
&self,
filter: FilterMode,
context: &Context,
max: Option<usize>,
unique: bool,
) -> Result<Vec<History>>;
async fn range(
&self,
from: chrono::DateTime<Utc>,
to: chrono::DateTime<Utc>,
) -> Result<Vec<History>>;
2021-03-19 18:50:31 -06:00
async fn update(&self, h: &History) -> Result<()>;
async fn history_count(&self) -> Result<i64>;
2021-03-19 18:50:31 -06:00
async fn first(&self) -> Result<History>;
async fn last(&self) -> Result<History>;
async fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>;
async fn search(
&self,
limit: Option<i64>,
search_mode: SearchMode,
filter: FilterMode,
context: &Context,
query: &str,
) -> Result<Vec<History>>;
async fn query_history(&self, query: &str) -> Result<Vec<History>>;
}
// Intended for use on a developer machine and not a sync server.
// TODO: implement IntoIterator
2021-02-14 08:15:26 -07:00
/// SQLite-backed implementation of [`Database`].
pub struct Sqlite {
    /// Connection pool used by the queries and transactions of this type.
    pool: SqlitePool,
}
2021-02-14 08:15:26 -07:00
impl Sqlite {
pub async fn new(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
debug!("opening sqlite database at {:?}", path);
let create = !path.exists();
if create {
if let Some(dir) = path.parent() {
2022-04-13 11:08:49 -06:00
fs::create_dir_all(dir)?;
}
}
let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
.journal_mode(SqliteJournalMode::Wal)
.create_if_missing(true);
let pool = SqlitePoolOptions::new().connect_with(opts).await?;
Self::setup_db(&pool).await?;
Ok(Self { pool })
}
/// Runs the migrations declared under `./migrations` against `pool`.
async fn setup_db(pool: &SqlitePool) -> Result<()> {
    debug!("running sqlite database setup");

    let migrator = sqlx::migrate!("./migrations");
    migrator.run(pool).await?;

    Ok(())
}
2021-02-14 10:18:02 -07:00
/// Inserts `h` into the `history` table as part of the transaction `tx`.
///
/// The statement is `insert or ignore`; the timestamp is stored as
/// nanoseconds.
async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> {
    let insert = sqlx::query(
        "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname)\n\
         values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
    )
    .bind(h.id.as_str())
    .bind(h.timestamp.timestamp_nanos())
    .bind(h.duration)
    .bind(h.exit)
    .bind(h.command.as_str())
    .bind(h.cwd.as_str())
    .bind(h.session.as_str())
    .bind(h.hostname.as_str());

    insert.execute(tx).await?;

    Ok(())
}
2021-04-25 14:27:51 -06:00
/// Converts one row of the `history` table into a [`History`] value.
///
/// The `timestamp` column is read as nanoseconds and turned into a UTC
/// date-time.
fn query_history(row: SqliteRow) -> History {
    let id = row.get("id");
    let nanos: i64 = row.get("timestamp");

    History {
        id,
        timestamp: Utc.timestamp_nanos(nanos),
        duration: row.get("duration"),
        exit: row.get("exit"),
        command: row.get("command"),
        cwd: row.get("cwd"),
        session: row.get("session"),
        hostname: row.get("hostname"),
    }
}
}
#[async_trait]
2021-02-14 08:15:26 -07:00
impl Database for Sqlite {
/// Writes a single history entry inside its own transaction.
async fn save(&mut self, h: &History) -> Result<()> {
    debug!("saving history to sqlite");

    let mut transaction = self.pool.begin().await?;
    Self::save_raw(&mut transaction, h).await?;

    transaction.commit().await
}
async fn save_bulk(&mut self, h: &[History]) -> Result<()> {
2021-02-13 12:37:00 -07:00
debug!("saving history to sqlite");
let mut tx = self.pool.begin().await?;
2021-02-13 12:37:00 -07:00
for i in h {
Self::save_raw(&mut tx, i).await?
2021-02-13 12:37:00 -07:00
}
tx.commit().await?;
2021-02-13 12:37:00 -07:00
Ok(())
}
async fn load(&self, id: &str) -> Result<History> {
debug!("loading history item {}", id);
2021-02-13 10:02:52 -07:00
2021-04-25 14:27:51 -06:00
let res = sqlx::query("select * from history where id = ?1")
.bind(id)
2021-04-25 14:27:51 -06:00
.map(Self::query_history)
.fetch_one(&self.pool)
.await?;
2021-02-13 10:02:52 -07:00
Ok(res)
2021-02-13 10:02:52 -07:00
}
async fn update(&self, h: &History) -> Result<()> {
2021-02-13 10:02:52 -07:00
debug!("updating sqlite history");
sqlx::query(
2021-02-13 10:02:52 -07:00
"update history
2021-02-13 13:21:49 -07:00
set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8
2021-02-13 10:02:52 -07:00
where id = ?1",
)
.bind(h.id.as_str())
2021-04-25 14:27:51 -06:00
.bind(h.timestamp.timestamp_nanos())
.bind(h.duration)
.bind(h.exit)
.bind(h.command.as_str())
.bind(h.cwd.as_str())
.bind(h.session.as_str())
.bind(h.hostname.as_str())
.execute(&self.pool)
.await?;
Ok(())
}
// make a unique list, that only shows the *newest* version of things
async fn list(
&self,
filter: FilterMode,
context: &Context,
max: Option<usize>,
unique: bool,
) -> Result<Vec<History>> {
debug!("listing history");
let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
query.field("*").order_desc("timestamp");
match filter {
FilterMode::Global => &mut query,
FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)),
FilterMode::Session => query.and_where_eq("session", quote(&context.session)),
FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)),
};
if unique {
query.and_where_eq(
"timestamp",
"(select max(timestamp) from history where h.command = history.command)",
);
}
if let Some(max) = max {
query.limit(max);
}
let query = query.sql().expect("bug in list query. please report");
2021-02-13 10:02:52 -07:00
let res = sqlx::query(&query)
2021-04-25 14:27:51 -06:00
.map(Self::query_history)
.fetch_all(&self.pool)
.await?;
Ok(res)
}
async fn range(
&self,
from: chrono::DateTime<Utc>,
to: chrono::DateTime<Utc>,
) -> Result<Vec<History>> {
debug!("listing history from {:?} to {:?}", from, to);
2021-04-25 14:27:51 -06:00
let res = sqlx::query(
"select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
)
2022-04-26 11:03:13 -06:00
.bind(from.timestamp_nanos())
.bind(to.timestamp_nanos())
2021-04-25 14:27:51 -06:00
.map(Self::query_history)
.fetch_all(&self.pool)
.await?;
Ok(res)
}
async fn first(&self) -> Result<History> {
2021-04-25 14:27:51 -06:00
let res =
sqlx::query("select * from history where duration >= 0 order by timestamp asc limit 1")
.map(Self::query_history)
.fetch_one(&self.pool)
.await?;
Ok(res)
}
async fn last(&self) -> Result<History> {
2021-04-25 14:27:51 -06:00
let res = sqlx::query(
"select * from history where duration >= 0 order by timestamp desc limit 1",
)
2021-04-25 14:27:51 -06:00
.map(Self::query_history)
.fetch_one(&self.pool)
.await?;
Ok(res)
}
async fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> {
2021-04-25 14:27:51 -06:00
let res = sqlx::query(
"select * from history where timestamp < ?1 order by timestamp desc limit ?2",
)
2021-04-25 14:27:51 -06:00
.bind(timestamp.timestamp_nanos())
.bind(count)
2021-04-25 14:27:51 -06:00
.map(Self::query_history)
.fetch_all(&self.pool)
.await?;
Ok(res)
}
/// Returns the number of rows in the `history` table.
async fn history_count(&self) -> Result<i64> {
    let (count,): (i64,) = sqlx::query_as("select count(1) from history")
        .fetch_one(&self.pool)
        .await?;

    Ok(count)
}
2021-03-19 18:50:31 -06:00
async fn search(
&self,
limit: Option<i64>,
search_mode: SearchMode,
filter: FilterMode,
context: &Context,
query: &str,
) -> Result<Vec<History>> {
let mut sql = SqlBuilder::select_from("history");
sql.group_by("command")
.having("max(timestamp)")
.order_desc("timestamp");
if let Some(limit) = limit {
sql.limit(limit);
}
match filter {
FilterMode::Global => &mut sql,
FilterMode::Host => sql.and_where_eq("hostname", quote(&context.hostname)),
FilterMode::Session => sql.and_where_eq("session", quote(&context.session)),
FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)),
};
let orig_query = query;
let query = query.replace('*', "%"); // allow wildcard char
match search_mode {
SearchMode::Prefix => sql.and_where_like_left("command", query),
SearchMode::FullText => sql.and_where_like_any("command", query),
SearchMode::Fuzzy => {
// don't recompile the regex on successive calls!
lazy_static! {
static ref SPLIT_REGEX: Regex = Regex::new(r" +").unwrap();
}
let mut is_or = false;
for query_part in SPLIT_REGEX.split(query.as_str()) {
// TODO smart case mode could be made configurable like in fzf
let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
(true, "*")
} else {
(false, "%")
};
let (is_inverse, query_part) = match query_part.strip_prefix('!') {
Some(stripped) => (true, stripped),
None => (false, query_part),
};
let param = if query_part == "|" {
if !is_or {
is_or = true;
continue;
} else {
format!("{glob}|{glob}")
}
} else if let Some(term) = query_part.strip_prefix('^') {
format!("{term}{glob}")
} else if let Some(term) = query_part.strip_suffix('$') {
format!("{glob}{term}")
} else if let Some(term) = query_part.strip_prefix('\'') {
format!("{glob}{term}{glob}")
} else if is_inverse {
format!("{glob}{term}{glob}", term = query_part)
} else {
query_part.split("").join(glob)
};
sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
is_or = false;
}
&mut sql
}
};
let query = sql.sql().expect("bug in search query. please report");
2022-04-22 15:15:50 -06:00
let res = sqlx::query(&query)
.map(Self::query_history)
.fetch_all(&self.pool)
.await?;
2021-03-19 18:50:31 -06:00
Ok(ordering::reorder_fuzzy(search_mode, orig_query, res))
2021-03-19 18:50:31 -06:00
}
async fn query_history(&self, query: &str) -> Result<Vec<History>> {
2021-04-25 14:27:51 -06:00
let res = sqlx::query(query)
.map(Self::query_history)
.fetch_all(&self.pool)
.await?;
2021-02-14 10:18:02 -07:00
Ok(res)
}
2021-02-14 10:18:02 -07:00
}
2021-06-01 01:38:19 -06:00
#[cfg(test)]
mod test {
use super::*;
use std::time::{Duration, Instant};
2021-06-01 01:38:19 -06:00
/// Runs `query` against `db` with a fixed test context and asserts that
/// exactly `expected` rows come back.
///
/// Returns the rows so callers can inspect them further.
async fn assert_search_eq(
    db: &impl Database,
    mode: SearchMode,
    filter_mode: FilterMode,
    query: &str,
    expected: usize,
) -> Result<Vec<History>> {
    let context = Context {
        hostname: "test:host".to_string(),
        session: "beepboopiamasession".to_string(),
        cwd: "/home/ellie".to_string(),
    };

    let results = db.search(None, mode, filter_mode, &context, query).await?;

    assert_eq!(
        results.len(),
        expected,
        "query \"{}\", commands: {:?}",
        query,
        results.iter().map(|a| &a.command).collect::<Vec<&String>>()
    );

    Ok(results)
}
/// Asserts that searching for `query` yields exactly `expected_commands`,
/// in that order.
async fn assert_search_commands(
    db: &impl Database,
    mode: SearchMode,
    filter_mode: FilterMode,
    query: &str,
    expected_commands: Vec<&str>,
) {
    let expected_len = expected_commands.len();
    let results = assert_search_eq(db, mode, filter_mode, query, expected_len)
        .await
        .unwrap();

    let commands = results
        .iter()
        .map(|item| item.command.as_str())
        .collect::<Vec<&str>>();

    assert_eq!(commands, expected_commands);
}
2021-06-01 01:38:19 -06:00
/// Saves a history entry for `cmd`, timestamped now, with fixed test metadata.
async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> {
    let entry = History::new(
        chrono::Utc::now(),
        cmd.to_string(),
        "/home/ellie".to_string(),
        0,
        1,
        Some("beep boop".to_string()),
        Some("booop".to_string()),
    );

    db.save(&entry).await
}
#[tokio::test(flavor = "multi_thread")]
async fn test_search_prefix() {
    let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
    new_history_item(&mut db, "ls /home/ellie").await.unwrap();

    // (query, expected number of rows)
    let cases = [("ls", 1), ("/home", 0), ("ls ", 0)];

    for (query, expected) in cases {
        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, query, expected)
            .await
            .unwrap();
    }
}
#[tokio::test(flavor = "multi_thread")]
async fn test_search_fulltext() {
    let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
    new_history_item(&mut db, "ls /home/ellie").await.unwrap();

    // (query, expected number of rows)
    let cases = [("ls", 1), ("/home", 1), ("ls ", 0)];

    for (query, expected) in cases {
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, query, expected)
            .await
            .unwrap();
    }
}
#[tokio::test(flavor = "multi_thread")]
async fn test_search_fuzzy() {
let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
new_history_item(&mut db, "ls /home/ellie").await.unwrap();
new_history_item(&mut db, "ls /home/frank").await.unwrap();
new_history_item(&mut db, "cd /home/Ellie").await.unwrap();
2021-06-01 01:38:19 -06:00
new_history_item(&mut db, "/home/ellie/.bin/rustup")
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4)
.await
.unwrap();
2021-06-01 01:38:19 -06:00
// single term operators
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2)
2021-06-01 01:38:19 -06:00
.await
.unwrap();
// multiple terms
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2)
.await
.unwrap();
assert_search_eq(
&db,
SearchMode::Fuzzy,
FilterMode::Global,
"'frank | 'rustup",
2,
)
.await
.unwrap();
assert_search_eq(
&db,
SearchMode::Fuzzy,
FilterMode::Global,
"'frank | 'rustup 'ls",
1,
)
.await
.unwrap();
2021-06-01 01:38:19 -06:00
// case matching
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
.await
.unwrap();
2021-06-01 01:38:19 -06:00
}
#[tokio::test(flavor = "multi_thread")]
async fn test_search_reordered_fuzzy() {
    let mut db = Sqlite::new("sqlite::memory:").await.unwrap();

    // test ordering of results: we should choose the first, even though it happened longer ago.
    for command in ["curl", "corburl"] {
        new_history_item(&mut db, command).await.unwrap();
    }

    // if fuzzy reordering is on, it should come back in a more sensible order
    let expected_order = vec!["curl", "corburl"];
    assert_search_commands(
        &db,
        SearchMode::Fuzzy,
        FilterMode::Global,
        "curl",
        expected_order,
    )
    .await;

    // (query, expected number of rows)
    for (query, expected) in [("xxxx", 0), ("", 2)] {
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, query, expected)
            .await
            .unwrap();
    }
}
#[tokio::test(flavor = "multi_thread")]
async fn test_search_bench_dupes() {
    let context = Context {
        hostname: "test:host".to_string(),
        session: "beepboopiamasession".to_string(),
        cwd: "/home/ellie".to_string(),
    };

    let mut db = Sqlite::new("sqlite::memory:").await.unwrap();

    // Fill the table with many copies of the same command.
    for _ in 1..10000 {
        new_history_item(&mut db, "i am a duplicated command")
            .await
            .unwrap();
    }

    // An empty fuzzy search over the duplicates must finish within the budget.
    let started = Instant::now();
    let _results = db
        .search(None, SearchMode::Fuzzy, FilterMode::Global, &context, "")
        .await
        .unwrap();
    let elapsed = started.elapsed();

    assert!(elapsed < Duration::from_secs(15));
}
2021-06-01 01:38:19 -06:00
}
/// Extension methods for [`SqlBuilder`] used by the fuzzy search.
trait SqlBuilderExt {
    /// Appends a `LIKE`/`GLOB` condition on `field` matching `mask`.
    ///
    /// `inverse` negates the condition, `glob` selects `GLOB` over `LIKE`, and
    /// `is_or` combines it with OR instead of AND.
    fn fuzzy_condition<S, T>(
        &mut self,
        field: S,
        mask: T,
        inverse: bool,
        glob: bool,
        is_or: bool,
    ) -> &mut Self
    where
        S: ToString,
        T: ToString;
}
impl SqlBuilderExt for SqlBuilder {
    /// adapted from the sql-builder *like functions
    ///
    /// Builds `<field> [NOT] GLOB|LIKE '<escaped mask>'` and attaches it to
    /// the builder with `or_where` when `is_or` is set, `and_where` otherwise.
    fn fuzzy_condition<S: ToString, T: ToString>(
        &mut self,
        field: S,
        mask: T,
        inverse: bool,
        glob: bool,
        is_or: bool,
    ) -> &mut Self {
        let negation = if inverse { " NOT" } else { "" };
        let operator = if glob { "GLOB" } else { "LIKE" };

        let cond = format!(
            "{}{} {} '{}'",
            field.to_string(),
            negation,
            operator,
            esc(&mask.to_string()),
        );

        if is_or {
            self.or_where(cond)
        } else {
            self.and_where(cond)
        }
    }
}