goodbye warp, hello axum (#296)

This commit is contained in:
Conrad Ludgate 2022-04-12 23:06:19 +01:00 committed by GitHub
parent 3b7ed7caff
commit a95018cc90
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 285 additions and 528 deletions

360
Cargo.lock generated
View file

@ -122,14 +122,15 @@ dependencies = [
name = "atuin-common"
version = "0.8.1"
dependencies = [
"axum",
"chrono",
"http",
"rust-crypto",
"serde",
"serde_derive",
"serde_json",
"sodiumoxide",
"uuid",
"warp",
]
[[package]]
@ -138,10 +139,12 @@ version = "0.8.1"
dependencies = [
"async-trait",
"atuin-common",
"axum",
"base64",
"chrono",
"config",
"eyre",
"http",
"log",
"rand 0.8.5",
"rust-crypto",
@ -152,7 +155,6 @@ dependencies = [
"sqlx",
"tokio",
"uuid",
"warp",
"whoami",
]
@ -162,6 +164,51 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47594e438a243791dba58124b6669561f5baa14cb12046641d8008bf035e5a25"
dependencies = [
"async-trait",
"axum-core",
"bitflags",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"itoa 1.0.1",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-http",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a671c9ae99531afdd5d3ee8340b8da547779430689947144c140fc74a740244"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"mime",
]
[[package]]
name = "base64"
version = "0.13.0"
@ -195,15 +242,6 @@ dependencies = [
"generic-array 0.14.5",
]
[[package]]
name = "block-buffer"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324"
dependencies = [
"generic-array 0.14.5",
]
[[package]]
name = "block-padding"
version = "0.1.5"
@ -225,16 +263,6 @@ dependencies = [
"serde",
]
[[package]]
name = "buf_redux"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b953a6887648bb07a535631f2bc00fbdb2a2216f135552cb3f534ed136b9c07f"
dependencies = [
"memchr",
"safemem",
]
[[package]]
name = "bumpalo"
version = "3.9.1"
@ -449,16 +477,6 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "crypto-common"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8"
dependencies = [
"generic-array 0.14.5",
"typenum",
]
[[package]]
name = "crypto-mac"
version = "0.11.1"
@ -509,16 +527,6 @@ dependencies = [
"generic-array 0.14.5",
]
[[package]]
name = "digest"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506"
dependencies = [
"block-buffer 0.10.2",
"crypto-common",
]
[[package]]
name = "directories"
version = "3.0.2"
@ -640,15 +648,6 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed"
[[package]]
name = "fastrand"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3fcf0cee53519c866c09b5de1f6c56ff9d647101f81c1964fa632e148896cdf"
dependencies = [
"instant",
]
[[package]]
name = "flume"
version = "0.10.12"
@ -806,7 +805,7 @@ dependencies = [
"indexmap",
"slab",
"tokio",
"tokio-util 0.7.1",
"tokio-util",
"tracing",
]
@ -837,31 +836,6 @@ dependencies = [
"hashbrown 0.11.2",
]
[[package]]
name = "headers"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cff78e5788be1e0ab65b04d306b2ed5092c815ec97ec70f4ebd5aee158aa55d"
dependencies = [
"base64",
"bitflags",
"bytes",
"headers-core",
"http",
"httpdate",
"mime",
"sha-1 0.10.0",
]
[[package]]
name = "headers-core"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
dependencies = [
"http",
]
[[package]]
name = "heck"
version = "0.3.3"
@ -924,6 +898,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]]
name = "httparse"
version = "1.7.0"
@ -1155,6 +1135,12 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
[[package]]
name = "matchit"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
[[package]]
name = "md-5"
version = "0.9.1"
@ -1178,16 +1164,6 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "mime_guess"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -1223,24 +1199,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "multipart"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00dec633863867f29cb39df64a397cdf4a6354708ddd7759f70c7fb51c5f9182"
dependencies = [
"buf_redux",
"httparse",
"log",
"mime",
"mime_guess",
"quick-error",
"rand 0.8.5",
"safemem",
"tempfile",
"twoway",
]
[[package]]
name = "nom"
version = "7.1.1"
@ -1747,15 +1705,6 @@ version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi",
]
[[package]]
name = "reqwest"
version = "0.11.10"
@ -1910,12 +1859,6 @@ version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
[[package]]
name = "safemem"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072"
[[package]]
name = "same-file"
version = "1.0.6"
@ -1931,12 +1874,6 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "088c5d71572124929ea7549a8ce98e1a6fd33d0a38367b09027b382e67c033db"
[[package]]
name = "scoped-tls"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2"
[[package]]
name = "scopeguard"
version = "1.1.0"
@ -2031,17 +1968,6 @@ dependencies = [
"opaque-debug 0.3.0",
]
[[package]]
name = "sha-1"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f"
dependencies = [
"cfg-if",
"cpufeatures",
"digest 0.10.3",
]
[[package]]
name = "sha2"
version = "0.9.9"
@ -2267,6 +2193,12 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "sync_wrapper"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
[[package]]
name = "tabwriter"
version = "1.2.1"
@ -2276,20 +2208,6 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "tempfile"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4"
dependencies = [
"cfg-if",
"fastrand",
"libc",
"redox_syscall",
"remove_dir_all",
"winapi",
]
[[package]]
name = "termcolor"
version = "1.1.3"
@ -2436,33 +2354,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "511de3f85caf1c98983545490c3d09685fa8eb634e57eec22bb4db271f46cbd8"
dependencies = [
"futures-util",
"log",
"pin-project",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e99e1983e5d376cd8eb4b66604d2e99e79f5bd988c3055891dcd8c9e2604cc0"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"log",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.1"
@ -2486,6 +2377,48 @@ dependencies = [
"serde",
]
[[package]]
name = "tower"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a89fd63ad6adf737582df5db40d286574513c69a11dac5214dc3b5603d6713e"
dependencies = [
"futures-core",
"futures-util",
"pin-project",
"pin-project-lite",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aba3f3efabf7fb41fae8534fc20a817013dd1c12cb45441efb6c82e6556b4cd8"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62"
[[package]]
name = "tower-service"
version = "0.3.1"
@ -2544,34 +2477,6 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "tungstenite"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0b2d8558abd2e276b0a8df5c05a2ec762609344191e5fd23e292c910e9165b5"
dependencies = [
"base64",
"byteorder",
"bytes",
"http",
"httparse",
"log",
"rand 0.8.5",
"sha-1 0.9.8",
"thiserror",
"url",
"utf-8",
]
[[package]]
name = "twoway"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59b11b2b5241ba34be09c3cc85a36e56e48f9888862e19cedf23336d35316ed1"
dependencies = [
"memchr",
]
[[package]]
name = "typenum"
version = "1.15.0"
@ -2584,15 +2489,6 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c"
[[package]]
name = "unicase"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
dependencies = [
"version_check",
]
[[package]]
name = "unicode-bidi"
version = "0.3.7"
@ -2656,12 +2552,6 @@ version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a1f0175e03a0973cf4afd476bef05c26e228520400eb1fd473ad417b1c00ffb"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "uuid"
version = "0.8.2"
@ -2704,36 +2594,6 @@ dependencies = [
"try-lock",
]
[[package]]
name = "warp"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cef4e1e9114a4b7f1ac799f16ce71c14de5778500c5450ec6b7b920c55b587e"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"headers",
"http",
"hyper",
"log",
"mime",
"mime_guess",
"multipart",
"percent-encoding",
"pin-project",
"scoped-tls",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-stream",
"tokio-tungstenite",
"tokio-util 0.6.9",
"tower-service",
"tracing",
]
[[package]]
name = "wasi"
version = "0.10.2+wasi-snapshot-preview1"

View file

@ -31,7 +31,7 @@ pub async fn register(
username: &str,
email: &str,
password: &str,
) -> Result<RegisterResponse<'static>> {
) -> Result<RegisterResponse> {
let mut map = HashMap::new();
map.insert("username", username);
map.insert("email", email);
@ -61,7 +61,7 @@ pub async fn register(
Ok(session)
}
pub async fn login(address: &str, req: LoginRequest<'_>) -> Result<LoginResponse<'static>> {
pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
let url = format!("{}/login", address);
let client = reqwest::Client::new();
@ -142,7 +142,7 @@ impl<'a> Client<'a> {
Ok(history)
}
pub async fn post_history(&self, history: &[AddHistoryRequest<'_, String>]) -> Result<()> {
pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
let url = format!("{}/history", self.sync_addr);
let url = Url::parse(url.as_str())?;

View file

@ -110,10 +110,10 @@ async fn sync_upload(
let data = serde_json::to_string(&data)?;
let add_hist = AddHistoryRequest {
id: i.id.into(),
id: i.id,
timestamp: i.timestamp,
data,
hostname: hash_str(&i.hostname).into(),
hostname: hash_str(&i.hostname),
};
buffer.push(add_hist);

View file

@ -17,5 +17,6 @@ chrono = { version = "0.4", features = ["serde"] }
serde_derive = "1.0.125"
serde = "1.0.126"
serde_json = "1.0.75"
warp = "0.3"
uuid = { version = "0.8", features = ["v4"] }
axum = "0.5"
http = "0.2"

View file

@ -1,43 +1,43 @@
use std::{borrow::Cow, convert::Infallible};
use std::borrow::Cow;
use axum::{response::IntoResponse, Json};
use chrono::Utc;
use serde::Serialize;
use warp::{reply::Response, Reply};
#[derive(Debug, Serialize, Deserialize)]
pub struct UserResponse<'a> {
pub username: Cow<'a, str>,
pub struct UserResponse {
pub username: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterRequest<'a> {
pub email: Cow<'a, str>,
pub username: Cow<'a, str>,
pub password: Cow<'a, str>,
pub struct RegisterRequest {
pub email: String,
pub username: String,
pub password: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterResponse<'a> {
pub session: Cow<'a, str>,
pub struct RegisterResponse {
pub session: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginRequest<'a> {
pub username: Cow<'a, str>,
pub password: Cow<'a, str>,
pub struct LoginRequest {
pub username: String,
pub password: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse<'a> {
pub session: Cow<'a, str>,
pub struct LoginResponse {
pub session: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AddHistoryRequest<'a, D> {
pub id: Cow<'a, str>,
pub struct AddHistoryRequest {
pub id: String,
pub timestamp: chrono::DateTime<Utc>,
pub data: D,
pub hostname: Cow<'a, str>,
pub data: String,
pub hostname: String,
}
#[derive(Debug, Serialize, Deserialize)]
@ -46,10 +46,10 @@ pub struct CountResponse {
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SyncHistoryRequest<'a> {
pub struct SyncHistoryRequest {
pub sync_ts: chrono::DateTime<chrono::FixedOffset>,
pub history_ts: chrono::DateTime<chrono::FixedOffset>,
pub host: Cow<'a, str>,
pub host: String,
}
#[derive(Debug, Serialize, Deserialize)]
@ -62,25 +62,19 @@ pub struct ErrorResponse<'a> {
pub reason: Cow<'a, str>,
}
impl Reply for ErrorResponse<'_> {
fn into_response(self) -> Response {
warp::reply::json(&self).into_response()
impl<'a> IntoResponse for ErrorResponseStatus<'a> {
fn into_response(self) -> axum::response::Response {
(self.status, Json(self.error)).into_response()
}
}
pub struct ErrorResponseStatus<'a> {
pub error: ErrorResponse<'a>,
pub status: warp::http::StatusCode,
}
impl Reply for ErrorResponseStatus<'_> {
fn into_response(self) -> Response {
warp::reply::with_status(self.error, self.status).into_response()
}
pub status: http::StatusCode,
}
impl<'a> ErrorResponse<'a> {
pub fn with_status(self, status: warp::http::StatusCode) -> ErrorResponseStatus<'a> {
pub fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a> {
ErrorResponseStatus {
error: self,
status,
@ -93,31 +87,3 @@ impl<'a> ErrorResponse<'a> {
}
}
}
pub enum ReplyEither<T, E> {
Ok(T),
Err(E),
}
impl<T: Reply, E: Reply> Reply for ReplyEither<T, E> {
fn into_response(self) -> Response {
match self {
ReplyEither::Ok(t) => t.into_response(),
ReplyEither::Err(e) => e.into_response(),
}
}
}
pub type ReplyResult<T, E> = Result<ReplyEither<T, E>, Infallible>;
pub fn reply_error<T, E>(e: E) -> ReplyResult<T, E> {
Ok(ReplyEither::Err(e))
}
pub type JSONResult<E> = Result<ReplyEither<warp::reply::Json, E>, Infallible>;
pub fn reply_json<E>(t: impl Serialize) -> JSONResult<E> {
reply(warp::reply::json(&t))
}
pub fn reply<T, E>(t: T) -> ReplyResult<T, E> {
Ok(ReplyEither::Ok(t))
}

View file

@ -25,6 +25,7 @@ base64 = "0.13.0"
rand = "0.8.4"
rust-crypto = "^0.2"
tokio = { version = "1", features = ["full"] }
warp = "0.3"
sqlx = { version = "0.5", features = [ "runtime-tokio-rustls", "uuid", "chrono", "postgres" ] }
async-trait = "0.1.49"
axum = "0.5"
http = "0.2"

View file

@ -1,26 +1,27 @@
use warp::{http::StatusCode, Reply};
use axum::extract::Query;
use axum::{Extension, Json};
use http::StatusCode;
use crate::database::Database;
use crate::database::{Database, Postgres};
use crate::models::{NewHistory, User};
use atuin_common::api::*;
pub async fn count(
user: User,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus<'static>> {
db.count_history(&user).await.map_or(
reply_error(
ErrorResponse::reply("failed to query history count")
.with_status(StatusCode::INTERNAL_SERVER_ERROR),
),
|count| reply_json(CountResponse { count }),
)
db: Extension<Postgres>,
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
match db.count_history(&user).await {
Ok(count) => Ok(Json(CountResponse { count })),
Err(_) => Err(ErrorResponse::reply("failed to query history count")
.with_status(StatusCode::INTERNAL_SERVER_ERROR)),
}
}
pub async fn list(
req: SyncHistoryRequest<'_>,
req: Query<SyncHistoryRequest>,
user: User,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus<'static>> {
db: Extension<Postgres>,
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
let history = db
.list_history(
&user,
@ -32,10 +33,8 @@ pub async fn list(
if let Err(e) = history {
error!("failed to load history: {}", e);
return reply_error(
ErrorResponse::reply("failed to load history")
.with_status(StatusCode::INTERNAL_SERVER_ERROR),
);
return Err(ErrorResponse::reply("failed to load history")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
}
let history: Vec<String> = history
@ -50,14 +49,14 @@ pub async fn list(
user.id
);
reply_json(SyncHistoryResponse { history })
Ok(Json(SyncHistoryResponse { history }))
}
pub async fn add(
req: Vec<AddHistoryRequest<'_, String>>,
Json(req): Json<Vec<AddHistoryRequest>>,
user: User,
db: impl Database + Clone + Send + Sync,
) -> ReplyResult<impl Reply, ErrorResponseStatus<'_>> {
db: Extension<Postgres>,
) -> Result<(), ErrorResponseStatus<'static>> {
debug!("request to add {} history items", req.len());
let history: Vec<NewHistory> = req
@ -67,18 +66,16 @@ pub async fn add(
user_id: user.id,
hostname: h.hostname,
timestamp: h.timestamp.naive_utc(),
data: h.data.into(),
data: h.data,
})
.collect();
if let Err(e) = db.add_history(&history).await {
error!("failed to add history: {}", e);
return reply_error(
ErrorResponse::reply("failed to add history")
.with_status(StatusCode::INTERNAL_SERVER_ERROR),
);
return Err(ErrorResponse::reply("failed to add history")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
};
reply(warp::reply())
Ok(())
}

View file

@ -1,6 +1,6 @@
pub mod history;
pub mod user;
pub const fn index() -> &'static str {
pub async fn index() -> &'static str {
"\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
}

View file

@ -2,11 +2,13 @@ use std::borrow::Borrow;
use atuin_common::api::*;
use atuin_common::utils::hash_secret;
use axum::extract::Path;
use axum::{Extension, Json};
use http::StatusCode;
use sodiumoxide::crypto::pwhash::argon2id13;
use uuid::Uuid;
use warp::http::StatusCode;
use crate::database::Database;
use crate::database::{Database, Postgres};
use crate::models::{NewSession, NewUser};
use crate::settings::Settings;
@ -25,31 +27,29 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
}
pub async fn get(
username: impl AsRef<str>,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus<'static>> {
Path(username): Path<String>,
db: Extension<Postgres>,
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
let user = match db.get_user(username.as_ref()).await {
Ok(user) => user,
Err(e) => {
debug!("user not found: {}", e);
return reply_error(
ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
);
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
};
reply_json(UserResponse {
username: user.username.into(),
})
Ok(Json(UserResponse {
username: user.username,
}))
}
pub async fn register(
register: RegisterRequest<'_>,
settings: Settings,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus<'static>> {
Json(register): Json<RegisterRequest>,
settings: Extension<Settings>,
db: Extension<Postgres>,
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
if !settings.open_registration {
return reply_error(
return Err(
ErrorResponse::reply("this server is not open for registrations")
.with_status(StatusCode::BAD_REQUEST),
);
@ -60,15 +60,15 @@ pub async fn register(
let new_user = NewUser {
email: register.email,
username: register.username,
password: hashed.into(),
password: hashed,
};
let user_id = match db.add_user(&new_user).await {
Ok(id) => id,
Err(e) => {
error!("failed to add user: {}", e);
return reply_error(
ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST),
return Err(
ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST)
);
}
};
@ -81,31 +81,25 @@ pub async fn register(
};
match db.add_session(&new_session).await {
Ok(_) => reply_json(RegisterResponse {
session: token.into(),
}),
Ok(_) => Ok(Json(RegisterResponse { session: token })),
Err(e) => {
error!("failed to add session: {}", e);
reply_error(
ErrorResponse::reply("failed to register user")
.with_status(StatusCode::BAD_REQUEST),
)
Err(ErrorResponse::reply("failed to register user")
.with_status(StatusCode::BAD_REQUEST))
}
}
}
pub async fn login(
login: LoginRequest<'_>,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus<'_>> {
login: Json<LoginRequest>,
db: Extension<Postgres>,
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
let user = match db.get_user(login.username.borrow()).await {
Ok(u) => u,
Err(e) => {
error!("failed to get user {}: {}", login.username.clone(), e);
return reply_error(
ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
);
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
};
@ -114,21 +108,17 @@ pub async fn login(
Err(e) => {
error!("failed to get session for {}: {}", login.username, e);
return reply_error(
ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
);
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
};
let verified = verify_str(user.password.as_str(), login.password.borrow());
if !verified {
return reply_error(
ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
);
return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
reply_json(LoginResponse {
session: session.token.into(),
})
Ok(Json(LoginResponse {
session: session.token,
}))
}

View file

@ -1,8 +1,10 @@
#![forbid(unsafe_code)]
use std::net::IpAddr;
use std::net::{IpAddr, SocketAddr};
use eyre::Result;
use axum::Server;
use database::Postgres;
use eyre::{Context, Result};
use crate::settings::Settings;
@ -19,14 +21,18 @@ pub mod models;
pub mod router;
pub mod settings;
pub async fn launch(settings: &Settings, host: String, port: u16) -> Result<()> {
// routes to run:
// index, register, add_history, login, get_user, sync_count, sync_list
pub async fn launch(settings: Settings, host: String, port: u16) -> Result<()> {
let host = host.parse::<IpAddr>()?;
let r = router::router(settings).await?;
let postgres = Postgres::new(settings.db_uri.as_str())
.await
.wrap_err_with(|| format!("failed to connect to db: {}", settings.db_uri))?;
warp::serve(r).run((host, port)).await;
let r = router::router(postgres, settings);
Server::bind(&SocketAddr::new(host, port))
.serve(r.into_make_service())
.await?;
Ok(())
}

View file

@ -1,5 +1,3 @@
use std::borrow::Cow;
use chrono::prelude::*;
#[derive(sqlx::FromRow)]
@ -15,13 +13,13 @@ pub struct History {
pub created_at: NaiveDateTime,
}
pub struct NewHistory<'a> {
pub client_id: Cow<'a, str>,
pub struct NewHistory {
pub client_id: String,
pub user_id: i64,
pub hostname: Cow<'a, str>,
pub hostname: String,
pub timestamp: chrono::NaiveDateTime,
pub data: Cow<'a, str>,
pub data: String,
}
#[derive(sqlx::FromRow)]
@ -39,13 +37,13 @@ pub struct Session {
pub token: String,
}
pub struct NewUser<'a> {
pub username: Cow<'a, str>,
pub email: Cow<'a, str>,
pub password: Cow<'a, str>,
pub struct NewUser {
pub username: String,
pub email: String,
pub password: String,
}
pub struct NewSession<'a> {
pub struct NewSession {
pub user_id: i64,
pub token: Cow<'a, str>,
pub token: String,
}

View file

@ -1,9 +1,12 @@
use std::convert::Infallible;
use async_trait::async_trait;
use axum::{
extract::{FromRequest, RequestParts},
handler::Handler,
response::IntoResponse,
routing::{get, post},
Extension, Router,
};
use eyre::Result;
use warp::{hyper::StatusCode, Filter};
use atuin_common::api::SyncHistoryRequest;
use super::{
database::{Database, Postgres},
@ -11,119 +14,57 @@ use super::{
};
use crate::{models::User, settings::Settings};
fn with_settings(
settings: Settings,
) -> impl Filter<Extract = (Settings,), Error = Infallible> + Clone {
warp::any().map(move || settings.clone())
}
#[async_trait]
impl<B> FromRequest<B> for User
where
B: Send,
{
type Rejection = http::StatusCode;
fn with_db(
db: impl Database + Clone + Send + Sync,
) -> impl Filter<Extract = (impl Database + Clone,), Error = Infallible> + Clone {
warp::any().map(move || db.clone())
}
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let postgres = req
.extensions()
.get::<Postgres>()
.ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?;
fn with_user(
postgres: Postgres,
) -> impl Filter<Extract = (User,), Error = warp::Rejection> + Clone {
warp::header::<String>("authorization").and_then(move |header: String| {
// async closures are still buggy :(
let postgres = postgres.clone();
let auth_header = req
.headers()
.get(http::header::AUTHORIZATION)
.ok_or(http::StatusCode::FORBIDDEN)?;
let auth_header = auth_header
.to_str()
.map_err(|_| http::StatusCode::FORBIDDEN)?;
let (typ, token) = auth_header
.split_once(' ')
.ok_or(http::StatusCode::FORBIDDEN)?;
async move {
let header: Vec<&str> = header.split(' ').collect();
let token = if header.len() == 2 {
if header[0] != "Token" {
return Err(warp::reject());
}
header[1]
} else {
return Err(warp::reject());
};
let user = postgres
.get_session_user(token)
.await
.map_err(|_| warp::reject())?;
Ok(user)
if typ != "Token" {
return Err(http::StatusCode::FORBIDDEN);
}
})
let user = postgres
.get_session_user(token)
.await
.map_err(|_| http::StatusCode::FORBIDDEN)?;
Ok(user)
}
}
pub async fn router(
settings: &Settings,
) -> Result<impl Filter<Extract = impl warp::Reply, Error = Infallible> + Clone> {
let postgres = Postgres::new(settings.db_uri.as_str()).await?;
let index = warp::get().and(warp::path::end()).map(handlers::index);
let count = warp::get()
.and(warp::path("sync"))
.and(warp::path("count"))
.and(warp::path::end())
.and(with_user(postgres.clone()))
.and(with_db(postgres.clone()))
.and_then(handlers::history::count)
.boxed();
let sync = warp::get()
.and(warp::path("sync"))
.and(warp::path("history"))
.and(warp::query::<SyncHistoryRequest>())
.and(warp::path::end())
.and(with_user(postgres.clone()))
.and(with_db(postgres.clone()))
.and_then(handlers::history::list)
.boxed();
let add_history = warp::post()
.and(warp::path("history"))
.and(warp::path::end())
.and(warp::body::json())
.and(with_user(postgres.clone()))
.and(with_db(postgres.clone()))
.and_then(handlers::history::add)
.boxed();
let user = warp::get()
.and(warp::path("user"))
.and(warp::path::param::<String>())
.and(warp::path::end())
.and(with_db(postgres.clone()))
.and_then(handlers::user::get)
.boxed();
let register = warp::post()
.and(warp::path("register"))
.and(warp::path::end())
.and(warp::body::json())
.and(with_settings(settings.clone()))
.and(with_db(postgres.clone()))
.and_then(handlers::user::register)
.boxed();
let login = warp::post()
.and(warp::path("login"))
.and(warp::path::end())
.and(warp::body::json())
.and(with_db(postgres))
.and_then(handlers::user::login)
.boxed();
let r = warp::any()
.and(
index
.or(count)
.or(sync)
.or(add_history)
.or(user)
.or(register)
.or(login)
.or(warp::any().map(|| warp::reply::with_status("", StatusCode::IM_A_TEAPOT))),
)
.with(warp::filters::log::log("atuin::api"));
Ok(r)
async fn teapot() -> impl IntoResponse {
(http::StatusCode::IM_A_TEAPOT, "")
}
pub fn router(postgres: Postgres, settings: Settings) -> Router {
Router::new()
.route("/", get(handlers::index))
.route("/sync/count", get(handlers::history::count))
.route("/sync/history", get(handlers::history::list))
.route("/history", post(handlers::history::add))
.route("/user/:username", get(handlers::user::get))
.route("/register", post(handlers::user::register))
.route("/login", post(handlers::user::login))
.fallback(teapot.into_service())
.layer(Extension(postgres))
.layer(Extension(settings))
}

View file

@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::io;
use atuin_common::api::LoginRequest;
@ -66,10 +65,8 @@ impl Cmd {
}
}
pub(super) fn or_user_input<'a>(value: &'a Option<String>, name: &'static str) -> Cow<'a, str> {
value
.as_deref()
.map_or_else(|| Cow::Owned(read_user_input(name)), Cow::Borrowed)
pub(super) fn or_user_input(value: &'_ Option<String>, name: &'static str) -> String {
value.clone().unwrap_or_else(|| read_user_input(name))
}
fn read_user_input(name: &'static str) -> String {

View file

@ -133,7 +133,7 @@ impl AtuinCmd {
match self {
Self::History(history) => history.run(&client_settings, &mut db).await,
Self::Import(import) => import.run(&mut db).await,
Self::Server(server) => server.run(&server_settings).await,
Self::Server(server) => server.run(server_settings).await,
Self::Stats(stats) => stats.run(&mut db, &client_settings).await,
Self::Init(init) => {
init.run();

View file

@ -10,7 +10,7 @@ pub enum Cmd {
/// Start the server
Start {
/// The host address to bind
#[clap(long, short)]
#[clap(long)]
host: Option<String>,
/// The port to bind
@ -20,7 +20,7 @@ pub enum Cmd {
}
impl Cmd {
pub async fn run(&self, settings: &Settings) -> Result<()> {
pub async fn run(&self, settings: Settings) -> Result<()> {
match self {
Self::Start { host, port } => {
let host = host