axum6 with typesafe state (#674)
This commit is contained in:
parent
ec24437735
commit
0acdb99eb3
5 changed files with 82 additions and 47 deletions
42
Cargo.lock
generated
42
Cargo.lock
generated
|
@ -158,7 +158,7 @@ dependencies = [
|
|||
"async-trait",
|
||||
"atuin-common",
|
||||
"axum",
|
||||
"base64 0.20.0",
|
||||
"base64 0.21.0",
|
||||
"chrono",
|
||||
"chronoutil",
|
||||
"config",
|
||||
|
@ -186,9 +186,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
|||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.5.16"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043"
|
||||
checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
|
@ -204,8 +204,10 @@ dependencies = [
|
|||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
|
@ -217,9 +219,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.2.8"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b"
|
||||
checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
|
@ -227,6 +229,7 @@ dependencies = [
|
|||
"http",
|
||||
"http-body",
|
||||
"mime",
|
||||
"rustversion",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
@ -243,6 +246,12 @@ version = "0.20.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
|
||||
|
||||
[[package]]
|
||||
name = "beef"
|
||||
version = "0.5.2"
|
||||
|
@ -1123,9 +1132,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.5.0"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
||||
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
|
||||
|
||||
[[package]]
|
||||
name = "md-5"
|
||||
|
@ -1720,6 +1729,12 @@ dependencies = [
|
|||
"base64 0.13.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.11"
|
||||
|
@ -1821,6 +1836,15 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.7.1"
|
||||
|
@ -2328,9 +2352,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.1"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62"
|
||||
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
|
|
|
@ -20,7 +20,7 @@ config = { version = "0.13", default-features = false, features = ["toml"] }
|
|||
serde = { version = "1.0.145", features = ["derive"] }
|
||||
serde_json = "1.0.86"
|
||||
sodiumoxide = "0.2.6"
|
||||
base64 = "0.20.0"
|
||||
base64 = "0.21.0"
|
||||
rand = "0.8.4"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
sqlx = { version = "0.6", features = [
|
||||
|
@ -29,7 +29,7 @@ sqlx = { version = "0.6", features = [
|
|||
"postgres",
|
||||
] }
|
||||
async-trait = "0.1.58"
|
||||
axum = "0.5"
|
||||
axum = "0.6.4"
|
||||
http = "0.2"
|
||||
fs-err = "2.9"
|
||||
chronoutil = "0.2.3"
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query},
|
||||
Extension, Json,
|
||||
extract::{Path, Query, State},
|
||||
Json,
|
||||
};
|
||||
use http::StatusCode;
|
||||
use tracing::{debug, error, instrument};
|
||||
|
@ -10,8 +10,9 @@ use tracing::{debug, error, instrument};
|
|||
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
||||
use crate::{
|
||||
calendar::{TimePeriod, TimePeriodInfo},
|
||||
database::{Database, Postgres},
|
||||
database::Database,
|
||||
models::{NewHistory, User},
|
||||
router::AppState,
|
||||
};
|
||||
|
||||
use atuin_common::api::*;
|
||||
|
@ -19,8 +20,9 @@ use atuin_common::api::*;
|
|||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn count(
|
||||
user: User,
|
||||
db: Extension<Postgres>,
|
||||
state: State<AppState>,
|
||||
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
|
||||
let db = &state.0.postgres;
|
||||
match db.count_history_cached(&user).await {
|
||||
// By default read out the cached value
|
||||
Ok(count) => Ok(Json(CountResponse { count })),
|
||||
|
@ -39,8 +41,9 @@ pub async fn count(
|
|||
pub async fn list(
|
||||
req: Query<SyncHistoryRequest>,
|
||||
user: User,
|
||||
db: Extension<Postgres>,
|
||||
state: State<AppState>,
|
||||
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
|
||||
let db = &state.0.postgres;
|
||||
let history = db
|
||||
.list_history(
|
||||
&user,
|
||||
|
@ -73,9 +76,9 @@ pub async fn list(
|
|||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn add(
|
||||
Json(req): Json<Vec<AddHistoryRequest>>,
|
||||
user: User,
|
||||
db: Extension<Postgres>,
|
||||
state: State<AppState>,
|
||||
Json(req): Json<Vec<AddHistoryRequest>>,
|
||||
) -> Result<(), ErrorResponseStatus<'static>> {
|
||||
debug!("request to add {} history items", req.len());
|
||||
|
||||
|
@ -90,6 +93,7 @@ pub async fn add(
|
|||
})
|
||||
.collect();
|
||||
|
||||
let db = &state.0.postgres;
|
||||
if let Err(e) = db.add_history(&history).await {
|
||||
error!("failed to add history: {}", e);
|
||||
|
||||
|
@ -105,13 +109,14 @@ pub async fn calendar(
|
|||
Path(focus): Path<String>,
|
||||
Query(params): Query<HashMap<String, u64>>,
|
||||
user: User,
|
||||
db: Extension<Postgres>,
|
||||
state: State<AppState>,
|
||||
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
|
||||
let focus = focus.as_str();
|
||||
|
||||
let year = params.get("year").unwrap_or(&0);
|
||||
let month = params.get("month").unwrap_or(&1);
|
||||
|
||||
let db = &state.0.postgres;
|
||||
let focus = match focus {
|
||||
"year" => db
|
||||
.calendar(&user, TimePeriod::YEAR, *year, *month)
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use std::borrow::Borrow;
|
||||
|
||||
use axum::{extract::Path, Extension, Json};
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
Extension, Json,
|
||||
};
|
||||
use http::StatusCode;
|
||||
use sodiumoxide::crypto::pwhash::argon2id13;
|
||||
use tracing::{debug, error, instrument};
|
||||
|
@ -8,8 +11,9 @@ use uuid::Uuid;
|
|||
|
||||
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
||||
use crate::{
|
||||
database::{Database, Postgres},
|
||||
database::Database,
|
||||
models::{NewSession, NewUser},
|
||||
router::AppState,
|
||||
settings::Settings,
|
||||
};
|
||||
|
||||
|
@ -32,8 +36,9 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
|
|||
#[instrument(skip_all, fields(user.username = username.as_str()))]
|
||||
pub async fn get(
|
||||
Path(username): Path<String>,
|
||||
db: Extension<Postgres>,
|
||||
state: State<AppState>,
|
||||
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
|
||||
let db = &state.0.postgres;
|
||||
let user = match db.get_user(username.as_ref()).await {
|
||||
Ok(user) => user,
|
||||
Err(sqlx::Error::RowNotFound) => {
|
||||
|
@ -54,9 +59,9 @@ pub async fn get(
|
|||
|
||||
#[instrument(skip_all)]
|
||||
pub async fn register(
|
||||
Json(register): Json<RegisterRequest>,
|
||||
settings: Extension<Settings>,
|
||||
db: Extension<Postgres>,
|
||||
state: State<AppState>,
|
||||
Json(register): Json<RegisterRequest>,
|
||||
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
|
||||
if !settings.open_registration {
|
||||
return Err(
|
||||
|
@ -73,6 +78,7 @@ pub async fn register(
|
|||
password: hashed,
|
||||
};
|
||||
|
||||
let db = &state.0.postgres;
|
||||
let user_id = match db.add_user(&new_user).await {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
|
@ -102,9 +108,10 @@ pub async fn register(
|
|||
|
||||
#[instrument(skip_all, fields(user.username = login.username.as_str()))]
|
||||
pub async fn login(
|
||||
state: State<AppState>,
|
||||
login: Json<LoginRequest>,
|
||||
db: Extension<Postgres>,
|
||||
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
|
||||
let db = &state.0.postgres;
|
||||
let user = match db.get_user(login.username.borrow()).await {
|
||||
Ok(u) => u,
|
||||
Err(sqlx::Error::RowNotFound) => {
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
extract::{FromRequest, RequestParts},
|
||||
handler::Handler,
|
||||
extract::FromRequestParts,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Extension, Router,
|
||||
Router,
|
||||
};
|
||||
use eyre::Result;
|
||||
use http::request::Parts;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
|
@ -17,20 +17,15 @@ use super::{
|
|||
use crate::{models::User, settings::Settings};
|
||||
|
||||
#[async_trait]
|
||||
impl<B> FromRequest<B> for User
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
impl FromRequestParts<AppState> for User {
|
||||
type Rejection = http::StatusCode;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let postgres = req
|
||||
.extensions()
|
||||
.get::<Postgres>()
|
||||
.ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
async fn from_request_parts(
|
||||
req: &mut Parts,
|
||||
state: &AppState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.headers
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.ok_or(http::StatusCode::FORBIDDEN)?;
|
||||
let auth_header = auth_header
|
||||
|
@ -44,7 +39,8 @@ where
|
|||
return Err(http::StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
let user = postgres
|
||||
let user = state
|
||||
.postgres
|
||||
.get_session_user(token)
|
||||
.await
|
||||
.map_err(|_| http::StatusCode::FORBIDDEN)?;
|
||||
|
@ -56,6 +52,13 @@ where
|
|||
async fn teapot() -> impl IntoResponse {
|
||||
(http::StatusCode::IM_A_TEAPOT, "☕")
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub postgres: Postgres,
|
||||
pub settings: Settings,
|
||||
}
|
||||
|
||||
pub fn router(postgres: Postgres, settings: Settings) -> Router {
|
||||
let routes = Router::new()
|
||||
.route("/", get(handlers::index))
|
||||
|
@ -73,11 +76,7 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router {
|
|||
} else {
|
||||
Router::new().nest(path, routes)
|
||||
}
|
||||
.fallback(teapot.into_service())
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(Extension(postgres))
|
||||
.layer(Extension(settings)),
|
||||
)
|
||||
.fallback(teapot)
|
||||
.with_state(AppState { postgres, settings })
|
||||
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue