Add support for generic database in AppState (#711)

This commit is contained in:
Erwin Kroon 2023-02-15 09:54:09 +01:00 committed by GitHub
parent 7e7dd63966
commit dcfad9a90d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 32 deletions

View file

@ -18,11 +18,11 @@ use crate::{
use atuin_common::api::*; use atuin_common::api::*;
#[instrument(skip_all, fields(user.id = user.id))] #[instrument(skip_all, fields(user.id = user.id))]
pub async fn count( pub async fn count<DB: Database>(
user: User, user: User,
state: State<AppState>, state: State<AppState<DB>>,
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres; let db = &state.0.database;
match db.count_history_cached(&user).await { match db.count_history_cached(&user).await {
// By default read out the cached value // By default read out the cached value
Ok(count) => Ok(Json(CountResponse { count })), Ok(count) => Ok(Json(CountResponse { count })),
@ -38,12 +38,12 @@ pub async fn count(
} }
#[instrument(skip_all, fields(user.id = user.id))] #[instrument(skip_all, fields(user.id = user.id))]
pub async fn list( pub async fn list<DB: Database>(
req: Query<SyncHistoryRequest>, req: Query<SyncHistoryRequest>,
user: User, user: User,
state: State<AppState>, state: State<AppState<DB>>,
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres; let db = &state.0.database;
let history = db let history = db
.list_history( .list_history(
&user, &user,
@ -75,9 +75,9 @@ pub async fn list(
} }
#[instrument(skip_all, fields(user.id = user.id))] #[instrument(skip_all, fields(user.id = user.id))]
pub async fn add( pub async fn add<DB: Database>(
user: User, user: User,
state: State<AppState>, state: State<AppState<DB>>,
Json(req): Json<Vec<AddHistoryRequest>>, Json(req): Json<Vec<AddHistoryRequest>>,
) -> Result<(), ErrorResponseStatus<'static>> { ) -> Result<(), ErrorResponseStatus<'static>> {
debug!("request to add {} history items", req.len()); debug!("request to add {} history items", req.len());
@ -93,7 +93,7 @@ pub async fn add(
}) })
.collect(); .collect();
let db = &state.0.postgres; let db = &state.0.database;
if let Err(e) = db.add_history(&history).await { if let Err(e) = db.add_history(&history).await {
error!("failed to add history: {}", e); error!("failed to add history: {}", e);
@ -105,18 +105,18 @@ pub async fn add(
} }
#[instrument(skip_all, fields(user.id = user.id))] #[instrument(skip_all, fields(user.id = user.id))]
pub async fn calendar( pub async fn calendar<DB: Database>(
Path(focus): Path<String>, Path(focus): Path<String>,
Query(params): Query<HashMap<String, u64>>, Query(params): Query<HashMap<String, u64>>,
user: User, user: User,
state: State<AppState>, state: State<AppState<DB>>,
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { ) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
let focus = focus.as_str(); let focus = focus.as_str();
let year = params.get("year").unwrap_or(&0); let year = params.get("year").unwrap_or(&0);
let month = params.get("month").unwrap_or(&1); let month = params.get("month").unwrap_or(&1);
let db = &state.0.postgres; let db = &state.0.database;
let focus = match focus { let focus = match focus {
"year" => db "year" => db
.calendar(&user, TimePeriod::YEAR, *year, *month) .calendar(&user, TimePeriod::YEAR, *year, *month)

View file

@ -34,11 +34,11 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
} }
#[instrument(skip_all, fields(user.username = username.as_str()))] #[instrument(skip_all, fields(user.username = username.as_str()))]
pub async fn get( pub async fn get<DB: Database>(
Path(username): Path<String>, Path(username): Path<String>,
state: State<AppState>, state: State<AppState<DB>>,
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres; let db = &state.0.database;
let user = match db.get_user(username.as_ref()).await { let user = match db.get_user(username.as_ref()).await {
Ok(user) => user, Ok(user) => user,
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => {
@ -58,9 +58,9 @@ pub async fn get(
} }
#[instrument(skip_all)] #[instrument(skip_all)]
pub async fn register( pub async fn register<DB: Database>(
settings: Extension<Settings>, settings: Extension<Settings>,
state: State<AppState>, state: State<AppState<DB>>,
Json(register): Json<RegisterRequest>, Json(register): Json<RegisterRequest>,
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
if !settings.open_registration { if !settings.open_registration {
@ -78,7 +78,7 @@ pub async fn register(
password: hashed, password: hashed,
}; };
let db = &state.0.postgres; let db = &state.0.database;
let user_id = match db.add_user(&new_user).await { let user_id = match db.add_user(&new_user).await {
Ok(id) => id, Ok(id) => id,
Err(e) => { Err(e) => {
@ -107,11 +107,11 @@ pub async fn register(
} }
#[instrument(skip_all, fields(user.username = login.username.as_str()))] #[instrument(skip_all, fields(user.username = login.username.as_str()))]
pub async fn login( pub async fn login<DB: Database>(
state: State<AppState>, state: State<AppState<DB>>,
login: Json<LoginRequest>, login: Json<LoginRequest>,
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres; let db = &state.0.database;
let user = match db.get_user(login.username.borrow()).await { let user = match db.get_user(login.username.borrow()).await {
Ok(u) => u, Ok(u) => u,
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => {

View file

@ -10,19 +10,19 @@ use http::request::Parts;
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
use super::{ use super::{database::Database, handlers};
database::{Database, Postgres},
handlers,
};
use crate::{models::User, settings::Settings}; use crate::{models::User, settings::Settings};
#[async_trait] #[async_trait]
impl FromRequestParts<AppState> for User { impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for User
where
DB: Database,
{
type Rejection = http::StatusCode; type Rejection = http::StatusCode;
async fn from_request_parts( async fn from_request_parts(
req: &mut Parts, req: &mut Parts,
state: &AppState, state: &AppState<DB>,
) -> Result<Self, Self::Rejection> { ) -> Result<Self, Self::Rejection> {
let auth_header = req let auth_header = req
.headers .headers
@ -40,7 +40,7 @@ impl FromRequestParts<AppState> for User {
} }
let user = state let user = state
.postgres .database
.get_session_user(token) .get_session_user(token)
.await .await
.map_err(|_| http::StatusCode::FORBIDDEN)?; .map_err(|_| http::StatusCode::FORBIDDEN)?;
@ -54,12 +54,15 @@ async fn teapot() -> impl IntoResponse {
} }
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState<DB> {
pub postgres: Postgres, pub database: DB,
pub settings: Settings, pub settings: Settings,
} }
pub fn router(postgres: Postgres, settings: Settings) -> Router { pub fn router<DB: Database + Clone + Send + Sync + 'static>(
database: DB,
settings: Settings,
) -> Router {
let routes = Router::new() let routes = Router::new()
.route("/", get(handlers::index)) .route("/", get(handlers::index))
.route("/sync/count", get(handlers::history::count)) .route("/sync/count", get(handlers::history::count))
@ -77,6 +80,6 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router {
Router::new().nest(path, routes) Router::new().nest(path, routes)
} }
.fallback(teapot) .fallback(teapot)
.with_state(AppState { postgres, settings }) .with_state(AppState { database, settings })
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())) .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
} }