Add support for generic database in AppState (#711)
This commit is contained in:
parent
7e7dd63966
commit
dcfad9a90d
3 changed files with 35 additions and 32 deletions
|
@ -18,11 +18,11 @@ use crate::{
|
||||||
use atuin_common::api::*;
|
use atuin_common::api::*;
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn count(
|
pub async fn count<DB: Database>(
|
||||||
user: User,
|
user: User,
|
||||||
state: State<AppState>,
|
state: State<AppState<DB>>,
|
||||||
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
|
||||||
let db = &state.0.postgres;
|
let db = &state.0.database;
|
||||||
match db.count_history_cached(&user).await {
|
match db.count_history_cached(&user).await {
|
||||||
// By default read out the cached value
|
// By default read out the cached value
|
||||||
Ok(count) => Ok(Json(CountResponse { count })),
|
Ok(count) => Ok(Json(CountResponse { count })),
|
||||||
|
@ -38,12 +38,12 @@ pub async fn count(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn list(
|
pub async fn list<DB: Database>(
|
||||||
req: Query<SyncHistoryRequest>,
|
req: Query<SyncHistoryRequest>,
|
||||||
user: User,
|
user: User,
|
||||||
state: State<AppState>,
|
state: State<AppState<DB>>,
|
||||||
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
|
||||||
let db = &state.0.postgres;
|
let db = &state.0.database;
|
||||||
let history = db
|
let history = db
|
||||||
.list_history(
|
.list_history(
|
||||||
&user,
|
&user,
|
||||||
|
@ -75,9 +75,9 @@ pub async fn list(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn add(
|
pub async fn add<DB: Database>(
|
||||||
user: User,
|
user: User,
|
||||||
state: State<AppState>,
|
state: State<AppState<DB>>,
|
||||||
Json(req): Json<Vec<AddHistoryRequest>>,
|
Json(req): Json<Vec<AddHistoryRequest>>,
|
||||||
) -> Result<(), ErrorResponseStatus<'static>> {
|
) -> Result<(), ErrorResponseStatus<'static>> {
|
||||||
debug!("request to add {} history items", req.len());
|
debug!("request to add {} history items", req.len());
|
||||||
|
@ -93,7 +93,7 @@ pub async fn add(
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let db = &state.0.postgres;
|
let db = &state.0.database;
|
||||||
if let Err(e) = db.add_history(&history).await {
|
if let Err(e) = db.add_history(&history).await {
|
||||||
error!("failed to add history: {}", e);
|
error!("failed to add history: {}", e);
|
||||||
|
|
||||||
|
@ -105,18 +105,18 @@ pub async fn add(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn calendar(
|
pub async fn calendar<DB: Database>(
|
||||||
Path(focus): Path<String>,
|
Path(focus): Path<String>,
|
||||||
Query(params): Query<HashMap<String, u64>>,
|
Query(params): Query<HashMap<String, u64>>,
|
||||||
user: User,
|
user: User,
|
||||||
state: State<AppState>,
|
state: State<AppState<DB>>,
|
||||||
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
|
||||||
let focus = focus.as_str();
|
let focus = focus.as_str();
|
||||||
|
|
||||||
let year = params.get("year").unwrap_or(&0);
|
let year = params.get("year").unwrap_or(&0);
|
||||||
let month = params.get("month").unwrap_or(&1);
|
let month = params.get("month").unwrap_or(&1);
|
||||||
|
|
||||||
let db = &state.0.postgres;
|
let db = &state.0.database;
|
||||||
let focus = match focus {
|
let focus = match focus {
|
||||||
"year" => db
|
"year" => db
|
||||||
.calendar(&user, TimePeriod::YEAR, *year, *month)
|
.calendar(&user, TimePeriod::YEAR, *year, *month)
|
||||||
|
|
|
@ -34,11 +34,11 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.username = username.as_str()))]
|
#[instrument(skip_all, fields(user.username = username.as_str()))]
|
||||||
pub async fn get(
|
pub async fn get<DB: Database>(
|
||||||
Path(username): Path<String>,
|
Path(username): Path<String>,
|
||||||
state: State<AppState>,
|
state: State<AppState<DB>>,
|
||||||
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
|
||||||
let db = &state.0.postgres;
|
let db = &state.0.database;
|
||||||
let user = match db.get_user(username.as_ref()).await {
|
let user = match db.get_user(username.as_ref()).await {
|
||||||
Ok(user) => user,
|
Ok(user) => user,
|
||||||
Err(sqlx::Error::RowNotFound) => {
|
Err(sqlx::Error::RowNotFound) => {
|
||||||
|
@ -58,9 +58,9 @@ pub async fn get(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub async fn register(
|
pub async fn register<DB: Database>(
|
||||||
settings: Extension<Settings>,
|
settings: Extension<Settings>,
|
||||||
state: State<AppState>,
|
state: State<AppState<DB>>,
|
||||||
Json(register): Json<RegisterRequest>,
|
Json(register): Json<RegisterRequest>,
|
||||||
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
|
||||||
if !settings.open_registration {
|
if !settings.open_registration {
|
||||||
|
@ -78,7 +78,7 @@ pub async fn register(
|
||||||
password: hashed,
|
password: hashed,
|
||||||
};
|
};
|
||||||
|
|
||||||
let db = &state.0.postgres;
|
let db = &state.0.database;
|
||||||
let user_id = match db.add_user(&new_user).await {
|
let user_id = match db.add_user(&new_user).await {
|
||||||
Ok(id) => id,
|
Ok(id) => id,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -107,11 +107,11 @@ pub async fn register(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.username = login.username.as_str()))]
|
#[instrument(skip_all, fields(user.username = login.username.as_str()))]
|
||||||
pub async fn login(
|
pub async fn login<DB: Database>(
|
||||||
state: State<AppState>,
|
state: State<AppState<DB>>,
|
||||||
login: Json<LoginRequest>,
|
login: Json<LoginRequest>,
|
||||||
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
|
||||||
let db = &state.0.postgres;
|
let db = &state.0.database;
|
||||||
let user = match db.get_user(login.username.borrow()).await {
|
let user = match db.get_user(login.username.borrow()).await {
|
||||||
Ok(u) => u,
|
Ok(u) => u,
|
||||||
Err(sqlx::Error::RowNotFound) => {
|
Err(sqlx::Error::RowNotFound) => {
|
||||||
|
|
|
@ -10,19 +10,19 @@ use http::request::Parts;
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
|
|
||||||
use super::{
|
use super::{database::Database, handlers};
|
||||||
database::{Database, Postgres},
|
|
||||||
handlers,
|
|
||||||
};
|
|
||||||
use crate::{models::User, settings::Settings};
|
use crate::{models::User, settings::Settings};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl FromRequestParts<AppState> for User {
|
impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for User
|
||||||
|
where
|
||||||
|
DB: Database,
|
||||||
|
{
|
||||||
type Rejection = http::StatusCode;
|
type Rejection = http::StatusCode;
|
||||||
|
|
||||||
async fn from_request_parts(
|
async fn from_request_parts(
|
||||||
req: &mut Parts,
|
req: &mut Parts,
|
||||||
state: &AppState,
|
state: &AppState<DB>,
|
||||||
) -> Result<Self, Self::Rejection> {
|
) -> Result<Self, Self::Rejection> {
|
||||||
let auth_header = req
|
let auth_header = req
|
||||||
.headers
|
.headers
|
||||||
|
@ -40,7 +40,7 @@ impl FromRequestParts<AppState> for User {
|
||||||
}
|
}
|
||||||
|
|
||||||
let user = state
|
let user = state
|
||||||
.postgres
|
.database
|
||||||
.get_session_user(token)
|
.get_session_user(token)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| http::StatusCode::FORBIDDEN)?;
|
.map_err(|_| http::StatusCode::FORBIDDEN)?;
|
||||||
|
@ -54,12 +54,15 @@ async fn teapot() -> impl IntoResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState<DB> {
|
||||||
pub postgres: Postgres,
|
pub database: DB,
|
||||||
pub settings: Settings,
|
pub settings: Settings,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn router(postgres: Postgres, settings: Settings) -> Router {
|
pub fn router<DB: Database + Clone + Send + Sync + 'static>(
|
||||||
|
database: DB,
|
||||||
|
settings: Settings,
|
||||||
|
) -> Router {
|
||||||
let routes = Router::new()
|
let routes = Router::new()
|
||||||
.route("/", get(handlers::index))
|
.route("/", get(handlers::index))
|
||||||
.route("/sync/count", get(handlers::history::count))
|
.route("/sync/count", get(handlers::history::count))
|
||||||
|
@ -77,6 +80,6 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router {
|
||||||
Router::new().nest(path, routes)
|
Router::new().nest(path, routes)
|
||||||
}
|
}
|
||||||
.fallback(teapot)
|
.fallback(teapot)
|
||||||
.with_state(AppState { postgres, settings })
|
.with_state(AppState { database, settings })
|
||||||
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
|
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue