diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index 9ee13e1..7cf1832 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -18,11 +18,11 @@ use crate::{ use atuin_common::api::*; #[instrument(skip_all, fields(user.id = user.id))] -pub async fn count( +pub async fn count( user: User, - state: State, + state: State>, ) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.postgres; + let db = &state.0.database; match db.count_history_cached(&user).await { // By default read out the cached value Ok(count) => Ok(Json(CountResponse { count })), @@ -38,12 +38,12 @@ pub async fn count( } #[instrument(skip_all, fields(user.id = user.id))] -pub async fn list( +pub async fn list( req: Query, user: User, - state: State, + state: State>, ) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.postgres; + let db = &state.0.database; let history = db .list_history( &user, @@ -75,9 +75,9 @@ pub async fn list( } #[instrument(skip_all, fields(user.id = user.id))] -pub async fn add( +pub async fn add( user: User, - state: State, + state: State>, Json(req): Json>, ) -> Result<(), ErrorResponseStatus<'static>> { debug!("request to add {} history items", req.len()); @@ -93,7 +93,7 @@ pub async fn add( }) .collect(); - let db = &state.0.postgres; + let db = &state.0.database; if let Err(e) = db.add_history(&history).await { error!("failed to add history: {}", e); @@ -105,18 +105,18 @@ pub async fn add( } #[instrument(skip_all, fields(user.id = user.id))] -pub async fn calendar( +pub async fn calendar( Path(focus): Path, Query(params): Query>, user: User, - state: State, + state: State>, ) -> Result>, ErrorResponseStatus<'static>> { let focus = focus.as_str(); let year = params.get("year").unwrap_or(&0); let month = params.get("month").unwrap_or(&1); - let db = &state.0.postgres; + let db = &state.0.database; let focus = match focus { "year" => db .calendar(&user, TimePeriod::YEAR, *year, *month) diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index 761724c..677e7c6 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -34,11 +34,11 @@ pub fn verify_str(secret: &str, verify: &str) -> bool { } #[instrument(skip_all, fields(user.username = username.as_str()))] -pub async fn get( +pub async fn get( Path(username): Path, - state: State, + state: State>, ) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.postgres; + let db = &state.0.database; let user = match db.get_user(username.as_ref()).await { Ok(user) => user, Err(sqlx::Error::RowNotFound) => { @@ -58,9 +58,9 @@ pub async fn get( } #[instrument(skip_all)] -pub async fn register( +pub async fn register( settings: Extension, - state: State, + state: State>, Json(register): Json, ) -> Result, ErrorResponseStatus<'static>> { if !settings.open_registration { @@ -78,7 +78,7 @@ pub async fn register( password: hashed, }; - let db = &state.0.postgres; + let db = &state.0.database; let user_id = match db.add_user(&new_user).await { Ok(id) => id, Err(e) => { @@ -107,11 +107,11 @@ pub async fn register( } #[instrument(skip_all, fields(user.username = login.username.as_str()))] -pub async fn login( - state: State, +pub async fn login( + state: State>, login: Json, ) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.postgres; + let db = &state.0.database; let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, Err(sqlx::Error::RowNotFound) => { diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index c4c15f1..c4f7d30 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -10,19 +10,19 @@ use http::request::Parts; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; -use super::{ - database::{Database, Postgres}, - handlers, -}; +use super::{database::Database, handlers}; use crate::{models::User, settings::Settings}; #[async_trait] -impl FromRequestParts for User { +impl FromRequestParts> for User +where + DB: Database, +{ type Rejection = http::StatusCode; async fn from_request_parts( req: &mut Parts, - state: &AppState, + state: &AppState, ) -> Result { let auth_header = req .headers @@ -40,7 +40,7 @@ impl FromRequestParts for User { } let user = state - .postgres + .database .get_session_user(token) .await .map_err(|_| http::StatusCode::FORBIDDEN)?; @@ -54,12 +54,15 @@ async fn teapot() -> impl IntoResponse { } #[derive(Clone)] -pub struct AppState { - pub postgres: Postgres, +pub struct AppState { + pub database: DB, pub settings: Settings, } -pub fn router(postgres: Postgres, settings: Settings) -> Router { +pub fn router( + database: DB, + settings: Settings, +) -> Router { let routes = Router::new() .route("/", get(handlers::index)) .route("/sync/count", get(handlers::history::count)) @@ -77,6 +80,6 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router { Router::new().nest(path, routes) } .fallback(teapot) - .with_state(AppState { postgres, settings }) + .with_state(AppState { database, settings }) .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())) }