diff --git a/atuin-server/src/database.rs b/atuin-server/src/database.rs index 9043c2d..e163d3a 100644 --- a/atuin-server/src/database.rs +++ b/atuin-server/src/database.rs @@ -1,8 +1,7 @@ use async_trait::async_trait; use std::collections::HashMap; -use eyre::{eyre, Result}; -use sqlx::postgres::PgPoolOptions; +use sqlx::{postgres::PgPoolOptions, Result}; use crate::settings::HISTORY_PAGE_SIZE; @@ -25,6 +24,7 @@ pub trait Database { async fn add_user(&self, user: &NewUser) -> Result; async fn count_history(&self, user: &User) -> Result; + async fn count_history_cached(&self, user: &User) -> Result; async fn count_history_range( &self, @@ -63,7 +63,7 @@ pub struct Postgres { } impl Postgres { - pub async fn new(uri: &str) -> Result { + pub async fn new(uri: &str) -> Result { let pool = PgPoolOptions::new() .max_connections(100) .connect(uri) @@ -78,52 +78,36 @@ impl Postgres { #[async_trait] impl Database for Postgres { async fn get_session(&self, token: &str) -> Result { - let res: Option = - sqlx::query_as::<_, Session>("select * from sessions where token = $1") - .bind(token) - .fetch_optional(&self.pool) - .await?; - - if let Some(s) = res { - Ok(s) - } else { - Err(eyre!("could not find session")) - } + sqlx::query_as::<_, Session>("select * from sessions where token = $1") + .bind(token) + .fetch_one(&self.pool) + .await } async fn get_user(&self, username: &str) -> Result { - let res: Option = - sqlx::query_as::<_, User>("select * from users where username = $1") - .bind(username) - .fetch_optional(&self.pool) - .await?; - - if let Some(u) = res { - Ok(u) - } else { - Err(eyre!("could not find user")) - } + sqlx::query_as::<_, User>("select * from users where username = $1") + .bind(username) + .fetch_one(&self.pool) + .await } async fn get_session_user(&self, token: &str) -> Result { - let res: Option = sqlx::query_as::<_, User>( + sqlx::query_as::<_, User>( "select * from users inner join sessions on users.id = sessions.user_id and sessions.token = $1", ) .bind(token) - .fetch_optional(&self.pool) - .await?; - - if let Some(u) = res { - Ok(u) - } else { - Err(eyre!("could not find user")) - } + .fetch_one(&self.pool) + .await } async fn count_history(&self, user: &User) -> Result { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + let res: (i64,) = sqlx::query_as( "select count(1) from history where user_id = $1", @@ -135,6 +119,18 @@ impl Database for Postgres { Ok(res.0) } + async fn count_history_cached(&self, user: &User) -> Result { + let res: (i64,) = sqlx::query_as( + "select total from total_history_count_user + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + async fn count_history_range( &self, user: &User, @@ -300,17 +296,10 @@ impl Database for Postgres { } async fn get_user_session(&self, u: &User) -> Result { - let res: Option = - sqlx::query_as::<_, Session>("select * from sessions where user_id = $1") - .bind(u.id) - .fetch_optional(&self.pool) - .await?; - - if let Some(s) = res { - Ok(s) - } else { - Err(eyre!("could not find session")) - } + sqlx::query_as::<_, Session>("select * from sessions where user_id = $1") + .bind(u.id) + .fetch_one(&self.pool) + .await } async fn oldest_history(&self, user: &User) -> Result { diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index fde7cf2..4fa2a96 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -13,10 +13,17 @@ pub async fn count( user: User, db: Extension, ) -> Result, ErrorResponseStatus<'static>> { - match db.count_history(&user).await { + match db.count_history_cached(&user).await { + // By default read out the cached value Ok(count) => Ok(Json(CountResponse { count })), - Err(_) => Err(ErrorResponse::reply("failed to query history count") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)), + + // If that fails, fallback on a full COUNT. Cache is built on a POST + // only + Err(_) => match db.count_history(&user).await { + Ok(count) => Ok(Json(CountResponse { count })), + Err(_) => Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)), + }, } } diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index 1bcfce2..42e4aa3 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -32,10 +32,15 @@ pub async fn get( ) -> Result, ErrorResponseStatus<'static>> { let user = match db.get_user(username.as_ref()).await { Ok(user) => user, - Err(e) => { - debug!("user not found: {}", e); + Err(sqlx::Error::RowNotFound) => { + debug!("user not found: {}", username); return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } + Err(err) => { + error!("database error: {}", err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } }; Ok(Json(UserResponse { @@ -96,20 +101,28 @@ pub async fn login( ) -> Result, ErrorResponseStatus<'static>> { let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, + Err(sqlx::Error::RowNotFound) => { + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } Err(e) => { error!("failed to get user {}: {}", login.username.clone(), e); - return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); } }; let session = match db.get_user_session(&user).await { Ok(u) => u, - Err(e) => { - error!("failed to get session for {}: {}", login.username, e); - + Err(sqlx::Error::RowNotFound) => { + debug!("user session not found for user id={}", user.id); return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } + Err(err) => { + error!("database error for user {}: {}", login.username, err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } }; let verified = verify_str(user.password.as_str(), login.password.borrow());