diff --git a/Cargo.lock b/Cargo.lock index 16e67c2..a3ac81b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -97,6 +97,7 @@ dependencies = [ "atuin-client", "atuin-common", "atuin-server", + "atuin-server-postgres", "base64 0.21.0", "bitflags", "cassowary", @@ -104,7 +105,6 @@ dependencies = [ "clap", "clap_complete", "colored", - "crossbeam-channel", "crossterm", "directories", "env_logger", @@ -160,7 +160,6 @@ dependencies = [ "serde_regex", "sha2", "shellexpand", - "sodiumoxide", "sql-builder", "sqlx", "tokio", @@ -187,6 +186,7 @@ dependencies = [ "argon2", "async-trait", "atuin-common", + "atuin-server-database", "axum", "base64 0.21.0", "chrono", @@ -200,14 +200,39 @@ dependencies = [ "semver", "serde", "serde_json", - "sodiumoxide", - "sqlx", "tokio", "tower", "tower-http", "tracing", "uuid", - "whoami", +] + +[[package]] +name = "atuin-server-database" +version = "15.0.0" +dependencies = [ + "async-trait", + "atuin-common", + "chrono", + "chronoutil", + "eyre", + "serde", + "tracing", + "uuid", +] + +[[package]] +name = "atuin-server-postgres" +version = "15.0.0" +dependencies = [ + "async-trait", + "atuin-common", + "atuin-server-database", + "chrono", + "futures-util", + "serde", + "sqlx", + "tracing", ] [[package]] @@ -515,16 +540,6 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d0165d2900ae6778e36e80bbc4da3b5eefccee9ba939761f9c2882a5d9af3ff" -[[package]] -name = "crossbeam-channel" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" -dependencies = [ - "cfg-if", - "crossbeam-utils", -] - [[package]] name = "crossbeam-queue" version = "0.3.6" @@ -631,15 +646,6 @@ dependencies = [ "dirs", ] -[[package]] -name = "ed25519" -version = "1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9c280362032ea4203659fc489832d0204ef09f247a0506f170dafcac08c369" -dependencies = [ - "signature", -] - [[package]] name = "either" version = "1.8.0" @@ -1175,18 +1181,6 @@ version = "0.2.141" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" -[[package]] -name = "libsodium-sys" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b779387cd56adfbc02ea4a668e704f729be8d6a6abd2c27ca5ee537849a92fd" -dependencies = [ - "cc", - "libc", - "pkg-config", - "walkdir", -] - [[package]] name = "libsqlite3-sys" version = "0.24.2" @@ -1875,15 +1869,6 @@ dependencies = [ "cipher", ] -[[package]] -name = "same-file" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" -dependencies = [ - "winapi-util", -] - [[package]] name = "schannel" version = "0.1.20" @@ -2071,12 +2056,6 @@ dependencies = [ "libc", ] -[[package]] -name = "signature" -version = "1.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e90531723b08e4d6d71b791108faf51f03e1b4a7784f96b2b87f852ebc247228" - [[package]] name = "slab" version = "0.4.7" @@ -2102,18 +2081,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "sodiumoxide" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e26be3acb6c2d9a7aac28482586a7856436af4cfe7100031d219de2d2ecb0028" -dependencies = [ - "ed25519", - "libc", - "libsodium-sys", - "serde", -] - [[package]] name = "spin" version = "0.5.2" @@ -2659,17 +2626,6 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" -[[package]] -name = "walkdir" -version = "2.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" -dependencies = [ - "same-file", - "winapi", - "winapi-util", -] - [[package]] name = "want" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 652efb8..00b0434 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,12 @@ [workspace] -members = ["atuin", "atuin-client", "atuin-server", "atuin-common"] +members = [ + "atuin", + "atuin-client", + "atuin-server", + "atuin-server-postgres", + "atuin-server-database", + "atuin-common", +] [workspace.package] name = "atuin" @@ -27,7 +34,6 @@ rand = { version = "0.8.5", features = ["std"] } semver = "1.0.14" serde = { version = "1.0.145", features = ["derive"] } serde_json = "1.0.86" -sodiumoxide = "0.2.6" tokio = { version = "1", features = ["full"] } uuid = { version = "1.2", features = ["v4"] } whoami = "1.1.2" diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml index fee3eb5..770d774 100644 --- a/atuin-client/Cargo.toml +++ b/atuin-client/Cargo.toml @@ -53,7 +53,6 @@ memchr = "2.5" # sync urlencoding = { version = "2.1.0", optional = true } -sodiumoxide = { workspace = true, optional = true } reqwest = { workspace = true, optional = true } hex = { version = "0.4", optional = true } sha2 = { version = "0.10", optional = true } diff --git a/atuin-server-database/Cargo.toml b/atuin-server-database/Cargo.toml new file mode 100644 index 0000000..485b324 --- /dev/null +++ b/atuin-server-database/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "atuin-server-database" +edition = "2021" +description = "server database library for atuin" + +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +atuin-common = { path = "../atuin-common", version = "15.0.0" } + +tracing = "0.1" +chrono = { workspace = true } +eyre = { workspace = true } +uuid = { workspace = true } +serde = { workspace = true } +async-trait = { workspace = true } +chronoutil = "0.2.3" diff --git a/atuin-server/src/calendar.rs b/atuin-server-database/src/calendar.rs similarity index 100% rename from atuin-server/src/calendar.rs rename to atuin-server-database/src/calendar.rs diff --git a/atuin-server-database/src/lib.rs b/atuin-server-database/src/lib.rs new file mode 100644 index 0000000..de33ba4 --- /dev/null +++ b/atuin-server-database/src/lib.rs @@ -0,0 +1,220 @@ +#![forbid(unsafe_code)] + +pub mod calendar; +pub mod models; + +use std::{ + collections::HashMap, + fmt::{Debug, Display}, +}; + +use self::{ + calendar::{TimePeriod, TimePeriodInfo}, + models::{History, NewHistory, NewSession, NewUser, Session, User}, +}; +use async_trait::async_trait; +use atuin_common::utils::get_days_from_month; +use chrono::{Datelike, TimeZone}; +use chronoutil::RelativeDuration; +use serde::{de::DeserializeOwned, Serialize}; +use tracing::instrument; + +#[derive(Debug)] +pub enum DbError { + NotFound, + Other(eyre::Report), +} + +impl Display for DbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl std::error::Error for DbError {} + +pub type DbResult = Result; + +#[async_trait] +pub trait Database: Sized + Clone + Send + Sync + 'static { + type Settings: Debug + Clone + DeserializeOwned + Serialize + Send + Sync + 'static; + async fn new(settings: &Self::Settings) -> DbResult; + + async fn get_session(&self, token: &str) -> DbResult; + async fn get_session_user(&self, token: &str) -> DbResult; + async fn add_session(&self, session: &NewSession) -> DbResult<()>; + + async fn get_user(&self, username: &str) -> DbResult; + async fn get_user_session(&self, u: &User) -> DbResult; + async fn add_user(&self, user: &NewUser) -> DbResult; + async fn delete_user(&self, u: &User) -> DbResult<()>; + + async fn count_history(&self, user: &User) -> DbResult; + async fn count_history_cached(&self, user: &User) -> DbResult; + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()>; + async fn deleted_history(&self, user: &User) -> DbResult>; + + async fn count_history_range( + &self, + user: &User, + start: chrono::NaiveDateTime, + end: chrono::NaiveDateTime, + ) -> DbResult; + + async fn list_history( + &self, + user: &User, + created_after: chrono::NaiveDateTime, + since: chrono::NaiveDateTime, + host: &str, + page_size: i64, + ) -> DbResult>; + + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>; + + async fn oldest_history(&self, user: &User) -> DbResult; + + /// Count the history for a given year + #[instrument(skip_all)] + async fn count_history_year(&self, user: &User, year: i32) -> DbResult { + let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0); + let end = start + RelativeDuration::years(1); + + let res = self + .count_history_range(user, start.naive_utc(), end.naive_utc()) + .await?; + Ok(res) + } + + /// Count the history for a given month + #[instrument(skip_all)] + async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> DbResult { + let start = chrono::Utc + .ymd(month.year(), month.month(), 1) + .and_hms_nano(0, 0, 0, 0); + + // ofc... + let end = if month.month() < 12 { + chrono::Utc + .ymd(month.year(), month.month() + 1, 1) + .and_hms_nano(0, 0, 0, 0) + } else { + chrono::Utc + .ymd(month.year() + 1, 1, 1) + .and_hms_nano(0, 0, 0, 0) + }; + + tracing::debug!("start: {}, end: {}", start, end); + + let res = self + .count_history_range(user, start.naive_utc(), end.naive_utc()) + .await?; + Ok(res) + } + + /// Count the history for a given day + #[instrument(skip_all)] + async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> DbResult { + let start = chrono::Utc + .ymd(day.year(), day.month(), day.day()) + .and_hms_nano(0, 0, 0, 0); + let end = chrono::Utc + .ymd(day.year(), day.month(), day.day() + 1) + .and_hms_nano(0, 0, 0, 0); + + let res = self + .count_history_range(user, start.naive_utc(), end.naive_utc()) + .await?; + Ok(res) + } + + #[instrument(skip_all)] + async fn calendar( + &self, + user: &User, + period: TimePeriod, + year: u64, + month: u64, + ) -> DbResult> { + // TODO: Support different timezones. Right now we assume UTC and + // everything is stored as such. But it _should_ be possible to + // interpret the stored date with a different TZ + + match period { + TimePeriod::YEAR => { + let mut ret = HashMap::new(); + // First we need to work out how far back to calculate. Get the + // oldest history item + let oldest = self.oldest_history(user).await?.timestamp.year(); + let current_year = chrono::Utc::now().year(); + + // All the years we need to get data for + // The upper bound is exclusive, so include current +1 + let years = oldest..current_year + 1; + + for year in years { + let count = self.count_history_year(user, year).await?; + + ret.insert( + year as u64, + TimePeriodInfo { + count: count as u64, + hash: "".to_string(), + }, + ); + } + + Ok(ret) + } + + TimePeriod::MONTH => { + let mut ret = HashMap::new(); + + for month in 1..13 { + let count = self + .count_history_month( + user, + chrono::Utc.ymd(year as i32, month, 1).naive_utc(), + ) + .await?; + + ret.insert( + month as u64, + TimePeriodInfo { + count: count as u64, + hash: "".to_string(), + }, + ); + } + + Ok(ret) + } + + TimePeriod::DAY => { + let mut ret = HashMap::new(); + + for day in 1..get_days_from_month(year as i32, month as u32) { + let count = self + .count_history_day( + user, + chrono::Utc + .ymd(year as i32, month as u32, day as u32) + .naive_utc(), + ) + .await?; + + ret.insert( + day as u64, + TimePeriodInfo { + count: count as u64, + hash: "".to_string(), + }, + ); + } + + Ok(ret) + } + } + } +} diff --git a/atuin-server/src/models.rs b/atuin-server-database/src/models.rs similarity index 91% rename from atuin-server/src/models.rs rename to atuin-server-database/src/models.rs index ee84f58..a95ceba 100644 --- a/atuin-server/src/models.rs +++ b/atuin-server-database/src/models.rs @@ -1,6 +1,5 @@ use chrono::prelude::*; -#[derive(sqlx::FromRow)] pub struct History { pub id: i64, pub client_id: String, // a client generated ID @@ -22,7 +21,6 @@ pub struct NewHistory { pub data: String, } -#[derive(sqlx::FromRow)] pub struct User { pub id: i64, pub username: String, @@ -30,7 +28,6 @@ pub struct User { pub password: String, } -#[derive(sqlx::FromRow)] pub struct Session { pub id: i64, pub user_id: i64, diff --git a/atuin-server-postgres/Cargo.toml b/atuin-server-postgres/Cargo.toml new file mode 100644 index 0000000..18864f6 --- /dev/null +++ b/atuin-server-postgres/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "atuin-server-postgres" +edition = "2018" +description = "server postgres database library for atuin" + +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +atuin-common = { path = "../atuin-common", version = "15.0.0" } +atuin-server-database = { path = "../atuin-server-database", version = "15.0.0" } + +tracing = "0.1" +chrono = { workspace = true } +serde = { workspace = true } +sqlx = { workspace = true } +async-trait = { workspace = true } +futures-util = "0.3" diff --git a/atuin-server/migrations/20210425153745_create_history.sql b/atuin-server-postgres/migrations/20210425153745_create_history.sql similarity index 100% rename from atuin-server/migrations/20210425153745_create_history.sql rename to atuin-server-postgres/migrations/20210425153745_create_history.sql diff --git a/atuin-server/migrations/20210425153757_create_users.sql b/atuin-server-postgres/migrations/20210425153757_create_users.sql similarity index 100% rename from atuin-server/migrations/20210425153757_create_users.sql rename to atuin-server-postgres/migrations/20210425153757_create_users.sql diff --git a/atuin-server/migrations/20210425153800_create_sessions.sql b/atuin-server-postgres/migrations/20210425153800_create_sessions.sql similarity index 100% rename from atuin-server/migrations/20210425153800_create_sessions.sql rename to atuin-server-postgres/migrations/20210425153800_create_sessions.sql diff --git a/atuin-server/migrations/20220419082412_add_count_trigger.sql b/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql similarity index 100% rename from atuin-server/migrations/20220419082412_add_count_trigger.sql rename to atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql diff --git a/atuin-server/migrations/20220421073605_fix_count_trigger_delete.sql b/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql similarity index 100% rename from atuin-server/migrations/20220421073605_fix_count_trigger_delete.sql rename to atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql diff --git a/atuin-server/migrations/20220421174016_larger-commands.sql b/atuin-server-postgres/migrations/20220421174016_larger-commands.sql similarity index 100% rename from atuin-server/migrations/20220421174016_larger-commands.sql rename to atuin-server-postgres/migrations/20220421174016_larger-commands.sql diff --git a/atuin-server/migrations/20220426172813_user-created-at.sql b/atuin-server-postgres/migrations/20220426172813_user-created-at.sql similarity index 100% rename from atuin-server/migrations/20220426172813_user-created-at.sql rename to atuin-server-postgres/migrations/20220426172813_user-created-at.sql diff --git a/atuin-server/migrations/20220505082442_create-events.sql b/atuin-server-postgres/migrations/20220505082442_create-events.sql similarity index 100% rename from atuin-server/migrations/20220505082442_create-events.sql rename to atuin-server-postgres/migrations/20220505082442_create-events.sql diff --git a/atuin-server/migrations/20220610074049_history-length.sql b/atuin-server-postgres/migrations/20220610074049_history-length.sql similarity index 100% rename from atuin-server/migrations/20220610074049_history-length.sql rename to atuin-server-postgres/migrations/20220610074049_history-length.sql diff --git a/atuin-server/migrations/20230315220537_drop-events.sql b/atuin-server-postgres/migrations/20230315220537_drop-events.sql similarity index 100% rename from atuin-server/migrations/20230315220537_drop-events.sql rename to atuin-server-postgres/migrations/20230315220537_drop-events.sql diff --git a/atuin-server/migrations/20230315224203_create-deleted.sql b/atuin-server-postgres/migrations/20230315224203_create-deleted.sql similarity index 100% rename from atuin-server/migrations/20230315224203_create-deleted.sql rename to atuin-server-postgres/migrations/20230315224203_create-deleted.sql diff --git a/atuin-server/migrations/20230515221038_trigger-delete-only.sql b/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql similarity index 100% rename from atuin-server/migrations/20230515221038_trigger-delete-only.sql rename to atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql diff --git a/atuin-server-postgres/src/lib.rs b/atuin-server-postgres/src/lib.rs new file mode 100644 index 0000000..0dc51da --- /dev/null +++ b/atuin-server-postgres/src/lib.rs @@ -0,0 +1,332 @@ +use async_trait::async_trait; +use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; +use atuin_server_database::{Database, DbError, DbResult}; +use futures_util::TryStreamExt; +use serde::{Deserialize, Serialize}; +use sqlx::postgres::PgPoolOptions; + +use sqlx::Row; + +use tracing::instrument; +use wrappers::{DbHistory, DbSession, DbUser}; + +mod wrappers; + +#[derive(Clone)] +pub struct Postgres { + pool: sqlx::Pool, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct PostgresSettings { + pub db_uri: String, +} + +fn fix_error(error: sqlx::Error) -> DbError { + match error { + sqlx::Error::RowNotFound => DbError::NotFound, + error => DbError::Other(error.into()), + } +} + +#[async_trait] +impl Database for Postgres { + type Settings = PostgresSettings; + async fn new(settings: &PostgresSettings) -> DbResult { + let pool = PgPoolOptions::new() + .max_connections(100) + .connect(settings.db_uri.as_str()) + .await + .map_err(fix_error)?; + + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|error| DbError::Other(error.into()))?; + + Ok(Self { pool }) + } + + #[instrument(skip_all)] + async fn get_session(&self, token: &str) -> DbResult { + sqlx::query_as("select id, user_id, token from sessions where token = $1") + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn get_user(&self, username: &str) -> DbResult { + sqlx::query_as("select id, username, email, password from users where username = $1") + .bind(username) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn get_session_user(&self, token: &str) -> DbResult { + sqlx::query_as( + "select users.id, users.username, users.email, users.password from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn count_history(&self, user: &User) -> DbResult { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history_cached(&self, user: &User) -> DbResult { + let res: (i32,) = sqlx::query_as( + "select total from total_history_count_user + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0 as i64) + } + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + sqlx::query( + "update history + set deleted_at = $3 + where user_id = $1 + and client_id = $2 + and deleted_at is null", // don't just keep setting it + ) + .bind(user.id) + .bind(id) + .bind(chrono::Utc::now().naive_utc()) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn deleted_history(&self, user: &User) -> DbResult> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res = sqlx::query( + "select client_id from history + where user_id = $1 + and deleted_at is not null", + ) + .bind(user.id) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + let res = res + .iter() + .map(|row| row.get::("client_id")) + .collect(); + + Ok(res) + } + + #[instrument(skip_all)] + async fn count_history_range( + &self, + user: &User, + start: chrono::NaiveDateTime, + end: chrono::NaiveDateTime, + ) -> DbResult { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1 + and timestamp >= $2::date + and timestamp < $3::date", + ) + .bind(user.id) + .bind(start) + .bind(end) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn list_history( + &self, + user: &User, + created_after: chrono::NaiveDateTime, + since: chrono::NaiveDateTime, + host: &str, + page_size: i64, + ) -> DbResult> { + let res = sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(created_after) + .bind(since) + .bind(page_size) + .fetch(&self.pool) + .map_ok(|DbHistory(h)| h) + .try_collect() + .await + .map_err(fix_error)?; + + Ok(res) + } + + #[instrument(skip_all)] + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in history { + let client_id: &str = &i.client_id; + let hostname: &str = &i.hostname; + let data: &str = &i.data; + + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(client_id) + .bind(i.user_id) + .bind(hostname) + .bind(i.timestamp) + .bind(data) + .execute(&mut tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn delete_user(&self, u: &User) -> DbResult<()> { + sqlx::query("delete from sessions where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from users where id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from history where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn add_user(&self, user: &NewUser) -> DbResult { + let email: &str = &user.email; + let username: &str = &user.username; + let password: &str = &user.password; + + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(username) + .bind(email) + .bind(password) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn add_session(&self, session: &NewSession) -> DbResult<()> { + let token: &str = &session.token; + + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(token) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn get_user_session(&self, u: &User) -> DbResult { + sqlx::query_as("select id, user_id, token from sessions where user_id = $1") + .bind(u.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn oldest_history(&self, user: &User) -> DbResult { + sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + order by timestamp asc + limit 1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbHistory(h)| h) + } +} diff --git a/atuin-server-postgres/src/wrappers.rs b/atuin-server-postgres/src/wrappers.rs new file mode 100644 index 0000000..cb3d5a9 --- /dev/null +++ b/atuin-server-postgres/src/wrappers.rs @@ -0,0 +1,42 @@ +use ::sqlx::{FromRow, Result}; +use atuin_server_database::models::{History, Session, User}; +use sqlx::{postgres::PgRow, Row}; + +pub struct DbUser(pub User); +pub struct DbSession(pub Session); +pub struct DbHistory(pub History); + +impl<'a> FromRow<'a, PgRow> for DbUser { + fn from_row(row: &'a PgRow) -> Result { + Ok(Self(User { + id: row.try_get("id")?, + username: row.try_get("username")?, + email: row.try_get("email")?, + password: row.try_get("password")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession { + fn from_row(row: &'a PgRow) -> ::sqlx::Result { + Ok(Self(Session { + id: row.try_get("id")?, + user_id: row.try_get("user_id")?, + token: row.try_get("token")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory { + fn from_row(row: &'a PgRow) -> ::sqlx::Result { + Ok(Self(History { + id: row.try_get("id")?, + client_id: row.try_get("client_id")?, + user_id: row.try_get("user_id")?, + hostname: row.try_get("hostname")?, + timestamp: row.try_get("timestamp")?, + data: row.try_get("data")?, + created_at: row.try_get("created_at")?, + })) + } +} diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml index e4cbf3e..f308fa3 100644 --- a/atuin-server/Cargo.toml +++ b/atuin-server/Cargo.toml @@ -11,20 +11,18 @@ repository = { workspace = true } [dependencies] atuin-common = { path = "../atuin-common", version = "15.0.0" } +atuin-server-database = { path = "../atuin-server-database", version = "15.0.0" } tracing = "0.1" chrono = { workspace = true } eyre = { workspace = true } uuid = { workspace = true } -whoami = { workspace = true } config = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } -sodiumoxide = { workspace = true } base64 = { workspace = true } rand = { workspace = true } tokio = { workspace = true } -sqlx = { workspace = true } async-trait = { workspace = true } axum = "0.6.4" http = "0.2" diff --git a/atuin-server/src/auth.rs b/atuin-server/src/auth.rs deleted file mode 100644 index 52a7310..0000000 --- a/atuin-server/src/auth.rs +++ /dev/null @@ -1,222 +0,0 @@ -/* -use self::diesel::prelude::*; -use eyre::Result; -use rocket::http::Status; -use rocket::request::{self, FromRequest, Outcome, Request}; -use rocket::State; -use rocket_contrib::databases::diesel; -use sodiumoxide::crypto::pwhash::argon2id13; - -use rocket_contrib::json::Json; -use uuid::Uuid; - -use super::models::{NewSession, NewUser, Session, User}; -use super::views::ApiResponse; - -use crate::api::{LoginRequest, RegisterRequest}; -use crate::schema::{sessions, users}; -use crate::settings::Settings; -use crate::utils::hash_secret; - -use super::database::AtuinDbConn; - -#[derive(Debug)] -pub enum KeyError { - Missing, - Invalid, -} - -pub fn verify_str(secret: &str, verify: &str) -> bool { - sodiumoxide::init().unwrap(); - - let mut padded = [0_u8; 128]; - secret.as_bytes().iter().enumerate().for_each(|(i, val)| { - padded[i] = *val; - }); - - match argon2id13::HashedPassword::from_slice(&padded) { - Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()), - None => false, - } -} - -impl<'a, 'r> FromRequest<'a, 'r> for User { - type Error = KeyError; - - fn from_request(request: &'a Request<'r>) -> request::Outcome { - let session: Vec<_> = request.headers().get("authorization").collect(); - - if session.is_empty() { - return Outcome::Failure((Status::BadRequest, KeyError::Missing)); - } else if session.len() > 1 { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - let session: Vec<_> = session[0].split(' ').collect(); - - if session.len() != 2 { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - if session[0] != "Token" { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - let session = session[1]; - - let db = request - .guard::() - .succeeded() - .expect("failed to load database"); - - let session = sessions::table - .filter(sessions::token.eq(session)) - .first::(&*db); - - if session.is_err() { - return Outcome::Failure((Status::Unauthorized, KeyError::Invalid)); - } - - let session = session.unwrap(); - - let user = users::table.find(session.user_id).first(&*db); - - match user { - Ok(user) => Outcome::Success(user), - Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)), - } - } -} - -#[get("/user/")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse { - use crate::schema::users::dsl::{username, users}; - - let user: Result = users - .select(username) - .filter(username.eq(user)) - .first(&*conn); - - if user.is_err() { - return ApiResponse { - json: json!({ - "message": "could not find user", - }), - status: Status::NotFound, - }; - } - - let user = user.unwrap(); - - ApiResponse { - json: json!({ "username": user.as_str() }), - status: Status::Ok, - } -} - -#[post("/register", data = "")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn register( - conn: AtuinDbConn, - register: Json, - settings: State, -) -> ApiResponse { - if !settings.server.open_registration { - return ApiResponse { - status: Status::BadRequest, - json: json!({ - "message": "registrations are not open" - }), - }; - } - - let hashed = hash_secret(register.password.as_str()); - - let new_user = NewUser { - email: register.email.as_str(), - username: register.username.as_str(), - password: hashed.as_str(), - }; - - let user = diesel::insert_into(users::table) - .values(&new_user) - .get_result(&*conn); - - if user.is_err() { - return ApiResponse { - status: Status::BadRequest, - json: json!({ - "message": "failed to create user - username or email in use?", - }), - }; - } - - let user: User = user.unwrap(); - let token = Uuid::new_v4().to_simple().to_string(); - - let new_session = NewSession { - user_id: user.id, - token: token.as_str(), - }; - - match diesel::insert_into(sessions::table) - .values(&new_session) - .execute(&*conn) - { - Ok(_) => ApiResponse { - status: Status::Ok, - json: json!({"message": "user created!", "session": token}), - }, - Err(_) => ApiResponse { - status: Status::BadRequest, - json: json!({ "message": "failed to create user"}), - }, - } -} - -#[post("/login", data = "")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn login(conn: AtuinDbConn, login: Json) -> ApiResponse { - let user = users::table - .filter(users::username.eq(login.username.as_str())) - .first(&*conn); - - if user.is_err() { - return ApiResponse { - status: Status::NotFound, - json: json!({"message": "user not found"}), - }; - } - - let user: User = user.unwrap(); - - let session = sessions::table - .filter(sessions::user_id.eq(user.id)) - .first(&*conn); - - // a session should exist... - if session.is_err() { - return ApiResponse { - status: Status::InternalServerError, - json: json!({"message": "something went wrong"}), - }; - } - - let verified = verify_str(user.password.as_str(), login.password.as_str()); - - if !verified { - return ApiResponse { - status: Status::NotFound, - json: json!({"message": "user not found"}), - }; - } - - let session: Session = session.unwrap(); - - ApiResponse { - status: Status::Ok, - json: json!({"session": session.token}), - } -} -*/ diff --git a/atuin-server/src/database.rs b/atuin-server/src/database.rs deleted file mode 100644 index 894fab7..0000000 --- a/atuin-server/src/database.rs +++ /dev/null @@ -1,510 +0,0 @@ -use std::collections::HashMap; - -use async_trait::async_trait; -use chrono::{Datelike, TimeZone}; -use chronoutil::RelativeDuration; -use sqlx::{postgres::PgPoolOptions, Result}; - -use sqlx::Row; - -use tracing::{debug, instrument, warn}; - -use super::{ - calendar::{TimePeriod, TimePeriodInfo}, - models::{History, NewHistory, NewSession, NewUser, Session, User}, -}; -use crate::settings::Settings; - -use atuin_common::utils::get_days_from_month; - -#[async_trait] -pub trait Database { - async fn get_session(&self, token: &str) -> Result; - async fn get_session_user(&self, token: &str) -> Result; - async fn add_session(&self, session: &NewSession) -> Result<()>; - - async fn get_user(&self, username: &str) -> Result; - async fn get_user_session(&self, u: &User) -> Result; - async fn add_user(&self, user: &NewUser) -> Result; - async fn delete_user(&self, u: &User) -> Result<()>; - - async fn count_history(&self, user: &User) -> Result; - async fn count_history_cached(&self, user: &User) -> Result; - - async fn delete_history(&self, user: &User, id: String) -> Result<()>; - async fn deleted_history(&self, user: &User) -> Result>; - - async fn count_history_range( - &self, - user: &User, - start: chrono::NaiveDateTime, - end: chrono::NaiveDateTime, - ) -> Result; - async fn count_history_day(&self, user: &User, date: chrono::NaiveDate) -> Result; - async fn count_history_month(&self, user: &User, date: chrono::NaiveDate) -> Result; - async fn count_history_year(&self, user: &User, year: i32) -> Result; - - async fn list_history( - &self, - user: &User, - created_after: chrono::NaiveDateTime, - since: chrono::NaiveDateTime, - host: &str, - page_size: i64, - ) -> Result>; - - async fn add_history(&self, history: &[NewHistory]) -> Result<()>; - - async fn oldest_history(&self, user: &User) -> Result; - - async fn calendar( - &self, - user: &User, - period: TimePeriod, - year: u64, - month: u64, - ) -> Result>; -} - -#[derive(Clone)] -pub struct Postgres { - pool: sqlx::Pool, - settings: Settings, -} - -impl Postgres { - pub async fn new(settings: Settings) -> Result { - let pool = PgPoolOptions::new() - .max_connections(100) - .connect(settings.db_uri.as_str()) - .await?; - - sqlx::migrate!("./migrations").run(&pool).await?; - - Ok(Self { pool, settings }) - } -} - -#[async_trait] -impl Database for Postgres { - #[instrument(skip_all)] - async fn get_session(&self, token: &str) -> Result { - sqlx::query_as::<_, Session>("select id, user_id, token from sessions where token = $1") - .bind(token) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn get_user(&self, username: &str) -> Result { - sqlx::query_as::<_, User>( - "select id, username, email, password from users where username = $1", - ) - .bind(username) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn get_session_user(&self, token: &str) -> Result { - sqlx::query_as::<_, User>( - "select users.id, users.username, users.email, users.password from users - inner join sessions - on users.id = sessions.user_id - and sessions.token = $1", - ) - .bind(token) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn count_history(&self, user: &User) -> Result { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn count_history_cached(&self, user: &User) -> Result { - let res: (i32,) = sqlx::query_as( - "select total from total_history_count_user - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0 as i64) - } - - async fn delete_history(&self, user: &User, id: String) -> Result<()> { - sqlx::query( - "update history - set deleted_at = $3 - where user_id = $1 - and client_id = $2 - and deleted_at is null", // don't just keep setting it - ) - .bind(user.id) - .bind(id) - .bind(chrono::Utc::now().naive_utc()) - .fetch_all(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn deleted_history(&self, user: &User) -> Result> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res = sqlx::query( - "select client_id from history - where user_id = $1 - and deleted_at is not null", - ) - .bind(user.id) - .fetch_all(&self.pool) - .await?; - - let res = res - .iter() - .map(|row| row.get::("client_id")) - .collect(); - - Ok(res) - } - - #[instrument(skip_all)] - async fn count_history_range( - &self, - user: &User, - start: chrono::NaiveDateTime, - end: chrono::NaiveDateTime, - ) -> Result { - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1 - and timestamp >= $2::date - and timestamp < $3::date", - ) - .bind(user.id) - .bind(start) - .bind(end) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - // Count the history for a given year - #[instrument(skip_all)] - async fn count_history_year(&self, user: &User, year: i32) -> Result { - let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0); - let end = start + RelativeDuration::years(1); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - // Count the history for a given month - #[instrument(skip_all)] - async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> Result { - let start = chrono::Utc - .ymd(month.year(), month.month(), 1) - .and_hms_nano(0, 0, 0, 0); - - // ofc... - let end = if month.month() < 12 { - chrono::Utc - .ymd(month.year(), month.month() + 1, 1) - .and_hms_nano(0, 0, 0, 0) - } else { - chrono::Utc - .ymd(month.year() + 1, 1, 1) - .and_hms_nano(0, 0, 0, 0) - }; - - debug!("start: {}, end: {}", start, end); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - // Count the history for a given day - #[instrument(skip_all)] - async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> Result { - let start = chrono::Utc - .ymd(day.year(), day.month(), day.day()) - .and_hms_nano(0, 0, 0, 0); - let end = chrono::Utc - .ymd(day.year(), day.month(), day.day() + 1) - .and_hms_nano(0, 0, 0, 0); - - let res = self - .count_history_range(user, start.naive_utc(), end.naive_utc()) - .await?; - Ok(res) - } - - #[instrument(skip_all)] - async fn list_history( - &self, - user: &User, - created_after: chrono::NaiveDateTime, - since: chrono::NaiveDateTime, - host: &str, - page_size: i64, - ) -> Result> { - let res = sqlx::query_as::<_, History>( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - and hostname != $2 - and created_at >= $3 - and timestamp >= $4 - order by timestamp asc - limit $5", - ) - .bind(user.id) - .bind(host) - .bind(created_after) - .bind(since) - .bind(page_size) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn add_history(&self, history: &[NewHistory]) -> Result<()> { - let mut tx = self.pool.begin().await?; - - for i in history { - let client_id: &str = &i.client_id; - let hostname: &str = &i.hostname; - let data: &str = &i.data; - - if data.len() > self.settings.max_history_length - && self.settings.max_history_length != 0 - { - // Don't return an error here. We want to insert as much of the - // history list as we can, so log the error and continue going. - - warn!( - "history too long, got length {}, max {}", - data.len(), - self.settings.max_history_length - ); - - continue; - } - - sqlx::query( - "insert into history - (client_id, user_id, hostname, timestamp, data) - values ($1, $2, $3, $4, $5) - on conflict do nothing - ", - ) - .bind(client_id) - .bind(i.user_id) - .bind(hostname) - .bind(i.timestamp) - .bind(data) - .execute(&mut tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn delete_user(&self, u: &User) -> Result<()> { - sqlx::query("delete from sessions where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from users where id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from history where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn add_user(&self, user: &NewUser) -> Result { - let email: &str = &user.email; - let username: &str = &user.username; - let password: &str = &user.password; - - let res: (i64,) = sqlx::query_as( - "insert into users - (username, email, password) - values($1, $2, $3) - returning id", - ) - .bind(username) - .bind(email) - .bind(password) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn add_session(&self, session: &NewSession) -> Result<()> { - let token: &str = &session.token; - - sqlx::query( - "insert into sessions - (user_id, token) - values($1, $2)", - ) - .bind(session.user_id) - .bind(token) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn get_user_session(&self, u: &User) -> Result { - sqlx::query_as::<_, Session>("select id, user_id, token from sessions where user_id = $1") - .bind(u.id) - .fetch_one(&self.pool) - .await - } - - #[instrument(skip_all)] - async fn oldest_history(&self, user: &User) -> Result { - let res = sqlx::query_as::<_, History>( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - order by timestamp asc - limit 1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn calendar( - &self, - user: &User, - period: TimePeriod, - year: u64, - month: u64, - ) -> Result> { - // TODO: Support different timezones. Right now we assume UTC and - // everything is stored as such. But it _should_ be possible to - // interpret the stored date with a different TZ - - match period { - TimePeriod::YEAR => { - let mut ret = HashMap::new(); - // First we need to work out how far back to calculate. Get the - // oldest history item - let oldest = self.oldest_history(user).await?.timestamp.year(); - let current_year = chrono::Utc::now().year(); - - // All the years we need to get data for - // The upper bound is exclusive, so include current +1 - let years = oldest..current_year + 1; - - for year in years { - let count = self.count_history_year(user, year).await?; - - ret.insert( - year as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - - TimePeriod::MONTH => { - let mut ret = HashMap::new(); - - for month in 1..13 { - let count = self - .count_history_month( - user, - chrono::Utc.ymd(year as i32, month, 1).naive_utc(), - ) - .await?; - - ret.insert( - month as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - - TimePeriod::DAY => { - let mut ret = HashMap::new(); - - for day in 1..get_days_from_month(year as i32, month as u32) { - let count = self - .count_history_day( - user, - chrono::Utc - .ymd(year as i32, month as u32, day as u32) - .naive_utc(), - ) - .await?; - - ret.insert( - day as u64, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } - } - } -} diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index 1c9dff5..bb0aa32 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -10,18 +10,20 @@ use tracing::{debug, error, instrument}; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::{ - calendar::{TimePeriod, TimePeriodInfo}, - database::Database, - models::{NewHistory, User}, - router::AppState, + router::{AppState, UserAuth}, utils::client_version_min, }; +use atuin_server_database::{ + calendar::{TimePeriod, TimePeriodInfo}, + models::NewHistory, + Database, +}; use atuin_common::api::*; #[instrument(skip_all, fields(user.id = user.id))] pub async fn count( - user: User, + UserAuth(user): UserAuth, state: State>, ) -> Result, ErrorResponseStatus<'static>> { let db = &state.0.database; @@ -42,7 +44,7 @@ pub async fn count( #[instrument(skip_all, fields(user.id = user.id))] pub async fn list( req: Query, - user: User, + UserAuth(user): UserAuth, headers: HeaderMap, state: State>, ) -> Result, ErrorResponseStatus<'static>> { @@ -101,7 +103,7 @@ pub async fn list( #[instrument(skip_all, fields(user.id = user.id))] pub async fn delete( - user: User, + UserAuth(user): UserAuth, state: State>, Json(req): Json, ) -> Result, ErrorResponseStatus<'static>> { @@ -123,13 +125,15 @@ pub async fn delete( #[instrument(skip_all, fields(user.id = user.id))] pub async fn add( - user: User, + UserAuth(user): UserAuth, state: State>, Json(req): Json>, ) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + debug!("request to add {} history items", req.len()); - let history: Vec = req + let mut history: Vec = req .into_iter() .map(|h| NewHistory { client_id: h.id, @@ -140,8 +144,24 @@ pub async fn add( }) .collect(); - let db = &state.0.database; - if let Err(e) = db.add_history(&history).await { + history.retain(|h| { + // keep if within limit, or limit is 0 (unlimited) + let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0; + + // Don't return an error here. We want to insert as much of the + // history list as we can, so log the error and continue going. + if !keep { + tracing::warn!( + "history too long, got length {}, max {}", + h.data.len(), + settings.max_history_length + ); + } + + keep + }); + + if let Err(e) = database.add_history(&history).await { error!("failed to add history: {}", e); return Err(ErrorResponse::reply("failed to add history") @@ -155,7 +175,7 @@ pub async fn add( pub async fn calendar( Path(focus): Path, Query(params): Query>, - user: User, + UserAuth(user): UserAuth, state: State>, ) -> Result>, ErrorResponseStatus<'static>> { let focus = focus.as_str(); diff --git a/atuin-server/src/handlers/status.rs b/atuin-server/src/handlers/status.rs index 97c0288..d9b6afa 100644 --- a/atuin-server/src/handlers/status.rs +++ b/atuin-server/src/handlers/status.rs @@ -3,7 +3,8 @@ use http::StatusCode; use tracing::instrument; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::{database::Database, models::User, router::AppState}; +use crate::router::{AppState, UserAuth}; +use atuin_server_database::Database; use atuin_common::api::*; @@ -11,7 +12,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); #[instrument(skip_all, fields(user.id = user.id))] pub async fn status( - user: User, + UserAuth(user): UserAuth, state: State>, ) -> Result, ErrorResponseStatus<'static>> { let db = &state.0.database; diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index e67828e..7508115 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -16,10 +16,10 @@ use tracing::{debug, error, info, instrument}; use uuid::Uuid; use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::{ - database::Database, - models::{NewSession, NewUser, User}, - router::AppState, +use crate::router::{AppState, UserAuth}; +use atuin_server_database::{ + models::{NewSession, NewUser}, + Database, DbError, }; use reqwest::header::CONTENT_TYPE; @@ -64,11 +64,11 @@ pub async fn get( let db = &state.0.database; let user = match db.get_user(username.as_ref()).await { Ok(user) => user, - Err(sqlx::Error::RowNotFound) => { + Err(DbError::NotFound) => { debug!("user not found: {}", username); return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - Err(err) => { + Err(DbError::Other(err)) => { error!("database error: {}", err); return Err(ErrorResponse::reply("database error") .with_status(StatusCode::INTERNAL_SERVER_ERROR)); @@ -152,7 +152,7 @@ pub async fn register( #[instrument(skip_all, fields(user.id = user.id))] pub async fn delete( - user: User, + UserAuth(user): UserAuth, state: State>, ) -> Result, ErrorResponseStatus<'static>> { debug!("request to delete user {}", user.id); @@ -175,10 +175,10 @@ pub async fn login( let db = &state.0.database; let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, - Err(sqlx::Error::RowNotFound) => { + Err(DbError::NotFound) => { return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - Err(e) => { + Err(DbError::Other(e)) => { error!("failed to get user {}: {}", login.username.clone(), e); return Err(ErrorResponse::reply("database error") @@ -188,11 +188,11 @@ pub async fn login( let session = match db.get_user_session(&user).await { Ok(u) => u, - Err(sqlx::Error::RowNotFound) => { + Err(DbError::NotFound) => { debug!("user session not found for user id={}", user.id); return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - Err(err) => { + Err(DbError::Other(err)) => { error!("database error for user {}: {}", login.username, err); return Err(ErrorResponse::reply("database error") .with_status(StatusCode::INTERNAL_SERVER_ERROR)); diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs index 01873af..aa2250d 100644 --- a/atuin-server/src/lib.rs +++ b/atuin-server/src/lib.rs @@ -2,45 +2,38 @@ use std::net::{IpAddr, SocketAddr}; +use atuin_server_database::Database; use axum::Server; -use database::Postgres; use eyre::{Context, Result}; -use crate::settings::Settings; +mod handlers; +mod router; +mod settings; +mod utils; +pub use settings::Settings; use tokio::signal; -pub mod auth; -pub mod calendar; -pub mod database; -pub mod handlers; -pub mod models; -pub mod router; -pub mod settings; -pub mod utils; - async fn shutdown_signal() { - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to register signal handler") - .recv() - .await; - }; - - tokio::select! { - _ = terminate => (), - } + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to register signal handler") + .recv() + .await; eprintln!("Shutting down gracefully..."); } -pub async fn launch(settings: Settings, host: String, port: u16) -> Result<()> { +pub async fn launch( + settings: Settings, + host: String, + port: u16, +) -> Result<()> { let host = host.parse::()?; - let postgres = Postgres::new(settings.clone()) + let db = Db::new(&settings.db_settings) .await - .wrap_err_with(|| format!("failed to connect to db: {}", settings.db_uri))?; + .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; - let r = router::router(postgres, settings); + let r = router::router(db, settings); Server::bind(&SocketAddr::new(host, port)) .serve(r.into_make_service()) diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index 20b11f4..ec558e7 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -10,11 +10,14 @@ use http::request::Parts; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; -use super::{database::Database, handlers}; -use crate::{models::User, settings::Settings}; +use super::handlers; +use crate::settings::Settings; +use atuin_server_database::{models::User, Database}; + +pub struct UserAuth(pub User); #[async_trait] -impl FromRequestParts> for User +impl FromRequestParts> for UserAuth where DB: Database, { @@ -45,7 +48,7 @@ where .await .map_err(|_| http::StatusCode::FORBIDDEN)?; - Ok(user) + Ok(UserAuth(user)) } } @@ -54,15 +57,12 @@ async fn teapot() -> impl IntoResponse { } #[derive(Clone)] -pub struct AppState { +pub struct AppState { pub database: DB, - pub settings: Settings, + pub settings: Settings, } -pub fn router( - database: DB, - settings: Settings, -) -> Router { +pub fn router(database: DB, settings: Settings) -> Router { let routes = Router::new() .route("/", get(handlers::index)) .route("/sync/count", get(handlers::history::count)) diff --git a/atuin-server/src/settings.rs b/atuin-server/src/settings.rs index 981d239..fb5325d 100644 --- a/atuin-server/src/settings.rs +++ b/atuin-server/src/settings.rs @@ -3,24 +3,24 @@ use std::{io::prelude::*, path::PathBuf}; use config::{Config, Environment, File as ConfigFile, FileFormat}; use eyre::{eyre, Result}; use fs_err::{create_dir_all, File}; -use serde::{Deserialize, Serialize}; - -pub const HISTORY_PAGE_SIZE: i64 = 100; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Settings { +pub struct Settings { pub host: String, pub port: u16, pub path: String, - pub db_uri: String, pub open_registration: bool, pub max_history_length: usize, pub page_size: i64, pub register_webhook_url: Option, pub register_webhook_username: String, + + #[serde(flatten)] + pub db_settings: DbSettings, } -impl Settings { +impl Settings { pub fn new() -> Result { let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { PathBuf::from(p) diff --git a/atuin/Cargo.toml b/atuin/Cargo.toml index 9085623..bf03805 100644 --- a/atuin/Cargo.toml +++ b/atuin/Cargo.toml @@ -33,15 +33,13 @@ buildflags = ["--release"] atuin = { path = "/usr/bin/atuin" } [features] -# TODO(conradludgate) -# Currently, this keeps the same default built behaviour for v0.8 -# We should rethink this by the time we hit a new breaking change default = ["client", "sync", "server"] client = ["atuin-client"] sync = ["atuin-client/sync"] -server = ["atuin-server", "tracing-subscriber"] +server = ["atuin-server", "atuin-server-postgres", "tracing-subscriber"] [dependencies] +atuin-server-postgres = { path = "../atuin-server-postgres", version = "15.0.0", optional = true } atuin-server = { path = "../atuin-server", version = "15.0.0", optional = true } atuin-client = { path = "../atuin-client", version = "15.0.0", optional = true, default-features = false } atuin-common = { path = "../atuin-common", version = "15.0.0" } @@ -61,7 +59,6 @@ tokio = { workspace = true } async-trait = { workspace = true } interim = { workspace = true } base64 = { workspace = true } -crossbeam-channel = "0.5.1" clap = { workspace = true } clap_complete = "4.0.3" fs-err = { workspace = true } diff --git a/atuin/src/command/server.rs b/atuin/src/command/server.rs index 495f85d..c65cb50 100644 --- a/atuin/src/command/server.rs +++ b/atuin/src/command/server.rs @@ -1,9 +1,10 @@ +use atuin_server_postgres::Postgres; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use clap::Parser; use eyre::{Context, Result}; -use atuin_server::{launch, settings::Settings}; +use atuin_server::{launch, Settings}; #[derive(Parser)] #[clap(infer_subcommands = true)] @@ -37,7 +38,7 @@ impl Cmd { .map_or(settings.host.clone(), std::string::ToString::to_string); let port = port.map_or(settings.port, |p| p); - launch(settings, host, port).await + launch::(settings, host, port).await } } }