Add support for generic database in AppState (#711)
This commit is contained in:
parent
7e7dd63966
commit
dcfad9a90d
3 changed files with 35 additions and 32 deletions
|
@ -18,11 +18,11 @@ use crate::{
|
|||
use atuin_common::api::*;
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn count(
|
||||
pub async fn count<DB: Database>(
|
||||
user: User,
|
||||
state: State<AppState>,
|
||||
state: State<AppState<DB>>,
|
||||
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
|
||||
let db = &state.0.postgres;
|
||||
let db = &state.0.database;
|
||||
match db.count_history_cached(&user).await {
|
||||
// By default read out the cached value
|
||||
Ok(count) => Ok(Json(CountResponse { count })),
|
||||
|
@ -38,12 +38,12 @@ pub async fn count(
|
|||
}
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn list(
|
||||
pub async fn list<DB: Database>(
|
||||
req: Query<SyncHistoryRequest>,
|
||||
user: User,
|
||||
state: State<AppState>,
|
||||
state: State<AppState<DB>>,
|
||||
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
|
||||
let db = &state.0.postgres;
|
||||
let db = &state.0.database;
|
||||
let history = db
|
||||
.list_history(
|
||||
&user,
|
||||
|
@ -75,9 +75,9 @@ pub async fn list(
|
|||
}
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn add(
|
||||
pub async fn add<DB: Database>(
|
||||
user: User,
|
||||
state: State<AppState>,
|
||||
state: State<AppState<DB>>,
|
||||
Json(req): Json<Vec<AddHistoryRequest>>,
|
||||
) -> Result<(), ErrorResponseStatus<'static>> {
|
||||
debug!("request to add {} history items", req.len());
|
||||
|
@ -93,7 +93,7 @@ pub async fn add(
|
|||
})
|
||||
.collect();
|
||||
|
||||
let db = &state.0.postgres;
|
||||
let db = &state.0.database;
|
||||
if let Err(e) = db.add_history(&history).await {
|
||||
error!("failed to add history: {}", e);
|
||||
|
||||
|
@ -105,18 +105,18 @@ pub async fn add(
|
|||
}
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn calendar(
|
||||
pub async fn calendar<DB: Database>(
|
||||
Path(focus): Path<String>,
|
||||
Query(params): Query<HashMap<String, u64>>,
|
||||
user: User,
|
||||
state: State<AppState>,
|
||||
state: State<AppState<DB>>,
|
||||
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
|
||||
let focus = focus.as_str();
|
||||
|
||||
let year = params.get("year").unwrap_or(&0);
|
||||
let month = params.get("month").unwrap_or(&1);
|
||||
|
||||
let db = &state.0.postgres;
|
||||
let db = &state.0.database;
|
||||
let focus = match focus {
|
||||
"year" => db
|
||||
.calendar(&user, TimePeriod::YEAR, *year, *month)
|
||||
|
|
|
@ -34,11 +34,11 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
|
|||
}
|
||||
|
||||
#[instrument(skip_all, fields(user.username = username.as_str()))]
|
||||
pub async fn get(
|
||||
pub async fn get<DB: Database>(
|
||||
Path(username): Path<String>,
|
||||
state: State<AppState>,
|
||||
state: State<AppState<DB>>,
|
||||
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
|
||||
let db = &state.0.postgres;
|
||||
let db = &state.0.database;
|
||||
let user = match db.get_user(username.as_ref()).await {
|
||||
Ok(user) => user,
|
||||
Err(sqlx::Error::RowNotFound) => {
|
||||
|
@ -58,9 +58,9 @@ pub async fn get(
|
|||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub async fn register(
|
||||
pub async fn register<DB: Database>(
|
||||
settings: Extension<Settings>,
|
||||
state: State<AppState>,
|
||||
state: State<AppState<DB>>,
|
||||
Json(register): Json<RegisterRequest>,
|
||||
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
|
||||
if !settings.open_registration {
|
||||
|
@ -78,7 +78,7 @@ pub async fn register(
|
|||
password: hashed,
|
||||
};
|
||||
|
||||
let db = &state.0.postgres;
|
||||
let db = &state.0.database;
|
||||
let user_id = match db.add_user(&new_user).await {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
|
@ -107,11 +107,11 @@ pub async fn register(
|
|||
}
|
||||
|
||||
#[instrument(skip_all, fields(user.username = login.username.as_str()))]
|
||||
pub async fn login(
|
||||
state: State<AppState>,
|
||||
pub async fn login<DB: Database>(
|
||||
state: State<AppState<DB>>,
|
||||
login: Json<LoginRequest>,
|
||||
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
|
||||
let db = &state.0.postgres;
|
||||
let db = &state.0.database;
|
||||
let user = match db.get_user(login.username.borrow()).await {
|
||||
Ok(u) => u,
|
||||
Err(sqlx::Error::RowNotFound) => {
|
||||
|
|
|
@ -10,19 +10,19 @@ use http::request::Parts;
|
|||
use tower::ServiceBuilder;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use super::{
|
||||
database::{Database, Postgres},
|
||||
handlers,
|
||||
};
|
||||
use super::{database::Database, handlers};
|
||||
use crate::{models::User, settings::Settings};
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<AppState> for User {
|
||||
impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for User
|
||||
where
|
||||
DB: Database,
|
||||
{
|
||||
type Rejection = http::StatusCode;
|
||||
|
||||
async fn from_request_parts(
|
||||
req: &mut Parts,
|
||||
state: &AppState,
|
||||
state: &AppState<DB>,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let auth_header = req
|
||||
.headers
|
||||
|
@ -40,7 +40,7 @@ impl FromRequestParts<AppState> for User {
|
|||
}
|
||||
|
||||
let user = state
|
||||
.postgres
|
||||
.database
|
||||
.get_session_user(token)
|
||||
.await
|
||||
.map_err(|_| http::StatusCode::FORBIDDEN)?;
|
||||
|
@ -54,12 +54,15 @@ async fn teapot() -> impl IntoResponse {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub postgres: Postgres,
|
||||
pub struct AppState<DB> {
|
||||
pub database: DB,
|
||||
pub settings: Settings,
|
||||
}
|
||||
|
||||
pub fn router(postgres: Postgres, settings: Settings) -> Router {
|
||||
pub fn router<DB: Database + Clone + Send + Sync + 'static>(
|
||||
database: DB,
|
||||
settings: Settings,
|
||||
) -> Router {
|
||||
let routes = Router::new()
|
||||
.route("/", get(handlers::index))
|
||||
.route("/sync/count", get(handlers::history::count))
|
||||
|
@ -77,6 +80,6 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router {
|
|||
Router::new().nest(path, routes)
|
||||
}
|
||||
.fallback(teapot)
|
||||
.with_state(AppState { postgres, settings })
|
||||
.with_state(AppState { database, settings })
|
||||
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue