refactor server to allow pluggable db and tracing (#1036)
* refactor server to allow pluggable db and tracing * clean up * fix descriptions * remove dependencies
This commit is contained in:
parent
dccdb2c33f
commit
8655c93853
33 changed files with 760 additions and 888 deletions
104
Cargo.lock
generated
104
Cargo.lock
generated
|
@ -97,6 +97,7 @@ dependencies = [
|
||||||
"atuin-client",
|
"atuin-client",
|
||||||
"atuin-common",
|
"atuin-common",
|
||||||
"atuin-server",
|
"atuin-server",
|
||||||
|
"atuin-server-postgres",
|
||||||
"base64 0.21.0",
|
"base64 0.21.0",
|
||||||
"bitflags",
|
"bitflags",
|
||||||
"cassowary",
|
"cassowary",
|
||||||
|
@ -104,7 +105,6 @@ dependencies = [
|
||||||
"clap",
|
"clap",
|
||||||
"clap_complete",
|
"clap_complete",
|
||||||
"colored",
|
"colored",
|
||||||
"crossbeam-channel",
|
|
||||||
"crossterm",
|
"crossterm",
|
||||||
"directories",
|
"directories",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
|
@ -160,7 +160,6 @@ dependencies = [
|
||||||
"serde_regex",
|
"serde_regex",
|
||||||
"sha2",
|
"sha2",
|
||||||
"shellexpand",
|
"shellexpand",
|
||||||
"sodiumoxide",
|
|
||||||
"sql-builder",
|
"sql-builder",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
@ -187,6 +186,7 @@ dependencies = [
|
||||||
"argon2",
|
"argon2",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"atuin-common",
|
"atuin-common",
|
||||||
|
"atuin-server-database",
|
||||||
"axum",
|
"axum",
|
||||||
"base64 0.21.0",
|
"base64 0.21.0",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
@ -200,14 +200,39 @@ dependencies = [
|
||||||
"semver",
|
"semver",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sodiumoxide",
|
|
||||||
"sqlx",
|
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"uuid",
|
"uuid",
|
||||||
"whoami",
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "atuin-server-database"
|
||||||
|
version = "15.0.0"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"atuin-common",
|
||||||
|
"chrono",
|
||||||
|
"chronoutil",
|
||||||
|
"eyre",
|
||||||
|
"serde",
|
||||||
|
"tracing",
|
||||||
|
"uuid",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "atuin-server-postgres"
|
||||||
|
version = "15.0.0"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"atuin-common",
|
||||||
|
"atuin-server-database",
|
||||||
|
"chrono",
|
||||||
|
"futures-util",
|
||||||
|
"serde",
|
||||||
|
"sqlx",
|
||||||
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -515,16 +540,6 @@ version = "2.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2d0165d2900ae6778e36e80bbc4da3b5eefccee9ba939761f9c2882a5d9af3ff"
|
checksum = "2d0165d2900ae6778e36e80bbc4da3b5eefccee9ba939761f9c2882a5d9af3ff"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "crossbeam-channel"
|
|
||||||
version = "0.5.8"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"crossbeam-utils",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crossbeam-queue"
|
name = "crossbeam-queue"
|
||||||
version = "0.3.6"
|
version = "0.3.6"
|
||||||
|
@ -631,15 +646,6 @@ dependencies = [
|
||||||
"dirs",
|
"dirs",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "ed25519"
|
|
||||||
version = "1.5.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "1e9c280362032ea4203659fc489832d0204ef09f247a0506f170dafcac08c369"
|
|
||||||
dependencies = [
|
|
||||||
"signature",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "either"
|
name = "either"
|
||||||
version = "1.8.0"
|
version = "1.8.0"
|
||||||
|
@ -1175,18 +1181,6 @@ version = "0.2.141"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5"
|
checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "libsodium-sys"
|
|
||||||
version = "0.2.7"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "6b779387cd56adfbc02ea4a668e704f729be8d6a6abd2c27ca5ee537849a92fd"
|
|
||||||
dependencies = [
|
|
||||||
"cc",
|
|
||||||
"libc",
|
|
||||||
"pkg-config",
|
|
||||||
"walkdir",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libsqlite3-sys"
|
name = "libsqlite3-sys"
|
||||||
version = "0.24.2"
|
version = "0.24.2"
|
||||||
|
@ -1875,15 +1869,6 @@ dependencies = [
|
||||||
"cipher",
|
"cipher",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "same-file"
|
|
||||||
version = "1.0.6"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
|
||||||
dependencies = [
|
|
||||||
"winapi-util",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "schannel"
|
name = "schannel"
|
||||||
version = "0.1.20"
|
version = "0.1.20"
|
||||||
|
@ -2071,12 +2056,6 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "signature"
|
|
||||||
version = "1.6.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e90531723b08e4d6d71b791108faf51f03e1b4a7784f96b2b87f852ebc247228"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "slab"
|
name = "slab"
|
||||||
version = "0.4.7"
|
version = "0.4.7"
|
||||||
|
@ -2102,18 +2081,6 @@ dependencies = [
|
||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "sodiumoxide"
|
|
||||||
version = "0.2.7"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e26be3acb6c2d9a7aac28482586a7856436af4cfe7100031d219de2d2ecb0028"
|
|
||||||
dependencies = [
|
|
||||||
"ed25519",
|
|
||||||
"libc",
|
|
||||||
"libsodium-sys",
|
|
||||||
"serde",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "spin"
|
name = "spin"
|
||||||
version = "0.5.2"
|
version = "0.5.2"
|
||||||
|
@ -2659,17 +2626,6 @@ version = "0.9.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "walkdir"
|
|
||||||
version = "2.3.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
|
|
||||||
dependencies = [
|
|
||||||
"same-file",
|
|
||||||
"winapi",
|
|
||||||
"winapi-util",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "want"
|
name = "want"
|
||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
|
|
10
Cargo.toml
10
Cargo.toml
|
@ -1,5 +1,12 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
members = ["atuin", "atuin-client", "atuin-server", "atuin-common"]
|
members = [
|
||||||
|
"atuin",
|
||||||
|
"atuin-client",
|
||||||
|
"atuin-server",
|
||||||
|
"atuin-server-postgres",
|
||||||
|
"atuin-server-database",
|
||||||
|
"atuin-common",
|
||||||
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
name = "atuin"
|
name = "atuin"
|
||||||
|
@ -27,7 +34,6 @@ rand = { version = "0.8.5", features = ["std"] }
|
||||||
semver = "1.0.14"
|
semver = "1.0.14"
|
||||||
serde = { version = "1.0.145", features = ["derive"] }
|
serde = { version = "1.0.145", features = ["derive"] }
|
||||||
serde_json = "1.0.86"
|
serde_json = "1.0.86"
|
||||||
sodiumoxide = "0.2.6"
|
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
uuid = { version = "1.2", features = ["v4"] }
|
uuid = { version = "1.2", features = ["v4"] }
|
||||||
whoami = "1.1.2"
|
whoami = "1.1.2"
|
||||||
|
|
|
@ -53,7 +53,6 @@ memchr = "2.5"
|
||||||
|
|
||||||
# sync
|
# sync
|
||||||
urlencoding = { version = "2.1.0", optional = true }
|
urlencoding = { version = "2.1.0", optional = true }
|
||||||
sodiumoxide = { workspace = true, optional = true }
|
|
||||||
reqwest = { workspace = true, optional = true }
|
reqwest = { workspace = true, optional = true }
|
||||||
hex = { version = "0.4", optional = true }
|
hex = { version = "0.4", optional = true }
|
||||||
sha2 = { version = "0.10", optional = true }
|
sha2 = { version = "0.10", optional = true }
|
||||||
|
|
21
atuin-server-database/Cargo.toml
Normal file
21
atuin-server-database/Cargo.toml
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
[package]
|
||||||
|
name = "atuin-server-database"
|
||||||
|
edition = "2021"
|
||||||
|
description = "server database library for atuin"
|
||||||
|
|
||||||
|
version = { workspace = true }
|
||||||
|
authors = { workspace = true }
|
||||||
|
license = { workspace = true }
|
||||||
|
homepage = { workspace = true }
|
||||||
|
repository = { workspace = true }
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
atuin-common = { path = "../atuin-common", version = "15.0.0" }
|
||||||
|
|
||||||
|
tracing = "0.1"
|
||||||
|
chrono = { workspace = true }
|
||||||
|
eyre = { workspace = true }
|
||||||
|
uuid = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
|
async-trait = { workspace = true }
|
||||||
|
chronoutil = "0.2.3"
|
220
atuin-server-database/src/lib.rs
Normal file
220
atuin-server-database/src/lib.rs
Normal file
|
@ -0,0 +1,220 @@
|
||||||
|
#![forbid(unsafe_code)]
|
||||||
|
|
||||||
|
pub mod calendar;
|
||||||
|
pub mod models;
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
fmt::{Debug, Display},
|
||||||
|
};
|
||||||
|
|
||||||
|
use self::{
|
||||||
|
calendar::{TimePeriod, TimePeriodInfo},
|
||||||
|
models::{History, NewHistory, NewSession, NewUser, Session, User},
|
||||||
|
};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use atuin_common::utils::get_days_from_month;
|
||||||
|
use chrono::{Datelike, TimeZone};
|
||||||
|
use chronoutil::RelativeDuration;
|
||||||
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum DbError {
|
||||||
|
NotFound,
|
||||||
|
Other(eyre::Report),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for DbError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{self:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for DbError {}
|
||||||
|
|
||||||
|
pub type DbResult<T> = Result<T, DbError>;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Database: Sized + Clone + Send + Sync + 'static {
|
||||||
|
type Settings: Debug + Clone + DeserializeOwned + Serialize + Send + Sync + 'static;
|
||||||
|
async fn new(settings: &Self::Settings) -> DbResult<Self>;
|
||||||
|
|
||||||
|
async fn get_session(&self, token: &str) -> DbResult<Session>;
|
||||||
|
async fn get_session_user(&self, token: &str) -> DbResult<User>;
|
||||||
|
async fn add_session(&self, session: &NewSession) -> DbResult<()>;
|
||||||
|
|
||||||
|
async fn get_user(&self, username: &str) -> DbResult<User>;
|
||||||
|
async fn get_user_session(&self, u: &User) -> DbResult<Session>;
|
||||||
|
async fn add_user(&self, user: &NewUser) -> DbResult<i64>;
|
||||||
|
async fn delete_user(&self, u: &User) -> DbResult<()>;
|
||||||
|
|
||||||
|
async fn count_history(&self, user: &User) -> DbResult<i64>;
|
||||||
|
async fn count_history_cached(&self, user: &User) -> DbResult<i64>;
|
||||||
|
|
||||||
|
async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
|
||||||
|
async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
|
||||||
|
|
||||||
|
async fn count_history_range(
|
||||||
|
&self,
|
||||||
|
user: &User,
|
||||||
|
start: chrono::NaiveDateTime,
|
||||||
|
end: chrono::NaiveDateTime,
|
||||||
|
) -> DbResult<i64>;
|
||||||
|
|
||||||
|
async fn list_history(
|
||||||
|
&self,
|
||||||
|
user: &User,
|
||||||
|
created_after: chrono::NaiveDateTime,
|
||||||
|
since: chrono::NaiveDateTime,
|
||||||
|
host: &str,
|
||||||
|
page_size: i64,
|
||||||
|
) -> DbResult<Vec<History>>;
|
||||||
|
|
||||||
|
async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>;
|
||||||
|
|
||||||
|
async fn oldest_history(&self, user: &User) -> DbResult<History>;
|
||||||
|
|
||||||
|
/// Count the history for a given year
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn count_history_year(&self, user: &User, year: i32) -> DbResult<i64> {
|
||||||
|
let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0);
|
||||||
|
let end = start + RelativeDuration::years(1);
|
||||||
|
|
||||||
|
let res = self
|
||||||
|
.count_history_range(user, start.naive_utc(), end.naive_utc())
|
||||||
|
.await?;
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Count the history for a given month
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> DbResult<i64> {
|
||||||
|
let start = chrono::Utc
|
||||||
|
.ymd(month.year(), month.month(), 1)
|
||||||
|
.and_hms_nano(0, 0, 0, 0);
|
||||||
|
|
||||||
|
// ofc...
|
||||||
|
let end = if month.month() < 12 {
|
||||||
|
chrono::Utc
|
||||||
|
.ymd(month.year(), month.month() + 1, 1)
|
||||||
|
.and_hms_nano(0, 0, 0, 0)
|
||||||
|
} else {
|
||||||
|
chrono::Utc
|
||||||
|
.ymd(month.year() + 1, 1, 1)
|
||||||
|
.and_hms_nano(0, 0, 0, 0)
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!("start: {}, end: {}", start, end);
|
||||||
|
|
||||||
|
let res = self
|
||||||
|
.count_history_range(user, start.naive_utc(), end.naive_utc())
|
||||||
|
.await?;
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Count the history for a given day
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> DbResult<i64> {
|
||||||
|
let start = chrono::Utc
|
||||||
|
.ymd(day.year(), day.month(), day.day())
|
||||||
|
.and_hms_nano(0, 0, 0, 0);
|
||||||
|
let end = chrono::Utc
|
||||||
|
.ymd(day.year(), day.month(), day.day() + 1)
|
||||||
|
.and_hms_nano(0, 0, 0, 0);
|
||||||
|
|
||||||
|
let res = self
|
||||||
|
.count_history_range(user, start.naive_utc(), end.naive_utc())
|
||||||
|
.await?;
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn calendar(
|
||||||
|
&self,
|
||||||
|
user: &User,
|
||||||
|
period: TimePeriod,
|
||||||
|
year: u64,
|
||||||
|
month: u64,
|
||||||
|
) -> DbResult<HashMap<u64, TimePeriodInfo>> {
|
||||||
|
// TODO: Support different timezones. Right now we assume UTC and
|
||||||
|
// everything is stored as such. But it _should_ be possible to
|
||||||
|
// interpret the stored date with a different TZ
|
||||||
|
|
||||||
|
match period {
|
||||||
|
TimePeriod::YEAR => {
|
||||||
|
let mut ret = HashMap::new();
|
||||||
|
// First we need to work out how far back to calculate. Get the
|
||||||
|
// oldest history item
|
||||||
|
let oldest = self.oldest_history(user).await?.timestamp.year();
|
||||||
|
let current_year = chrono::Utc::now().year();
|
||||||
|
|
||||||
|
// All the years we need to get data for
|
||||||
|
// The upper bound is exclusive, so include current +1
|
||||||
|
let years = oldest..current_year + 1;
|
||||||
|
|
||||||
|
for year in years {
|
||||||
|
let count = self.count_history_year(user, year).await?;
|
||||||
|
|
||||||
|
ret.insert(
|
||||||
|
year as u64,
|
||||||
|
TimePeriodInfo {
|
||||||
|
count: count as u64,
|
||||||
|
hash: "".to_string(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
TimePeriod::MONTH => {
|
||||||
|
let mut ret = HashMap::new();
|
||||||
|
|
||||||
|
for month in 1..13 {
|
||||||
|
let count = self
|
||||||
|
.count_history_month(
|
||||||
|
user,
|
||||||
|
chrono::Utc.ymd(year as i32, month, 1).naive_utc(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
ret.insert(
|
||||||
|
month as u64,
|
||||||
|
TimePeriodInfo {
|
||||||
|
count: count as u64,
|
||||||
|
hash: "".to_string(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
TimePeriod::DAY => {
|
||||||
|
let mut ret = HashMap::new();
|
||||||
|
|
||||||
|
for day in 1..get_days_from_month(year as i32, month as u32) {
|
||||||
|
let count = self
|
||||||
|
.count_history_day(
|
||||||
|
user,
|
||||||
|
chrono::Utc
|
||||||
|
.ymd(year as i32, month as u32, day as u32)
|
||||||
|
.naive_utc(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
ret.insert(
|
||||||
|
day as u64,
|
||||||
|
TimePeriodInfo {
|
||||||
|
count: count as u64,
|
||||||
|
hash: "".to_string(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,5 @@
|
||||||
use chrono::prelude::*;
|
use chrono::prelude::*;
|
||||||
|
|
||||||
#[derive(sqlx::FromRow)]
|
|
||||||
pub struct History {
|
pub struct History {
|
||||||
pub id: i64,
|
pub id: i64,
|
||||||
pub client_id: String, // a client generated ID
|
pub client_id: String, // a client generated ID
|
||||||
|
@ -22,7 +21,6 @@ pub struct NewHistory {
|
||||||
pub data: String,
|
pub data: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(sqlx::FromRow)]
|
|
||||||
pub struct User {
|
pub struct User {
|
||||||
pub id: i64,
|
pub id: i64,
|
||||||
pub username: String,
|
pub username: String,
|
||||||
|
@ -30,7 +28,6 @@ pub struct User {
|
||||||
pub password: String,
|
pub password: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(sqlx::FromRow)]
|
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
pub id: i64,
|
pub id: i64,
|
||||||
pub user_id: i64,
|
pub user_id: i64,
|
21
atuin-server-postgres/Cargo.toml
Normal file
21
atuin-server-postgres/Cargo.toml
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
[package]
|
||||||
|
name = "atuin-server-postgres"
|
||||||
|
edition = "2018"
|
||||||
|
description = "server postgres database library for atuin"
|
||||||
|
|
||||||
|
version = { workspace = true }
|
||||||
|
authors = { workspace = true }
|
||||||
|
license = { workspace = true }
|
||||||
|
homepage = { workspace = true }
|
||||||
|
repository = { workspace = true }
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
atuin-common = { path = "../atuin-common", version = "15.0.0" }
|
||||||
|
atuin-server-database = { path = "../atuin-server-database", version = "15.0.0" }
|
||||||
|
|
||||||
|
tracing = "0.1"
|
||||||
|
chrono = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
|
sqlx = { workspace = true }
|
||||||
|
async-trait = { workspace = true }
|
||||||
|
futures-util = "0.3"
|
332
atuin-server-postgres/src/lib.rs
Normal file
332
atuin-server-postgres/src/lib.rs
Normal file
|
@ -0,0 +1,332 @@
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
|
||||||
|
use atuin_server_database::{Database, DbError, DbResult};
|
||||||
|
use futures_util::TryStreamExt;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sqlx::postgres::PgPoolOptions;
|
||||||
|
|
||||||
|
use sqlx::Row;
|
||||||
|
|
||||||
|
use tracing::instrument;
|
||||||
|
use wrappers::{DbHistory, DbSession, DbUser};
|
||||||
|
|
||||||
|
mod wrappers;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Postgres {
|
||||||
|
pool: sqlx::Pool<sqlx::postgres::Postgres>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct PostgresSettings {
|
||||||
|
pub db_uri: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fix_error(error: sqlx::Error) -> DbError {
|
||||||
|
match error {
|
||||||
|
sqlx::Error::RowNotFound => DbError::NotFound,
|
||||||
|
error => DbError::Other(error.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Database for Postgres {
|
||||||
|
type Settings = PostgresSettings;
|
||||||
|
async fn new(settings: &PostgresSettings) -> DbResult<Self> {
|
||||||
|
let pool = PgPoolOptions::new()
|
||||||
|
.max_connections(100)
|
||||||
|
.connect(settings.db_uri.as_str())
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
sqlx::migrate!("./migrations")
|
||||||
|
.run(&pool)
|
||||||
|
.await
|
||||||
|
.map_err(|error| DbError::Other(error.into()))?;
|
||||||
|
|
||||||
|
Ok(Self { pool })
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn get_session(&self, token: &str) -> DbResult<Session> {
|
||||||
|
sqlx::query_as("select id, user_id, token from sessions where token = $1")
|
||||||
|
.bind(token)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)
|
||||||
|
.map(|DbSession(session)| session)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn get_user(&self, username: &str) -> DbResult<User> {
|
||||||
|
sqlx::query_as("select id, username, email, password from users where username = $1")
|
||||||
|
.bind(username)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)
|
||||||
|
.map(|DbUser(user)| user)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn get_session_user(&self, token: &str) -> DbResult<User> {
|
||||||
|
sqlx::query_as(
|
||||||
|
"select users.id, users.username, users.email, users.password from users
|
||||||
|
inner join sessions
|
||||||
|
on users.id = sessions.user_id
|
||||||
|
and sessions.token = $1",
|
||||||
|
)
|
||||||
|
.bind(token)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)
|
||||||
|
.map(|DbUser(user)| user)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn count_history(&self, user: &User) -> DbResult<i64> {
|
||||||
|
// The cache is new, and the user might not yet have a cache value.
|
||||||
|
// They will have one as soon as they post up some new history, but handle that
|
||||||
|
// edge case.
|
||||||
|
|
||||||
|
let res: (i64,) = sqlx::query_as(
|
||||||
|
"select count(1) from history
|
||||||
|
where user_id = $1",
|
||||||
|
)
|
||||||
|
.bind(user.id)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(res.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
|
||||||
|
let res: (i32,) = sqlx::query_as(
|
||||||
|
"select total from total_history_count_user
|
||||||
|
where user_id = $1",
|
||||||
|
)
|
||||||
|
.bind(user.id)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(res.0 as i64)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
|
||||||
|
sqlx::query(
|
||||||
|
"update history
|
||||||
|
set deleted_at = $3
|
||||||
|
where user_id = $1
|
||||||
|
and client_id = $2
|
||||||
|
and deleted_at is null", // don't just keep setting it
|
||||||
|
)
|
||||||
|
.bind(user.id)
|
||||||
|
.bind(id)
|
||||||
|
.bind(chrono::Utc::now().naive_utc())
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
|
||||||
|
// The cache is new, and the user might not yet have a cache value.
|
||||||
|
// They will have one as soon as they post up some new history, but handle that
|
||||||
|
// edge case.
|
||||||
|
|
||||||
|
let res = sqlx::query(
|
||||||
|
"select client_id from history
|
||||||
|
where user_id = $1
|
||||||
|
and deleted_at is not null",
|
||||||
|
)
|
||||||
|
.bind(user.id)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
let res = res
|
||||||
|
.iter()
|
||||||
|
.map(|row| row.get::<String, _>("client_id"))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn count_history_range(
|
||||||
|
&self,
|
||||||
|
user: &User,
|
||||||
|
start: chrono::NaiveDateTime,
|
||||||
|
end: chrono::NaiveDateTime,
|
||||||
|
) -> DbResult<i64> {
|
||||||
|
let res: (i64,) = sqlx::query_as(
|
||||||
|
"select count(1) from history
|
||||||
|
where user_id = $1
|
||||||
|
and timestamp >= $2::date
|
||||||
|
and timestamp < $3::date",
|
||||||
|
)
|
||||||
|
.bind(user.id)
|
||||||
|
.bind(start)
|
||||||
|
.bind(end)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(res.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn list_history(
|
||||||
|
&self,
|
||||||
|
user: &User,
|
||||||
|
created_after: chrono::NaiveDateTime,
|
||||||
|
since: chrono::NaiveDateTime,
|
||||||
|
host: &str,
|
||||||
|
page_size: i64,
|
||||||
|
) -> DbResult<Vec<History>> {
|
||||||
|
let res = sqlx::query_as(
|
||||||
|
"select id, client_id, user_id, hostname, timestamp, data, created_at from history
|
||||||
|
where user_id = $1
|
||||||
|
and hostname != $2
|
||||||
|
and created_at >= $3
|
||||||
|
and timestamp >= $4
|
||||||
|
order by timestamp asc
|
||||||
|
limit $5",
|
||||||
|
)
|
||||||
|
.bind(user.id)
|
||||||
|
.bind(host)
|
||||||
|
.bind(created_after)
|
||||||
|
.bind(since)
|
||||||
|
.bind(page_size)
|
||||||
|
.fetch(&self.pool)
|
||||||
|
.map_ok(|DbHistory(h)| h)
|
||||||
|
.try_collect()
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
|
||||||
|
let mut tx = self.pool.begin().await.map_err(fix_error)?;
|
||||||
|
|
||||||
|
for i in history {
|
||||||
|
let client_id: &str = &i.client_id;
|
||||||
|
let hostname: &str = &i.hostname;
|
||||||
|
let data: &str = &i.data;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"insert into history
|
||||||
|
(client_id, user_id, hostname, timestamp, data)
|
||||||
|
values ($1, $2, $3, $4, $5)
|
||||||
|
on conflict do nothing
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.bind(client_id)
|
||||||
|
.bind(i.user_id)
|
||||||
|
.bind(hostname)
|
||||||
|
.bind(i.timestamp)
|
||||||
|
.bind(data)
|
||||||
|
.execute(&mut tx)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.commit().await.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn delete_user(&self, u: &User) -> DbResult<()> {
|
||||||
|
sqlx::query("delete from sessions where user_id = $1")
|
||||||
|
.bind(u.id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
sqlx::query("delete from users where id = $1")
|
||||||
|
.bind(u.id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
sqlx::query("delete from history where user_id = $1")
|
||||||
|
.bind(u.id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
|
||||||
|
let email: &str = &user.email;
|
||||||
|
let username: &str = &user.username;
|
||||||
|
let password: &str = &user.password;
|
||||||
|
|
||||||
|
let res: (i64,) = sqlx::query_as(
|
||||||
|
"insert into users
|
||||||
|
(username, email, password)
|
||||||
|
values($1, $2, $3)
|
||||||
|
returning id",
|
||||||
|
)
|
||||||
|
.bind(username)
|
||||||
|
.bind(email)
|
||||||
|
.bind(password)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(res.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn add_session(&self, session: &NewSession) -> DbResult<()> {
|
||||||
|
let token: &str = &session.token;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"insert into sessions
|
||||||
|
(user_id, token)
|
||||||
|
values($1, $2)",
|
||||||
|
)
|
||||||
|
.bind(session.user_id)
|
||||||
|
.bind(token)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn get_user_session(&self, u: &User) -> DbResult<Session> {
|
||||||
|
sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
|
||||||
|
.bind(u.id)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)
|
||||||
|
.map(|DbSession(session)| session)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn oldest_history(&self, user: &User) -> DbResult<History> {
|
||||||
|
sqlx::query_as(
|
||||||
|
"select id, client_id, user_id, hostname, timestamp, data, created_at from history
|
||||||
|
where user_id = $1
|
||||||
|
order by timestamp asc
|
||||||
|
limit 1",
|
||||||
|
)
|
||||||
|
.bind(user.id)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)
|
||||||
|
.map(|DbHistory(h)| h)
|
||||||
|
}
|
||||||
|
}
|
42
atuin-server-postgres/src/wrappers.rs
Normal file
42
atuin-server-postgres/src/wrappers.rs
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
use ::sqlx::{FromRow, Result};
|
||||||
|
use atuin_server_database::models::{History, Session, User};
|
||||||
|
use sqlx::{postgres::PgRow, Row};
|
||||||
|
|
||||||
|
pub struct DbUser(pub User);
|
||||||
|
pub struct DbSession(pub Session);
|
||||||
|
pub struct DbHistory(pub History);
|
||||||
|
|
||||||
|
impl<'a> FromRow<'a, PgRow> for DbUser {
|
||||||
|
fn from_row(row: &'a PgRow) -> Result<Self> {
|
||||||
|
Ok(Self(User {
|
||||||
|
id: row.try_get("id")?,
|
||||||
|
username: row.try_get("username")?,
|
||||||
|
email: row.try_get("email")?,
|
||||||
|
password: row.try_get("password")?,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession {
|
||||||
|
fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
|
||||||
|
Ok(Self(Session {
|
||||||
|
id: row.try_get("id")?,
|
||||||
|
user_id: row.try_get("user_id")?,
|
||||||
|
token: row.try_get("token")?,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
|
||||||
|
fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
|
||||||
|
Ok(Self(History {
|
||||||
|
id: row.try_get("id")?,
|
||||||
|
client_id: row.try_get("client_id")?,
|
||||||
|
user_id: row.try_get("user_id")?,
|
||||||
|
hostname: row.try_get("hostname")?,
|
||||||
|
timestamp: row.try_get("timestamp")?,
|
||||||
|
data: row.try_get("data")?,
|
||||||
|
created_at: row.try_get("created_at")?,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,20 +11,18 @@ repository = { workspace = true }
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
atuin-common = { path = "../atuin-common", version = "15.0.0" }
|
atuin-common = { path = "../atuin-common", version = "15.0.0" }
|
||||||
|
atuin-server-database = { path = "../atuin-server-database", version = "15.0.0" }
|
||||||
|
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
chrono = { workspace = true }
|
chrono = { workspace = true }
|
||||||
eyre = { workspace = true }
|
eyre = { workspace = true }
|
||||||
uuid = { workspace = true }
|
uuid = { workspace = true }
|
||||||
whoami = { workspace = true }
|
|
||||||
config = { workspace = true }
|
config = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
sodiumoxide = { workspace = true }
|
|
||||||
base64 = { workspace = true }
|
base64 = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
sqlx = { workspace = true }
|
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
axum = "0.6.4"
|
axum = "0.6.4"
|
||||||
http = "0.2"
|
http = "0.2"
|
||||||
|
|
|
@ -1,222 +0,0 @@
|
||||||
/*
|
|
||||||
use self::diesel::prelude::*;
|
|
||||||
use eyre::Result;
|
|
||||||
use rocket::http::Status;
|
|
||||||
use rocket::request::{self, FromRequest, Outcome, Request};
|
|
||||||
use rocket::State;
|
|
||||||
use rocket_contrib::databases::diesel;
|
|
||||||
use sodiumoxide::crypto::pwhash::argon2id13;
|
|
||||||
|
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use super::models::{NewSession, NewUser, Session, User};
|
|
||||||
use super::views::ApiResponse;
|
|
||||||
|
|
||||||
use crate::api::{LoginRequest, RegisterRequest};
|
|
||||||
use crate::schema::{sessions, users};
|
|
||||||
use crate::settings::Settings;
|
|
||||||
use crate::utils::hash_secret;
|
|
||||||
|
|
||||||
use super::database::AtuinDbConn;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum KeyError {
|
|
||||||
Missing,
|
|
||||||
Invalid,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn verify_str(secret: &str, verify: &str) -> bool {
|
|
||||||
sodiumoxide::init().unwrap();
|
|
||||||
|
|
||||||
let mut padded = [0_u8; 128];
|
|
||||||
secret.as_bytes().iter().enumerate().for_each(|(i, val)| {
|
|
||||||
padded[i] = *val;
|
|
||||||
});
|
|
||||||
|
|
||||||
match argon2id13::HashedPassword::from_slice(&padded) {
|
|
||||||
Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()),
|
|
||||||
None => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for User {
|
|
||||||
type Error = KeyError;
|
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> {
|
|
||||||
let session: Vec<_> = request.headers().get("authorization").collect();
|
|
||||||
|
|
||||||
if session.is_empty() {
|
|
||||||
return Outcome::Failure((Status::BadRequest, KeyError::Missing));
|
|
||||||
} else if session.len() > 1 {
|
|
||||||
return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
|
|
||||||
}
|
|
||||||
|
|
||||||
let session: Vec<_> = session[0].split(' ').collect();
|
|
||||||
|
|
||||||
if session.len() != 2 {
|
|
||||||
return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
|
|
||||||
}
|
|
||||||
|
|
||||||
if session[0] != "Token" {
|
|
||||||
return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
|
|
||||||
}
|
|
||||||
|
|
||||||
let session = session[1];
|
|
||||||
|
|
||||||
let db = request
|
|
||||||
.guard::<AtuinDbConn>()
|
|
||||||
.succeeded()
|
|
||||||
.expect("failed to load database");
|
|
||||||
|
|
||||||
let session = sessions::table
|
|
||||||
.filter(sessions::token.eq(session))
|
|
||||||
.first::<Session>(&*db);
|
|
||||||
|
|
||||||
if session.is_err() {
|
|
||||||
return Outcome::Failure((Status::Unauthorized, KeyError::Invalid));
|
|
||||||
}
|
|
||||||
|
|
||||||
let session = session.unwrap();
|
|
||||||
|
|
||||||
let user = users::table.find(session.user_id).first(&*db);
|
|
||||||
|
|
||||||
match user {
|
|
||||||
Ok(user) => Outcome::Success(user),
|
|
||||||
Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/user/<user>")]
|
|
||||||
#[allow(clippy::clippy::needless_pass_by_value)]
|
|
||||||
pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse {
|
|
||||||
use crate::schema::users::dsl::{username, users};
|
|
||||||
|
|
||||||
let user: Result<String, diesel::result::Error> = users
|
|
||||||
.select(username)
|
|
||||||
.filter(username.eq(user))
|
|
||||||
.first(&*conn);
|
|
||||||
|
|
||||||
if user.is_err() {
|
|
||||||
return ApiResponse {
|
|
||||||
json: json!({
|
|
||||||
"message": "could not find user",
|
|
||||||
}),
|
|
||||||
status: Status::NotFound,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let user = user.unwrap();
|
|
||||||
|
|
||||||
ApiResponse {
|
|
||||||
json: json!({ "username": user.as_str() }),
|
|
||||||
status: Status::Ok,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/register", data = "<register>")]
|
|
||||||
#[allow(clippy::clippy::needless_pass_by_value)]
|
|
||||||
pub fn register(
|
|
||||||
conn: AtuinDbConn,
|
|
||||||
register: Json<RegisterRequest>,
|
|
||||||
settings: State<Settings>,
|
|
||||||
) -> ApiResponse {
|
|
||||||
if !settings.server.open_registration {
|
|
||||||
return ApiResponse {
|
|
||||||
status: Status::BadRequest,
|
|
||||||
json: json!({
|
|
||||||
"message": "registrations are not open"
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let hashed = hash_secret(register.password.as_str());
|
|
||||||
|
|
||||||
let new_user = NewUser {
|
|
||||||
email: register.email.as_str(),
|
|
||||||
username: register.username.as_str(),
|
|
||||||
password: hashed.as_str(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let user = diesel::insert_into(users::table)
|
|
||||||
.values(&new_user)
|
|
||||||
.get_result(&*conn);
|
|
||||||
|
|
||||||
if user.is_err() {
|
|
||||||
return ApiResponse {
|
|
||||||
status: Status::BadRequest,
|
|
||||||
json: json!({
|
|
||||||
"message": "failed to create user - username or email in use?",
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let user: User = user.unwrap();
|
|
||||||
let token = Uuid::new_v4().to_simple().to_string();
|
|
||||||
|
|
||||||
let new_session = NewSession {
|
|
||||||
user_id: user.id,
|
|
||||||
token: token.as_str(),
|
|
||||||
};
|
|
||||||
|
|
||||||
match diesel::insert_into(sessions::table)
|
|
||||||
.values(&new_session)
|
|
||||||
.execute(&*conn)
|
|
||||||
{
|
|
||||||
Ok(_) => ApiResponse {
|
|
||||||
status: Status::Ok,
|
|
||||||
json: json!({"message": "user created!", "session": token}),
|
|
||||||
},
|
|
||||||
Err(_) => ApiResponse {
|
|
||||||
status: Status::BadRequest,
|
|
||||||
json: json!({ "message": "failed to create user"}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/login", data = "<login>")]
|
|
||||||
#[allow(clippy::clippy::needless_pass_by_value)]
|
|
||||||
pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse {
|
|
||||||
let user = users::table
|
|
||||||
.filter(users::username.eq(login.username.as_str()))
|
|
||||||
.first(&*conn);
|
|
||||||
|
|
||||||
if user.is_err() {
|
|
||||||
return ApiResponse {
|
|
||||||
status: Status::NotFound,
|
|
||||||
json: json!({"message": "user not found"}),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let user: User = user.unwrap();
|
|
||||||
|
|
||||||
let session = sessions::table
|
|
||||||
.filter(sessions::user_id.eq(user.id))
|
|
||||||
.first(&*conn);
|
|
||||||
|
|
||||||
// a session should exist...
|
|
||||||
if session.is_err() {
|
|
||||||
return ApiResponse {
|
|
||||||
status: Status::InternalServerError,
|
|
||||||
json: json!({"message": "something went wrong"}),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let verified = verify_str(user.password.as_str(), login.password.as_str());
|
|
||||||
|
|
||||||
if !verified {
|
|
||||||
return ApiResponse {
|
|
||||||
status: Status::NotFound,
|
|
||||||
json: json!({"message": "user not found"}),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let session: Session = session.unwrap();
|
|
||||||
|
|
||||||
ApiResponse {
|
|
||||||
status: Status::Ok,
|
|
||||||
json: json!({"session": session.token}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
|
@ -1,510 +0,0 @@
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use chrono::{Datelike, TimeZone};
|
|
||||||
use chronoutil::RelativeDuration;
|
|
||||||
use sqlx::{postgres::PgPoolOptions, Result};
|
|
||||||
|
|
||||||
use sqlx::Row;
|
|
||||||
|
|
||||||
use tracing::{debug, instrument, warn};
|
|
||||||
|
|
||||||
use super::{
|
|
||||||
calendar::{TimePeriod, TimePeriodInfo},
|
|
||||||
models::{History, NewHistory, NewSession, NewUser, Session, User},
|
|
||||||
};
|
|
||||||
use crate::settings::Settings;
|
|
||||||
|
|
||||||
use atuin_common::utils::get_days_from_month;
|
|
||||||
|
|
||||||
/// Storage backend for the sync server: users, sessions, and shell history.
///
/// Implementations (e.g. `Postgres` below) provide the persistence layer;
/// handlers only talk to this trait.
#[async_trait]
pub trait Database {
    /// Look up a session by its token.
    async fn get_session(&self, token: &str) -> Result<Session>;
    /// Resolve the user that owns the given session token.
    async fn get_session_user(&self, token: &str) -> Result<User>;
    /// Persist a newly created session.
    async fn add_session(&self, session: &NewSession) -> Result<()>;

    /// Look up a user by username.
    async fn get_user(&self, username: &str) -> Result<User>;
    /// Fetch a session belonging to the given user.
    async fn get_user_session(&self, u: &User) -> Result<Session>;
    /// Create a user; returns the new user's id.
    async fn add_user(&self, user: &NewUser) -> Result<i64>;
    /// Remove a user together with their sessions and history.
    async fn delete_user(&self, u: &User) -> Result<()>;

    /// Count a user's history entries (full count, uncached).
    async fn count_history(&self, user: &User) -> Result<i64>;
    /// Count a user's history entries via a cached total where available.
    async fn count_history_cached(&self, user: &User) -> Result<i64>;

    /// Delete (mark deleted) a single history entry by its client id.
    async fn delete_history(&self, user: &User, id: String) -> Result<()>;
    /// Client ids of this user's deleted history entries.
    async fn deleted_history(&self, user: &User) -> Result<Vec<String>>;

    /// Count history entries with timestamps in the half-open range
    /// `[start, end)`.
    async fn count_history_range(
        &self,
        user: &User,
        start: chrono::NaiveDateTime,
        end: chrono::NaiveDateTime,
    ) -> Result<i64>;
    /// Count history entries for a single calendar day.
    async fn count_history_day(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>;
    /// Count history entries for a single calendar month.
    async fn count_history_month(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>;
    /// Count history entries for a single calendar year.
    async fn count_history_year(&self, user: &User, year: i32) -> Result<i64>;

    /// Page through history for sync, excluding entries recorded on `host`,
    /// returning at most `page_size` entries.
    async fn list_history(
        &self,
        user: &User,
        created_after: chrono::NaiveDateTime,
        since: chrono::NaiveDateTime,
        host: &str,
        page_size: i64,
    ) -> Result<Vec<History>>;

    /// Insert a batch of history items.
    async fn add_history(&self, history: &[NewHistory]) -> Result<()>;

    /// The user's earliest history entry by timestamp.
    async fn oldest_history(&self, user: &User) -> Result<History>;

    /// Activity counts keyed by period index (year / month-in-year /
    /// day-in-month), for the calendar endpoint.
    async fn calendar(
        &self,
        user: &User,
        period: TimePeriod,
        year: u64,
        month: u64,
    ) -> Result<HashMap<u64, TimePeriodInfo>>;
}
|
|
||||||
|
|
||||||
/// Postgres-backed implementation of the `Database` trait.
#[derive(Clone)]
pub struct Postgres {
    // Connection pool shared by all queries on this handle.
    pool: sqlx::Pool<sqlx::postgres::Postgres>,
    // Server settings; `add_history` reads `max_history_length` from here.
    settings: Settings,
}
|
|
||||||
|
|
||||||
impl Postgres {
    /// Connect to the database named in `settings.db_uri`, run any pending
    /// migrations, and return a ready-to-use handle.
    ///
    /// # Errors
    /// Returns an error if the pool cannot connect or a migration fails.
    pub async fn new(settings: Settings) -> Result<Self> {
        let pool = PgPoolOptions::new()
            // NOTE(review): 100 connections is a hard-coded cap — confirm it
            // suits the deployment before tuning.
            .max_connections(100)
            .connect(settings.db_uri.as_str())
            .await?;

        // Bring the schema up to date before handing the pool out.
        sqlx::migrate!("./migrations").run(&pool).await?;

        Ok(Self { pool, settings })
    }
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Database for Postgres {
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn get_session(&self, token: &str) -> Result<Session> {
|
|
||||||
sqlx::query_as::<_, Session>("select id, user_id, token from sessions where token = $1")
|
|
||||||
.bind(token)
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn get_user(&self, username: &str) -> Result<User> {
|
|
||||||
sqlx::query_as::<_, User>(
|
|
||||||
"select id, username, email, password from users where username = $1",
|
|
||||||
)
|
|
||||||
.bind(username)
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn get_session_user(&self, token: &str) -> Result<User> {
|
|
||||||
sqlx::query_as::<_, User>(
|
|
||||||
"select users.id, users.username, users.email, users.password from users
|
|
||||||
inner join sessions
|
|
||||||
on users.id = sessions.user_id
|
|
||||||
and sessions.token = $1",
|
|
||||||
)
|
|
||||||
.bind(token)
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn count_history(&self, user: &User) -> Result<i64> {
|
|
||||||
// The cache is new, and the user might not yet have a cache value.
|
|
||||||
// They will have one as soon as they post up some new history, but handle that
|
|
||||||
// edge case.
|
|
||||||
|
|
||||||
let res: (i64,) = sqlx::query_as(
|
|
||||||
"select count(1) from history
|
|
||||||
where user_id = $1",
|
|
||||||
)
|
|
||||||
.bind(user.id)
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(res.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn count_history_cached(&self, user: &User) -> Result<i64> {
|
|
||||||
let res: (i32,) = sqlx::query_as(
|
|
||||||
"select total from total_history_count_user
|
|
||||||
where user_id = $1",
|
|
||||||
)
|
|
||||||
.bind(user.id)
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(res.0 as i64)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn delete_history(&self, user: &User, id: String) -> Result<()> {
|
|
||||||
sqlx::query(
|
|
||||||
"update history
|
|
||||||
set deleted_at = $3
|
|
||||||
where user_id = $1
|
|
||||||
and client_id = $2
|
|
||||||
and deleted_at is null", // don't just keep setting it
|
|
||||||
)
|
|
||||||
.bind(user.id)
|
|
||||||
.bind(id)
|
|
||||||
.bind(chrono::Utc::now().naive_utc())
|
|
||||||
.fetch_all(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn deleted_history(&self, user: &User) -> Result<Vec<String>> {
|
|
||||||
// The cache is new, and the user might not yet have a cache value.
|
|
||||||
// They will have one as soon as they post up some new history, but handle that
|
|
||||||
// edge case.
|
|
||||||
|
|
||||||
let res = sqlx::query(
|
|
||||||
"select client_id from history
|
|
||||||
where user_id = $1
|
|
||||||
and deleted_at is not null",
|
|
||||||
)
|
|
||||||
.bind(user.id)
|
|
||||||
.fetch_all(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let res = res
|
|
||||||
.iter()
|
|
||||||
.map(|row| row.get::<String, _>("client_id"))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn count_history_range(
|
|
||||||
&self,
|
|
||||||
user: &User,
|
|
||||||
start: chrono::NaiveDateTime,
|
|
||||||
end: chrono::NaiveDateTime,
|
|
||||||
) -> Result<i64> {
|
|
||||||
let res: (i64,) = sqlx::query_as(
|
|
||||||
"select count(1) from history
|
|
||||||
where user_id = $1
|
|
||||||
and timestamp >= $2::date
|
|
||||||
and timestamp < $3::date",
|
|
||||||
)
|
|
||||||
.bind(user.id)
|
|
||||||
.bind(start)
|
|
||||||
.bind(end)
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(res.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count the history for a given year
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn count_history_year(&self, user: &User, year: i32) -> Result<i64> {
|
|
||||||
let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0);
|
|
||||||
let end = start + RelativeDuration::years(1);
|
|
||||||
|
|
||||||
let res = self
|
|
||||||
.count_history_range(user, start.naive_utc(), end.naive_utc())
|
|
||||||
.await?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count the history for a given month
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> Result<i64> {
|
|
||||||
let start = chrono::Utc
|
|
||||||
.ymd(month.year(), month.month(), 1)
|
|
||||||
.and_hms_nano(0, 0, 0, 0);
|
|
||||||
|
|
||||||
// ofc...
|
|
||||||
let end = if month.month() < 12 {
|
|
||||||
chrono::Utc
|
|
||||||
.ymd(month.year(), month.month() + 1, 1)
|
|
||||||
.and_hms_nano(0, 0, 0, 0)
|
|
||||||
} else {
|
|
||||||
chrono::Utc
|
|
||||||
.ymd(month.year() + 1, 1, 1)
|
|
||||||
.and_hms_nano(0, 0, 0, 0)
|
|
||||||
};
|
|
||||||
|
|
||||||
debug!("start: {}, end: {}", start, end);
|
|
||||||
|
|
||||||
let res = self
|
|
||||||
.count_history_range(user, start.naive_utc(), end.naive_utc())
|
|
||||||
.await?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count the history for a given day
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> Result<i64> {
|
|
||||||
let start = chrono::Utc
|
|
||||||
.ymd(day.year(), day.month(), day.day())
|
|
||||||
.and_hms_nano(0, 0, 0, 0);
|
|
||||||
let end = chrono::Utc
|
|
||||||
.ymd(day.year(), day.month(), day.day() + 1)
|
|
||||||
.and_hms_nano(0, 0, 0, 0);
|
|
||||||
|
|
||||||
let res = self
|
|
||||||
.count_history_range(user, start.naive_utc(), end.naive_utc())
|
|
||||||
.await?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn list_history(
|
|
||||||
&self,
|
|
||||||
user: &User,
|
|
||||||
created_after: chrono::NaiveDateTime,
|
|
||||||
since: chrono::NaiveDateTime,
|
|
||||||
host: &str,
|
|
||||||
page_size: i64,
|
|
||||||
) -> Result<Vec<History>> {
|
|
||||||
let res = sqlx::query_as::<_, History>(
|
|
||||||
"select id, client_id, user_id, hostname, timestamp, data, created_at from history
|
|
||||||
where user_id = $1
|
|
||||||
and hostname != $2
|
|
||||||
and created_at >= $3
|
|
||||||
and timestamp >= $4
|
|
||||||
order by timestamp asc
|
|
||||||
limit $5",
|
|
||||||
)
|
|
||||||
.bind(user.id)
|
|
||||||
.bind(host)
|
|
||||||
.bind(created_after)
|
|
||||||
.bind(since)
|
|
||||||
.bind(page_size)
|
|
||||||
.fetch_all(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn add_history(&self, history: &[NewHistory]) -> Result<()> {
|
|
||||||
let mut tx = self.pool.begin().await?;
|
|
||||||
|
|
||||||
for i in history {
|
|
||||||
let client_id: &str = &i.client_id;
|
|
||||||
let hostname: &str = &i.hostname;
|
|
||||||
let data: &str = &i.data;
|
|
||||||
|
|
||||||
if data.len() > self.settings.max_history_length
|
|
||||||
&& self.settings.max_history_length != 0
|
|
||||||
{
|
|
||||||
// Don't return an error here. We want to insert as much of the
|
|
||||||
// history list as we can, so log the error and continue going.
|
|
||||||
|
|
||||||
warn!(
|
|
||||||
"history too long, got length {}, max {}",
|
|
||||||
data.len(),
|
|
||||||
self.settings.max_history_length
|
|
||||||
);
|
|
||||||
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
"insert into history
|
|
||||||
(client_id, user_id, hostname, timestamp, data)
|
|
||||||
values ($1, $2, $3, $4, $5)
|
|
||||||
on conflict do nothing
|
|
||||||
",
|
|
||||||
)
|
|
||||||
.bind(client_id)
|
|
||||||
.bind(i.user_id)
|
|
||||||
.bind(hostname)
|
|
||||||
.bind(i.timestamp)
|
|
||||||
.bind(data)
|
|
||||||
.execute(&mut tx)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.commit().await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn delete_user(&self, u: &User) -> Result<()> {
|
|
||||||
sqlx::query("delete from sessions where user_id = $1")
|
|
||||||
.bind(u.id)
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
sqlx::query("delete from users where id = $1")
|
|
||||||
.bind(u.id)
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
sqlx::query("delete from history where user_id = $1")
|
|
||||||
.bind(u.id)
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn add_user(&self, user: &NewUser) -> Result<i64> {
|
|
||||||
let email: &str = &user.email;
|
|
||||||
let username: &str = &user.username;
|
|
||||||
let password: &str = &user.password;
|
|
||||||
|
|
||||||
let res: (i64,) = sqlx::query_as(
|
|
||||||
"insert into users
|
|
||||||
(username, email, password)
|
|
||||||
values($1, $2, $3)
|
|
||||||
returning id",
|
|
||||||
)
|
|
||||||
.bind(username)
|
|
||||||
.bind(email)
|
|
||||||
.bind(password)
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(res.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn add_session(&self, session: &NewSession) -> Result<()> {
|
|
||||||
let token: &str = &session.token;
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
"insert into sessions
|
|
||||||
(user_id, token)
|
|
||||||
values($1, $2)",
|
|
||||||
)
|
|
||||||
.bind(session.user_id)
|
|
||||||
.bind(token)
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn get_user_session(&self, u: &User) -> Result<Session> {
|
|
||||||
sqlx::query_as::<_, Session>("select id, user_id, token from sessions where user_id = $1")
|
|
||||||
.bind(u.id)
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn oldest_history(&self, user: &User) -> Result<History> {
|
|
||||||
let res = sqlx::query_as::<_, History>(
|
|
||||||
"select id, client_id, user_id, hostname, timestamp, data, created_at from history
|
|
||||||
where user_id = $1
|
|
||||||
order by timestamp asc
|
|
||||||
limit 1",
|
|
||||||
)
|
|
||||||
.bind(user.id)
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn calendar(
|
|
||||||
&self,
|
|
||||||
user: &User,
|
|
||||||
period: TimePeriod,
|
|
||||||
year: u64,
|
|
||||||
month: u64,
|
|
||||||
) -> Result<HashMap<u64, TimePeriodInfo>> {
|
|
||||||
// TODO: Support different timezones. Right now we assume UTC and
|
|
||||||
// everything is stored as such. But it _should_ be possible to
|
|
||||||
// interpret the stored date with a different TZ
|
|
||||||
|
|
||||||
match period {
|
|
||||||
TimePeriod::YEAR => {
|
|
||||||
let mut ret = HashMap::new();
|
|
||||||
// First we need to work out how far back to calculate. Get the
|
|
||||||
// oldest history item
|
|
||||||
let oldest = self.oldest_history(user).await?.timestamp.year();
|
|
||||||
let current_year = chrono::Utc::now().year();
|
|
||||||
|
|
||||||
// All the years we need to get data for
|
|
||||||
// The upper bound is exclusive, so include current +1
|
|
||||||
let years = oldest..current_year + 1;
|
|
||||||
|
|
||||||
for year in years {
|
|
||||||
let count = self.count_history_year(user, year).await?;
|
|
||||||
|
|
||||||
ret.insert(
|
|
||||||
year as u64,
|
|
||||||
TimePeriodInfo {
|
|
||||||
count: count as u64,
|
|
||||||
hash: "".to_string(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
TimePeriod::MONTH => {
|
|
||||||
let mut ret = HashMap::new();
|
|
||||||
|
|
||||||
for month in 1..13 {
|
|
||||||
let count = self
|
|
||||||
.count_history_month(
|
|
||||||
user,
|
|
||||||
chrono::Utc.ymd(year as i32, month, 1).naive_utc(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
ret.insert(
|
|
||||||
month as u64,
|
|
||||||
TimePeriodInfo {
|
|
||||||
count: count as u64,
|
|
||||||
hash: "".to_string(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
TimePeriod::DAY => {
|
|
||||||
let mut ret = HashMap::new();
|
|
||||||
|
|
||||||
for day in 1..get_days_from_month(year as i32, month as u32) {
|
|
||||||
let count = self
|
|
||||||
.count_history_day(
|
|
||||||
user,
|
|
||||||
chrono::Utc
|
|
||||||
.ymd(year as i32, month as u32, day as u32)
|
|
||||||
.naive_utc(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
ret.insert(
|
|
||||||
day as u64,
|
|
||||||
TimePeriodInfo {
|
|
||||||
count: count as u64,
|
|
||||||
hash: "".to_string(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ret)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -10,18 +10,20 @@ use tracing::{debug, error, instrument};
|
||||||
|
|
||||||
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
||||||
use crate::{
|
use crate::{
|
||||||
calendar::{TimePeriod, TimePeriodInfo},
|
router::{AppState, UserAuth},
|
||||||
database::Database,
|
|
||||||
models::{NewHistory, User},
|
|
||||||
router::AppState,
|
|
||||||
utils::client_version_min,
|
utils::client_version_min,
|
||||||
};
|
};
|
||||||
|
use atuin_server_database::{
|
||||||
|
calendar::{TimePeriod, TimePeriodInfo},
|
||||||
|
models::NewHistory,
|
||||||
|
Database,
|
||||||
|
};
|
||||||
|
|
||||||
use atuin_common::api::*;
|
use atuin_common::api::*;
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn count<DB: Database>(
|
pub async fn count<DB: Database>(
|
||||||
user: User,
|
UserAuth(user): UserAuth,
|
||||||
state: State<AppState<DB>>,
|
state: State<AppState<DB>>,
|
||||||
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
|
||||||
let db = &state.0.database;
|
let db = &state.0.database;
|
||||||
|
@ -42,7 +44,7 @@ pub async fn count<DB: Database>(
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn list<DB: Database>(
|
pub async fn list<DB: Database>(
|
||||||
req: Query<SyncHistoryRequest>,
|
req: Query<SyncHistoryRequest>,
|
||||||
user: User,
|
UserAuth(user): UserAuth,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
state: State<AppState<DB>>,
|
state: State<AppState<DB>>,
|
||||||
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
|
||||||
|
@ -101,7 +103,7 @@ pub async fn list<DB: Database>(
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn delete<DB: Database>(
|
pub async fn delete<DB: Database>(
|
||||||
user: User,
|
UserAuth(user): UserAuth,
|
||||||
state: State<AppState<DB>>,
|
state: State<AppState<DB>>,
|
||||||
Json(req): Json<DeleteHistoryRequest>,
|
Json(req): Json<DeleteHistoryRequest>,
|
||||||
) -> Result<Json<MessageResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<MessageResponse>, ErrorResponseStatus<'static>> {
|
||||||
|
@ -123,13 +125,15 @@ pub async fn delete<DB: Database>(
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn add<DB: Database>(
|
pub async fn add<DB: Database>(
|
||||||
user: User,
|
UserAuth(user): UserAuth,
|
||||||
state: State<AppState<DB>>,
|
state: State<AppState<DB>>,
|
||||||
Json(req): Json<Vec<AddHistoryRequest>>,
|
Json(req): Json<Vec<AddHistoryRequest>>,
|
||||||
) -> Result<(), ErrorResponseStatus<'static>> {
|
) -> Result<(), ErrorResponseStatus<'static>> {
|
||||||
|
let State(AppState { database, settings }) = state;
|
||||||
|
|
||||||
debug!("request to add {} history items", req.len());
|
debug!("request to add {} history items", req.len());
|
||||||
|
|
||||||
let history: Vec<NewHistory> = req
|
let mut history: Vec<NewHistory> = req
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|h| NewHistory {
|
.map(|h| NewHistory {
|
||||||
client_id: h.id,
|
client_id: h.id,
|
||||||
|
@ -140,8 +144,24 @@ pub async fn add<DB: Database>(
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let db = &state.0.database;
|
history.retain(|h| {
|
||||||
if let Err(e) = db.add_history(&history).await {
|
// keep if within limit, or limit is 0 (unlimited)
|
||||||
|
let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0;
|
||||||
|
|
||||||
|
// Don't return an error here. We want to insert as much of the
|
||||||
|
// history list as we can, so log the error and continue going.
|
||||||
|
if !keep {
|
||||||
|
tracing::warn!(
|
||||||
|
"history too long, got length {}, max {}",
|
||||||
|
h.data.len(),
|
||||||
|
settings.max_history_length
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
keep
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Err(e) = database.add_history(&history).await {
|
||||||
error!("failed to add history: {}", e);
|
error!("failed to add history: {}", e);
|
||||||
|
|
||||||
return Err(ErrorResponse::reply("failed to add history")
|
return Err(ErrorResponse::reply("failed to add history")
|
||||||
|
@ -155,7 +175,7 @@ pub async fn add<DB: Database>(
|
||||||
pub async fn calendar<DB: Database>(
|
pub async fn calendar<DB: Database>(
|
||||||
Path(focus): Path<String>,
|
Path(focus): Path<String>,
|
||||||
Query(params): Query<HashMap<String, u64>>,
|
Query(params): Query<HashMap<String, u64>>,
|
||||||
user: User,
|
UserAuth(user): UserAuth,
|
||||||
state: State<AppState<DB>>,
|
state: State<AppState<DB>>,
|
||||||
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
|
||||||
let focus = focus.as_str();
|
let focus = focus.as_str();
|
||||||
|
|
|
@ -3,7 +3,8 @@ use http::StatusCode;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
||||||
use crate::{database::Database, models::User, router::AppState};
|
use crate::router::{AppState, UserAuth};
|
||||||
|
use atuin_server_database::Database;
|
||||||
|
|
||||||
use atuin_common::api::*;
|
use atuin_common::api::*;
|
||||||
|
|
||||||
|
@ -11,7 +12,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn status<DB: Database>(
|
pub async fn status<DB: Database>(
|
||||||
user: User,
|
UserAuth(user): UserAuth,
|
||||||
state: State<AppState<DB>>,
|
state: State<AppState<DB>>,
|
||||||
) -> Result<Json<StatusResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<StatusResponse>, ErrorResponseStatus<'static>> {
|
||||||
let db = &state.0.database;
|
let db = &state.0.database;
|
||||||
|
|
|
@ -16,10 +16,10 @@ use tracing::{debug, error, info, instrument};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
||||||
use crate::{
|
use crate::router::{AppState, UserAuth};
|
||||||
database::Database,
|
use atuin_server_database::{
|
||||||
models::{NewSession, NewUser, User},
|
models::{NewSession, NewUser},
|
||||||
router::AppState,
|
Database, DbError,
|
||||||
};
|
};
|
||||||
|
|
||||||
use reqwest::header::CONTENT_TYPE;
|
use reqwest::header::CONTENT_TYPE;
|
||||||
|
@ -64,11 +64,11 @@ pub async fn get<DB: Database>(
|
||||||
let db = &state.0.database;
|
let db = &state.0.database;
|
||||||
let user = match db.get_user(username.as_ref()).await {
|
let user = match db.get_user(username.as_ref()).await {
|
||||||
Ok(user) => user,
|
Ok(user) => user,
|
||||||
Err(sqlx::Error::RowNotFound) => {
|
Err(DbError::NotFound) => {
|
||||||
debug!("user not found: {}", username);
|
debug!("user not found: {}", username);
|
||||||
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
|
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(DbError::Other(err)) => {
|
||||||
error!("database error: {}", err);
|
error!("database error: {}", err);
|
||||||
return Err(ErrorResponse::reply("database error")
|
return Err(ErrorResponse::reply("database error")
|
||||||
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||||
|
@ -152,7 +152,7 @@ pub async fn register<DB: Database>(
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn delete<DB: Database>(
|
pub async fn delete<DB: Database>(
|
||||||
user: User,
|
UserAuth(user): UserAuth,
|
||||||
state: State<AppState<DB>>,
|
state: State<AppState<DB>>,
|
||||||
) -> Result<Json<DeleteUserResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<DeleteUserResponse>, ErrorResponseStatus<'static>> {
|
||||||
debug!("request to delete user {}", user.id);
|
debug!("request to delete user {}", user.id);
|
||||||
|
@ -175,10 +175,10 @@ pub async fn login<DB: Database>(
|
||||||
let db = &state.0.database;
|
let db = &state.0.database;
|
||||||
let user = match db.get_user(login.username.borrow()).await {
|
let user = match db.get_user(login.username.borrow()).await {
|
||||||
Ok(u) => u,
|
Ok(u) => u,
|
||||||
Err(sqlx::Error::RowNotFound) => {
|
Err(DbError::NotFound) => {
|
||||||
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
|
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(DbError::Other(e)) => {
|
||||||
error!("failed to get user {}: {}", login.username.clone(), e);
|
error!("failed to get user {}: {}", login.username.clone(), e);
|
||||||
|
|
||||||
return Err(ErrorResponse::reply("database error")
|
return Err(ErrorResponse::reply("database error")
|
||||||
|
@ -188,11 +188,11 @@ pub async fn login<DB: Database>(
|
||||||
|
|
||||||
let session = match db.get_user_session(&user).await {
|
let session = match db.get_user_session(&user).await {
|
||||||
Ok(u) => u,
|
Ok(u) => u,
|
||||||
Err(sqlx::Error::RowNotFound) => {
|
Err(DbError::NotFound) => {
|
||||||
debug!("user session not found for user id={}", user.id);
|
debug!("user session not found for user id={}", user.id);
|
||||||
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
|
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(DbError::Other(err)) => {
|
||||||
error!("database error for user {}: {}", login.username, err);
|
error!("database error for user {}: {}", login.username, err);
|
||||||
return Err(ErrorResponse::reply("database error")
|
return Err(ErrorResponse::reply("database error")
|
||||||
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||||
|
|
|
@ -2,45 +2,38 @@
|
||||||
|
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
|
|
||||||
|
use atuin_server_database::Database;
|
||||||
use axum::Server;
|
use axum::Server;
|
||||||
use database::Postgres;
|
|
||||||
use eyre::{Context, Result};
|
use eyre::{Context, Result};
|
||||||
|
|
||||||
use crate::settings::Settings;
|
mod handlers;
|
||||||
|
mod router;
|
||||||
|
mod settings;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
|
pub use settings::Settings;
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
|
|
||||||
pub mod auth;
|
|
||||||
pub mod calendar;
|
|
||||||
pub mod database;
|
|
||||||
pub mod handlers;
|
|
||||||
pub mod models;
|
|
||||||
pub mod router;
|
|
||||||
pub mod settings;
|
|
||||||
pub mod utils;
|
|
||||||
|
|
||||||
async fn shutdown_signal() {
|
async fn shutdown_signal() {
|
||||||
let terminate = async {
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
.expect("failed to register signal handler")
|
||||||
.expect("failed to register signal handler")
|
.recv()
|
||||||
.recv()
|
.await;
|
||||||
.await;
|
|
||||||
};
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
_ = terminate => (),
|
|
||||||
}
|
|
||||||
eprintln!("Shutting down gracefully...");
|
eprintln!("Shutting down gracefully...");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn launch(settings: Settings, host: String, port: u16) -> Result<()> {
|
pub async fn launch<Db: Database>(
|
||||||
|
settings: Settings<Db::Settings>,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
) -> Result<()> {
|
||||||
let host = host.parse::<IpAddr>()?;
|
let host = host.parse::<IpAddr>()?;
|
||||||
|
|
||||||
let postgres = Postgres::new(settings.clone())
|
let db = Db::new(&settings.db_settings)
|
||||||
.await
|
.await
|
||||||
.wrap_err_with(|| format!("failed to connect to db: {}", settings.db_uri))?;
|
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
|
||||||
|
|
||||||
let r = router::router(postgres, settings);
|
let r = router::router(db, settings);
|
||||||
|
|
||||||
Server::bind(&SocketAddr::new(host, port))
|
Server::bind(&SocketAddr::new(host, port))
|
||||||
.serve(r.into_make_service())
|
.serve(r.into_make_service())
|
||||||
|
|
|
@ -10,11 +10,14 @@ use http::request::Parts;
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
|
|
||||||
use super::{database::Database, handlers};
|
use super::handlers;
|
||||||
use crate::{models::User, settings::Settings};
|
use crate::settings::Settings;
|
||||||
|
use atuin_server_database::{models::User, Database};
|
||||||
|
|
||||||
|
pub struct UserAuth(pub User);
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for User
|
impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for UserAuth
|
||||||
where
|
where
|
||||||
DB: Database,
|
DB: Database,
|
||||||
{
|
{
|
||||||
|
@ -45,7 +48,7 @@ where
|
||||||
.await
|
.await
|
||||||
.map_err(|_| http::StatusCode::FORBIDDEN)?;
|
.map_err(|_| http::StatusCode::FORBIDDEN)?;
|
||||||
|
|
||||||
Ok(user)
|
Ok(UserAuth(user))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,15 +57,12 @@ async fn teapot() -> impl IntoResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState<DB> {
|
pub struct AppState<DB: Database> {
|
||||||
pub database: DB,
|
pub database: DB,
|
||||||
pub settings: Settings,
|
pub settings: Settings<DB::Settings>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn router<DB: Database + Clone + Send + Sync + 'static>(
|
pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> Router {
|
||||||
database: DB,
|
|
||||||
settings: Settings,
|
|
||||||
) -> Router {
|
|
||||||
let routes = Router::new()
|
let routes = Router::new()
|
||||||
.route("/", get(handlers::index))
|
.route("/", get(handlers::index))
|
||||||
.route("/sync/count", get(handlers::history::count))
|
.route("/sync/count", get(handlers::history::count))
|
||||||
|
|
|
@ -3,24 +3,24 @@ use std::{io::prelude::*, path::PathBuf};
|
||||||
use config::{Config, Environment, File as ConfigFile, FileFormat};
|
use config::{Config, Environment, File as ConfigFile, FileFormat};
|
||||||
use eyre::{eyre, Result};
|
use eyre::{eyre, Result};
|
||||||
use fs_err::{create_dir_all, File};
|
use fs_err::{create_dir_all, File};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
|
|
||||||
pub const HISTORY_PAGE_SIZE: i64 = 100;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct Settings {
|
pub struct Settings<DbSettings> {
|
||||||
pub host: String,
|
pub host: String,
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
pub path: String,
|
pub path: String,
|
||||||
pub db_uri: String,
|
|
||||||
pub open_registration: bool,
|
pub open_registration: bool,
|
||||||
pub max_history_length: usize,
|
pub max_history_length: usize,
|
||||||
pub page_size: i64,
|
pub page_size: i64,
|
||||||
pub register_webhook_url: Option<String>,
|
pub register_webhook_url: Option<String>,
|
||||||
pub register_webhook_username: String,
|
pub register_webhook_username: String,
|
||||||
|
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub db_settings: DbSettings,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Settings {
|
impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
|
||||||
pub fn new() -> Result<Self> {
|
pub fn new() -> Result<Self> {
|
||||||
let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") {
|
let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") {
|
||||||
PathBuf::from(p)
|
PathBuf::from(p)
|
||||||
|
|
|
@ -33,15 +33,13 @@ buildflags = ["--release"]
|
||||||
atuin = { path = "/usr/bin/atuin" }
|
atuin = { path = "/usr/bin/atuin" }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
# TODO(conradludgate)
|
|
||||||
# Currently, this keeps the same default built behaviour for v0.8
|
|
||||||
# We should rethink this by the time we hit a new breaking change
|
|
||||||
default = ["client", "sync", "server"]
|
default = ["client", "sync", "server"]
|
||||||
client = ["atuin-client"]
|
client = ["atuin-client"]
|
||||||
sync = ["atuin-client/sync"]
|
sync = ["atuin-client/sync"]
|
||||||
server = ["atuin-server", "tracing-subscriber"]
|
server = ["atuin-server", "atuin-server-postgres", "tracing-subscriber"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
atuin-server-postgres = { path = "../atuin-server-postgres", version = "15.0.0", optional = true }
|
||||||
atuin-server = { path = "../atuin-server", version = "15.0.0", optional = true }
|
atuin-server = { path = "../atuin-server", version = "15.0.0", optional = true }
|
||||||
atuin-client = { path = "../atuin-client", version = "15.0.0", optional = true, default-features = false }
|
atuin-client = { path = "../atuin-client", version = "15.0.0", optional = true, default-features = false }
|
||||||
atuin-common = { path = "../atuin-common", version = "15.0.0" }
|
atuin-common = { path = "../atuin-common", version = "15.0.0" }
|
||||||
|
@ -61,7 +59,6 @@ tokio = { workspace = true }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
interim = { workspace = true }
|
interim = { workspace = true }
|
||||||
base64 = { workspace = true }
|
base64 = { workspace = true }
|
||||||
crossbeam-channel = "0.5.1"
|
|
||||||
clap = { workspace = true }
|
clap = { workspace = true }
|
||||||
clap_complete = "4.0.3"
|
clap_complete = "4.0.3"
|
||||||
fs-err = { workspace = true }
|
fs-err = { workspace = true }
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
|
use atuin_server_postgres::Postgres;
|
||||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use eyre::{Context, Result};
|
use eyre::{Context, Result};
|
||||||
|
|
||||||
use atuin_server::{launch, settings::Settings};
|
use atuin_server::{launch, Settings};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[clap(infer_subcommands = true)]
|
#[clap(infer_subcommands = true)]
|
||||||
|
@ -37,7 +38,7 @@ impl Cmd {
|
||||||
.map_or(settings.host.clone(), std::string::ToString::to_string);
|
.map_or(settings.host.clone(), std::string::ToString::to_string);
|
||||||
let port = port.map_or(settings.port, |p| p);
|
let port = port.map_or(settings.port, |p| p);
|
||||||
|
|
||||||
launch(settings, host, port).await
|
launch::<Postgres>(settings, host, port).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue