axum6 with typesafe state (#674)
This commit is contained in:
parent
ec24437735
commit
0acdb99eb3
5 changed files with 82 additions and 47 deletions
42
Cargo.lock
generated
42
Cargo.lock
generated
|
@ -158,7 +158,7 @@ dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"atuin-common",
|
"atuin-common",
|
||||||
"axum",
|
"axum",
|
||||||
"base64 0.20.0",
|
"base64 0.21.0",
|
||||||
"chrono",
|
"chrono",
|
||||||
"chronoutil",
|
"chronoutil",
|
||||||
"config",
|
"config",
|
||||||
|
@ -186,9 +186,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum"
|
name = "axum"
|
||||||
version = "0.5.16"
|
version = "0.6.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043"
|
checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum-core",
|
"axum-core",
|
||||||
|
@ -204,8 +204,10 @@ dependencies = [
|
||||||
"mime",
|
"mime",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
"rustversion",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"serde_path_to_error",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"sync_wrapper",
|
"sync_wrapper",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
@ -217,9 +219,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-core"
|
name = "axum-core"
|
||||||
version = "0.2.8"
|
version = "0.3.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b"
|
checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
@ -227,6 +229,7 @@ dependencies = [
|
||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"mime",
|
"mime",
|
||||||
|
"rustversion",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
]
|
]
|
||||||
|
@ -243,6 +246,12 @@ version = "0.20.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5"
|
checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "base64"
|
||||||
|
version = "0.21.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "beef"
|
name = "beef"
|
||||||
version = "0.5.2"
|
version = "0.5.2"
|
||||||
|
@ -1123,9 +1132,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "matchit"
|
name = "matchit"
|
||||||
version = "0.5.0"
|
version = "0.7.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "md-5"
|
name = "md-5"
|
||||||
|
@ -1720,6 +1729,12 @@ dependencies = [
|
||||||
"base64 0.13.1",
|
"base64 0.13.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustversion"
|
||||||
|
version = "1.0.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ryu"
|
name = "ryu"
|
||||||
version = "1.0.11"
|
version = "1.0.11"
|
||||||
|
@ -1821,6 +1836,15 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_path_to_error"
|
||||||
|
version = "0.1.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_urlencoded"
|
name = "serde_urlencoded"
|
||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
|
@ -2328,9 +2352,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tower-layer"
|
name = "tower-layer"
|
||||||
version = "0.3.1"
|
version = "0.3.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62"
|
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tower-service"
|
name = "tower-service"
|
||||||
|
|
|
@ -20,7 +20,7 @@ config = { version = "0.13", default-features = false, features = ["toml"] }
|
||||||
serde = { version = "1.0.145", features = ["derive"] }
|
serde = { version = "1.0.145", features = ["derive"] }
|
||||||
serde_json = "1.0.86"
|
serde_json = "1.0.86"
|
||||||
sodiumoxide = "0.2.6"
|
sodiumoxide = "0.2.6"
|
||||||
base64 = "0.20.0"
|
base64 = "0.21.0"
|
||||||
rand = "0.8.4"
|
rand = "0.8.4"
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
sqlx = { version = "0.6", features = [
|
sqlx = { version = "0.6", features = [
|
||||||
|
@ -29,7 +29,7 @@ sqlx = { version = "0.6", features = [
|
||||||
"postgres",
|
"postgres",
|
||||||
] }
|
] }
|
||||||
async-trait = "0.1.58"
|
async-trait = "0.1.58"
|
||||||
axum = "0.5"
|
axum = "0.6.4"
|
||||||
http = "0.2"
|
http = "0.2"
|
||||||
fs-err = "2.9"
|
fs-err = "2.9"
|
||||||
chronoutil = "0.2.3"
|
chronoutil = "0.2.3"
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, Query},
|
extract::{Path, Query, State},
|
||||||
Extension, Json,
|
Json,
|
||||||
};
|
};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use tracing::{debug, error, instrument};
|
use tracing::{debug, error, instrument};
|
||||||
|
@ -10,8 +10,9 @@ use tracing::{debug, error, instrument};
|
||||||
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
||||||
use crate::{
|
use crate::{
|
||||||
calendar::{TimePeriod, TimePeriodInfo},
|
calendar::{TimePeriod, TimePeriodInfo},
|
||||||
database::{Database, Postgres},
|
database::Database,
|
||||||
models::{NewHistory, User},
|
models::{NewHistory, User},
|
||||||
|
router::AppState,
|
||||||
};
|
};
|
||||||
|
|
||||||
use atuin_common::api::*;
|
use atuin_common::api::*;
|
||||||
|
@ -19,8 +20,9 @@ use atuin_common::api::*;
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn count(
|
pub async fn count(
|
||||||
user: User,
|
user: User,
|
||||||
db: Extension<Postgres>,
|
state: State<AppState>,
|
||||||
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
|
||||||
|
let db = &state.0.postgres;
|
||||||
match db.count_history_cached(&user).await {
|
match db.count_history_cached(&user).await {
|
||||||
// By default read out the cached value
|
// By default read out the cached value
|
||||||
Ok(count) => Ok(Json(CountResponse { count })),
|
Ok(count) => Ok(Json(CountResponse { count })),
|
||||||
|
@ -39,8 +41,9 @@ pub async fn count(
|
||||||
pub async fn list(
|
pub async fn list(
|
||||||
req: Query<SyncHistoryRequest>,
|
req: Query<SyncHistoryRequest>,
|
||||||
user: User,
|
user: User,
|
||||||
db: Extension<Postgres>,
|
state: State<AppState>,
|
||||||
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
|
||||||
|
let db = &state.0.postgres;
|
||||||
let history = db
|
let history = db
|
||||||
.list_history(
|
.list_history(
|
||||||
&user,
|
&user,
|
||||||
|
@ -73,9 +76,9 @@ pub async fn list(
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.id = user.id))]
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
pub async fn add(
|
pub async fn add(
|
||||||
Json(req): Json<Vec<AddHistoryRequest>>,
|
|
||||||
user: User,
|
user: User,
|
||||||
db: Extension<Postgres>,
|
state: State<AppState>,
|
||||||
|
Json(req): Json<Vec<AddHistoryRequest>>,
|
||||||
) -> Result<(), ErrorResponseStatus<'static>> {
|
) -> Result<(), ErrorResponseStatus<'static>> {
|
||||||
debug!("request to add {} history items", req.len());
|
debug!("request to add {} history items", req.len());
|
||||||
|
|
||||||
|
@ -90,6 +93,7 @@ pub async fn add(
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
let db = &state.0.postgres;
|
||||||
if let Err(e) = db.add_history(&history).await {
|
if let Err(e) = db.add_history(&history).await {
|
||||||
error!("failed to add history: {}", e);
|
error!("failed to add history: {}", e);
|
||||||
|
|
||||||
|
@ -105,13 +109,14 @@ pub async fn calendar(
|
||||||
Path(focus): Path<String>,
|
Path(focus): Path<String>,
|
||||||
Query(params): Query<HashMap<String, u64>>,
|
Query(params): Query<HashMap<String, u64>>,
|
||||||
user: User,
|
user: User,
|
||||||
db: Extension<Postgres>,
|
state: State<AppState>,
|
||||||
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
|
||||||
let focus = focus.as_str();
|
let focus = focus.as_str();
|
||||||
|
|
||||||
let year = params.get("year").unwrap_or(&0);
|
let year = params.get("year").unwrap_or(&0);
|
||||||
let month = params.get("month").unwrap_or(&1);
|
let month = params.get("month").unwrap_or(&1);
|
||||||
|
|
||||||
|
let db = &state.0.postgres;
|
||||||
let focus = match focus {
|
let focus = match focus {
|
||||||
"year" => db
|
"year" => db
|
||||||
.calendar(&user, TimePeriod::YEAR, *year, *month)
|
.calendar(&user, TimePeriod::YEAR, *year, *month)
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
use std::borrow::Borrow;
|
use std::borrow::Borrow;
|
||||||
|
|
||||||
use axum::{extract::Path, Extension, Json};
|
use axum::{
|
||||||
|
extract::{Path, State},
|
||||||
|
Extension, Json,
|
||||||
|
};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use sodiumoxide::crypto::pwhash::argon2id13;
|
use sodiumoxide::crypto::pwhash::argon2id13;
|
||||||
use tracing::{debug, error, instrument};
|
use tracing::{debug, error, instrument};
|
||||||
|
@ -8,8 +11,9 @@ use uuid::Uuid;
|
||||||
|
|
||||||
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
||||||
use crate::{
|
use crate::{
|
||||||
database::{Database, Postgres},
|
database::Database,
|
||||||
models::{NewSession, NewUser},
|
models::{NewSession, NewUser},
|
||||||
|
router::AppState,
|
||||||
settings::Settings,
|
settings::Settings,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -32,8 +36,9 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
|
||||||
#[instrument(skip_all, fields(user.username = username.as_str()))]
|
#[instrument(skip_all, fields(user.username = username.as_str()))]
|
||||||
pub async fn get(
|
pub async fn get(
|
||||||
Path(username): Path<String>,
|
Path(username): Path<String>,
|
||||||
db: Extension<Postgres>,
|
state: State<AppState>,
|
||||||
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
|
||||||
|
let db = &state.0.postgres;
|
||||||
let user = match db.get_user(username.as_ref()).await {
|
let user = match db.get_user(username.as_ref()).await {
|
||||||
Ok(user) => user,
|
Ok(user) => user,
|
||||||
Err(sqlx::Error::RowNotFound) => {
|
Err(sqlx::Error::RowNotFound) => {
|
||||||
|
@ -54,9 +59,9 @@ pub async fn get(
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub async fn register(
|
pub async fn register(
|
||||||
Json(register): Json<RegisterRequest>,
|
|
||||||
settings: Extension<Settings>,
|
settings: Extension<Settings>,
|
||||||
db: Extension<Postgres>,
|
state: State<AppState>,
|
||||||
|
Json(register): Json<RegisterRequest>,
|
||||||
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
|
||||||
if !settings.open_registration {
|
if !settings.open_registration {
|
||||||
return Err(
|
return Err(
|
||||||
|
@ -73,6 +78,7 @@ pub async fn register(
|
||||||
password: hashed,
|
password: hashed,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let db = &state.0.postgres;
|
||||||
let user_id = match db.add_user(&new_user).await {
|
let user_id = match db.add_user(&new_user).await {
|
||||||
Ok(id) => id,
|
Ok(id) => id,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -102,9 +108,10 @@ pub async fn register(
|
||||||
|
|
||||||
#[instrument(skip_all, fields(user.username = login.username.as_str()))]
|
#[instrument(skip_all, fields(user.username = login.username.as_str()))]
|
||||||
pub async fn login(
|
pub async fn login(
|
||||||
|
state: State<AppState>,
|
||||||
login: Json<LoginRequest>,
|
login: Json<LoginRequest>,
|
||||||
db: Extension<Postgres>,
|
|
||||||
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
|
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
|
||||||
|
let db = &state.0.postgres;
|
||||||
let user = match db.get_user(login.username.borrow()).await {
|
let user = match db.get_user(login.username.borrow()).await {
|
||||||
Ok(u) => u,
|
Ok(u) => u,
|
||||||
Err(sqlx::Error::RowNotFound) => {
|
Err(sqlx::Error::RowNotFound) => {
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{FromRequest, RequestParts},
|
extract::FromRequestParts,
|
||||||
handler::Handler,
|
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
Extension, Router,
|
Router,
|
||||||
};
|
};
|
||||||
use eyre::Result;
|
use eyre::Result;
|
||||||
|
use http::request::Parts;
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
|
|
||||||
|
@ -17,20 +17,15 @@ use super::{
|
||||||
use crate::{models::User, settings::Settings};
|
use crate::{models::User, settings::Settings};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<B> FromRequest<B> for User
|
impl FromRequestParts<AppState> for User {
|
||||||
where
|
|
||||||
B: Send,
|
|
||||||
{
|
|
||||||
type Rejection = http::StatusCode;
|
type Rejection = http::StatusCode;
|
||||||
|
|
||||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
async fn from_request_parts(
|
||||||
let postgres = req
|
req: &mut Parts,
|
||||||
.extensions()
|
state: &AppState,
|
||||||
.get::<Postgres>()
|
) -> Result<Self, Self::Rejection> {
|
||||||
.ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
||||||
|
|
||||||
let auth_header = req
|
let auth_header = req
|
||||||
.headers()
|
.headers
|
||||||
.get(http::header::AUTHORIZATION)
|
.get(http::header::AUTHORIZATION)
|
||||||
.ok_or(http::StatusCode::FORBIDDEN)?;
|
.ok_or(http::StatusCode::FORBIDDEN)?;
|
||||||
let auth_header = auth_header
|
let auth_header = auth_header
|
||||||
|
@ -44,7 +39,8 @@ where
|
||||||
return Err(http::StatusCode::FORBIDDEN);
|
return Err(http::StatusCode::FORBIDDEN);
|
||||||
}
|
}
|
||||||
|
|
||||||
let user = postgres
|
let user = state
|
||||||
|
.postgres
|
||||||
.get_session_user(token)
|
.get_session_user(token)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| http::StatusCode::FORBIDDEN)?;
|
.map_err(|_| http::StatusCode::FORBIDDEN)?;
|
||||||
|
@ -56,6 +52,13 @@ where
|
||||||
async fn teapot() -> impl IntoResponse {
|
async fn teapot() -> impl IntoResponse {
|
||||||
(http::StatusCode::IM_A_TEAPOT, "☕")
|
(http::StatusCode::IM_A_TEAPOT, "☕")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub postgres: Postgres,
|
||||||
|
pub settings: Settings,
|
||||||
|
}
|
||||||
|
|
||||||
pub fn router(postgres: Postgres, settings: Settings) -> Router {
|
pub fn router(postgres: Postgres, settings: Settings) -> Router {
|
||||||
let routes = Router::new()
|
let routes = Router::new()
|
||||||
.route("/", get(handlers::index))
|
.route("/", get(handlers::index))
|
||||||
|
@ -73,11 +76,7 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router {
|
||||||
} else {
|
} else {
|
||||||
Router::new().nest(path, routes)
|
Router::new().nest(path, routes)
|
||||||
}
|
}
|
||||||
.fallback(teapot.into_service())
|
.fallback(teapot)
|
||||||
.layer(
|
.with_state(AppState { postgres, settings })
|
||||||
ServiceBuilder::new()
|
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
|
||||||
.layer(TraceLayer::new_for_http())
|
|
||||||
.layer(Extension(postgres))
|
|
||||||
.layer(Extension(settings)),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue