diff --git a/Cargo.lock b/Cargo.lock index c3604b6..a8ee317 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -177,6 +177,7 @@ dependencies = [ "chrono", "rand", "serde", + "typed-builder", "uuid", ] @@ -2528,6 +2529,17 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "typed-builder" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64cba322cb9b7bc6ca048de49e83918223f35e7a86311267013afff257004870" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.99", +] + [[package]] name = "typenum" version = "1.15.0" diff --git a/Cargo.toml b/Cargo.toml index b485676..bde7ed6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ serde_json = "1.0.86" tokio = { version = "1", features = ["full"] } uuid = { version = "1.3", features = ["v4"] } whoami = "1.1.2" +typed-builder = "0.14.0" [workspace.dependencies.reqwest] version = "0.11" diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml index 42e3cf6..7b85bf7 100644 --- a/atuin-client/Cargo.toml +++ b/atuin-client/Cargo.toml @@ -18,7 +18,6 @@ sync = [ "reqwest", "sha2", "hex", - "rmp-serde", "base64", "generic-array", "xsalsa20poly1305", @@ -51,13 +50,13 @@ fs-err = { workspace = true } sql-builder = "3" lazy_static = "1" memchr = "2.5" +rmp-serde = { version = "1.1.1" } # sync urlencoding = { version = "2.1.0", optional = true } reqwest = { workspace = true, optional = true } hex = { version = "0.4", optional = true } sha2 = { version = "0.10", optional = true } -rmp-serde = { version = "1.1.1", optional = true } base64 = { workspace = true, optional = true } tokio = { workspace = true } semver = { workspace = true } diff --git a/atuin-client/record-migrations/20230531212437_create-records.sql b/atuin-client/record-migrations/20230531212437_create-records.sql new file mode 100644 index 0000000..4696335 --- /dev/null +++ 
b/atuin-client/record-migrations/20230531212437_create-records.sql @@ -0,0 +1,15 @@ +-- Add migration script here +create table if not exists records ( + id text primary key, + parent text unique, -- null if this is the first one + host text not null, + + timestamp integer not null, + tag text not null, + version text not null, + data blob not null +); + +create index host_idx on records (host); +create index tag_idx on records (tag); +create index host_tag_idx on records (host, tag); diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs index 22bd588..a2d8c53 100644 --- a/atuin-client/src/database.rs +++ b/atuin-client/src/database.rs @@ -17,13 +17,14 @@ use sqlx::{ use super::{ history::History, ordering, - settings::{FilterMode, SearchMode}, + settings::{FilterMode, SearchMode, Settings}, }; pub struct Context { pub session: String, pub cwd: String, pub hostname: String, + pub host_id: String, } #[derive(Default, Clone)] @@ -50,11 +51,13 @@ pub fn current_context() -> Context { env::var("ATUIN_HOST_USER").unwrap_or_else(|_| whoami::username()) ); let cwd = utils::get_current_dir(); + let host_id = Settings::host_id().expect("failed to load host ID"); Context { session, hostname, cwd, + host_id, } } @@ -551,6 +554,7 @@ mod test { hostname: "test:host".to_string(), session: "beepboopiamasession".to_string(), cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), }; let results = db @@ -757,6 +761,7 @@ mod test { hostname: "test:host".to_string(), session: "beepboopiamasession".to_string(), cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), }; let mut db = Sqlite::new("sqlite::memory:").await.unwrap(); diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs new file mode 100644 index 0000000..8714927 --- /dev/null +++ b/atuin-client/src/kv.rs @@ -0,0 +1,103 @@ +use eyre::Result; +use serde::{Deserialize, Serialize}; + +use crate::record::store::Store; +use crate::settings::Settings; + +const KV_VERSION: &str = 
"v0"; +const KV_TAG: &str = "kv"; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct KvRecord { + pub key: String, + pub value: String, +} + +impl KvRecord { + pub fn serialize(&self) -> Result<Vec<u8>> { + let buf = rmp_serde::to_vec(self)?; + + Ok(buf) + } +} + +pub struct KvStore; + +impl Default for KvStore { + fn default() -> Self { + Self::new() + } +} + +impl KvStore { + // will want to init the actual kv store when that is done + pub fn new() -> KvStore { + KvStore {} + } + + pub async fn set( + &self, + store: &mut (impl Store + Send + Sync), + key: &str, + value: &str, + ) -> Result<()> { + let host_id = Settings::host_id().expect("failed to get host_id"); + + let record = KvRecord { + key: key.to_string(), + value: value.to_string(), + }; + + let bytes = record.serialize()?; + + let parent = store + .last(host_id.as_str(), KV_TAG) + .await? + .map(|entry| entry.id); + + let record = atuin_common::record::Record::builder() + .host(host_id) + .version(KV_VERSION.to_string()) + .tag(KV_TAG.to_string()) + .parent(parent) + .data(bytes) + .build(); + + store.push(&record).await?; + + Ok(()) + } + + // TODO: setup an actual kv store, rebuild func, and do not pass the main store in here as + // well. + pub async fn get(&self, store: &impl Store, key: &str) -> Result<Option<KvRecord>> { + // TODO: don't load this from disk so much + let host_id = Settings::host_id().expect("failed to get host_id"); + + // Currently, this is O(n). When we have an actual KV store, it can be better + // Just a poc for now! + + // iterate records to find the value we want + // start at the end, so we get the most recent version + let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await?
else { + return Ok(None); + }; + let kv: KvRecord = rmp_serde::from_slice(&record.data)?; + + if kv.key == key { + return Ok(Some(kv)); + } + + while let Some(parent) = record.parent { + record = store.get(parent.as_str()).await?; + let kv: KvRecord = rmp_serde::from_slice(&record.data)?; + + if kv.key == key { + return Ok(Some(kv)); + } + } + + // if we get here, then... we didn't find the record with that key :( + Ok(None) + } +} diff --git a/atuin-client/src/lib.rs b/atuin-client/src/lib.rs index 497c5e7..3f12153 100644 --- a/atuin-client/src/lib.rs +++ b/atuin-client/src/lib.rs @@ -13,5 +13,7 @@ pub mod sync; pub mod database; pub mod history; pub mod import; +pub mod kv; pub mod ordering; +pub mod record; pub mod settings; diff --git a/atuin-client/src/record/mod.rs b/atuin-client/src/record/mod.rs new file mode 100644 index 0000000..72c1f88 --- /dev/null +++ b/atuin-client/src/record/mod.rs @@ -0,0 +1,2 @@ +pub mod sqlite_store; +pub mod store; diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs new file mode 100644 index 0000000..f116b6e --- /dev/null +++ b/atuin-client/src/record/sqlite_store.rs @@ -0,0 +1,331 @@ +// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. 
+// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index +// by tag/host + +use std::path::Path; +use std::str::FromStr; + +use async_trait::async_trait; +use eyre::{eyre, Result}; +use fs_err as fs; +use sqlx::{ + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, + Row, +}; + +use atuin_common::record::Record; + +use super::store::Store; + +pub struct SqliteStore { + pool: SqlitePool, +} + +impl SqliteStore { + pub async fn new(path: impl AsRef<Path>) -> Result<Self> { + let path = path.as_ref(); + + debug!("opening sqlite database at {:?}", path); + + let create = !path.exists(); + if create { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir)?; + } + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new().connect_with(opts).await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./record-migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, r: &Record) -> Result<()> { + // In sqlite, we are "limited" to i64. But that is still fine, until 2262.
+ sqlx::query( + "insert or ignore into records(id, host, tag, timestamp, parent, version, data) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7)", + ) + .bind(r.id.as_str()) + .bind(r.host.as_str()) + .bind(r.tag.as_str()) + .bind(r.timestamp as i64) + .bind(r.parent.as_ref()) + .bind(r.version.as_str()) + .bind(r.data.as_slice()) + .execute(tx) + .await?; + + Ok(()) + } + + fn query_row(row: SqliteRow) -> Record { + let timestamp: i64 = row.get("timestamp"); + + Record { + id: row.get("id"), + host: row.get("host"), + parent: row.get("parent"), + timestamp: timestamp as u64, + tag: row.get("tag"), + version: row.get("version"), + data: row.get("data"), + } + } +} + +#[async_trait] +impl Store for SqliteStore { + async fn push_batch(&self, records: impl Iterator<Item = &Record> + Send + Sync) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for record in records { + Self::save_raw(&mut tx, record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn get(&self, id: &str) -> Result<Record> { + let res = sqlx::query("select * from records where id = ?1") + .bind(id) + .map(Self::query_row) + .fetch_one(&self.pool) + .await?; + + Ok(res) + } + + async fn len(&self, host: &str, tag: &str) -> Result<u64> { + let res: (i64,) = + sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2") + .bind(host) + .bind(tag) + .fetch_one(&self.pool) + .await?; + + Ok(res.0 as u64) + } + + async fn next(&self, record: &Record) -> Result<Option<Record>> { + let res = sqlx::query("select * from records where parent = ?1") + .bind(record.id.clone()) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occured: {}", e)), + Ok(v) => Ok(Some(v)), + } + } + + async fn first(&self, host: &str, tag: &str) -> Result<Option<Record>> { + let res = sqlx::query( + "select * from records where host = ?1 and tag = ?2 and parent is null limit 1", + ) + .bind(host) + .bind(tag) + .map(Self::query_row) + .fetch_optional(&self.pool) +
.await?; + + Ok(res) + } + + async fn last(&self, host: &str, tag: &str) -> Result<Option<Record>> { + let res = sqlx::query( + "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;", + ) + .bind(tag) + .bind(host) + .map(Self::query_row) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use atuin_common::record::Record; + + use crate::record::store::Store; + + use super::SqliteStore; + + fn test_record() -> Record { + Record::builder() + .host(atuin_common::utils::uuid_v7().simple().to_string()) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(vec![0, 1, 2, 3]) + .build() + } + + #[tokio::test] + async fn create_db() { + let db = SqliteStore::new(":memory:").await; + + assert!( + db.is_ok(), + "db could not be created, {:?}", + db.err().unwrap() + ); + } + + #[tokio::test] + async fn push_record() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + + db.push(&record).await.expect("failed to insert record"); + } + + #[tokio::test] + async fn get_record() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let new_record = db + .get(record.id.as_str()) + .await + .expect("failed to fetch record"); + + assert_eq!(record, new_record, "records are not equal"); + } + + #[tokio::test] + async fn len() { + let db = SqliteStore::new(":memory:").await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len(record.host.as_str(), record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_different_tags() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + // these have different tags, so the len should be the same + // we model multiple stores within one database + // new store = new
tag = independent length + let first = test_record(); + let second = test_record(); + + db.push(&first).await.unwrap(); + db.push(&second).await.unwrap(); + + let first_len = db + .len(first.host.as_str(), first.tag.as_str()) + .await + .unwrap(); + let second_len = db + .len(second.host.as_str(), second.tag.as_str()) + .await + .unwrap(); + + assert_eq!(first_len, 1, "expected length of 1 after insert"); + assert_eq!(second_len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn append_a_bunch() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + let mut tail = test_record(); + db.push(&tail).await.expect("failed to push record"); + + for _ in 1..100 { + tail = tail.new_child(vec![1, 2, 3, 4]); + db.push(&tail).await.unwrap(); + } + + assert_eq!( + db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + } + + #[tokio::test] + async fn append_a_big_bunch() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + let mut records: Vec<Record> = Vec::with_capacity(10000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..10000 { + tail = tail.new_child(vec![1, 2, 3]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + assert_eq!( + db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), + 10000, + "failed to insert 10k records" + ); + } + + #[tokio::test] + async fn test_chain() { + let db = SqliteStore::new(":memory:").await.unwrap(); + + let mut records: Vec<Record> = Vec::with_capacity(1000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..1000 { + tail = tail.new_child(vec![1, 2, 3]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + let mut record = db + .first(tail.host.as_str(), tail.tag.as_str()) + .await + .expect("in memory sqlite should not fail") + .expect("entry exists"); + + let mut count = 1; + + while let Some(next) =
db.next(&record).await.unwrap() { + assert_eq!(record.id, next.clone().parent.unwrap()); + record = next; + + count += 1; + } + + assert_eq!(count, 1000); + } +} diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs new file mode 100644 index 0000000..75d79fb --- /dev/null +++ b/atuin-client/src/record/store.rs @@ -0,0 +1,30 @@ +use async_trait::async_trait; +use eyre::Result; + +use atuin_common::record::Record; + +/// A record store stores records +/// In more detail - we tend to need to process this into _another_ format to actually query it. +/// As is, the record store is intended as the source of truth for arbitrary data, which could +/// be shell history, kvs, etc. +#[async_trait] +pub trait Store { + // Push a record + async fn push(&self, record: &Record) -> Result<()> { + self.push_batch(std::iter::once(record)).await + } + + // Push a batch of records, all in one transaction + async fn push_batch(&self, records: impl Iterator<Item = &Record> + Send + Sync) -> Result<()>; + + async fn get(&self, id: &str) -> Result<Record>; + async fn len(&self, host: &str, tag: &str) -> Result<u64>; + + /// Get the record that follows this record + async fn next(&self, record: &Record) -> Result<Option<Record>>; + + /// Get the first record for a given host and tag + async fn first(&self, host: &str, tag: &str) -> Result<Option<Record>>; + /// Get the last record for a given host and tag + async fn last(&self, host: &str, tag: &str) -> Result<Option<Record>>; +} diff --git a/atuin-client/src/settings.rs b/atuin-client/src/settings.rs index 524b2fd..dd07245 100644 --- a/atuin-client/src/settings.rs +++ b/atuin-client/src/settings.rs @@ -17,6 +17,7 @@ pub const HISTORY_PAGE_SIZE: i64 = 100; pub const LAST_SYNC_FILENAME: &str = "last_sync_time"; pub const LAST_VERSION_CHECK_FILENAME: &str = "last_version_check_time"; pub const LATEST_VERSION_FILENAME: &str = "latest_version"; +pub const HOST_ID_FILENAME: &str = "host_id"; #[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq)] pub enum SearchMode { @@ -140,6
+141,7 @@ pub struct Settings { pub sync_address: String, pub sync_frequency: String, pub db_path: String, + pub record_store_path: String, pub key_path: String, pub session_path: String, pub search_mode: SearchMode, @@ -226,6 +228,21 @@ impl Settings { Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME) } + + pub fn host_id() -> Option<String> { + let id = Settings::read_from_data_dir(HOST_ID_FILENAME); + + if id.is_some() { + return id; + } + + let uuid = atuin_common::utils::uuid_v7(); + + Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref()) + .expect("Could not write host ID to data dir"); + + Some(uuid.as_simple().to_string()) + } + pub fn should_sync(&self) -> Result<bool> { if !self.auto_sync || !PathBuf::from(self.session_path.as_str()).exists() { return Ok(false); @@ -321,11 +338,14 @@ impl Settings { config_file.push("config.toml"); let db_path = data_dir.join("history.db"); + let record_store_path = data_dir.join("records.db"); + let key_path = data_dir.join("key"); let session_path = data_dir.join("session"); let mut config_builder = Config::builder() .set_default("db_path", db_path.to_str())? + .set_default("record_store_path", record_store_path.to_str())? .set_default("key_path", key_path.to_str())? .set_default("session_path", session_path.to_str())? .set_default("dialect", "us")?
diff --git a/atuin-common/Cargo.toml b/atuin-common/Cargo.toml index 94225e6..918b5b5 100644 --- a/atuin-common/Cargo.toml +++ b/atuin-common/Cargo.toml @@ -16,3 +16,4 @@ chrono = { workspace = true } serde = { workspace = true } uuid = { workspace = true } rand = { workspace = true } +typed-builder = { workspace = true } diff --git a/atuin-common/src/lib.rs b/atuin-common/src/lib.rs index e76a7ab..b332e23 100644 --- a/atuin-common/src/lib.rs +++ b/atuin-common/src/lib.rs @@ -1,4 +1,5 @@ #![forbid(unsafe_code)] pub mod api; +pub mod record; pub mod utils; diff --git a/atuin-common/src/record.rs b/atuin-common/src/record.rs new file mode 100644 index 0000000..1fb60e5 --- /dev/null +++ b/atuin-common/src/record.rs @@ -0,0 +1,49 @@ +use serde::{Deserialize, Serialize}; +use typed_builder::TypedBuilder; + +/// A single record stored inside of our local database +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] +pub struct Record { + /// a unique ID + #[builder(default = crate::utils::uuid_v7().as_simple().to_string())] + pub id: String, + + /// The unique ID of the host. + // TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store + // as strings. I would rather avoid normalization, so store as UUID binary instead of + // encoding to a string and wasting much more storage. + pub host: String, + + /// The ID of the parent entry + // A store is technically just a double linked list + // We can do some cheating with the timestamps, but should not rely upon them. + // Clocks are tricksy. + #[builder(default)] + pub parent: Option<String>, + + /// The creation time in nanoseconds since unix epoch + #[builder(default = chrono::Utc::now().timestamp_nanos() as u64)] + pub timestamp: u64, + + /// The version the data in the entry conforms to + // However we want to track versions for this tag, eg v2 + pub version: String, + + /// The type of data we are storing here. Eg, "history" + pub tag: String, + + /// Some data.
This can be anything you wish to store. Use the tag field to know how to handle it. + pub data: Vec<u8>, +} + +impl Record { + pub fn new_child(&self, data: Vec<u8>) -> Record { + Record::builder() + .host(self.host.clone()) + .version(self.version.clone()) + .parent(Some(self.id.clone())) + .tag(self.tag.clone()) + .data(data) + .build() + } +} diff --git a/atuin/src/command/client.rs b/atuin/src/command/client.rs index 6a2d868..6d5fe56 100644 --- a/atuin/src/command/client.rs +++ b/atuin/src/command/client.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use clap::Subcommand; use eyre::{Result, WrapErr}; -use atuin_client::{database::Sqlite, settings::Settings}; +use atuin_client::{database::Sqlite, record::sqlite_store::SqliteStore, settings::Settings}; use env_logger::Builder; #[cfg(feature = "sync")] @@ -14,6 +14,7 @@ mod account; mod history; mod import; +mod kv; mod search; mod stats; @@ -40,6 +41,9 @@ pub enum Cmd { #[cfg(feature = "sync")] Account(account::Cmd), + + #[command(subcommand)] + Kv(kv::Cmd), } impl Cmd { @@ -53,7 +57,10 @@ impl Cmd { let mut settings = Settings::new().wrap_err("could not load client settings")?; let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + let mut db = Sqlite::new(db_path).await?; + let mut store = SqliteStore::new(record_store_path).await?; match self { Self::History(history) => history.run(&settings, &mut db).await, @@ -66,6 +73,8 @@ impl Cmd { #[cfg(feature = "sync")] Self::Account(account) => account.run(settings).await, + + Self::Kv(kv) => kv.run(&settings, &mut store).await, } } } diff --git a/atuin/src/command/client/kv.rs b/atuin/src/command/client/kv.rs new file mode 100644 index 0000000..f922b06 --- /dev/null +++ b/atuin/src/command/client/kv.rs @@ -0,0 +1,45 @@ +use clap::Subcommand; +use eyre::Result; + +use atuin_client::{kv::KvStore, record::store::Store, settings::Settings}; + +#[derive(Subcommand)] +#[command(infer_subcommands = true)]
+pub enum Cmd { + // atuin kv set foo bar bar + Set { + #[arg(long, short)] + key: String, + + value: String, + }, + + // atuin kv get foo => bar baz + Get { + key: String, + }, +} + +impl Cmd { + pub async fn run( + &self, + _settings: &Settings, + store: &mut (impl Store + Send + Sync), + ) -> Result<()> { + let kv_store = KvStore::new(); + + match self { + Self::Set { key, value } => kv_store.set(store, key, value).await, + + Self::Get { key } => { + let val = kv_store.get(store, key).await?; + + if let Some(kv) = val { + println!("{}", kv.value); + } + + Ok(()) + } + } + } +}