some changes 🤷 (#83)

* make everything a cow

* fmt + clippy
This commit is contained in:
Conrad Ludgate 2021-05-09 21:17:24 +01:00 committed by GitHub
parent e43e5ce74a
commit de2e34ac50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 134 additions and 156 deletions

View file

@ -7,21 +7,21 @@ use reqwest::{StatusCode, Url};
use sodiumoxide::crypto::secretbox;
use atuin_common::api::{
AddHistoryRequest, CountResponse, LoginResponse, RegisterResponse, SyncHistoryResponse,
AddHistoryRequest, CountResponse, LoginRequest, LoginResponse, RegisterResponse,
SyncHistoryResponse,
};
use atuin_common::utils::hash_str;
use crate::encryption::{decode_key, decrypt};
use crate::encryption::{decode_key, decrypt, EncryptedHistory};
use crate::history::History;
const VERSION: &str = env!("CARGO_PKG_VERSION");
static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),);
// TODO: remove all references to the encryption key from this
// It should be handled *elsewhere*
pub struct Client<'a> {
sync_addr: &'a str,
token: &'a str,
key: secretbox::Key,
client: reqwest::Client,
}
@ -31,7 +31,7 @@ pub fn register(
username: &str,
email: &str,
password: &str,
) -> Result<RegisterResponse> {
) -> Result<RegisterResponse<'static>> {
let mut map = HashMap::new();
map.insert("username", username);
map.insert("email", email);
@ -48,7 +48,7 @@ pub fn register(
let client = reqwest::blocking::Client::new();
let resp = client
.post(url)
.header(USER_AGENT, format!("atuin/{}", VERSION))
.header(USER_AGENT, APP_USER_AGENT)
.json(&map)
.send()?;
@ -60,18 +60,14 @@ pub fn register(
Ok(session)
}
pub fn login(address: &str, username: &str, password: &str) -> Result<LoginResponse> {
let mut map = HashMap::new();
map.insert("username", username);
map.insert("password", password);
pub fn login(address: &str, req: LoginRequest) -> Result<LoginResponse<'static>> {
let url = format!("{}/login", address);
let client = reqwest::blocking::Client::new();
let resp = client
.post(url)
.header(USER_AGENT, format!("atuin/{}", VERSION))
.json(&map)
.header(USER_AGENT, APP_USER_AGENT)
.json(&req)
.send()?;
if resp.status() != reqwest::StatusCode::OK {
@ -83,31 +79,25 @@ pub fn login(address: &str, username: &str, password: &str) -> Result<LoginRespo
}
impl<'a> Client<'a> {
pub fn new(sync_addr: &'a str, token: &'a str, key: String) -> Result<Self> {
pub fn new(sync_addr: &'a str, session_token: &'a str, key: String) -> Result<Self> {
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, format!("Token {}", session_token).parse()?);
Ok(Client {
sync_addr,
token,
key: decode_key(key)?,
client: reqwest::Client::new(),
client: reqwest::Client::builder()
.user_agent(APP_USER_AGENT)
.default_headers(headers)
.build()?,
})
}
pub async fn count(&self) -> Result<i64> {
let url = format!("{}/sync/count", self.sync_addr);
let url = Url::parse(url.as_str())?;
let token = format!("Token {}", self.token);
let token = token.parse()?;
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, token);
let resp = self
.client
.get(url)
.header(USER_AGENT, format!("atuin/{}", VERSION))
.headers(headers)
.send()
.await?;
let resp = self.client.get(url).send().await?;
if resp.status() != StatusCode::OK {
return Err(eyre!("failed to get count (are you logged in?)"));
@ -137,13 +127,7 @@ impl<'a> Client<'a> {
host,
);
let resp = self
.client
.get(url)
.header(AUTHORIZATION, format!("Token {}", self.token))
.header(USER_AGENT, format!("atuin/{}", VERSION))
.send()
.await?;
let resp = self.client.get(url).send().await?;
let history = resp.json::<SyncHistoryResponse>().await?;
let history = history
@ -156,41 +140,15 @@ impl<'a> Client<'a> {
Ok(history)
}
pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
pub async fn post_history(
&self,
history: &[AddHistoryRequest<'_, EncryptedHistory>],
) -> Result<()> {
let url = format!("{}/history", self.sync_addr);
let url = Url::parse(url.as_str())?;
self.client
.post(url)
.json(history)
.header(AUTHORIZATION, format!("Token {}", self.token))
.header(USER_AGENT, format!("atuin/{}", VERSION))
.send()
.await?;
self.client.post(url).json(history).send().await?;
Ok(())
}
pub async fn login(&self, username: &str, password: &str) -> Result<LoginResponse> {
let mut map = HashMap::new();
map.insert("username", username);
map.insert("password", password);
let url = format!("{}/login", self.sync_addr);
let resp = self
.client
.post(url)
.json(&map)
.header(USER_AGENT, format!("atuin/{}", VERSION))
.send()
.await?;
if resp.status() != reqwest::StatusCode::OK {
return Err(eyre!("invalid login details"));
}
let session = resp.json::<LoginResponse>().await?;
Ok(session)
}
}

View file

@ -1,3 +1,5 @@
#![forbid(unsafe_code)]
#[macro_use]
extern crate log;

View file

@ -99,7 +99,7 @@ async fn sync_upload(
while local_count > remote_count {
let last = db.before(cursor, HISTORY_PAGE_SIZE).await?;
let mut buffer = Vec::<AddHistoryRequest>::new();
let mut buffer = Vec::new();
if last.is_empty() {
break;
@ -107,13 +107,11 @@ async fn sync_upload(
for i in last {
let data = encrypt(&i, &key)?;
let data = serde_json::to_string(&data)?;
let add_hist = AddHistoryRequest {
id: i.id,
id: i.id.into(),
timestamp: i.timestamp,
data,
hostname: hash_str(i.hostname.as_str()),
hostname: hash_str(&i.hostname).into(),
};
buffer.push(add_hist);
@ -132,8 +130,8 @@ async fn sync_upload(
pub async fn sync(settings: &Settings, force: bool, db: &mut (impl Database + Send)) -> Result<()> {
let client = api_client::Client::new(
settings.sync_address.as_str(),
settings.session_token.as_str(),
&settings.sync_address,
&settings.session_token,
load_encoded_key(settings)?,
)?;

View file

@ -1,43 +1,43 @@
use std::convert::Infallible;
use std::{borrow::Cow, convert::Infallible};
use chrono::Utc;
use serde::Serialize;
use serde::{Deserialize, Serialize};
use warp::{reply::Response, Reply};
#[derive(Debug, Serialize, Deserialize)]
pub struct UserResponse {
pub username: String,
pub struct UserResponse<'a> {
pub username: Cow<'a, str>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterRequest {
pub email: String,
pub username: String,
pub password: String,
pub struct RegisterRequest<'a> {
pub email: Cow<'a, str>,
pub username: Cow<'a, str>,
pub password: Cow<'a, str>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterResponse {
pub session: String,
pub struct RegisterResponse<'a> {
pub session: Cow<'a, str>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginRequest {
pub username: String,
pub password: String,
pub struct LoginRequest<'a> {
pub username: Cow<'a, str>,
pub password: Cow<'a, str>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse {
pub session: String,
pub struct LoginResponse<'a> {
pub session: Cow<'a, str>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AddHistoryRequest {
pub id: String,
pub struct AddHistoryRequest<'a, D> {
pub id: Cow<'a, str>,
pub timestamp: chrono::DateTime<Utc>,
pub data: String,
pub hostname: String,
pub data: D,
pub hostname: Cow<'a, str>,
}
#[derive(Debug, Serialize, Deserialize)]
@ -46,10 +46,10 @@ pub struct CountResponse {
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SyncHistoryRequest {
pub struct SyncHistoryRequest<'a> {
pub sync_ts: chrono::DateTime<chrono::FixedOffset>,
pub history_ts: chrono::DateTime<chrono::FixedOffset>,
pub host: String,
pub host: Cow<'a, str>,
}
#[derive(Debug, Serialize, Deserialize)]
@ -58,38 +58,38 @@ pub struct SyncHistoryResponse {
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorResponse {
pub reason: String,
pub struct ErrorResponse<'a> {
pub reason: Cow<'a, str>,
}
impl Reply for ErrorResponse {
impl Reply for ErrorResponse<'_> {
fn into_response(self) -> Response {
warp::reply::json(&self).into_response()
}
}
pub struct ErrorResponseStatus {
pub error: ErrorResponse,
pub struct ErrorResponseStatus<'a> {
pub error: ErrorResponse<'a>,
pub status: warp::http::StatusCode,
}
impl Reply for ErrorResponseStatus {
impl Reply for ErrorResponseStatus<'_> {
fn into_response(self) -> Response {
warp::reply::with_status(self.error, self.status).into_response()
}
}
impl ErrorResponse {
pub fn with_status(self, status: warp::http::StatusCode) -> ErrorResponseStatus {
impl<'a> ErrorResponse<'a> {
pub fn with_status(self, status: warp::http::StatusCode) -> ErrorResponseStatus<'a> {
ErrorResponseStatus {
error: self,
status,
}
}
pub fn reply(reason: &str) -> ErrorResponse {
pub fn reply(reason: &'a str) -> ErrorResponse {
Self {
reason: reason.to_string(),
reason: reason.into(),
}
}
}

View file

@ -1,5 +1,4 @@
#[macro_use]
extern crate serde_derive;
#![forbid(unsafe_code)]
pub mod api;
pub mod utils;

View file

@ -13,9 +13,9 @@ pub trait Database {
async fn get_session_user(&self, token: &str) -> Result<User>;
async fn add_session(&self, session: &NewSession) -> Result<()>;
async fn get_user(&self, username: String) -> Result<User>;
async fn get_user(&self, username: &str) -> Result<User>;
async fn get_user_session(&self, u: &User) -> Result<Session>;
async fn add_user(&self, user: NewUser) -> Result<i64>;
async fn add_user(&self, user: &NewUser) -> Result<i64>;
async fn count_history(&self, user: &User) -> Result<i64>;
async fn list_history(
@ -23,7 +23,7 @@ pub trait Database {
user: &User,
created_since: chrono::NaiveDateTime,
since: chrono::NaiveDateTime,
host: String,
host: &str,
) -> Result<Vec<History>>;
async fn add_history(&self, history: &[NewHistory]) -> Result<()>;
}
@ -62,7 +62,7 @@ impl Database for Postgres {
}
}
async fn get_user(&self, username: String) -> Result<User> {
async fn get_user(&self, username: &str) -> Result<User> {
let res: Option<User> =
sqlx::query_as::<_, User>("select * from users where username = $1")
.bind(username)
@ -111,7 +111,7 @@ impl Database for Postgres {
user: &User,
created_since: chrono::NaiveDateTime,
since: chrono::NaiveDateTime,
host: String,
host: &str,
) -> Result<Vec<History>> {
let res = sqlx::query_as::<_, History>(
"select * from history
@ -137,6 +137,10 @@ impl Database for Postgres {
let mut tx = self.pool.begin().await?;
for i in history {
let client_id: &str = &i.client_id;
let hostname: &str = &i.hostname;
let data: &str = &i.data;
sqlx::query(
"insert into history
(client_id, user_id, hostname, timestamp, data)
@ -144,11 +148,11 @@ impl Database for Postgres {
on conflict do nothing
",
)
.bind(i.client_id)
.bind(client_id)
.bind(i.user_id)
.bind(i.hostname)
.bind(hostname)
.bind(i.timestamp)
.bind(i.data)
.bind(data)
.execute(&mut tx)
.await?;
}
@ -158,16 +162,20 @@ impl Database for Postgres {
Ok(())
}
async fn add_user(&self, user: NewUser) -> Result<i64> {
async fn add_user(&self, user: &NewUser) -> Result<i64> {
let email: &str = &user.email;
let username: &str = &user.username;
let password: &str = &user.password;
let res: (i64,) = sqlx::query_as(
"insert into users
(username, email, password)
values($1, $2, $3)
returning id",
)
.bind(user.username.as_str())
.bind(user.email.as_str())
.bind(user.password)
.bind(username)
.bind(email)
.bind(password)
.fetch_one(&self.pool)
.await?;
@ -175,13 +183,15 @@ impl Database for Postgres {
}
async fn add_session(&self, session: &NewSession) -> Result<()> {
let token: &str = &session.token;
sqlx::query(
"insert into sessions
(user_id, token)
values($1, $2)",
)
.bind(session.user_id)
.bind(session.token)
.bind(token)
.execute(&self.pool)
.await?;

View file

@ -6,7 +6,7 @@ use atuin_common::api::*;
pub async fn count(
user: User,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus> {
) -> JSONResult<ErrorResponseStatus<'static>> {
db.count_history(&user).await.map_or(
reply_error(
ErrorResponse::reply("failed to query history count")
@ -17,16 +17,16 @@ pub async fn count(
}
pub async fn list(
req: SyncHistoryRequest,
req: SyncHistoryRequest<'_>,
user: User,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus> {
) -> JSONResult<ErrorResponseStatus<'static>> {
let history = db
.list_history(
&user,
req.sync_ts.naive_utc(),
req.history_ts.naive_utc(),
req.host,
&req.host,
)
.await;
@ -54,20 +54,20 @@ pub async fn list(
}
pub async fn add(
req: Vec<AddHistoryRequest>,
req: Vec<AddHistoryRequest<'_, String>>,
user: User,
db: impl Database + Clone + Send + Sync,
) -> ReplyResult<impl Reply, ErrorResponseStatus> {
) -> ReplyResult<impl Reply, ErrorResponseStatus<'_>> {
debug!("request to add {} history items", req.len());
let history: Vec<NewHistory> = req
.iter()
.into_iter()
.map(|h| NewHistory {
client_id: h.id.as_str(),
client_id: h.id,
user_id: user.id,
hostname: h.hostname.as_str(),
hostname: h.hostname,
timestamp: h.timestamp.naive_utc(),
data: h.data.as_str(),
data: h.data.into(),
})
.collect();

View file

@ -1,3 +1,5 @@
use std::borrow::Borrow;
use atuin_common::api::*;
use atuin_common::utils::hash_secret;
use sodiumoxide::crypto::pwhash::argon2id13;
@ -23,10 +25,10 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
}
pub async fn get(
username: String,
username: impl AsRef<str>,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus> {
let user = match db.get_user(username).await {
) -> JSONResult<ErrorResponseStatus<'static>> {
let user = match db.get_user(username.as_ref()).await {
Ok(user) => user,
Err(e) => {
debug!("user not found: {}", e);
@ -37,15 +39,15 @@ pub async fn get(
};
reply_json(UserResponse {
username: user.username,
username: user.username.into(),
})
}
pub async fn register(
register: RegisterRequest,
register: RegisterRequest<'_>,
settings: Settings,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus> {
) -> JSONResult<ErrorResponseStatus<'static>> {
if !settings.open_registration {
return reply_error(
ErrorResponse::reply("this server is not open for registrations")
@ -53,15 +55,15 @@ pub async fn register(
);
}
let hashed = hash_secret(register.password.as_str());
let hashed = hash_secret(&register.password);
let new_user = NewUser {
email: register.email,
username: register.username,
password: hashed,
password: hashed.into(),
};
let user_id = match db.add_user(new_user).await {
let user_id = match db.add_user(&new_user).await {
Ok(id) => id,
Err(e) => {
error!("failed to add user: {}", e);
@ -75,11 +77,13 @@ pub async fn register(
let new_session = NewSession {
user_id,
token: token.as_str(),
token: (&token).into(),
};
match db.add_session(&new_session).await {
Ok(_) => reply_json(RegisterResponse { session: token }),
Ok(_) => reply_json(RegisterResponse {
session: token.into(),
}),
Err(e) => {
error!("failed to add session: {}", e);
reply_error(
@ -91,10 +95,10 @@ pub async fn register(
}
pub async fn login(
login: LoginRequest,
login: LoginRequest<'_>,
db: impl Database + Clone + Send + Sync,
) -> JSONResult<ErrorResponseStatus> {
let user = match db.get_user(login.username.clone()).await {
) -> JSONResult<ErrorResponseStatus<'_>> {
let user = match db.get_user(login.username.borrow()).await {
Ok(u) => u,
Err(e) => {
error!("failed to get user {}: {}", login.username.clone(), e);
@ -116,7 +120,7 @@ pub async fn login(
}
};
let verified = verify_str(user.password.as_str(), login.password.as_str());
let verified = verify_str(user.password.as_str(), login.password.borrow());
if !verified {
return reply_error(
@ -125,6 +129,6 @@ pub async fn login(
}
reply_json(LoginResponse {
session: session.token,
session: session.token.into(),
})
}

View file

@ -1,3 +1,5 @@
#![forbid(unsafe_code)]
use std::net::IpAddr;
use eyre::Result;

View file

@ -1,3 +1,5 @@
use std::borrow::Cow;
use chrono::prelude::*;
#[derive(sqlx::FromRow)]
@ -14,12 +16,12 @@ pub struct History {
}
pub struct NewHistory<'a> {
pub client_id: &'a str,
pub client_id: Cow<'a, str>,
pub user_id: i64,
pub hostname: &'a str,
pub hostname: Cow<'a, str>,
pub timestamp: chrono::NaiveDateTime,
pub data: &'a str,
pub data: Cow<'a, str>,
}
#[derive(sqlx::FromRow)]
@ -37,13 +39,13 @@ pub struct Session {
pub token: String,
}
pub struct NewUser {
pub username: String,
pub email: String,
pub password: String,
pub struct NewUser<'a> {
pub username: Cow<'a, str>,
pub email: Cow<'a, str>,
pub password: Cow<'a, str>,
}
pub struct NewSession<'a> {
pub user_id: i64,
pub token: &'a str,
pub token: Cow<'a, str>,
}

View file

@ -1,6 +1,7 @@
use std::fs::File;
use std::io::prelude::*;
use std::{borrow::Cow, fs::File};
use atuin_common::api::LoginRequest;
use eyre::Result;
use structopt::StructOpt;
@ -34,8 +35,10 @@ impl Cmd {
let session = api_client::login(
settings.sync_address.as_str(),
self.username.as_str(),
self.password.as_str(),
LoginRequest {
username: Cow::Borrowed(&self.username),
password: Cow::Borrowed(&self.password),
},
)?;
let session_path = settings.session_path.as_str();