diff --git a/Cargo.lock b/Cargo.lock index 7ba0403..de22cc3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,6 +118,7 @@ dependencies = [ "itertools", "lazy_static", "log", + "memchr", "minspan", "parse_duration", "regex", @@ -1189,9 +1190,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "mime" diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml index 4a0062f..6e52b93 100644 --- a/atuin-client/Cargo.toml +++ b/atuin-client/Cargo.toml @@ -49,6 +49,7 @@ regex = "1.5.4" fs-err = "2.7" sql-builder = "3" lazy_static = "1" +memchr = "2.5" # sync urlencoding = { version = "2.1.0", optional = true } diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs index 5f37e8b..7b3ab3b 100644 --- a/atuin-client/src/database.rs +++ b/atuin-client/src/database.rs @@ -41,7 +41,7 @@ pub fn current_context() -> Context { } #[async_trait] -pub trait Database { +pub trait Database: Send + Sync { async fn save(&mut self, h: &History) -> Result<()>; async fn save_bulk(&mut self, h: &[History]) -> Result<()>; diff --git a/atuin-client/src/import/bash.rs b/atuin-client/src/import/bash.rs index 1a17162..10e8de1 100644 --- a/atuin-client/src/import/bash.rs +++ b/atuin-client/src/import/bash.rs @@ -1,134 +1,106 @@ -use std::{ - fs::File, - io::{BufRead, BufReader, Read, Seek}, - path::{Path, PathBuf}, -}; +use std::{fs::File, io::Read, path::PathBuf}; +use async_trait::async_trait; use directories::UserDirs; use eyre::{eyre, Result}; -use super::{count_lines, Importer}; +use super::{get_histpath, unix_byte_lines, Importer, Loader}; use crate::history::History; #[derive(Debug)] -pub struct Bash { - file: BufReader, - strbuf: String, - loc: usize, - counter: i64, +pub struct Bash { + bytes: Vec, } -impl Bash { - fn new(r: R) -> Result { - let mut buf = BufReader::new(r); - let loc = count_lines(&mut buf)?; +fn default_histpath() -> Result { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); - Ok(Self { - file: buf, - strbuf: String::new(), - loc, - counter: 0, - }) - } + Ok(home_dir.join(".bash_history")) } -impl Importer for Bash { +#[async_trait] +impl Importer for Bash { const NAME: &'static str = "bash"; - fn histpath() -> Result { - let user_dirs = UserDirs::new().unwrap(); - let home_dir = user_dirs.home_dir(); - - Ok(home_dir.join(".bash_history")) + async fn new() -> Result { + let mut bytes = Vec::new(); + let path = get_histpath(default_histpath)?; + let mut f = File::open(path)?; + f.read_to_end(&mut bytes)?; + Ok(Self { bytes }) } - fn parse(path: impl AsRef) -> Result { - Self::new(File::open(path)?) + async fn entries(&mut self) -> Result { + Ok(super::count_lines(&self.bytes)) } -} -impl Iterator for Bash { - type Item = Result; + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = chrono::Utc::now(); + let mut line = String::new(); - fn next(&mut self) -> Option { - self.strbuf.clear(); - match self.file.read_line(&mut self.strbuf) { - Ok(0) => return None, - Ok(_) => (), - Err(e) => return Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 - } - - self.loc -= 1; - - while self.strbuf.ends_with("\\\n") { - if self.file.read_line(&mut self.strbuf).is_err() { - // There's a chance that the last line of a command has invalid - // characters, the only safe thing to do is break :/ - // usually just invalid utf8 or smth - // however, we really need to avoid missing history, so it's - // better to have some items that should have been part of - // something else, than to miss things. So break. - break; + for (i, b) in unix_byte_lines(&self.bytes).enumerate() { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 }; - self.loc -= 1; + if let Some(s) = s.strip_suffix('\\') { + line.push_str(s); + line.push_str("\\\n"); + } else { + line.push_str(s); + let command = std::mem::take(&mut line); + + let offset = chrono::Duration::seconds(i as i64); + h.push(History::new( + now - offset, // preserve ordering + command, + String::from("unknown"), + -1, + -1, + None, + None, + )) + .await?; + } } - let time = chrono::Utc::now(); - let offset = chrono::Duration::seconds(self.counter); - let time = time - offset; - - self.counter += 1; - - Some(Ok(History::new( - time, - self.strbuf.trim_end().to_string(), - String::from("unknown"), - -1, - -1, - None, - None, - ))) - } - - fn size_hint(&self) -> (usize, Option) { - (0, Some(self.loc)) + Ok(()) } } #[cfg(test)] mod tests { - use std::io::Cursor; + use itertools::assert_equal; + + use crate::import::{tests::TestLoader, Importer}; use super::Bash; - #[test] - fn test_parse_file() { - let input = r"cargo install atuin + #[tokio::test] + async fn test_parse_file() { + let bytes = r"cargo install atuin cargo install atuin; \ cargo update cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ -"; +" + .as_bytes() + .to_owned(); - let cursor = Cursor::new(input); - let mut bash = Bash::new(cursor).unwrap(); - assert_eq!(bash.loc, 4); - assert_eq!(bash.size_hint(), (0, Some(4))); + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 4); - assert_eq!( - &bash.next().unwrap().unwrap().command, - "cargo install atuin" - ); - assert_eq!( - &bash.next().unwrap().unwrap().command, - "cargo install atuin; \\\ncargo update" - ); - assert_eq!( - &bash.next().unwrap().unwrap().command, - "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷" - ); - assert!(bash.next().is_none()); + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); - assert_eq!(bash.size_hint(), (0, Some(0))); + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo install atuin; \\\ncargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], + ); } } diff --git a/atuin-client/src/import/fish.rs b/atuin-client/src/import/fish.rs index 7c05d18..af932d7 100644 --- a/atuin-client/src/import/fish.rs +++ b/atuin-client/src/import/fish.rs @@ -1,99 +1,90 @@ // import old shell history! // automatically hoover up all that we can find -use std::{ - fs::File, - io::{self, BufRead, BufReader, Read, Seek}, - path::{Path, PathBuf}, -}; +use std::{fs::File, io::Read, path::PathBuf}; +use async_trait::async_trait; use chrono::{prelude::*, Utc}; use directories::BaseDirs; use eyre::{eyre, Result}; -use super::{count_lines, Importer}; +use super::{get_histpath, unix_byte_lines, Importer, Loader}; use crate::history::History; #[derive(Debug)] -pub struct Fish { - file: BufReader, - strbuf: String, - loc: usize, +pub struct Fish { + bytes: Vec, } -impl Fish { - fn new(r: R) -> Result { - let mut buf = BufReader::new(r); - let loc = count_lines(&mut buf)?; +/// see https://fishshell.com/docs/current/interactive.html#searchable-command-history +fn default_histpath() -> Result { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let data = base.data_local_dir(); - Ok(Self { - file: buf, - strbuf: String::new(), - loc, - }) + // fish supports multiple history sessions + // If `fish_history` var is missing, or set to `default`, use `fish` as the session + let session = std::env::var("fish_history").unwrap_or_else(|_| String::from("fish")); + let session = if session == "default" { + String::from("fish") + } else { + session + }; + + let mut histpath = data.join("fish"); + histpath.push(format!("{}_history", session)); + + if histpath.exists() { + Ok(histpath) + } else { + Err(eyre!("Could not find history file. Try setting $HISTFILE")) } } -impl Fish { - fn new_entry(&mut self) -> io::Result { - let inner = self.file.fill_buf()?; - Ok(inner.starts_with(b"- ")) - } -} - -impl Importer for Fish { +#[async_trait] +impl Importer for Fish { const NAME: &'static str = "fish"; - /// see https://fishshell.com/docs/current/interactive.html#searchable-command-history - fn histpath() -> Result { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let data = base.data_local_dir(); - - // fish supports multiple history sessions - // If `fish_history` var is missing, or set to `default`, use `fish` as the session - let session = std::env::var("fish_history").unwrap_or_else(|_| String::from("fish")); - let session = if session == "default" { - String::from("fish") - } else { - session - }; - - let mut histpath = data.join("fish"); - histpath.push(format!("{}_history", session)); - - if histpath.exists() { - Ok(histpath) - } else { - Err(eyre!("Could not find history file. Try setting $HISTFILE")) - } + async fn new() -> Result { + let mut bytes = Vec::new(); + let path = get_histpath(default_histpath)?; + let mut f = File::open(path)?; + f.read_to_end(&mut bytes)?; + Ok(Self { bytes }) } - fn parse(path: impl AsRef) -> Result { - Self::new(File::open(path)?) + async fn entries(&mut self) -> Result { + Ok(super::count_lines(&self.bytes)) } -} -impl Iterator for Fish { - type Item = Result; - - fn next(&mut self) -> Option { + async fn load(self, loader: &mut impl Loader) -> Result<()> { + let now = Utc::now(); let mut time: Option> = None; let mut cmd: Option = None; - loop { - self.strbuf.clear(); - match self.file.read_line(&mut self.strbuf) { - // no more content to read - Ok(0) => break, - // bail on IO error - Err(e) => return Some(Err(e.into())), - _ => (), - } + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; - // `read_line` adds the line delimeter to the string. No thanks - self.strbuf.pop(); + if let Some(c) = s.strip_prefix("- cmd: ") { + // first, we must deal with the prev cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); + + loader + .push(History::new( + time, + cmd, + "unknown".into(), + -1, + -1, + None, + None, + )) + .await?; + } - if let Some(c) = self.strbuf.strip_prefix("- cmd: ") { // using raw strings to avoid needing escaping. // replaces double backslashes with single backslashes let c = c.replace(r"\\", r"\"); @@ -102,7 +93,7 @@ impl Iterator for Fish { // TODO: any other escape characters? cmd = Some(c); - } else if let Some(t) = self.strbuf.strip_prefix(" when: ") { + } else if let Some(t) = s.strip_prefix(" when: ") { // if t is not an int, just ignore this line if let Ok(t) = t.parse::() { time = Some(Utc.timestamp(t, 0)); @@ -110,47 +101,40 @@ impl Iterator for Fish { } else { // ... ignore paths lines } - - match self.new_entry() { - // next line is a new entry, so let's stop here - // only if we have found a command though - Ok(true) if cmd.is_some() => break, - // bail on IO error - Err(e) => return Some(Err(e.into())), - _ => (), - } } - let cmd = cmd?; - let time = time.unwrap_or_else(Utc::now); + // we might have a trailing cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); - Some(Ok(History::new( - time, - cmd, - "unknown".into(), - -1, - -1, - None, - None, - ))) - } + loader + .push(History::new( + time, + cmd, + "unknown".into(), + -1, + -1, + None, + None, + )) + .await?; + } - fn size_hint(&self) -> (usize, Option) { - // worst case, entry per line - (0, Some(self.loc)) + Ok(()) } } #[cfg(test)] mod test { - use std::io::Cursor; + + use crate::import::{tests::TestLoader, Importer}; use super::Fish; - #[test] - fn parse_complex() { + #[tokio::test] + async fn parse_complex() { // complicated input with varying contents and escaped strings. - let input = r#"- cmd: history --help + let bytes = r#"- cmd: history --help when: 1639162832 - cmd: cat ~/.bash_history when: 1639162851 @@ -181,14 +165,20 @@ ERROR when: 1639163066 paths: - ~/.local/share/fish/fish_history -"#; - let cursor = Cursor::new(input); - let mut fish = Fish::new(cursor).unwrap(); +"# + .as_bytes() + .to_owned(); + + let fish = Fish { bytes }; + + let mut loader = TestLoader::default(); + fish.load(&mut loader).await.unwrap(); + let mut history = loader.buf.into_iter(); // simple wrapper for fish history entry macro_rules! fishtory { ($timestamp:expr, $command:expr) => { - let h = fish.next().expect("missing entry in history").unwrap(); + let h = history.next().expect("missing entry in history"); assert_eq!(h.command.as_str(), $command); assert_eq!(h.timestamp.timestamp(), $timestamp); }; diff --git a/atuin-client/src/import/mod.rs b/atuin-client/src/import/mod.rs index 8d4aa17..07178d1 100644 --- a/atuin-client/src/import/mod.rs +++ b/atuin-client/src/import/mod.rs @@ -1,9 +1,8 @@ -use std::{ - io::{BufRead, BufReader, Read, Seek, SeekFrom}, - path::{Path, PathBuf}, -}; +use std::path::PathBuf; -use eyre::Result; +use async_trait::async_trait; +use eyre::{bail, Result}; +use memchr::Memchr; use crate::history::History; @@ -12,16 +11,88 @@ pub mod fish; pub mod resh; pub mod zsh; -// this could probably be sped up -fn count_lines(buf: &mut BufReader) -> Result { - let lines = buf.lines().count(); - buf.seek(SeekFrom::Start(0))?; - - Ok(lines) -} - -pub trait Importer: IntoIterator> + Sized { +#[async_trait] +pub trait Importer: Sized { const NAME: &'static str; - fn histpath() -> Result; - fn parse(path: impl AsRef) -> Result; + async fn new() -> Result; + async fn entries(&mut self) -> Result; + async fn load(self, loader: &mut impl Loader) -> Result<()>; +} + +#[async_trait] +pub trait Loader: Sync + Send { + async fn push(&mut self, hist: History) -> eyre::Result<()>; +} + +fn unix_byte_lines(input: &[u8]) -> impl Iterator { + UnixByteLines { + iter: memchr::memchr_iter(b'\n', input), + bytes: input, + i: 0, + } +} + +struct UnixByteLines<'a> { + iter: Memchr<'a>, + bytes: &'a [u8], + i: usize, +} + +impl<'a> Iterator for UnixByteLines<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + let j = self.iter.next()?; + let out = &self.bytes[self.i..j]; + self.i = j + 1; + Some(out) + } + + fn count(self) -> usize + where + Self: Sized, + { + self.iter.count() + } +} + +fn count_lines(input: &[u8]) -> usize { + unix_byte_lines(input).count() +} + +fn get_histpath(def: D) -> Result +where + D: FnOnce() -> Result, +{ + if let Ok(p) = std::env::var("HISTFILE") { + is_file(PathBuf::from(p)) + } else { + is_file(def()?) + } +} + +fn is_file(p: PathBuf) -> Result { + if p.is_file() { + Ok(p) + } else { + bail!("Could not find history file {:?}. Try setting $HISTFILE", p) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Default)] + pub struct TestLoader { + pub buf: Vec, + } + + #[async_trait] + impl Loader for TestLoader { + async fn push(&mut self, hist: History) -> Result<()> { + self.buf.push(hist); + Ok(()) + } + } } diff --git a/atuin-client/src/import/resh.rs b/atuin-client/src/import/resh.rs index 3eea84d..75487fe 100644 --- a/atuin-client/src/import/resh.rs +++ b/atuin-client/src/import/resh.rs @@ -1,9 +1,6 @@ -use std::{ - fs::File, - io::{BufRead, BufReader}, - path::{Path, PathBuf}, -}; +use std::{fs::File, io::Read, path::PathBuf}; +use async_trait::async_trait; use chrono::{TimeZone, Utc}; use directories::UserDirs; use eyre::{eyre, Result}; @@ -11,7 +8,7 @@ use serde::Deserialize; use atuin_common::utils::uuid_v4; -use super::{count_lines, Importer}; +use super::{get_histpath, unix_byte_lines, Importer, Loader}; use crate::history::History; #[derive(Deserialize, Debug)] @@ -72,88 +69,72 @@ pub struct ReshEntry { #[derive(Debug)] pub struct Resh { - file: BufReader, - strbuf: String, - loc: usize, + bytes: Vec, } +fn default_histpath() -> Result { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".resh_history.json")) +} + +#[async_trait] impl Importer for Resh { const NAME: &'static str = "resh"; - fn histpath() -> Result { - let user_dirs = UserDirs::new().unwrap(); - let home_dir = user_dirs.home_dir(); - - Ok(home_dir.join(".resh_history.json")) + async fn new() -> Result { + let mut bytes = Vec::new(); + let path = get_histpath(default_histpath)?; + let mut f = File::open(path)?; + f.read_to_end(&mut bytes)?; + Ok(Self { bytes }) } - fn parse(path: impl AsRef) -> Result { - let file = File::open(path)?; - let mut buf = BufReader::new(file); - let loc = count_lines(&mut buf)?; - - Ok(Self { - file: buf, - strbuf: String::new(), - loc, - }) + async fn entries(&mut self) -> Result { + Ok(super::count_lines(&self.bytes)) } -} -impl Iterator for Resh { - type Item = Result; + async fn load(self, h: &mut impl Loader) -> Result<()> { + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let entry = match serde_json::from_str::(s) { + Ok(e) => e, + Err(_) => continue, // skip invalid json :shrug: + }; - fn next(&mut self) -> Option { - self.strbuf.clear(); - match self.file.read_line(&mut self.strbuf) { - Ok(0) => return None, - Ok(_) => (), - Err(e) => return Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + let timestamp = { + let secs = entry.realtime_before.floor() as i64; + let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as u32; + Utc.timestamp(secs, nanosecs) + }; + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + let duration = { + let secs = entry.realtime_after.floor() as i64; + let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as u32; + let difference = Utc.timestamp(secs, nanosecs) - timestamp; + difference.num_nanoseconds().unwrap_or(0) + }; + + h.push(History { + id: uuid_v4(), + timestamp, + duration, + exit: entry.exit_code, + command: entry.cmd_line, + cwd: entry.pwd, + session: uuid_v4(), + hostname: entry.host, + }) + .await?; } - // .resh_history.json lies about being a json. It is actually a file containing valid json - // on every line. This means that the last line will throw an error, as it is just an EOF. - // Without the special case here, that will crash the importer. - let entry = match serde_json::from_str::(&self.strbuf) { - Ok(e) => e, - Err(e) if e.is_eof() => return None, - Err(e) => { - return Some(Err(eyre!( - "Invalid entry found in resh_history file: {}", - e - ))) - } - }; - - #[allow(clippy::cast_possible_truncation)] - #[allow(clippy::cast_sign_loss)] - let timestamp = { - let secs = entry.realtime_before.floor() as i64; - let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as u32; - Utc.timestamp(secs, nanosecs) - }; - #[allow(clippy::cast_possible_truncation)] - #[allow(clippy::cast_sign_loss)] - let duration = { - let secs = entry.realtime_after.floor() as i64; - let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as u32; - let difference = Utc.timestamp(secs, nanosecs) - timestamp; - difference.num_nanoseconds().unwrap_or(0) - }; - - Some(Ok(History { - id: uuid_v4(), - timestamp, - duration, - exit: entry.exit_code, - command: entry.cmd_line, - cwd: entry.pwd, - session: uuid_v4(), - hostname: entry.host, - })) - } - - fn size_hint(&self) -> (usize, Option) { - (self.loc, Some(self.loc)) + Ok(()) } } diff --git a/atuin-client/src/import/zsh.rs b/atuin-client/src/import/zsh.rs index 915b311..62e814d 100644 --- a/atuin-client/src/import/zsh.rs +++ b/atuin-client/src/import/zsh.rs @@ -1,138 +1,104 @@ // import old shell history! // automatically hoover up all that we can find -use std::{ - fs::File, - io::{BufRead, BufReader, Read, Seek}, - path::{Path, PathBuf}, -}; +use std::{fs::File, io::Read, path::PathBuf}; +use async_trait::async_trait; use chrono::{prelude::*, Utc}; use directories::UserDirs; use eyre::{eyre, Result}; -use itertools::Itertools; -use super::{count_lines, Importer}; +use super::{get_histpath, unix_byte_lines, Importer, Loader}; use crate::history::History; #[derive(Debug)] -pub struct Zsh { - file: BufReader, - strbuf: String, - loc: usize, - counter: i64, +pub struct Zsh { + bytes: Vec, } -impl Zsh { - fn new(r: R) -> Result { - let mut buf = BufReader::new(r); - let loc = count_lines(&mut buf)?; +fn default_histpath() -> Result { + // oh-my-zsh sets HISTFILE=~/.zhistory + // zsh has no default value for this var, but uses ~/.zhistory. + // we could maybe be smarter about this in the future :) + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); - Ok(Self { - file: buf, - strbuf: String::new(), - loc, - counter: 0, - }) - } -} - -impl Importer for Zsh { - const NAME: &'static str = "zsh"; - - fn histpath() -> Result { - // oh-my-zsh sets HISTFILE=~/.zhistory - // zsh has no default value for this var, but uses ~/.zhistory. - // we could maybe be smarter about this in the future :) - let user_dirs = UserDirs::new().unwrap(); - let home_dir = user_dirs.home_dir(); - - let mut candidates = [".zhistory", ".zsh_history"].iter(); - loop { - match candidates.next() { - Some(candidate) => { - let histpath = home_dir.join(candidate); - if histpath.exists() { - break Ok(histpath); - } + let mut candidates = [".zhistory", ".zsh_history"].iter(); + loop { + match candidates.next() { + Some(candidate) => { + let histpath = home_dir.join(candidate); + if histpath.exists() { + break Ok(histpath); } - None => break Err(eyre!("Could not find history file. Try setting $HISTFILE")), } + None => break Err(eyre!("Could not find history file. Try setting $HISTFILE")), } } - - fn parse(path: impl AsRef) -> Result { - Self::new(File::open(path)?) - } } -impl Iterator for Zsh { - type Item = Result; +#[async_trait] +impl Importer for Zsh { + const NAME: &'static str = "bash"; - fn next(&mut self) -> Option { - // ZSH extended history records the timestamp + command duration - // These lines begin with : - // So, if the line begins with :, parse it. Otherwise it's just - // the command - self.strbuf.clear(); - match self.file.read_line(&mut self.strbuf) { - Ok(0) => return None, - Ok(_) => (), - Err(e) => return Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 - } + async fn new() -> Result { + let mut bytes = Vec::new(); + let path = get_histpath(default_histpath)?; + let mut f = File::open(path)?; + f.read_to_end(&mut bytes)?; + Ok(Self { bytes }) + } - self.loc -= 1; + async fn entries(&mut self) -> Result { + Ok(super::count_lines(&self.bytes)) + } - while self.strbuf.ends_with("\\\n") { - if self.file.read_line(&mut self.strbuf).is_err() { - // There's a chance that the last line of a command has invalid - // characters, the only safe thing to do is break :/ - // usually just invalid utf8 or smth - // however, we really need to avoid missing history, so it's - // better to have some items that should have been part of - // something else, than to miss things. So break. - break; + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = chrono::Utc::now(); + let mut line = String::new(); + + let mut counter = 0; + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 }; - self.loc -= 1; + if let Some(s) = s.strip_suffix('\\') { + line.push_str(s); + line.push_str("\\\n"); + } else { + line.push_str(s); + let command = std::mem::take(&mut line); + + if let Some(command) = command.strip_prefix(": ") { + counter += 1; + h.push(parse_extended(command, counter)).await?; + } else { + let offset = chrono::Duration::seconds(counter); + counter += 1; + + h.push(History::new( + now - offset, // preserve ordering + command.trim_end().to_string(), + String::from("unknown"), + -1, + -1, + None, + None, + )) + .await?; + } + } } - // We have to handle the case where a line has escaped newlines. - // Keep reading until we have a non-escaped newline - - let extended = self.strbuf.starts_with(':'); - - if extended { - self.counter += 1; - Some(Ok(parse_extended(&self.strbuf, self.counter))) - } else { - let time = chrono::Utc::now(); - let offset = chrono::Duration::seconds(self.counter); - let time = time - offset; - - self.counter += 1; - - Some(Ok(History::new( - time, - self.strbuf.trim_end().to_string(), - String::from("unknown"), - -1, - -1, - None, - None, - ))) - } - } - - fn size_hint(&self) -> (usize, Option) { - (0, Some(self.loc)) + Ok(()) } } fn parse_extended(line: &str, counter: i64) -> History { - let line = line.replacen(": ", "", 2); - let (time, duration) = line.splitn(2, ':').collect_tuple().unwrap(); - let (duration, command) = duration.splitn(2, ';').collect_tuple().unwrap(); + let (time, duration) = line.split_once(':').unwrap(); + let (duration, command) = duration.split_once(';').unwrap(); let time = time .parse::() @@ -158,64 +124,64 @@ fn parse_extended(line: &str, counter: i64) -> History { #[cfg(test)] mod test { - use std::io::Cursor; - use chrono::prelude::*; use chrono::Utc; + use itertools::assert_equal; + + use crate::import::tests::TestLoader; use super::*; #[test] fn test_parse_extended_simple() { - let parsed = parse_extended(": 1613322469:0;cargo install atuin", 0); + let parsed = parse_extended("1613322469:0;cargo install atuin", 0); assert_eq!(parsed.command, "cargo install atuin"); assert_eq!(parsed.duration, 0); assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); - let parsed = parse_extended(": 1613322469:10;cargo install atuin;cargo update", 0); + let parsed = parse_extended("1613322469:10;cargo install atuin;cargo update", 0); assert_eq!(parsed.command, "cargo install atuin;cargo update"); assert_eq!(parsed.duration, 10_000_000_000); assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); - let parsed = parse_extended(": 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); + let parsed = parse_extended("1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); assert_eq!(parsed.duration, 10_000_000_000); assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); - let parsed = parse_extended(": 1613322469:10;cargo install \\n atuin\n", 0); + let parsed = parse_extended("1613322469:10;cargo install \\n atuin\n", 0); assert_eq!(parsed.command, "cargo install \\n atuin"); assert_eq!(parsed.duration, 10_000_000_000); assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); } - #[test] - fn test_parse_file() { - let input = r": 1613322469:0;cargo install atuin + #[tokio::test] + async fn test_parse_file() { + let bytes = r": 1613322469:0;cargo install atuin : 1613322469:10;cargo install atuin; \ cargo update : 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ -"; +" + .as_bytes() + .to_owned(); - let cursor = Cursor::new(input); - let mut zsh = Zsh::new(cursor).unwrap(); - assert_eq!(zsh.loc, 4); - assert_eq!(zsh.size_hint(), (0, Some(4))); + let mut zsh = Zsh { bytes }; + assert_eq!(zsh.entries().await.unwrap(), 4); - assert_eq!(&zsh.next().unwrap().unwrap().command, "cargo install atuin"); - assert_eq!( - &zsh.next().unwrap().unwrap().command, - "cargo install atuin; \\\ncargo update" + let mut loader = TestLoader::default(); + zsh.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo install atuin; \\\ncargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], ); - assert_eq!( - &zsh.next().unwrap().unwrap().command, - "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷" - ); - assert!(zsh.next().is_none()); - - assert_eq!(zsh.size_hint(), (0, Some(0))); } } diff --git a/src/command/client.rs b/src/command/client.rs index c75872a..b9d43b3 100644 --- a/src/command/client.rs +++ b/src/command/client.rs @@ -58,6 +58,7 @@ pub enum Cmd { } impl Cmd { + #[tokio::main(flavor = "current_thread")] pub async fn run(self) -> Result<()> { pretty_env_logger::init(); diff --git a/src/command/client/history.rs b/src/command/client/history.rs index d001658..805fe4c 100644 --- a/src/command/client/history.rs +++ b/src/command/client/history.rs @@ -128,11 +128,7 @@ pub fn print_cmd_only(w: &mut StdoutLock, h: &[History]) { } impl Cmd { - pub async fn run( - &self, - settings: &Settings, - db: &mut (impl Database + Send + Sync), - ) -> Result<()> { + pub async fn run(&self, settings: &Settings, db: &mut impl Database) -> Result<()> { let context = current_context(); match self { diff --git a/src/command/client/import.rs b/src/command/client/import.rs index c70446d..580e4b0 100644 --- a/src/command/client/import.rs +++ b/src/command/client/import.rs @@ -1,13 +1,14 @@ -use std::{env, path::PathBuf}; +use std::env; +use async_trait::async_trait; use clap::Parser; -use eyre::{eyre, Result}; +use eyre::Result; use indicatif::ProgressBar; use atuin_client::{ database::Database, history::History, - import::{bash::Bash, fish::Fish, resh::Resh, zsh::Zsh, Importer}, + import::{bash::Bash, fish::Fish, resh::Resh, zsh::Zsh, Importer, Loader}, }; #[derive(Parser)] @@ -18,13 +19,10 @@ pub enum Cmd { /// Import history from the zsh history file Zsh, - /// Import history from the bash history file Bash, - /// Import history from the resh history file Resh, - /// Import history from the fish history file Fish, } @@ -32,7 +30,7 @@ pub enum Cmd { const BATCH_SIZE: usize = 100; impl Cmd { - pub async fn run(&self, db: &mut (impl Database + Send + Sync)) -> Result<()> { + pub async fn run(&self, db: &mut DB) -> Result<()> { println!(" Atuin "); println!("======================"); println!(" \u{1f30d} "); @@ -47,124 +45,73 @@ impl Cmd { if shell.ends_with("/zsh") { println!("Detected ZSH"); - import::, _>(db, BATCH_SIZE).await + import::(db).await } else if shell.ends_with("/fish") { println!("Detected Fish"); - import::, _>(db, BATCH_SIZE).await + import::(db).await } else if shell.ends_with("/bash") { println!("Detected Bash"); - import::, _>(db, BATCH_SIZE).await + import::(db).await } else { println!("cannot import {} history", shell); Ok(()) } } - Self::Zsh => import::, _>(db, BATCH_SIZE).await, - Self::Bash => import::, _>(db, BATCH_SIZE).await, - Self::Resh => import::(db, BATCH_SIZE).await, - Self::Fish => import::, _>(db, BATCH_SIZE).await, + Self::Zsh => import::(db).await, + Self::Bash => import::(db).await, + Self::Resh => import::(db).await, + Self::Fish => import::(db).await, } } } -async fn import( - db: &mut DB, - buf_size: usize, -) -> Result<()> -where - I::IntoIter: Send, -{ +pub struct HistoryImporter<'db, DB: Database> { + pb: ProgressBar, + buf: Vec, + db: &'db mut DB, +} + +impl<'db, DB: Database> HistoryImporter<'db, DB> { + fn new(db: &'db mut DB, len: usize) -> Self { + Self { + pb: ProgressBar::new(len as u64), + buf: Vec::with_capacity(BATCH_SIZE), + db, + } + } + + async fn flush(self) -> Result<()> { + if !self.buf.is_empty() { + self.db.save_bulk(&self.buf).await?; + } + self.pb.finish(); + Ok(()) + } +} + +#[async_trait] +impl<'db, DB: Database> Loader for HistoryImporter<'db, DB> { + async fn push(&mut self, hist: History) -> Result<()> { + self.pb.inc(1); + self.buf.push(hist); + if self.buf.len() == self.buf.capacity() { + self.db.save_bulk(&self.buf).await?; + self.buf.clear(); + } + Ok(()) + } +} + +async fn import(db: &mut DB) -> Result<()> { println!("Importing history from {}", I::NAME); - let histpath = get_histpath::()?; - let contents = I::parse(histpath)?; - - let iter = contents.into_iter(); - let progress = if let (_, Some(upper_bound)) = iter.size_hint() { - ProgressBar::new(upper_bound as u64) - } else { - ProgressBar::new_spinner() - }; - - let mut buf = Vec::::with_capacity(buf_size); - let mut iter = progress.wrap_iter(iter); - loop { - // fill until either no more entries - // or until the buffer is full - let done = fill_buf(&mut buf, &mut iter); - - // flush - db.save_bulk(&buf).await?; - - if done { - break; - } - } + let mut importer = I::new().await?; + let len = importer.entries().await.unwrap(); + let mut loader = HistoryImporter::new(db, len); + importer.load(&mut loader).await?; + loader.flush().await?; println!("Import complete!"); - Ok(()) } - -fn get_histpath() -> Result { - if let Ok(p) = env::var("HISTFILE") { - is_file(PathBuf::from(p)) - } else { - is_file(I::histpath()?) - } -} - -fn is_file(p: PathBuf) -> Result { - if p.is_file() { - Ok(p) - } else { - Err(eyre!( - "Could not find history file {:?}. Try setting $HISTFILE", - p - )) - } -} - -fn fill_buf(buf: &mut Vec, iter: &mut impl Iterator>) -> bool { - buf.clear(); - loop { - match iter.next() { - Some(Ok(t)) => buf.push(t), - Some(Err(_)) => (), - None => break true, - } - - if buf.len() == buf.capacity() { - break false; - } - } -} - -#[cfg(test)] -mod tests { - use super::fill_buf; - - #[test] - fn test_fill_buf() { - let mut buf = Vec::with_capacity(4); - let mut iter = vec![ - Ok(1), - Err(2), - Ok(3), - Ok(4), - Err(5), - Ok(6), - Ok(7), - Err(8), - Ok(9), - ] - .into_iter(); - - assert!(!fill_buf(&mut buf, &mut iter)); - assert_eq!(buf, vec![1, 3, 4, 6]); - - assert!(fill_buf(&mut buf, &mut iter)); - assert_eq!(buf, vec![7, 9]); - } -} diff --git a/src/command/client/search.rs b/src/command/client/search.rs index 8c60bd3..de6e796 100644 --- a/src/command/client/search.rs +++ b/src/command/client/search.rs @@ -75,11 +75,7 @@ pub struct Cmd { } impl Cmd { - pub async fn run( - self, - db: &mut (impl Database + Send + Sync), - settings: &Settings, - ) -> Result<()> { + pub async fn run(self, db: &mut impl Database, settings: &Settings) -> Result<()> { if self.interactive { let item = select_history( &self.query, @@ -257,7 +253,7 @@ impl State { async fn query_results( app: &mut State, search_mode: SearchMode, - db: &mut (impl Database + Send + Sync), + db: &mut impl Database, ) -> Result<()> { let results = match app.input.as_str() { "" => { @@ -284,7 +280,7 @@ async fn query_results( async fn key_handler( input: Key, search_mode: SearchMode, - db: &mut (impl Database + Send + Sync), + db: &mut impl Database, app: &mut State, ) -> Option { match input { @@ -537,7 +533,7 @@ async fn select_history( search_mode: SearchMode, filter_mode: FilterMode, style: atuin_client::settings::Style, - db: &mut (impl Database + Send + Sync), + db: &mut impl Database, ) -> Result { let stdout = stdout().into_raw_mode()?; let stdout = MouseTerminal::from(stdout); @@ -596,7 +592,7 @@ async fn run_non_interactive( after: Option, limit: Option, query: &[String], - db: &mut (impl Database + Send + Sync), + db: &mut impl Database, ) -> Result<()> { let dir = if cwd.as_deref() == Some(".") { let current = std::env::current_dir()?; diff --git a/src/command/client/stats.rs b/src/command/client/stats.rs index 8045098..157496d 100644 --- a/src/command/client/stats.rs +++ b/src/command/client/stats.rs @@ -62,11 +62,7 @@ fn compute_stats(history: &[History]) -> Result<()> { } impl Cmd { - pub async fn run( - &self, - db: &mut (impl Database + Send + Sync), - settings: &Settings, - ) -> Result<()> { + pub async fn run(&self, db: &mut impl Database, settings: &Settings) -> Result<()> { let context = current_context(); let words = if self.period.is_empty() { String::from("all") diff --git a/src/command/client/sync.rs b/src/command/client/sync.rs index 6fbf8df..af809f3 100644 --- a/src/command/client/sync.rs +++ b/src/command/client/sync.rs @@ -31,11 +31,7 @@ pub enum Cmd { } impl Cmd { - pub async fn run( - self, - settings: Settings, - db: &mut (impl Database + Send + Sync), - ) -> Result<()> { + pub async fn run(self, settings: Settings, db: &mut impl Database) -> Result<()> { match self { Self::Sync { force } => run(&settings, force, db).await, Self::Login(l) => l.run(&settings).await, @@ -52,11 +48,7 @@ impl Cmd { } } -async fn run( - settings: &Settings, - force: bool, - db: &mut (impl Database + Send + Sync), -) -> Result<()> { +async fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> { atuin_client::sync::sync(settings, force, db).await?; println!( "Sync complete! {} items in database, force: {}", diff --git a/src/command/mod.rs b/src/command/mod.rs index 953b76b..c86e76f 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -19,11 +19,11 @@ pub enum AtuinCmd { } impl AtuinCmd { - pub async fn run(self) -> Result<()> { + pub fn run(self) -> Result<()> { match self { - Self::Client(client) => client.run().await, + Self::Client(client) => client.run(), #[cfg(feature = "server")] - Self::Server(server) => server.run().await, + Self::Server(server) => server.run(), } } } diff --git a/src/command/server.rs b/src/command/server.rs index 1d514bb..495f85d 100644 --- a/src/command/server.rs +++ b/src/command/server.rs @@ -21,6 +21,7 @@ pub enum Cmd { } impl Cmd { + #[tokio::main] pub async fn run(self) -> Result<()> { tracing_subscriber::registry() .with(fmt::layer()) diff --git a/src/main.rs b/src/main.rs index 00028cd..5d43cc7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,12 +25,11 @@ struct Atuin { } impl Atuin { - async fn run(self) -> Result<()> { - self.atuin.run().await + fn run(self) -> Result<()> { + self.atuin.run() } } -#[tokio::main] -async fn main() -> Result<()> { - Atuin::parse().run().await +fn main() -> Result<()> { + Atuin::parse().run() }