From a95018cc9039851e707973bc19faf907132ae4f3 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 12 Apr 2022 23:06:19 +0100 Subject: [PATCH] goodbye warp, hello axum (#296) --- Cargo.lock | 360 ++++++++------------------- atuin-client/src/api_client.rs | 6 +- atuin-client/src/sync.rs | 4 +- atuin-common/Cargo.toml | 3 +- atuin-common/src/api.rs | 86 ++----- atuin-server/Cargo.toml | 3 +- atuin-server/src/handlers/history.rs | 53 ++-- atuin-server/src/handlers/mod.rs | 2 +- atuin-server/src/handlers/user.rs | 72 +++--- atuin-server/src/lib.rs | 20 +- atuin-server/src/models.rs | 22 +- atuin-server/src/router.rs | 169 ++++--------- src/command/login.rs | 7 +- src/command/mod.rs | 2 +- src/command/server.rs | 4 +- 15 files changed, 285 insertions(+), 528 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8a8a7d9..7d605e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -122,14 +122,15 @@ dependencies = [ name = "atuin-common" version = "0.8.1" dependencies = [ + "axum", "chrono", + "http", "rust-crypto", "serde", "serde_derive", "serde_json", "sodiumoxide", "uuid", - "warp", ] [[package]] @@ -138,10 +139,12 @@ version = "0.8.1" dependencies = [ "async-trait", "atuin-common", + "axum", "base64", "chrono", "config", "eyre", + "http", "log", "rand 0.8.5", "rust-crypto", @@ -152,7 +155,6 @@ dependencies = [ "sqlx", "tokio", "uuid", - "warp", "whoami", ] @@ -162,6 +164,51 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "axum" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47594e438a243791dba58124b6669561f5baa14cb12046641d8008bf035e5a25" +dependencies = [ + "async-trait", + "axum-core", + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa 1.0.1", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-http", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a671c9ae99531afdd5d3ee8340b8da547779430689947144c140fc74a740244" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", +] + [[package]] name = "base64" version = "0.13.0" @@ -195,15 +242,6 @@ dependencies = [ "generic-array 0.14.5", ] -[[package]] -name = "block-buffer" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" -dependencies = [ - "generic-array 0.14.5", -] - [[package]] name = "block-padding" version = "0.1.5" @@ -225,16 +263,6 @@ dependencies = [ "serde", ] -[[package]] -name = "buf_redux" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b953a6887648bb07a535631f2bc00fbdb2a2216f135552cb3f534ed136b9c07f" -dependencies = [ - "memchr", - "safemem", -] - [[package]] name = "bumpalo" version = "3.9.1" @@ -449,16 +477,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "crypto-common" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8" -dependencies = [ - "generic-array 0.14.5", - "typenum", -] - [[package]] name = "crypto-mac" version = "0.11.1" @@ -509,16 +527,6 @@ dependencies = [ "generic-array 0.14.5", ] -[[package]] -name = "digest" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" -dependencies = [ - "block-buffer 0.10.2", - "crypto-common", -] - [[package]] name = "directories" version = "3.0.2" @@ -640,15 +648,6 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" -[[package]] -name = "fastrand" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3fcf0cee53519c866c09b5de1f6c56ff9d647101f81c1964fa632e148896cdf" -dependencies = [ - "instant", -] - [[package]] name = "flume" version = "0.10.12" @@ -806,7 +805,7 @@ dependencies = [ "indexmap", "slab", "tokio", - "tokio-util 0.7.1", + "tokio-util", "tracing", ] @@ -837,31 +836,6 @@ dependencies = [ "hashbrown 0.11.2", ] -[[package]] -name = "headers" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cff78e5788be1e0ab65b04d306b2ed5092c815ec97ec70f4ebd5aee158aa55d" -dependencies = [ - "base64", - "bitflags", - "bytes", - "headers-core", - "http", - "httpdate", - "mime", - "sha-1 0.10.0", -] - -[[package]] -name = "headers-core" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" -dependencies = [ - "http", -] - [[package]] name = "heck" version = "0.3.3" @@ -924,6 +898,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.7.0" @@ -1155,6 +1135,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" +[[package]] +name = "matchit" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" + [[package]] name = "md-5" version = "0.9.1" @@ -1178,16 +1164,6 @@ version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" -[[package]] -name = "mime_guess" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" -dependencies = [ - "mime", - "unicase", -] - [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1223,24 +1199,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "multipart" -version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00dec633863867f29cb39df64a397cdf4a6354708ddd7759f70c7fb51c5f9182" -dependencies = [ - "buf_redux", - "httparse", - "log", - "mime", - "mime_guess", - "quick-error", - "rand 0.8.5", - "safemem", - "tempfile", - "twoway", -] - [[package]] name = "nom" version = "7.1.1" @@ -1747,15 +1705,6 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" -[[package]] -name = "remove_dir_all" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" -dependencies = [ - "winapi", -] - [[package]] name = "reqwest" version = "0.11.10" @@ -1910,12 +1859,6 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f" -[[package]] -name = "safemem" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072" - [[package]] name = "same-file" version = "1.0.6" @@ -1931,12 +1874,6 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "088c5d71572124929ea7549a8ce98e1a6fd33d0a38367b09027b382e67c033db" -[[package]] -name = "scoped-tls" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2" - [[package]] name = "scopeguard" version = "1.1.0" @@ -2031,17 +1968,6 @@ dependencies = [ "opaque-debug 0.3.0", ] -[[package]] -name = "sha-1" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest 0.10.3", -] - [[package]] name = "sha2" version = "0.9.9" @@ -2267,6 +2193,12 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "sync_wrapper" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" + [[package]] name = "tabwriter" version = "1.2.1" @@ -2276,20 +2208,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "tempfile" -version = "3.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" -dependencies = [ - "cfg-if", - "fastrand", - "libc", - "redox_syscall", - "remove_dir_all", - "winapi", -] - [[package]] name = "termcolor" version = "1.1.3" @@ -2436,33 +2354,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-tungstenite" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "511de3f85caf1c98983545490c3d09685fa8eb634e57eec22bb4db271f46cbd8" -dependencies = [ - "futures-util", - "log", - "pin-project", - "tokio", - "tungstenite", -] - -[[package]] -name = "tokio-util" -version = "0.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e99e1983e5d376cd8eb4b66604d2e99e79f5bd988c3055891dcd8c9e2604cc0" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "log", - "pin-project-lite", - "tokio", -] - [[package]] name = "tokio-util" version = "0.7.1" @@ -2486,6 +2377,48 @@ dependencies = [ "serde", ] +[[package]] +name = "tower" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a89fd63ad6adf737582df5db40d286574513c69a11dac5214dc3b5603d6713e" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aba3f3efabf7fb41fae8534fc20a817013dd1c12cb45441efb6c82e6556b4cd8" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" + [[package]] name = "tower-service" version = "0.3.1" @@ -2544,34 +2477,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "tungstenite" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0b2d8558abd2e276b0a8df5c05a2ec762609344191e5fd23e292c910e9165b5" -dependencies = [ - "base64", - "byteorder", - "bytes", - "http", - "httparse", - "log", - "rand 0.8.5", - "sha-1 0.9.8", - "thiserror", - "url", - "utf-8", -] - -[[package]] -name = "twoway" -version = "0.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b11b2b5241ba34be09c3cc85a36e56e48f9888862e19cedf23336d35316ed1" -dependencies = [ - "memchr", -] - [[package]] name = "typenum" version = "1.15.0" @@ -2584,15 +2489,6 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c" -[[package]] -name = "unicase" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" -dependencies = [ - "version_check", -] - [[package]] name = "unicode-bidi" version = "0.3.7" @@ -2656,12 +2552,6 @@ version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a1f0175e03a0973cf4afd476bef05c26e228520400eb1fd473ad417b1c00ffb" -[[package]] -name = "utf-8" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" - [[package]] name = "uuid" version = "0.8.2" @@ -2704,36 +2594,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "warp" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cef4e1e9114a4b7f1ac799f16ce71c14de5778500c5450ec6b7b920c55b587e" -dependencies = [ - "bytes", - "futures-channel", - "futures-util", - "headers", - "http", - "hyper", - "log", - "mime", - "mime_guess", - "multipart", - "percent-encoding", - "pin-project", - "scoped-tls", - "serde", - "serde_json", - "serde_urlencoded", - "tokio", - "tokio-stream", - "tokio-tungstenite", - "tokio-util 0.6.9", - "tower-service", - "tracing", -] - [[package]] name = "wasi" version = "0.10.2+wasi-snapshot-preview1" diff --git a/atuin-client/src/api_client.rs b/atuin-client/src/api_client.rs index 3a4c859..87c4b6a 100644 --- a/atuin-client/src/api_client.rs +++ b/atuin-client/src/api_client.rs @@ -31,7 +31,7 @@ pub async fn register( username: &str, email: &str, password: &str, -) -> Result> { +) -> Result { let mut map = HashMap::new(); map.insert("username", username); map.insert("email", email); @@ -61,7 +61,7 @@ pub async fn register( Ok(session) } -pub async fn login(address: &str, req: LoginRequest<'_>) -> Result> { +pub async fn login(address: &str, req: LoginRequest) -> Result { let url = format!("{}/login", address); let client = reqwest::Client::new(); @@ -142,7 +142,7 @@ impl<'a> Client<'a> { Ok(history) } - pub async fn post_history(&self, history: &[AddHistoryRequest<'_, String>]) -> Result<()> { + pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { let url = format!("{}/history", self.sync_addr); let url = Url::parse(url.as_str())?; diff --git a/atuin-client/src/sync.rs b/atuin-client/src/sync.rs index c1c02b0..9e74961 100644 --- a/atuin-client/src/sync.rs +++ b/atuin-client/src/sync.rs @@ -110,10 +110,10 @@ async fn sync_upload( let data = serde_json::to_string(&data)?; let add_hist = AddHistoryRequest { - id: i.id.into(), + id: i.id, timestamp: i.timestamp, data, - hostname: hash_str(&i.hostname).into(), + hostname: hash_str(&i.hostname), }; buffer.push(add_hist); diff --git a/atuin-common/Cargo.toml b/atuin-common/Cargo.toml index 85f8012..93814ac 100644 --- a/atuin-common/Cargo.toml +++ b/atuin-common/Cargo.toml @@ -17,5 +17,6 @@ chrono = { version = "0.4", features = ["serde"] } serde_derive = "1.0.125" serde = "1.0.126" serde_json = "1.0.75" -warp = "0.3" uuid = { version = "0.8", features = ["v4"] } +axum = "0.5" +http = "0.2" diff --git a/atuin-common/src/api.rs b/atuin-common/src/api.rs index 862759b..803fbbc 100644 --- a/atuin-common/src/api.rs +++ b/atuin-common/src/api.rs @@ -1,43 +1,43 @@ -use std::{borrow::Cow, convert::Infallible}; +use std::borrow::Cow; +use axum::{response::IntoResponse, Json}; use chrono::Utc; use serde::Serialize; -use warp::{reply::Response, Reply}; #[derive(Debug, Serialize, Deserialize)] -pub struct UserResponse<'a> { - pub username: Cow<'a, str>, +pub struct UserResponse { + pub username: String, } #[derive(Debug, Serialize, Deserialize)] -pub struct RegisterRequest<'a> { - pub email: Cow<'a, str>, - pub username: Cow<'a, str>, - pub password: Cow<'a, str>, +pub struct RegisterRequest { + pub email: String, + pub username: String, + pub password: String, } #[derive(Debug, Serialize, Deserialize)] -pub struct RegisterResponse<'a> { - pub session: Cow<'a, str>, +pub struct RegisterResponse { + pub session: String, } #[derive(Debug, Serialize, Deserialize)] -pub struct LoginRequest<'a> { - pub username: Cow<'a, str>, - pub password: Cow<'a, str>, +pub struct LoginRequest { + pub username: String, + pub password: String, } #[derive(Debug, Serialize, Deserialize)] -pub struct LoginResponse<'a> { - pub session: Cow<'a, str>, +pub struct LoginResponse { + pub session: String, } #[derive(Debug, Serialize, Deserialize)] -pub struct AddHistoryRequest<'a, D> { - pub id: Cow<'a, str>, +pub struct AddHistoryRequest { + pub id: String, pub timestamp: chrono::DateTime, - pub data: D, - pub hostname: Cow<'a, str>, + pub data: String, + pub hostname: String, } #[derive(Debug, Serialize, Deserialize)] @@ -46,10 +46,10 @@ pub struct CountResponse { } #[derive(Debug, Serialize, Deserialize)] -pub struct SyncHistoryRequest<'a> { +pub struct SyncHistoryRequest { pub sync_ts: chrono::DateTime, pub history_ts: chrono::DateTime, - pub host: Cow<'a, str>, + pub host: String, } #[derive(Debug, Serialize, Deserialize)] @@ -62,25 +62,19 @@ pub struct ErrorResponse<'a> { pub reason: Cow<'a, str>, } -impl Reply for ErrorResponse<'_> { - fn into_response(self) -> Response { - warp::reply::json(&self).into_response() +impl<'a> IntoResponse for ErrorResponseStatus<'a> { + fn into_response(self) -> axum::response::Response { + (self.status, Json(self.error)).into_response() } } pub struct ErrorResponseStatus<'a> { pub error: ErrorResponse<'a>, - pub status: warp::http::StatusCode, -} - -impl Reply for ErrorResponseStatus<'_> { - fn into_response(self) -> Response { - warp::reply::with_status(self.error, self.status).into_response() - } + pub status: http::StatusCode, } impl<'a> ErrorResponse<'a> { - pub fn with_status(self, status: warp::http::StatusCode) -> ErrorResponseStatus<'a> { + pub fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a> { ErrorResponseStatus { error: self, status, @@ -93,31 +87,3 @@ impl<'a> ErrorResponse<'a> { } } } - -pub enum ReplyEither { - Ok(T), - Err(E), -} - -impl Reply for ReplyEither { - fn into_response(self) -> Response { - match self { - ReplyEither::Ok(t) => t.into_response(), - ReplyEither::Err(e) => e.into_response(), - } - } -} - -pub type ReplyResult = Result, Infallible>; -pub fn reply_error(e: E) -> ReplyResult { - Ok(ReplyEither::Err(e)) -} - -pub type JSONResult = Result, Infallible>; -pub fn reply_json(t: impl Serialize) -> JSONResult { - reply(warp::reply::json(&t)) -} - -pub fn reply(t: T) -> ReplyResult { - Ok(ReplyEither::Ok(t)) -} diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml index e1acc97..16a9fa0 100644 --- a/atuin-server/Cargo.toml +++ b/atuin-server/Cargo.toml @@ -25,6 +25,7 @@ base64 = "0.13.0" rand = "0.8.4" rust-crypto = "^0.2" tokio = { version = "1", features = ["full"] } -warp = "0.3" sqlx = { version = "0.5", features = [ "runtime-tokio-rustls", "uuid", "chrono", "postgres" ] } async-trait = "0.1.49" +axum = "0.5" +http = "0.2" diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index 0671538..546e5a2 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -1,26 +1,27 @@ -use warp::{http::StatusCode, Reply}; +use axum::extract::Query; +use axum::{Extension, Json}; +use http::StatusCode; -use crate::database::Database; +use crate::database::{Database, Postgres}; use crate::models::{NewHistory, User}; use atuin_common::api::*; + pub async fn count( user: User, - db: impl Database + Clone + Send + Sync, -) -> JSONResult> { - db.count_history(&user).await.map_or( - reply_error( - ErrorResponse::reply("failed to query history count") - .with_status(StatusCode::INTERNAL_SERVER_ERROR), - ), - |count| reply_json(CountResponse { count }), - ) + db: Extension, +) -> Result, ErrorResponseStatus<'static>> { + match db.count_history(&user).await { + Ok(count) => Ok(Json(CountResponse { count })), + Err(_) => Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)), + } } pub async fn list( - req: SyncHistoryRequest<'_>, + req: Query, user: User, - db: impl Database + Clone + Send + Sync, -) -> JSONResult> { + db: Extension, +) -> Result, ErrorResponseStatus<'static>> { let history = db .list_history( &user, @@ -32,10 +33,8 @@ pub async fn list( if let Err(e) = history { error!("failed to load history: {}", e); - return reply_error( - ErrorResponse::reply("failed to load history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR), - ); + return Err(ErrorResponse::reply("failed to load history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); } let history: Vec = history @@ -50,14 +49,14 @@ pub async fn list( user.id ); - reply_json(SyncHistoryResponse { history }) + Ok(Json(SyncHistoryResponse { history })) } pub async fn add( - req: Vec>, + Json(req): Json>, user: User, - db: impl Database + Clone + Send + Sync, -) -> ReplyResult> { + db: Extension, +) -> Result<(), ErrorResponseStatus<'static>> { debug!("request to add {} history items", req.len()); let history: Vec = req @@ -67,18 +66,16 @@ pub async fn add( user_id: user.id, hostname: h.hostname, timestamp: h.timestamp.naive_utc(), - data: h.data.into(), + data: h.data, }) .collect(); if let Err(e) = db.add_history(&history).await { error!("failed to add history: {}", e); - return reply_error( - ErrorResponse::reply("failed to add history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR), - ); + return Err(ErrorResponse::reply("failed to add history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); }; - reply(warp::reply()) + Ok(()) } diff --git a/atuin-server/src/handlers/mod.rs b/atuin-server/src/handlers/mod.rs index 3c20538..83c2d0c 100644 --- a/atuin-server/src/handlers/mod.rs +++ b/atuin-server/src/handlers/mod.rs @@ -1,6 +1,6 @@ pub mod history; pub mod user; -pub const fn index() -> &'static str { +pub async fn index() -> &'static str { "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett" } diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index 8144ada..1bcfce2 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -2,11 +2,13 @@ use std::borrow::Borrow; use atuin_common::api::*; use atuin_common::utils::hash_secret; +use axum::extract::Path; +use axum::{Extension, Json}; +use http::StatusCode; use sodiumoxide::crypto::pwhash::argon2id13; use uuid::Uuid; -use warp::http::StatusCode; -use crate::database::Database; +use crate::database::{Database, Postgres}; use crate::models::{NewSession, NewUser}; use crate::settings::Settings; @@ -25,31 +27,29 @@ pub fn verify_str(secret: &str, verify: &str) -> bool { } pub async fn get( - username: impl AsRef, - db: impl Database + Clone + Send + Sync, -) -> JSONResult> { + Path(username): Path, + db: Extension, +) -> Result, ErrorResponseStatus<'static>> { let user = match db.get_user(username.as_ref()).await { Ok(user) => user, Err(e) => { debug!("user not found: {}", e); - return reply_error( - ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), - ); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } }; - reply_json(UserResponse { - username: user.username.into(), - }) + Ok(Json(UserResponse { + username: user.username, + })) } pub async fn register( - register: RegisterRequest<'_>, - settings: Settings, - db: impl Database + Clone + Send + Sync, -) -> JSONResult> { + Json(register): Json, + settings: Extension, + db: Extension, +) -> Result, ErrorResponseStatus<'static>> { if !settings.open_registration { - return reply_error( + return Err( ErrorResponse::reply("this server is not open for registrations") .with_status(StatusCode::BAD_REQUEST), ); @@ -60,15 +60,15 @@ pub async fn register( let new_user = NewUser { email: register.email, username: register.username, - password: hashed.into(), + password: hashed, }; let user_id = match db.add_user(&new_user).await { Ok(id) => id, Err(e) => { error!("failed to add user: {}", e); - return reply_error( - ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST), + return Err( + ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST) ); } }; @@ -81,31 +81,25 @@ pub async fn register( }; match db.add_session(&new_session).await { - Ok(_) => reply_json(RegisterResponse { - session: token.into(), - }), + Ok(_) => Ok(Json(RegisterResponse { session: token })), Err(e) => { error!("failed to add session: {}", e); - reply_error( - ErrorResponse::reply("failed to register user") - .with_status(StatusCode::BAD_REQUEST), - ) + Err(ErrorResponse::reply("failed to register user") + .with_status(StatusCode::BAD_REQUEST)) } } } pub async fn login( - login: LoginRequest<'_>, - db: impl Database + Clone + Send + Sync, -) -> JSONResult> { + login: Json, + db: Extension, +) -> Result, ErrorResponseStatus<'static>> { let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, Err(e) => { error!("failed to get user {}: {}", login.username.clone(), e); - return reply_error( - ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), - ); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } }; @@ -114,21 +108,17 @@ pub async fn login( Err(e) => { error!("failed to get session for {}: {}", login.username, e); - return reply_error( - ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), - ); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } }; let verified = verify_str(user.password.as_str(), login.password.borrow()); if !verified { - return reply_error( - ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), - ); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - reply_json(LoginResponse { - session: session.token.into(), - }) + Ok(Json(LoginResponse { + session: session.token, + })) } diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs index e485881..ca0aa11 100644 --- a/atuin-server/src/lib.rs +++ b/atuin-server/src/lib.rs @@ -1,8 +1,10 @@ #![forbid(unsafe_code)] -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; -use eyre::Result; +use axum::Server; +use database::Postgres; +use eyre::{Context, Result}; use crate::settings::Settings; @@ -19,14 +21,18 @@ pub mod models; pub mod router; pub mod settings; -pub async fn launch(settings: &Settings, host: String, port: u16) -> Result<()> { - // routes to run: - // index, register, add_history, login, get_user, sync_count, sync_list +pub async fn launch(settings: Settings, host: String, port: u16) -> Result<()> { let host = host.parse::()?; - let r = router::router(settings).await?; + let postgres = Postgres::new(settings.db_uri.as_str()) + .await + .wrap_err_with(|| format!("failed to connect to db: {}", settings.db_uri))?; - warp::serve(r).run((host, port)).await; + let r = router::router(postgres, settings); + + Server::bind(&SocketAddr::new(host, port)) + .serve(r.into_make_service()) + .await?; Ok(()) } diff --git a/atuin-server/src/models.rs b/atuin-server/src/models.rs index d493153..ee84f58 100644 --- a/atuin-server/src/models.rs +++ b/atuin-server/src/models.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use chrono::prelude::*; #[derive(sqlx::FromRow)] @@ -15,13 +13,13 @@ pub struct History { pub created_at: NaiveDateTime, } -pub struct NewHistory<'a> { - pub client_id: Cow<'a, str>, +pub struct NewHistory { + pub client_id: String, pub user_id: i64, - pub hostname: Cow<'a, str>, + pub hostname: String, pub timestamp: chrono::NaiveDateTime, - pub data: Cow<'a, str>, + pub data: String, } #[derive(sqlx::FromRow)] @@ -39,13 +37,13 @@ pub struct Session { pub token: String, } -pub struct NewUser<'a> { - pub username: Cow<'a, str>, - pub email: Cow<'a, str>, - pub password: Cow<'a, str>, +pub struct NewUser { + pub username: String, + pub email: String, + pub password: String, } -pub struct NewSession<'a> { +pub struct NewSession { pub user_id: i64, - pub token: Cow<'a, str>, + pub token: String, } diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index f7e142a..6ca4722 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -1,9 +1,12 @@ -use std::convert::Infallible; - +use async_trait::async_trait; +use axum::{ + extract::{FromRequest, RequestParts}, + handler::Handler, + response::IntoResponse, + routing::{get, post}, + Extension, Router, +}; use eyre::Result; -use warp::{hyper::StatusCode, Filter}; - -use atuin_common::api::SyncHistoryRequest; use super::{ database::{Database, Postgres}, @@ -11,119 +14,57 @@ use super::{ }; use crate::{models::User, settings::Settings}; -fn with_settings( - settings: Settings, -) -> impl Filter + Clone { - warp::any().map(move || settings.clone()) -} +#[async_trait] +impl FromRequest for User +where + B: Send, +{ + type Rejection = http::StatusCode; -fn with_db( - db: impl Database + Clone + Send + Sync, -) -> impl Filter + Clone { - warp::any().map(move || db.clone()) -} + async fn from_request(req: &mut RequestParts) -> Result { + let postgres = req + .extensions() + .get::() + .ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?; -fn with_user( - postgres: Postgres, -) -> impl Filter + Clone { - warp::header::("authorization").and_then(move |header: String| { - // async closures are still buggy :( - let postgres = postgres.clone(); + let auth_header = req + .headers() + .get(http::header::AUTHORIZATION) + .ok_or(http::StatusCode::FORBIDDEN)?; + let auth_header = auth_header + .to_str() + .map_err(|_| http::StatusCode::FORBIDDEN)?; + let (typ, token) = auth_header + .split_once(' ') + .ok_or(http::StatusCode::FORBIDDEN)?; - async move { - let header: Vec<&str> = header.split(' ').collect(); - - let token = if header.len() == 2 { - if header[0] != "Token" { - return Err(warp::reject()); - } - - header[1] - } else { - return Err(warp::reject()); - }; - - let user = postgres - .get_session_user(token) - .await - .map_err(|_| warp::reject())?; - - Ok(user) + if typ != "Token" { + return Err(http::StatusCode::FORBIDDEN); } - }) + + let user = postgres + .get_session_user(token) + .await + .map_err(|_| http::StatusCode::FORBIDDEN)?; + + Ok(user) + } } -pub async fn router( - settings: &Settings, -) -> Result + Clone> { - let postgres = Postgres::new(settings.db_uri.as_str()).await?; - let index = warp::get().and(warp::path::end()).map(handlers::index); - - let count = warp::get() - .and(warp::path("sync")) - .and(warp::path("count")) - .and(warp::path::end()) - .and(with_user(postgres.clone())) - .and(with_db(postgres.clone())) - .and_then(handlers::history::count) - .boxed(); - - let sync = warp::get() - .and(warp::path("sync")) - .and(warp::path("history")) - .and(warp::query::()) - .and(warp::path::end()) - .and(with_user(postgres.clone())) - .and(with_db(postgres.clone())) - .and_then(handlers::history::list) - .boxed(); - - let add_history = warp::post() - .and(warp::path("history")) - .and(warp::path::end()) - .and(warp::body::json()) - .and(with_user(postgres.clone())) - .and(with_db(postgres.clone())) - .and_then(handlers::history::add) - .boxed(); - - let user = warp::get() - .and(warp::path("user")) - .and(warp::path::param::()) - .and(warp::path::end()) - .and(with_db(postgres.clone())) - .and_then(handlers::user::get) - .boxed(); - - let register = warp::post() - .and(warp::path("register")) - .and(warp::path::end()) - .and(warp::body::json()) - .and(with_settings(settings.clone())) - .and(with_db(postgres.clone())) - .and_then(handlers::user::register) - .boxed(); - - let login = warp::post() - .and(warp::path("login")) - .and(warp::path::end()) - .and(warp::body::json()) - .and(with_db(postgres)) - .and_then(handlers::user::login) - .boxed(); - - let r = warp::any() - .and( - index - .or(count) - .or(sync) - .or(add_history) - .or(user) - .or(register) - .or(login) - .or(warp::any().map(|| warp::reply::with_status("☕", StatusCode::IM_A_TEAPOT))), - ) - .with(warp::filters::log::log("atuin::api")); - - Ok(r) +async fn teapot() -> impl IntoResponse { + (http::StatusCode::IM_A_TEAPOT, "☕") +} + +pub fn router(postgres: Postgres, settings: Settings) -> Router { + Router::new() + .route("/", get(handlers::index)) + .route("/sync/count", get(handlers::history::count)) + .route("/sync/history", get(handlers::history::list)) + .route("/history", post(handlers::history::add)) + .route("/user/:username", get(handlers::user::get)) + .route("/register", post(handlers::user::register)) + .route("/login", post(handlers::user::login)) + .fallback(teapot.into_service()) + .layer(Extension(postgres)) + .layer(Extension(settings)) } diff --git a/src/command/login.rs b/src/command/login.rs index fe442bc..efc9c59 100644 --- a/src/command/login.rs +++ b/src/command/login.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::io; use atuin_common::api::LoginRequest; @@ -66,10 +65,8 @@ impl Cmd { } } -pub(super) fn or_user_input<'a>(value: &'a Option, name: &'static str) -> Cow<'a, str> { - value - .as_deref() - .map_or_else(|| Cow::Owned(read_user_input(name)), Cow::Borrowed) +pub(super) fn or_user_input(value: &'_ Option, name: &'static str) -> String { + value.clone().unwrap_or_else(|| read_user_input(name)) } fn read_user_input(name: &'static str) -> String { diff --git a/src/command/mod.rs b/src/command/mod.rs index 6873c58..8463421 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -133,7 +133,7 @@ impl AtuinCmd { match self { Self::History(history) => history.run(&client_settings, &mut db).await, Self::Import(import) => import.run(&mut db).await, - Self::Server(server) => server.run(&server_settings).await, + Self::Server(server) => server.run(server_settings).await, Self::Stats(stats) => stats.run(&mut db, &client_settings).await, Self::Init(init) => { init.run(); diff --git a/src/command/server.rs b/src/command/server.rs index 6047e5b..9d97e92 100644 --- a/src/command/server.rs +++ b/src/command/server.rs @@ -10,7 +10,7 @@ pub enum Cmd { /// Start the server Start { /// The host address to bind - #[clap(long, short)] + #[clap(long)] host: Option, /// The port to bind @@ -20,7 +20,7 @@ pub enum Cmd { } impl Cmd { - pub async fn run(&self, settings: &Settings) -> Result<()> { + pub async fn run(&self, settings: Settings) -> Result<()> { match self { Self::Start { host, port } => { let host = host