diff --git a/atuin-common/src/api.rs b/atuin-common/src/api.rs index 82ee660..44a73c1 100644 --- a/atuin-common/src/api.rs +++ b/atuin-common/src/api.rs @@ -1,4 +1,8 @@ +use std::convert::Infallible; + use chrono::Utc; +use serde::Serialize; +use warp::{reply::Response, Reply}; #[derive(Debug, Serialize, Deserialize)] pub struct UserResponse { @@ -58,13 +62,62 @@ pub struct ErrorResponse { pub reason: String, } -impl ErrorResponse { - pub fn reply(reason: &str, status: warp::http::StatusCode) -> impl warp::Reply { - warp::reply::with_status( - warp::reply::json(&ErrorResponse { - reason: String::from(reason), - }), - status, - ) +impl Reply for ErrorResponse { + fn into_response(self) -> Response { + warp::reply::json(&self).into_response() } } + +pub struct ErrorResponseStatus { + pub error: ErrorResponse, + pub status: warp::http::StatusCode, +} + +impl Reply for ErrorResponseStatus { + fn into_response(self) -> Response { + warp::reply::with_status(self.error, self.status).into_response() + } +} + +impl ErrorResponse { + pub fn with_status(self, status: warp::http::StatusCode) -> ErrorResponseStatus { + ErrorResponseStatus { + error: self, + status, + } + } + + pub fn reply(reason: &str) -> ErrorResponse { + Self { + reason: reason.to_string(), + } + } +} + +pub enum ReplyEither { + Ok(T), + Err(E), +} + +impl Reply for ReplyEither { + fn into_response(self) -> Response { + match self { + ReplyEither::Ok(t) => t.into_response(), + ReplyEither::Err(e) => e.into_response(), + } + } +} + +pub type ReplyResult = Result, Infallible>; +pub fn reply_error(e: E) -> ReplyResult { + Ok(ReplyEither::Err(e)) +} + +pub type JSONResult = Result, Infallible>; +pub fn reply_json(t: impl Serialize) -> JSONResult { + reply(warp::reply::json(&t)) +} + +pub fn reply(t: T) -> ReplyResult { + Ok(ReplyEither::Ok(t)) +} diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index 1aebdde..18852b5 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -1,23 +1,18 @@ -use std::convert::Infallible; - -use warp::{http::StatusCode, reply::json}; +use warp::{http::StatusCode, Reply}; use crate::database::Database; use crate::models::{NewHistory, User}; -use atuin_common::api::{ - AddHistoryRequest, CountResponse, ErrorResponse, SyncHistoryRequest, SyncHistoryResponse, -}; - +use atuin_common::api::*; pub async fn count( user: User, db: impl Database + Clone + Send + Sync, -) -> Result, Infallible> { +) -> JSONResult { db.count_history(&user).await.map_or( - Ok(Box::new(ErrorResponse::reply( - "failed to query history count", - StatusCode::INTERNAL_SERVER_ERROR, - ))), - |count| Ok(Box::new(json(&CountResponse { count }))), + reply_error( + ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR), + ), + |count| reply_json(CountResponse { count }), ) } @@ -25,7 +20,7 @@ pub async fn list( req: SyncHistoryRequest, user: User, db: impl Database + Clone + Send + Sync, -) -> Result, Infallible> { +) -> JSONResult { let history = db .list_history( &user, @@ -37,10 +32,10 @@ pub async fn list( if let Err(e) = history { error!("failed to load history: {}", e); - let resp = - ErrorResponse::reply("failed to load history", StatusCode::INTERNAL_SERVER_ERROR); - let resp = Box::new(resp); - return Ok(resp); + return reply_error( + ErrorResponse::reply("failed to load history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR), + ); } let history: Vec = history @@ -55,14 +50,14 @@ pub async fn list( user.id ); - Ok(Box::new(json(&SyncHistoryResponse { history }))) + reply_json(SyncHistoryResponse { history }) } pub async fn add( req: Vec, user: User, db: impl Database + Clone + Send + Sync, -) -> Result, Infallible> { +) -> ReplyResult { debug!("request to add {} history items", req.len()); let history: Vec = req @@ -79,11 +74,11 @@ pub async fn add( if let Err(e) = db.add_history(&history).await { error!("failed to add history: {}", e); - return Ok(Box::new(ErrorResponse::reply( - "failed to add history", - StatusCode::INTERNAL_SERVER_ERROR, - ))); + return reply_error( + ErrorResponse::reply("failed to add history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR), + ); }; - Ok(Box::new(warp::reply())) + reply(warp::reply()) } diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index 6b142cd..ed77916 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -1,14 +1,8 @@ -use std::convert::Infallible; - +use atuin_common::api::*; +use atuin_common::utils::hash_secret; use sodiumoxide::crypto::pwhash::argon2id13; use uuid::Uuid; use warp::http::StatusCode; -use warp::reply::json; - -use atuin_common::api::{ - ErrorResponse, LoginRequest, LoginResponse, RegisterRequest, RegisterResponse, UserResponse, -}; -use atuin_common::utils::hash_secret; use crate::database::Database; use crate::models::{NewSession, NewUser}; @@ -31,33 +25,32 @@ pub fn verify_str(secret: &str, verify: &str) -> bool { pub async fn get( username: String, db: impl Database + Clone + Send + Sync, -) -> Result, Infallible> { +) -> JSONResult { let user = match db.get_user(username).await { Ok(user) => user, Err(e) => { debug!("user not found: {}", e); - return Ok(Box::new(ErrorResponse::reply( - "user not found", - StatusCode::NOT_FOUND, - ))); + return reply_error( + ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), + ); } }; - Ok(Box::new(warp::reply::json(&UserResponse { + reply_json(UserResponse { username: user.username, - }))) + }) } pub async fn register( register: RegisterRequest, settings: Settings, db: impl Database + Clone + Send + Sync, -) -> Result, Infallible> { +) -> JSONResult { if !settings.open_registration { - return Ok(Box::new(ErrorResponse::reply( - "this server is not open for registrations", - StatusCode::BAD_REQUEST, - ))); + return reply_error( + ErrorResponse::reply("this server is not open for registrations") + .with_status(StatusCode::BAD_REQUEST), + ); } let hashed = hash_secret(register.password.as_str()); @@ -72,10 +65,9 @@ pub async fn register( Ok(id) => id, Err(e) => { error!("failed to add user: {}", e); - return Ok(Box::new(ErrorResponse::reply( - "failed to add user", - StatusCode::BAD_REQUEST, - ))); + return reply_error( + ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST), + ); } }; @@ -87,13 +79,13 @@ pub async fn register( }; match db.add_session(&new_session).await { - Ok(_) => Ok(Box::new(json(&RegisterResponse { session: token }))), + Ok(_) => reply_json(RegisterResponse { session: token }), Err(e) => { error!("failed to add session: {}", e); - Ok(Box::new(ErrorResponse::reply( - "failed to register user", - StatusCode::BAD_REQUEST, - ))) + reply_error( + ErrorResponse::reply("failed to register user") + .with_status(StatusCode::BAD_REQUEST), + ) } } } @@ -101,16 +93,15 @@ pub async fn register( pub async fn login( login: LoginRequest, db: impl Database + Clone + Send + Sync, -) -> Result, Infallible> { +) -> JSONResult { let user = match db.get_user(login.username.clone()).await { Ok(u) => u, Err(e) => { error!("failed to get user {}: {}", login.username.clone(), e); - return Ok(Box::new(ErrorResponse::reply( - "user not found", - StatusCode::NOT_FOUND, - ))); + return reply_error( + ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), + ); } }; @@ -119,23 +110,21 @@ pub async fn login( Err(e) => { error!("failed to get session for {}: {}", login.username, e); - return Ok(Box::new(ErrorResponse::reply( - "user not found", - StatusCode::NOT_FOUND, - ))); + return reply_error( + ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), + ); } }; let verified = verify_str(user.password.as_str(), login.password.as_str()); if !verified { - return Ok(Box::new(ErrorResponse::reply( - "user not found", - StatusCode::NOT_FOUND, - ))); + return reply_error( + ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), + ); } - Ok(Box::new(warp::reply::json(&LoginResponse { + reply_json(LoginResponse { session: session.token, - }))) + }) }