remove dyn Reply (#48)

* cleanup reply types

* cleanup error api

* small update

* improve api some more

* fmt
This commit is contained in:
Conrad Ludgate 2021-05-07 21:06:56 +01:00 committed by GitHub
parent e2edcbf994
commit 1c59f85ea8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 114 additions and 77 deletions

View file

@ -1,4 +1,8 @@
use std::convert::Infallible;
use chrono::Utc; use chrono::Utc;
use serde::Serialize;
use warp::{reply::Response, Reply};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct UserResponse { pub struct UserResponse {
@ -58,13 +62,62 @@ pub struct ErrorResponse {
pub reason: String, pub reason: String,
} }
impl ErrorResponse { impl Reply for ErrorResponse {
pub fn reply(reason: &str, status: warp::http::StatusCode) -> impl warp::Reply { fn into_response(self) -> Response {
warp::reply::with_status( warp::reply::json(&self).into_response()
warp::reply::json(&ErrorResponse {
reason: String::from(reason),
}),
status,
)
} }
} }
pub struct ErrorResponseStatus {
pub error: ErrorResponse,
pub status: warp::http::StatusCode,
}
impl Reply for ErrorResponseStatus {
fn into_response(self) -> Response {
warp::reply::with_status(self.error, self.status).into_response()
}
}
impl ErrorResponse {
pub fn with_status(self, status: warp::http::StatusCode) -> ErrorResponseStatus {
ErrorResponseStatus {
error: self,
status,
}
}
pub fn reply(reason: &str) -> ErrorResponse {
Self {
reason: reason.to_string(),
}
}
}
pub enum ReplyEither<T, E> {
Ok(T),
Err(E),
}
impl<T: Reply, E: Reply> Reply for ReplyEither<T, E> {
fn into_response(self) -> Response {
match self {
ReplyEither::Ok(t) => t.into_response(),
ReplyEither::Err(e) => e.into_response(),
}
}
}
pub type ReplyResult<T, E> = Result<ReplyEither<T, E>, Infallible>;
pub fn reply_error<T, E>(e: E) -> ReplyResult<T, E> {
Ok(ReplyEither::Err(e))
}
pub type JSONResult<E> = Result<ReplyEither<warp::reply::Json, E>, Infallible>;
pub fn reply_json<E>(t: impl Serialize) -> JSONResult<E> {
reply(warp::reply::json(&t))
}
pub fn reply<T, E>(t: T) -> ReplyResult<T, E> {
Ok(ReplyEither::Ok(t))
}

View file

@ -1,23 +1,18 @@
use std::convert::Infallible; use warp::{http::StatusCode, Reply};
use warp::{http::StatusCode, reply::json};
use crate::database::Database; use crate::database::Database;
use crate::models::{NewHistory, User}; use crate::models::{NewHistory, User};
use atuin_common::api::{ use atuin_common::api::*;
AddHistoryRequest, CountResponse, ErrorResponse, SyncHistoryRequest, SyncHistoryResponse,
};
pub async fn count( pub async fn count(
user: User, user: User,
db: impl Database + Clone + Send + Sync, db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> { ) -> JSONResult<ErrorResponseStatus> {
db.count_history(&user).await.map_or( db.count_history(&user).await.map_or(
Ok(Box::new(ErrorResponse::reply( reply_error(
"failed to query history count", ErrorResponse::reply("failed to query history count")
StatusCode::INTERNAL_SERVER_ERROR, .with_status(StatusCode::INTERNAL_SERVER_ERROR),
))), ),
|count| Ok(Box::new(json(&CountResponse { count }))), |count| reply_json(CountResponse { count }),
) )
} }
@ -25,7 +20,7 @@ pub async fn list(
req: SyncHistoryRequest, req: SyncHistoryRequest,
user: User, user: User,
db: impl Database + Clone + Send + Sync, db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> { ) -> JSONResult<ErrorResponseStatus> {
let history = db let history = db
.list_history( .list_history(
&user, &user,
@ -37,10 +32,10 @@ pub async fn list(
if let Err(e) = history { if let Err(e) = history {
error!("failed to load history: {}", e); error!("failed to load history: {}", e);
let resp = return reply_error(
ErrorResponse::reply("failed to load history", StatusCode::INTERNAL_SERVER_ERROR); ErrorResponse::reply("failed to load history")
let resp = Box::new(resp); .with_status(StatusCode::INTERNAL_SERVER_ERROR),
return Ok(resp); );
} }
let history: Vec<String> = history let history: Vec<String> = history
@ -55,14 +50,14 @@ pub async fn list(
user.id user.id
); );
Ok(Box::new(json(&SyncHistoryResponse { history }))) reply_json(SyncHistoryResponse { history })
} }
pub async fn add( pub async fn add(
req: Vec<AddHistoryRequest>, req: Vec<AddHistoryRequest>,
user: User, user: User,
db: impl Database + Clone + Send + Sync, db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> { ) -> ReplyResult<impl Reply, ErrorResponseStatus> {
debug!("request to add {} history items", req.len()); debug!("request to add {} history items", req.len());
let history: Vec<NewHistory> = req let history: Vec<NewHistory> = req
@ -79,11 +74,11 @@ pub async fn add(
if let Err(e) = db.add_history(&history).await { if let Err(e) = db.add_history(&history).await {
error!("failed to add history: {}", e); error!("failed to add history: {}", e);
return Ok(Box::new(ErrorResponse::reply( return reply_error(
"failed to add history", ErrorResponse::reply("failed to add history")
StatusCode::INTERNAL_SERVER_ERROR, .with_status(StatusCode::INTERNAL_SERVER_ERROR),
))); );
}; };
Ok(Box::new(warp::reply())) reply(warp::reply())
} }

View file

@ -1,14 +1,8 @@
use std::convert::Infallible; use atuin_common::api::*;
use atuin_common::utils::hash_secret;
use sodiumoxide::crypto::pwhash::argon2id13; use sodiumoxide::crypto::pwhash::argon2id13;
use uuid::Uuid; use uuid::Uuid;
use warp::http::StatusCode; use warp::http::StatusCode;
use warp::reply::json;
use atuin_common::api::{
ErrorResponse, LoginRequest, LoginResponse, RegisterRequest, RegisterResponse, UserResponse,
};
use atuin_common::utils::hash_secret;
use crate::database::Database; use crate::database::Database;
use crate::models::{NewSession, NewUser}; use crate::models::{NewSession, NewUser};
@ -31,33 +25,32 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
pub async fn get( pub async fn get(
username: String, username: String,
db: impl Database + Clone + Send + Sync, db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> { ) -> JSONResult<ErrorResponseStatus> {
let user = match db.get_user(username).await { let user = match db.get_user(username).await {
Ok(user) => user, Ok(user) => user,
Err(e) => { Err(e) => {
debug!("user not found: {}", e); debug!("user not found: {}", e);
return Ok(Box::new(ErrorResponse::reply( return reply_error(
"user not found", ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
StatusCode::NOT_FOUND, );
)));
} }
}; };
Ok(Box::new(warp::reply::json(&UserResponse { reply_json(UserResponse {
username: user.username, username: user.username,
}))) })
} }
pub async fn register( pub async fn register(
register: RegisterRequest, register: RegisterRequest,
settings: Settings, settings: Settings,
db: impl Database + Clone + Send + Sync, db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> { ) -> JSONResult<ErrorResponseStatus> {
if !settings.open_registration { if !settings.open_registration {
return Ok(Box::new(ErrorResponse::reply( return reply_error(
"this server is not open for registrations", ErrorResponse::reply("this server is not open for registrations")
StatusCode::BAD_REQUEST, .with_status(StatusCode::BAD_REQUEST),
))); );
} }
let hashed = hash_secret(register.password.as_str()); let hashed = hash_secret(register.password.as_str());
@ -72,10 +65,9 @@ pub async fn register(
Ok(id) => id, Ok(id) => id,
Err(e) => { Err(e) => {
error!("failed to add user: {}", e); error!("failed to add user: {}", e);
return Ok(Box::new(ErrorResponse::reply( return reply_error(
"failed to add user", ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST),
StatusCode::BAD_REQUEST, );
)));
} }
}; };
@ -87,13 +79,13 @@ pub async fn register(
}; };
match db.add_session(&new_session).await { match db.add_session(&new_session).await {
Ok(_) => Ok(Box::new(json(&RegisterResponse { session: token }))), Ok(_) => reply_json(RegisterResponse { session: token }),
Err(e) => { Err(e) => {
error!("failed to add session: {}", e); error!("failed to add session: {}", e);
Ok(Box::new(ErrorResponse::reply( reply_error(
"failed to register user", ErrorResponse::reply("failed to register user")
StatusCode::BAD_REQUEST, .with_status(StatusCode::BAD_REQUEST),
))) )
} }
} }
} }
@ -101,16 +93,15 @@ pub async fn register(
pub async fn login( pub async fn login(
login: LoginRequest, login: LoginRequest,
db: impl Database + Clone + Send + Sync, db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> { ) -> JSONResult<ErrorResponseStatus> {
let user = match db.get_user(login.username.clone()).await { let user = match db.get_user(login.username.clone()).await {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
error!("failed to get user {}: {}", login.username.clone(), e); error!("failed to get user {}: {}", login.username.clone(), e);
return Ok(Box::new(ErrorResponse::reply( return reply_error(
"user not found", ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
StatusCode::NOT_FOUND, );
)));
} }
}; };
@ -119,23 +110,21 @@ pub async fn login(
Err(e) => { Err(e) => {
error!("failed to get session for {}: {}", login.username, e); error!("failed to get session for {}: {}", login.username, e);
return Ok(Box::new(ErrorResponse::reply( return reply_error(
"user not found", ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
StatusCode::NOT_FOUND, );
)));
} }
}; };
let verified = verify_str(user.password.as_str(), login.password.as_str()); let verified = verify_str(user.password.as_str(), login.password.as_str());
if !verified { if !verified {
return Ok(Box::new(ErrorResponse::reply( return reply_error(
"user not found", ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
StatusCode::NOT_FOUND, );
)));
} }
Ok(Box::new(warp::reply::json(&LoginResponse { reply_json(LoginResponse {
session: session.token, session: session.token,
}))) })
} }