diff --git a/atuin-client/src/api_client.rs b/atuin-client/src/api_client.rs index b6625a3..fbeea9a 100644 --- a/atuin-client/src/api_client.rs +++ b/atuin-client/src/api_client.rs @@ -181,9 +181,19 @@ impl<'a> Client<'a> { let resp = self.client.get(url).send().await?; - let history = resp.json::().await?; - - Ok(history) + let status = resp.status(); + if status.is_success() { + let history = resp.json::().await?; + Ok(history) + } else if status.is_client_error() { + let error = resp.json::().await?.reason; + bail!("Could not fetch history: {error}.") + } else if status.is_server_error() { + let error = resp.json::().await?.reason; + bail!("There was an error with the atuin sync service: {error}.\nIf the problem persists, contact the host") + } else { + bail!("There was an error with the atuin sync service: Status {status:?}.\nIf the problem persists, contact the host") + } } pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index 263d6cb..7d6b273 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -63,10 +63,6 @@ pub async fn list( 100 }; - let history = db - .list_history(&user, req.sync_ts, req.history_ts, &req.host, page_size) - .await; - if req.sync_ts.unix_timestamp_nanos() < 0 || req.history_ts.unix_timestamp_nanos() < 0 { error!("client asked for history from < epoch 0"); return Err( @@ -75,6 +71,10 @@ pub async fn list( ); } + let history = db + .list_history(&user, req.sync_ts, req.history_ts, &req.host, page_size) + .await; + if let Err(e) = history { error!("failed to load history: {}", e); return Err(ErrorResponse::reply("failed to load history") diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index e5b756b..7cfcdad 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use atuin_common::api::ErrorResponse; use axum::{ extract::FromRequestParts, response::IntoResponse, @@ -11,8 +12,11 @@ use tower::ServiceBuilder; use tower_http::trace::TraceLayer; use super::handlers; -use crate::settings::Settings; -use atuin_server_database::{models::User, Database}; +use crate::{ + handlers::{ErrorResponseStatus, RespExt}, + settings::Settings, +}; +use atuin_server_database::{models::User, Database, DbError}; pub struct UserAuth(pub User); @@ -21,7 +25,7 @@ impl FromRequestParts> for UserAuth where DB: Database, { - type Rejection = http::StatusCode; + type Rejection = ErrorResponseStatus<'static>; async fn from_request_parts( req: &mut Parts, @@ -30,23 +34,39 @@ where let auth_header = req .headers .get(http::header::AUTHORIZATION) - .ok_or(http::StatusCode::FORBIDDEN)?; - let auth_header = auth_header - .to_str() - .map_err(|_| http::StatusCode::FORBIDDEN)?; - let (typ, token) = auth_header - .split_once(' ') - .ok_or(http::StatusCode::FORBIDDEN)?; + .ok_or_else(|| { + ErrorResponse::reply("missing authorization header") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let auth_header = auth_header.to_str().map_err(|_| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let (typ, token) = auth_header.split_once(' ').ok_or_else(|| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; if typ != "Token" { - return Err(http::StatusCode::FORBIDDEN); + return Err( + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST), + ); } let user = state .database .get_session_user(token) .await - .map_err(|_| http::StatusCode::FORBIDDEN)?; + .map_err(|e| match e { + DbError::NotFound => ErrorResponse::reply("session not found") + .with_status(http::StatusCode::FORBIDDEN), + DbError::Other(e) => { + tracing::error!(error = ?e, "could not query user session"); + ErrorResponse::reply("could not query user session") + .with_status(http::StatusCode::INTERNAL_SERVER_ERROR) + } + })?; Ok(UserAuth(user)) }