axum6 with typesafe state (#674)

This commit is contained in:
Conrad Ludgate 2023-02-10 09:45:20 +00:00 committed by GitHub
parent ec24437735
commit 0acdb99eb3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 82 additions and 47 deletions

42
Cargo.lock generated
View file

@ -158,7 +158,7 @@ dependencies = [
"async-trait", "async-trait",
"atuin-common", "atuin-common",
"axum", "axum",
"base64 0.20.0", "base64 0.21.0",
"chrono", "chrono",
"chronoutil", "chronoutil",
"config", "config",
@ -186,9 +186,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.5.16" version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043" checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
@ -204,8 +204,10 @@ dependencies = [
"mime", "mime",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"rustversion",
"serde", "serde",
"serde_json", "serde_json",
"serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
@ -217,9 +219,9 @@ dependencies = [
[[package]] [[package]]
name = "axum-core" name = "axum-core"
version = "0.2.8" version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b" checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"bytes", "bytes",
@ -227,6 +229,7 @@ dependencies = [
"http", "http",
"http-body", "http-body",
"mime", "mime",
"rustversion",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
] ]
@ -243,6 +246,12 @@ version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5" checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5"
[[package]]
name = "base64"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
[[package]] [[package]]
name = "beef" name = "beef"
version = "0.5.2" version = "0.5.2"
@ -1123,9 +1132,9 @@ dependencies = [
[[package]] [[package]]
name = "matchit" name = "matchit"
version = "0.5.0" version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
[[package]] [[package]]
name = "md-5" name = "md-5"
@ -1720,6 +1729,12 @@ dependencies = [
"base64 0.13.1", "base64 0.13.1",
] ]
[[package]]
name = "rustversion"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
[[package]] [[package]]
name = "ryu" name = "ryu"
version = "1.0.11" version = "1.0.11"
@ -1821,6 +1836,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "serde_path_to_error"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde_urlencoded" name = "serde_urlencoded"
version = "0.7.1" version = "0.7.1"
@ -2328,9 +2352,9 @@ dependencies = [
[[package]] [[package]]
name = "tower-layer" name = "tower-layer"
version = "0.3.1" version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]] [[package]]
name = "tower-service" name = "tower-service"

View file

@ -20,7 +20,7 @@ config = { version = "0.13", default-features = false, features = ["toml"] }
serde = { version = "1.0.145", features = ["derive"] } serde = { version = "1.0.145", features = ["derive"] }
serde_json = "1.0.86" serde_json = "1.0.86"
sodiumoxide = "0.2.6" sodiumoxide = "0.2.6"
base64 = "0.20.0" base64 = "0.21.0"
rand = "0.8.4" rand = "0.8.4"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
sqlx = { version = "0.6", features = [ sqlx = { version = "0.6", features = [
@ -29,7 +29,7 @@ sqlx = { version = "0.6", features = [
"postgres", "postgres",
] } ] }
async-trait = "0.1.58" async-trait = "0.1.58"
axum = "0.5" axum = "0.6.4"
http = "0.2" http = "0.2"
fs-err = "2.9" fs-err = "2.9"
chronoutil = "0.2.3" chronoutil = "0.2.3"

View file

@ -1,8 +1,8 @@
use std::collections::HashMap; use std::collections::HashMap;
use axum::{ use axum::{
extract::{Path, Query}, extract::{Path, Query, State},
Extension, Json, Json,
}; };
use http::StatusCode; use http::StatusCode;
use tracing::{debug, error, instrument}; use tracing::{debug, error, instrument};
@ -10,8 +10,9 @@ use tracing::{debug, error, instrument};
use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::{ use crate::{
calendar::{TimePeriod, TimePeriodInfo}, calendar::{TimePeriod, TimePeriodInfo},
database::{Database, Postgres}, database::Database,
models::{NewHistory, User}, models::{NewHistory, User},
router::AppState,
}; };
use atuin_common::api::*; use atuin_common::api::*;
@ -19,8 +20,9 @@ use atuin_common::api::*;
#[instrument(skip_all, fields(user.id = user.id))] #[instrument(skip_all, fields(user.id = user.id))]
pub async fn count( pub async fn count(
user: User, user: User,
db: Extension<Postgres>, state: State<AppState>,
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres;
match db.count_history_cached(&user).await { match db.count_history_cached(&user).await {
// By default read out the cached value // By default read out the cached value
Ok(count) => Ok(Json(CountResponse { count })), Ok(count) => Ok(Json(CountResponse { count })),
@ -39,8 +41,9 @@ pub async fn count(
pub async fn list( pub async fn list(
req: Query<SyncHistoryRequest>, req: Query<SyncHistoryRequest>,
user: User, user: User,
db: Extension<Postgres>, state: State<AppState>,
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres;
let history = db let history = db
.list_history( .list_history(
&user, &user,
@ -73,9 +76,9 @@ pub async fn list(
#[instrument(skip_all, fields(user.id = user.id))] #[instrument(skip_all, fields(user.id = user.id))]
pub async fn add( pub async fn add(
Json(req): Json<Vec<AddHistoryRequest>>,
user: User, user: User,
db: Extension<Postgres>, state: State<AppState>,
Json(req): Json<Vec<AddHistoryRequest>>,
) -> Result<(), ErrorResponseStatus<'static>> { ) -> Result<(), ErrorResponseStatus<'static>> {
debug!("request to add {} history items", req.len()); debug!("request to add {} history items", req.len());
@ -90,6 +93,7 @@ pub async fn add(
}) })
.collect(); .collect();
let db = &state.0.postgres;
if let Err(e) = db.add_history(&history).await { if let Err(e) = db.add_history(&history).await {
error!("failed to add history: {}", e); error!("failed to add history: {}", e);
@ -105,13 +109,14 @@ pub async fn calendar(
Path(focus): Path<String>, Path(focus): Path<String>,
Query(params): Query<HashMap<String, u64>>, Query(params): Query<HashMap<String, u64>>,
user: User, user: User,
db: Extension<Postgres>, state: State<AppState>,
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { ) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
let focus = focus.as_str(); let focus = focus.as_str();
let year = params.get("year").unwrap_or(&0); let year = params.get("year").unwrap_or(&0);
let month = params.get("month").unwrap_or(&1); let month = params.get("month").unwrap_or(&1);
let db = &state.0.postgres;
let focus = match focus { let focus = match focus {
"year" => db "year" => db
.calendar(&user, TimePeriod::YEAR, *year, *month) .calendar(&user, TimePeriod::YEAR, *year, *month)

View file

@ -1,6 +1,9 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use axum::{extract::Path, Extension, Json}; use axum::{
extract::{Path, State},
Extension, Json,
};
use http::StatusCode; use http::StatusCode;
use sodiumoxide::crypto::pwhash::argon2id13; use sodiumoxide::crypto::pwhash::argon2id13;
use tracing::{debug, error, instrument}; use tracing::{debug, error, instrument};
@ -8,8 +11,9 @@ use uuid::Uuid;
use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::{ use crate::{
database::{Database, Postgres}, database::Database,
models::{NewSession, NewUser}, models::{NewSession, NewUser},
router::AppState,
settings::Settings, settings::Settings,
}; };
@ -32,8 +36,9 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
#[instrument(skip_all, fields(user.username = username.as_str()))] #[instrument(skip_all, fields(user.username = username.as_str()))]
pub async fn get( pub async fn get(
Path(username): Path<String>, Path(username): Path<String>,
db: Extension<Postgres>, state: State<AppState>,
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres;
let user = match db.get_user(username.as_ref()).await { let user = match db.get_user(username.as_ref()).await {
Ok(user) => user, Ok(user) => user,
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => {
@ -54,9 +59,9 @@ pub async fn get(
#[instrument(skip_all)] #[instrument(skip_all)]
pub async fn register( pub async fn register(
Json(register): Json<RegisterRequest>,
settings: Extension<Settings>, settings: Extension<Settings>,
db: Extension<Postgres>, state: State<AppState>,
Json(register): Json<RegisterRequest>,
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
if !settings.open_registration { if !settings.open_registration {
return Err( return Err(
@ -73,6 +78,7 @@ pub async fn register(
password: hashed, password: hashed,
}; };
let db = &state.0.postgres;
let user_id = match db.add_user(&new_user).await { let user_id = match db.add_user(&new_user).await {
Ok(id) => id, Ok(id) => id,
Err(e) => { Err(e) => {
@ -102,9 +108,10 @@ pub async fn register(
#[instrument(skip_all, fields(user.username = login.username.as_str()))] #[instrument(skip_all, fields(user.username = login.username.as_str()))]
pub async fn login( pub async fn login(
state: State<AppState>,
login: Json<LoginRequest>, login: Json<LoginRequest>,
db: Extension<Postgres>,
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> { ) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
let db = &state.0.postgres;
let user = match db.get_user(login.username.borrow()).await { let user = match db.get_user(login.username.borrow()).await {
Ok(u) => u, Ok(u) => u,
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => {

View file

@ -1,12 +1,12 @@
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
extract::{FromRequest, RequestParts}, extract::FromRequestParts,
handler::Handler,
response::IntoResponse, response::IntoResponse,
routing::{get, post}, routing::{get, post},
Extension, Router, Router,
}; };
use eyre::Result; use eyre::Result;
use http::request::Parts;
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
@ -17,20 +17,15 @@ use super::{
use crate::{models::User, settings::Settings}; use crate::{models::User, settings::Settings};
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for User impl FromRequestParts<AppState> for User {
where
B: Send,
{
type Rejection = http::StatusCode; type Rejection = http::StatusCode;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(
let postgres = req req: &mut Parts,
.extensions() state: &AppState,
.get::<Postgres>() ) -> Result<Self, Self::Rejection> {
.ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?;
let auth_header = req let auth_header = req
.headers() .headers
.get(http::header::AUTHORIZATION) .get(http::header::AUTHORIZATION)
.ok_or(http::StatusCode::FORBIDDEN)?; .ok_or(http::StatusCode::FORBIDDEN)?;
let auth_header = auth_header let auth_header = auth_header
@ -44,7 +39,8 @@ where
return Err(http::StatusCode::FORBIDDEN); return Err(http::StatusCode::FORBIDDEN);
} }
let user = postgres let user = state
.postgres
.get_session_user(token) .get_session_user(token)
.await .await
.map_err(|_| http::StatusCode::FORBIDDEN)?; .map_err(|_| http::StatusCode::FORBIDDEN)?;
@ -56,6 +52,13 @@ where
async fn teapot() -> impl IntoResponse { async fn teapot() -> impl IntoResponse {
(http::StatusCode::IM_A_TEAPOT, "") (http::StatusCode::IM_A_TEAPOT, "")
} }
#[derive(Clone)]
pub struct AppState {
pub postgres: Postgres,
pub settings: Settings,
}
pub fn router(postgres: Postgres, settings: Settings) -> Router { pub fn router(postgres: Postgres, settings: Settings) -> Router {
let routes = Router::new() let routes = Router::new()
.route("/", get(handlers::index)) .route("/", get(handlers::index))
@ -73,11 +76,7 @@ pub fn router(postgres: Postgres, settings: Settings) -> Router {
} else { } else {
Router::new().nest(path, routes) Router::new().nest(path, routes)
} }
.fallback(teapot.into_service()) .fallback(teapot)
.layer( .with_state(AppState { postgres, settings })
ServiceBuilder::new() .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
.layer(TraceLayer::new_for_http())
.layer(Extension(postgres))
.layer(Extension(settings)),
)
} }