Switch to Warp + SQLx, use async, switch to Rust stable (#36)
* Switch to warp + sql, use async and stable rust * Update CI to use stable
This commit is contained in:
parent
f6de558070
commit
34888827f8
32 changed files with 1520 additions and 1324 deletions
16
.github/workflows/rust.yml
vendored
16
.github/workflows/rust.yml
vendored
|
@ -16,10 +16,10 @@ jobs:
|
|||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Install latest nightly
|
||||
- name: Install rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
toolchain: stable
|
||||
override: true
|
||||
|
||||
- name: Run cargo build
|
||||
|
@ -31,10 +31,10 @@ jobs:
|
|||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Install latest nightly
|
||||
- name: Install rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
toolchain: stable
|
||||
override: true
|
||||
|
||||
- name: Run cargo test
|
||||
|
@ -46,10 +46,10 @@ jobs:
|
|||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Install latest nightly
|
||||
- name: Install latest rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
toolchain: stable
|
||||
override: true
|
||||
components: clippy
|
||||
|
||||
|
@ -62,10 +62,10 @@ jobs:
|
|||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Install latest nightly
|
||||
- name: Install latest rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
toolchain: stable
|
||||
override: true
|
||||
components: rustfmt
|
||||
|
||||
|
|
1507
Cargo.lock
generated
1507
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
16
Cargo.toml
16
Cargo.toml
|
@ -8,7 +8,7 @@ description = "atuin - magical shell history"
|
|||
|
||||
[dependencies]
|
||||
log = "0.4"
|
||||
fern = "0.6.0"
|
||||
fern = {version = "0.6.0", features = ["colored"] }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
eyre = "0.6"
|
||||
shellexpand = "2"
|
||||
|
@ -17,7 +17,6 @@ directories = "3"
|
|||
uuid = { version = "0.8", features = ["v4"] }
|
||||
indicatif = "0.15.0"
|
||||
whoami = "1.1.2"
|
||||
rocket = "0.4.7"
|
||||
chrono-english = "0.1.4"
|
||||
cli-table = "0.4"
|
||||
config = "0.11"
|
||||
|
@ -29,8 +28,6 @@ tui = "0.14"
|
|||
termion = "1.5"
|
||||
unicode-width = "0.1"
|
||||
itertools = "0.10.0"
|
||||
diesel = { version = "1.4.4", features = ["postgres", "chrono"] }
|
||||
diesel_migrations = "1.4.0"
|
||||
dotenv = "0.15.0"
|
||||
sodiumoxide = "0.2.6"
|
||||
reqwest = { version = "0.11", features = ["blocking", "json"] }
|
||||
|
@ -40,12 +37,13 @@ parse_duration = "2.1.1"
|
|||
rand = "0.8.3"
|
||||
rust-crypto = "^0.2"
|
||||
human-panic = "1.0.3"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
warp = "0.3"
|
||||
sqlx = { version = "0.5", features = [ "runtime-tokio-native-tls", "uuid", "chrono", "postgres" ] }
|
||||
async-trait = "0.1.49"
|
||||
urlencoding = "1.1.1"
|
||||
humantime = "2.1.0"
|
||||
|
||||
[dependencies.rusqlite]
|
||||
version = "0.25"
|
||||
features = ["bundled"]
|
||||
|
||||
[dependencies.rocket_contrib]
|
||||
version = "0.4.7"
|
||||
default-features = false
|
||||
features = ["diesel_postgres_pool", "json"]
|
||||
|
|
|
@ -1,7 +1,4 @@
|
|||
FROM rust as builder
|
||||
|
||||
RUN rustup default nightly
|
||||
|
||||
FROM rust:1.51-buster as builder
|
||||
|
||||
RUN cargo new --bin atuin
|
||||
WORKDIR /atuin
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
# sync_address = "https://api.atuin.sh"
|
||||
|
||||
# This section configures the sync server, if you decide to host your own
|
||||
[remote]
|
||||
[server]
|
||||
## host to bind, can also be passed via CLI args
|
||||
# host = "127.0.0.1"
|
||||
|
||||
|
|
42
src/api.rs
42
src/api.rs
|
@ -1,8 +1,9 @@
|
|||
use chrono::Utc;
|
||||
|
||||
// This is shared between the client and the server, and has the data structures
|
||||
// representing the requests/responses for each method.
|
||||
// TODO: Properly define responses rather than using json!
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct UserResponse {
|
||||
pub username: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RegisterRequest {
|
||||
|
@ -11,12 +12,22 @@ pub struct RegisterRequest {
|
|||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RegisterResponse {
|
||||
pub session: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LoginResponse {
|
||||
pub session: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AddHistoryRequest {
|
||||
pub id: String,
|
||||
|
@ -31,6 +42,29 @@ pub struct CountResponse {
|
|||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ListHistoryResponse {
|
||||
pub struct SyncHistoryRequest {
|
||||
pub sync_ts: chrono::DateTime<chrono::FixedOffset>,
|
||||
pub history_ts: chrono::DateTime<chrono::FixedOffset>,
|
||||
pub host: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SyncHistoryResponse {
|
||||
pub history: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub reason: String,
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
pub fn reply(reason: &str, status: warp::http::StatusCode) -> impl warp::Reply {
|
||||
warp::reply::with_status(
|
||||
warp::reply::json(&ErrorResponse {
|
||||
reason: String::from(reason),
|
||||
}),
|
||||
status,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ fn print_list(h: &[History]) {
|
|||
}
|
||||
|
||||
impl Cmd {
|
||||
pub fn run(&self, settings: &Settings, db: &mut impl Database) -> Result<()> {
|
||||
pub async fn run(&self, settings: &Settings, db: &mut (impl Database + Send)) -> Result<()> {
|
||||
match self {
|
||||
Self::Start { command: words } => {
|
||||
let command = words.join(" ");
|
||||
|
@ -69,6 +69,10 @@ impl Cmd {
|
|||
}
|
||||
|
||||
Self::End { id, exit } => {
|
||||
if id.trim() == "" {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut h = db.load(id)?;
|
||||
h.exit = *exit;
|
||||
h.duration = chrono::Utc::now().timestamp_nanos() - h.timestamp.timestamp_nanos();
|
||||
|
@ -82,7 +86,7 @@ impl Cmd {
|
|||
}
|
||||
Ok(Fork::Child) => {
|
||||
debug!("running periodic background sync");
|
||||
sync::sync(settings, false, db)?;
|
||||
sync::sync(settings, false, db).await?;
|
||||
}
|
||||
Err(_) => println!("Fork failed"),
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||
use std::fs::File;
|
||||
use std::io::prelude::*;
|
||||
|
||||
use eyre::Result;
|
||||
use eyre::{eyre, Result};
|
||||
use structopt::StructOpt;
|
||||
|
||||
use crate::settings::Settings;
|
||||
|
@ -28,8 +28,13 @@ impl Cmd {
|
|||
|
||||
let url = format!("{}/login", settings.local.sync_address);
|
||||
let client = reqwest::blocking::Client::new();
|
||||
|
||||
let resp = client.post(url).json(&map).send()?;
|
||||
|
||||
if resp.status() != reqwest::StatusCode::OK {
|
||||
return Err(eyre!("invalid login details"));
|
||||
}
|
||||
|
||||
let session = resp.json::<HashMap<String, String>>()?;
|
||||
let session = session["session"].clone();
|
||||
|
||||
|
|
|
@ -63,16 +63,16 @@ pub fn uuid_v4() -> String {
|
|||
}
|
||||
|
||||
impl AtuinCmd {
|
||||
pub fn run(self, db: &mut impl Database, settings: &Settings) -> Result<()> {
|
||||
pub async fn run<T: Database + Send>(self, db: &mut T, settings: &Settings) -> Result<()> {
|
||||
match self {
|
||||
Self::History(history) => history.run(settings, db),
|
||||
Self::History(history) => history.run(settings, db).await,
|
||||
Self::Import(import) => import.run(db),
|
||||
Self::Server(server) => server.run(settings),
|
||||
Self::Server(server) => server.run(settings).await,
|
||||
Self::Stats(stats) => stats.run(db, settings),
|
||||
Self::Init => init::init(),
|
||||
Self::Search { query } => search::run(&query, db),
|
||||
|
||||
Self::Sync { force } => sync::run(settings, force, db),
|
||||
Self::Sync { force } => sync::run(settings, force, db).await,
|
||||
Self::Login(l) => l.run(settings),
|
||||
Self::Register(r) => register::run(
|
||||
settings,
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use eyre::Result;
|
||||
use itertools::Itertools;
|
||||
use std::io::stdout;
|
||||
use std::time::Duration;
|
||||
|
||||
use termion::{event::Key, input::MouseTerminal, raw::IntoRawMode, screen::AlternateScreen};
|
||||
use tui::{
|
||||
backend::TermionBackend,
|
||||
|
@ -26,6 +28,78 @@ struct State {
|
|||
results_state: ListState,
|
||||
}
|
||||
|
||||
#[allow(clippy::clippy::cast_sign_loss)]
|
||||
impl State {
|
||||
fn durations(&self) -> Vec<String> {
|
||||
self.results
|
||||
.iter()
|
||||
.map(|h| {
|
||||
let duration =
|
||||
Duration::from_millis(std::cmp::max(h.duration, 0) as u64 / 1_000_000);
|
||||
let duration = humantime::format_duration(duration).to_string();
|
||||
let duration: Vec<&str> = duration.split(' ').collect();
|
||||
|
||||
duration[0].to_string()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn render_results<T: tui::backend::Backend>(
|
||||
&mut self,
|
||||
f: &mut tui::Frame<T>,
|
||||
r: tui::layout::Rect,
|
||||
) {
|
||||
let durations = self.durations();
|
||||
let max_length = durations
|
||||
.iter()
|
||||
.fold(0, |largest, i| std::cmp::max(largest, i.len()));
|
||||
|
||||
let results: Vec<ListItem> = self
|
||||
.results
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, m)| {
|
||||
let command = m.command.to_string().replace("\n", " ").replace("\t", " ");
|
||||
|
||||
let mut command = Span::raw(command);
|
||||
|
||||
let mut duration = durations[i].clone();
|
||||
|
||||
while duration.len() < max_length {
|
||||
duration.push(' ');
|
||||
}
|
||||
|
||||
let duration = Span::styled(
|
||||
duration,
|
||||
Style::default().fg(if m.exit == 0 || m.duration == -1 {
|
||||
Color::Green
|
||||
} else {
|
||||
Color::Red
|
||||
}),
|
||||
);
|
||||
|
||||
if let Some(selected) = self.results_state.selected() {
|
||||
if selected == i {
|
||||
command.style =
|
||||
Style::default().fg(Color::Red).add_modifier(Modifier::BOLD);
|
||||
}
|
||||
}
|
||||
|
||||
let spans = Spans::from(vec![duration, Span::raw(" "), command]);
|
||||
|
||||
ListItem::new(spans)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = List::new(results)
|
||||
.block(Block::default().borders(Borders::ALL).title("History"))
|
||||
.start_corner(Corner::BottomLeft)
|
||||
.highlight_symbol(">> ");
|
||||
|
||||
f.render_stateful_widget(results, r, &mut self.results_state);
|
||||
}
|
||||
}
|
||||
|
||||
fn query_results(app: &mut State, db: &mut impl Database) {
|
||||
let results = match app.input.as_str() {
|
||||
"" => db.list(),
|
||||
|
@ -48,7 +122,11 @@ fn key_handler(input: Key, db: &mut impl Database, app: &mut State) -> Option<St
|
|||
Key::Esc | Key::Char('\n') => {
|
||||
let i = app.results_state.selected().unwrap_or(0);
|
||||
|
||||
return Some(app.results.get(i).unwrap().command.clone());
|
||||
return Some(
|
||||
app.results
|
||||
.get(i)
|
||||
.map_or("".to_string(), |h| h.command.clone()),
|
||||
);
|
||||
}
|
||||
Key::Char(c) => {
|
||||
app.input.push(c);
|
||||
|
@ -163,32 +241,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result<String> {
|
|||
let help = Text::from(Spans::from(help));
|
||||
let help = Paragraph::new(help);
|
||||
|
||||
let input = Paragraph::new(app.input.as_ref())
|
||||
.block(Block::default().borders(Borders::ALL).title("Search"));
|
||||
|
||||
let results: Vec<ListItem> = app
|
||||
.results
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, m)| {
|
||||
let mut content =
|
||||
Span::raw(m.command.to_string().replace("\n", " ").replace("\t", " "));
|
||||
|
||||
if let Some(selected) = app.results_state.selected() {
|
||||
if selected == i {
|
||||
content.style =
|
||||
Style::default().fg(Color::Red).add_modifier(Modifier::BOLD);
|
||||
}
|
||||
}
|
||||
|
||||
ListItem::new(content)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = List::new(results)
|
||||
.block(Block::default().borders(Borders::ALL).title("History"))
|
||||
.start_corner(Corner::BottomLeft)
|
||||
.highlight_symbol(">> ");
|
||||
let input = Paragraph::new(app.input.clone())
|
||||
.block(Block::default().borders(Borders::ALL).title("Query"));
|
||||
|
||||
let stats = Paragraph::new(Text::from(Span::raw(format!(
|
||||
"history count: {}",
|
||||
|
@ -199,8 +253,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result<String> {
|
|||
f.render_widget(title, top_left_chunks[0]);
|
||||
f.render_widget(help, top_left_chunks[1]);
|
||||
|
||||
app.render_results(f, chunks[1]);
|
||||
f.render_widget(stats, top_right_chunks[0]);
|
||||
f.render_stateful_widget(results, chunks[1], &mut app.results_state);
|
||||
f.render_widget(input, chunks[2]);
|
||||
|
||||
f.set_cursor(
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use eyre::Result;
|
||||
use structopt::StructOpt;
|
||||
|
||||
use crate::remote::server;
|
||||
use crate::server;
|
||||
use crate::settings::Settings;
|
||||
|
||||
#[derive(StructOpt)]
|
||||
|
@ -20,7 +20,7 @@ pub enum Cmd {
|
|||
}
|
||||
|
||||
impl Cmd {
|
||||
pub fn run(&self, settings: &Settings) -> Result<()> {
|
||||
pub async fn run(&self, settings: &Settings) -> Result<()> {
|
||||
match self {
|
||||
Self::Start { host, port } => {
|
||||
let host = host.as_ref().map_or(
|
||||
|
@ -29,7 +29,7 @@ impl Cmd {
|
|||
);
|
||||
let port = port.map_or(settings.server.port, |p| p);
|
||||
|
||||
server::launch(settings, host, port)
|
||||
server::launch(settings, host, port).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,8 +4,8 @@ use crate::local::database::Database;
|
|||
use crate::local::sync;
|
||||
use crate::settings::Settings;
|
||||
|
||||
pub fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
|
||||
sync::sync(settings, force, db)?;
|
||||
pub async fn run(settings: &Settings, force: bool, db: &mut (impl Database + Send)) -> Result<()> {
|
||||
sync::sync(settings, force, db).await?;
|
||||
println!(
|
||||
"Sync complete! {} items in database, force: {}",
|
||||
db.history_count()?,
|
||||
|
|
|
@ -1,93 +1,94 @@
|
|||
use chrono::Utc;
|
||||
use eyre::Result;
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use reqwest::header::{HeaderMap, AUTHORIZATION};
|
||||
use reqwest::Url;
|
||||
use sodiumoxide::crypto::secretbox;
|
||||
|
||||
use crate::api::{AddHistoryRequest, CountResponse, ListHistoryResponse};
|
||||
use crate::local::encryption::{decrypt, load_key};
|
||||
use crate::api::{AddHistoryRequest, CountResponse, SyncHistoryResponse};
|
||||
use crate::local::encryption::decrypt;
|
||||
use crate::local::history::History;
|
||||
use crate::settings::Settings;
|
||||
use crate::utils::hash_str;
|
||||
|
||||
pub struct Client<'a> {
|
||||
settings: &'a Settings,
|
||||
sync_addr: &'a str,
|
||||
token: &'a str,
|
||||
key: secretbox::Key,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl<'a> Client<'a> {
|
||||
pub const fn new(settings: &'a Settings) -> Self {
|
||||
Client { settings }
|
||||
pub fn new(sync_addr: &'a str, token: &'a str, key: secretbox::Key) -> Self {
|
||||
Client {
|
||||
sync_addr,
|
||||
token,
|
||||
key,
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count(&self) -> Result<i64> {
|
||||
let url = format!("{}/sync/count", self.settings.local.sync_address);
|
||||
let client = reqwest::blocking::Client::new();
|
||||
pub async fn count(&self) -> Result<i64> {
|
||||
let url = format!("{}/sync/count", self.sync_addr);
|
||||
let url = Url::parse(url.as_str())?;
|
||||
let token = format!("Token {}", self.token);
|
||||
let token = token.parse()?;
|
||||
|
||||
let resp = client
|
||||
.get(url)
|
||||
.header(
|
||||
AUTHORIZATION,
|
||||
format!("Token {}", self.settings.local.session_token),
|
||||
)
|
||||
.send()?;
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(AUTHORIZATION, token);
|
||||
|
||||
let count = resp.json::<CountResponse>()?;
|
||||
let resp = self.client.get(url).headers(headers).send().await?;
|
||||
|
||||
let count = resp.json::<CountResponse>().await?;
|
||||
|
||||
Ok(count.count)
|
||||
}
|
||||
|
||||
pub fn get_history(
|
||||
pub async fn get_history(
|
||||
&self,
|
||||
sync_ts: chrono::DateTime<Utc>,
|
||||
history_ts: chrono::DateTime<Utc>,
|
||||
host: Option<String>,
|
||||
) -> Result<Vec<History>> {
|
||||
let key = load_key(self.settings)?;
|
||||
|
||||
let host = match host {
|
||||
None => hash_str(&format!("{}:{}", whoami::hostname(), whoami::username())),
|
||||
Some(h) => h,
|
||||
};
|
||||
|
||||
// this allows for syncing between users on the same machine
|
||||
let url = format!(
|
||||
"{}/sync/history?sync_ts={}&history_ts={}&host={}",
|
||||
self.settings.local.sync_address,
|
||||
sync_ts.to_rfc3339(),
|
||||
history_ts.to_rfc3339(),
|
||||
self.sync_addr,
|
||||
urlencoding::encode(sync_ts.to_rfc3339().as_str()),
|
||||
urlencoding::encode(history_ts.to_rfc3339().as_str()),
|
||||
host,
|
||||
);
|
||||
let client = reqwest::blocking::Client::new();
|
||||
|
||||
let resp = client
|
||||
let resp = self
|
||||
.client
|
||||
.get(url)
|
||||
.header(
|
||||
AUTHORIZATION,
|
||||
format!("Token {}", self.settings.local.session_token),
|
||||
)
|
||||
.send()?;
|
||||
.header(AUTHORIZATION, format!("Token {}", self.token))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let history = resp.json::<ListHistoryResponse>()?;
|
||||
let history = resp.json::<SyncHistoryResponse>().await?;
|
||||
let history = history
|
||||
.history
|
||||
.iter()
|
||||
.map(|h| serde_json::from_str(h).expect("invalid base64"))
|
||||
.map(|h| decrypt(&h, &key).expect("failed to decrypt history! check your key"))
|
||||
.map(|h| decrypt(&h, &self.key).expect("failed to decrypt history! check your key"))
|
||||
.collect();
|
||||
|
||||
Ok(history)
|
||||
}
|
||||
|
||||
pub fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
|
||||
let url = format!("{}/history", self.sync_addr);
|
||||
let url = Url::parse(url.as_str())?;
|
||||
|
||||
let url = format!("{}/history", self.settings.local.sync_address);
|
||||
client
|
||||
self.client
|
||||
.post(url)
|
||||
.json(history)
|
||||
.header(
|
||||
AUTHORIZATION,
|
||||
format!("Token {}", self.settings.local.session_token),
|
||||
)
|
||||
.send()?;
|
||||
.header(AUTHORIZATION, format!("Token {}", self.token))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -215,9 +215,9 @@ impl Database for Sqlite {
|
|||
}
|
||||
|
||||
fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT * FROM history where timestamp <= ? order by timestamp desc limit ?",
|
||||
)?;
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare("SELECT * FROM history where timestamp < ? order by timestamp desc limit ?")?;
|
||||
|
||||
let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| {
|
||||
history_from_sqlite_row(None, row)
|
||||
|
@ -236,7 +236,7 @@ impl Database for Sqlite {
|
|||
|
||||
fn prefix_search(&self, query: &str) -> Result<Vec<History>> {
|
||||
self.query(
|
||||
"select * from history where command like ?1 || '%' order by timestamp asc",
|
||||
"select * from history where command like ?1 || '%' order by timestamp asc limit 1000",
|
||||
&[query],
|
||||
)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ use std::{fs::File, path::Path};
|
|||
use chrono::prelude::*;
|
||||
use chrono::Utc;
|
||||
use eyre::{eyre, Result};
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::history::History;
|
||||
|
||||
|
@ -42,8 +43,8 @@ impl Zsh {
|
|||
|
||||
fn parse_extended(line: &str, counter: i64) -> History {
|
||||
let line = line.replacen(": ", "", 2);
|
||||
let (time, duration) = line.split_once(':').unwrap();
|
||||
let (duration, command) = duration.split_once(';').unwrap();
|
||||
let (time, duration) = line.splitn(2, ':').collect_tuple().unwrap();
|
||||
let (duration, command) = duration.splitn(2, ';').collect_tuple().unwrap();
|
||||
|
||||
let time = time
|
||||
.parse::<i64>()
|
||||
|
@ -60,7 +61,7 @@ fn parse_extended(line: &str, counter: i64) -> History {
|
|||
time,
|
||||
command.trim_end().to_string(),
|
||||
String::from("unknown"),
|
||||
-1,
|
||||
0, // assume 0, we have no way of knowing :(
|
||||
duration,
|
||||
None,
|
||||
None,
|
||||
|
|
|
@ -20,12 +20,12 @@ use crate::{api::AddHistoryRequest, utils::hash_str};
|
|||
|
||||
// Check if remote has things we don't, and if so, download them.
|
||||
// Returns (num downloaded, total local)
|
||||
fn sync_download(
|
||||
async fn sync_download(
|
||||
force: bool,
|
||||
client: &api_client::Client,
|
||||
db: &mut impl Database,
|
||||
client: &api_client::Client<'_>,
|
||||
db: &mut (impl Database + Send),
|
||||
) -> Result<(i64, i64)> {
|
||||
let remote_count = client.count()?;
|
||||
let remote_count = client.count().await?;
|
||||
|
||||
let initial_local = db.history_count()?;
|
||||
let mut local_count = initial_local;
|
||||
|
@ -41,7 +41,9 @@ fn sync_download(
|
|||
let host = if force { Some(String::from("")) } else { None };
|
||||
|
||||
while remote_count > local_count {
|
||||
let page = client.get_history(last_sync, last_timestamp, host.clone())?;
|
||||
let page = client
|
||||
.get_history(last_sync, last_timestamp, host.clone())
|
||||
.await?;
|
||||
|
||||
if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() {
|
||||
break;
|
||||
|
@ -71,13 +73,13 @@ fn sync_download(
|
|||
}
|
||||
|
||||
// Check if we have things remote doesn't, and if so, upload them
|
||||
fn sync_upload(
|
||||
async fn sync_upload(
|
||||
settings: &Settings,
|
||||
_force: bool,
|
||||
client: &api_client::Client,
|
||||
db: &mut impl Database,
|
||||
client: &api_client::Client<'_>,
|
||||
db: &mut (impl Database + Send),
|
||||
) -> Result<()> {
|
||||
let initial_remote_count = client.count()?;
|
||||
let initial_remote_count = client.count().await?;
|
||||
let mut remote_count = initial_remote_count;
|
||||
|
||||
let local_count = db.history_count()?;
|
||||
|
@ -111,21 +113,25 @@ fn sync_upload(
|
|||
}
|
||||
|
||||
// anything left over outside of the 100 block size
|
||||
client.post_history(&buffer)?;
|
||||
client.post_history(&buffer).await?;
|
||||
cursor = buffer.last().unwrap().timestamp;
|
||||
|
||||
remote_count = client.count()?;
|
||||
remote_count = client.count().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn sync(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
|
||||
let client = api_client::Client::new(settings);
|
||||
pub async fn sync(settings: &Settings, force: bool, db: &mut (impl Database + Send)) -> Result<()> {
|
||||
let client = api_client::Client::new(
|
||||
settings.local.sync_address.as_str(),
|
||||
settings.local.session_token.as_str(),
|
||||
load_key(settings)?,
|
||||
);
|
||||
|
||||
sync_upload(settings, force, &client, db)?;
|
||||
sync_upload(settings, force, &client, db).await?;
|
||||
|
||||
let download = sync_download(force, &client, db)?;
|
||||
let download = sync_download(force, &client, db).await?;
|
||||
|
||||
debug!("sync downloaded {}", download.0);
|
||||
|
||||
|
|
43
src/main.rs
43
src/main.rs
|
@ -1,32 +1,19 @@
|
|||
#![feature(proc_macro_hygiene)]
|
||||
#![feature(decl_macro)]
|
||||
#![warn(clippy::pedantic, clippy::nursery)]
|
||||
#![allow(clippy::use_self)] // not 100% reliable
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use eyre::{eyre, Result};
|
||||
use fern::colors::{Color, ColoredLevelConfig};
|
||||
use human_panic::setup_panic;
|
||||
use structopt::{clap::AppSettings, StructOpt};
|
||||
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
#[macro_use]
|
||||
extern crate rocket;
|
||||
|
||||
#[macro_use]
|
||||
extern crate serde_derive;
|
||||
|
||||
#[macro_use]
|
||||
extern crate diesel;
|
||||
|
||||
#[macro_use]
|
||||
extern crate diesel_migrations;
|
||||
|
||||
#[macro_use]
|
||||
extern crate rocket_contrib;
|
||||
|
||||
use command::AtuinCmd;
|
||||
use local::database::Sqlite;
|
||||
use settings::Settings;
|
||||
|
@ -34,12 +21,10 @@ use settings::Settings;
|
|||
mod api;
|
||||
mod command;
|
||||
mod local;
|
||||
mod remote;
|
||||
mod server;
|
||||
mod settings;
|
||||
mod utils;
|
||||
|
||||
pub mod schema;
|
||||
|
||||
#[derive(StructOpt)]
|
||||
#[structopt(
|
||||
author = "Ellie Huxtable <e@elm.sh>",
|
||||
|
@ -56,7 +41,7 @@ struct Atuin {
|
|||
}
|
||||
|
||||
impl Atuin {
|
||||
fn run(self, settings: &Settings) -> Result<()> {
|
||||
async fn run(self, settings: &Settings) -> Result<()> {
|
||||
let db_path = if let Some(db_path) = self.db {
|
||||
let path = db_path
|
||||
.to_str()
|
||||
|
@ -69,26 +54,32 @@ impl Atuin {
|
|||
|
||||
let mut db = Sqlite::new(db_path)?;
|
||||
|
||||
self.atuin.run(&mut db, settings)
|
||||
self.atuin.run(&mut db, settings).await
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
setup_panic!();
|
||||
let settings = Settings::new()?;
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let colors = ColoredLevelConfig::new()
|
||||
.warn(Color::Yellow)
|
||||
.error(Color::Red);
|
||||
|
||||
fern::Dispatch::new()
|
||||
.format(|out, message, record| {
|
||||
.format(move |out, message, record| {
|
||||
out.finish(format_args!(
|
||||
"{} [{}] {}",
|
||||
chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"),
|
||||
record.level(),
|
||||
chrono::Local::now().to_rfc3339(),
|
||||
colors.color(record.level()),
|
||||
message
|
||||
))
|
||||
})
|
||||
.level(log::LevelFilter::Info)
|
||||
.level_for("sqlx", log::LevelFilter::Warn)
|
||||
.chain(std::io::stdout())
|
||||
.apply()?;
|
||||
|
||||
Atuin::from_args().run(&settings)
|
||||
let settings = Settings::new()?;
|
||||
setup_panic!();
|
||||
|
||||
Atuin::from_args().run(&settings).await
|
||||
}
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
use diesel::pg::PgConnection;
|
||||
use diesel::prelude::*;
|
||||
use eyre::{eyre, Result};
|
||||
|
||||
use crate::settings::Settings;
|
||||
|
||||
#[database("atuin")]
|
||||
pub struct AtuinDbConn(diesel::PgConnection);
|
||||
|
||||
// TODO: connection pooling
|
||||
pub fn establish_connection(settings: &Settings) -> Result<PgConnection> {
|
||||
if settings.server.db_uri == "default_uri" {
|
||||
Err(eyre!(
|
||||
"Please configure your database! Set db_uri in config.toml"
|
||||
))
|
||||
} else {
|
||||
let database_url = &settings.server.db_uri;
|
||||
let conn = PgConnection::establish(database_url)?;
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
pub mod auth;
|
||||
pub mod database;
|
||||
pub mod models;
|
||||
pub mod server;
|
||||
pub mod views;
|
|
@ -1,61 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::remote::database::establish_connection;
|
||||
use crate::settings::Settings;
|
||||
|
||||
use super::database::AtuinDbConn;
|
||||
|
||||
use eyre::Result;
|
||||
use rocket::config::{Config, Environment, LoggingLevel, Value};
|
||||
|
||||
// a bunch of these imports are generated by macros, it's easier to wildcard
|
||||
#[allow(clippy::clippy::wildcard_imports)]
|
||||
use super::views::*;
|
||||
|
||||
#[allow(clippy::clippy::wildcard_imports)]
|
||||
use super::auth::*;
|
||||
|
||||
embed_migrations!("migrations");
|
||||
|
||||
pub fn launch(settings: &Settings, host: String, port: u16) -> Result<()> {
|
||||
let settings: Settings = settings.clone(); // clone so rocket can manage it
|
||||
|
||||
let mut database_config = HashMap::new();
|
||||
let mut databases = HashMap::new();
|
||||
|
||||
database_config.insert("url", Value::from(settings.server.db_uri.clone()));
|
||||
databases.insert("atuin", Value::from(database_config));
|
||||
|
||||
let connection = establish_connection(&settings)?;
|
||||
|
||||
embedded_migrations::run(&connection).expect("failed to run migrations");
|
||||
|
||||
let config = Config::build(Environment::Production)
|
||||
.address(host)
|
||||
.log_level(LoggingLevel::Normal)
|
||||
.port(port)
|
||||
.extra("databases", databases)
|
||||
.finalize()
|
||||
.unwrap();
|
||||
|
||||
let app = rocket::custom(config);
|
||||
|
||||
app.mount(
|
||||
"/",
|
||||
routes![
|
||||
index,
|
||||
register,
|
||||
add_history,
|
||||
login,
|
||||
get_user,
|
||||
sync_count,
|
||||
sync_list
|
||||
],
|
||||
)
|
||||
.manage(settings)
|
||||
.attach(AtuinDbConn::fairing())
|
||||
.register(catchers![internal_error, bad_request])
|
||||
.launch();
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1,185 +0,0 @@
|
|||
use chrono::Utc;
|
||||
use rocket::http::uri::Uri;
|
||||
use rocket::http::RawStr;
|
||||
use rocket::http::{ContentType, Status};
|
||||
use rocket::request::FromFormValue;
|
||||
use rocket::request::Request;
|
||||
use rocket::response;
|
||||
use rocket::response::{Responder, Response};
|
||||
use rocket_contrib::databases::diesel;
|
||||
use rocket_contrib::json::{Json, JsonValue};
|
||||
|
||||
use self::diesel::prelude::*;
|
||||
|
||||
use crate::api::AddHistoryRequest;
|
||||
use crate::schema::history;
|
||||
use crate::settings::HISTORY_PAGE_SIZE;
|
||||
|
||||
use super::database::AtuinDbConn;
|
||||
use super::models::{History, NewHistory, User};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ApiResponse {
|
||||
pub json: JsonValue,
|
||||
pub status: Status,
|
||||
}
|
||||
|
||||
impl<'r> Responder<'r> for ApiResponse {
|
||||
fn respond_to(self, req: &Request) -> response::Result<'r> {
|
||||
Response::build_from(self.json.respond_to(req).unwrap())
|
||||
.status(self.status)
|
||||
.header(ContentType::JSON)
|
||||
.ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/")]
|
||||
pub const fn index() -> &'static str {
|
||||
"\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
|
||||
}
|
||||
|
||||
#[catch(500)]
|
||||
pub fn internal_error(_req: &Request) -> ApiResponse {
|
||||
ApiResponse {
|
||||
status: Status::InternalServerError,
|
||||
json: json!({"status": "error", "message": "an internal server error has occured"}),
|
||||
}
|
||||
}
|
||||
|
||||
#[catch(400)]
|
||||
pub fn bad_request(_req: &Request) -> ApiResponse {
|
||||
ApiResponse {
|
||||
status: Status::InternalServerError,
|
||||
json: json!({"status": "error", "message": "bad request. don't do that."}),
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/history", data = "<add_history>")]
|
||||
#[allow(
|
||||
clippy::clippy::cast_sign_loss,
|
||||
clippy::cast_possible_truncation,
|
||||
clippy::clippy::needless_pass_by_value
|
||||
)]
|
||||
pub fn add_history(
|
||||
conn: AtuinDbConn,
|
||||
user: User,
|
||||
add_history: Json<Vec<AddHistoryRequest>>,
|
||||
) -> ApiResponse {
|
||||
let new_history: Vec<NewHistory> = add_history
|
||||
.iter()
|
||||
.map(|h| NewHistory {
|
||||
client_id: h.id.as_str(),
|
||||
hostname: h.hostname.to_string(),
|
||||
user_id: user.id,
|
||||
timestamp: h.timestamp.naive_utc(),
|
||||
data: h.data.as_str(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
match diesel::insert_into(history::table)
|
||||
.values(&new_history)
|
||||
.on_conflict_do_nothing()
|
||||
.execute(&*conn)
|
||||
{
|
||||
Ok(_) => ApiResponse {
|
||||
status: Status::Ok,
|
||||
json: json!({"status": "ok", "message": "history added"}),
|
||||
},
|
||||
Err(_) => ApiResponse {
|
||||
status: Status::BadRequest,
|
||||
json: json!({"status": "error", "message": "failed to add history"}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/sync/count")]
|
||||
#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
|
||||
pub fn sync_count(conn: AtuinDbConn, user: User) -> ApiResponse {
|
||||
use crate::schema::history::dsl::*;
|
||||
|
||||
// we need to return the number of history items we have for this user
|
||||
// in the future I'd like to use something like a merkel tree to calculate
|
||||
// which day specifically needs syncing
|
||||
let count = history
|
||||
.filter(user_id.eq(user.id))
|
||||
.count()
|
||||
.first::<i64>(&*conn);
|
||||
|
||||
if count.is_err() {
|
||||
error!("failed to count: {}", count.err().unwrap());
|
||||
|
||||
return ApiResponse {
|
||||
json: json!({"message": "internal server error"}),
|
||||
status: Status::InternalServerError,
|
||||
};
|
||||
}
|
||||
|
||||
ApiResponse {
|
||||
status: Status::Ok,
|
||||
json: json!({"count": count.ok()}),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UtcDateTime(chrono::DateTime<Utc>);
|
||||
|
||||
impl<'v> FromFormValue<'v> for UtcDateTime {
|
||||
type Error = &'v RawStr;
|
||||
|
||||
fn from_form_value(form_value: &'v RawStr) -> Result<UtcDateTime, &'v RawStr> {
|
||||
let time = Uri::percent_decode(form_value.as_bytes()).map_err(|_| form_value)?;
|
||||
let time = time.to_string();
|
||||
|
||||
match chrono::DateTime::parse_from_rfc3339(time.as_str()) {
|
||||
Ok(t) => Ok(UtcDateTime(t.with_timezone(&Utc))),
|
||||
Err(e) => {
|
||||
error!("failed to parse time {}, got: {}", time, e);
|
||||
Err(form_value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Request a list of all history items added to the DB after a given timestamp.
|
||||
// Provide the current hostname, so that we don't send the client data that
|
||||
// originated from them
|
||||
#[get("/sync/history?<sync_ts>&<history_ts>&<host>")]
|
||||
#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
|
||||
pub fn sync_list(
|
||||
conn: AtuinDbConn,
|
||||
user: User,
|
||||
sync_ts: UtcDateTime,
|
||||
history_ts: UtcDateTime,
|
||||
host: String,
|
||||
) -> ApiResponse {
|
||||
use crate::schema::history::dsl::*;
|
||||
|
||||
// we need to return the number of history items we have for this user
|
||||
// in the future I'd like to use something like a merkel tree to calculate
|
||||
// which day specifically needs syncing
|
||||
// TODO: Allow for configuring the page size, both from params, and setting
|
||||
// the max in config. 100 is fine for now.
|
||||
let h = history
|
||||
.filter(user_id.eq(user.id))
|
||||
.filter(hostname.ne(host))
|
||||
.filter(created_at.ge(sync_ts.0.naive_utc()))
|
||||
.filter(timestamp.ge(history_ts.0.naive_utc()))
|
||||
.order(timestamp.asc())
|
||||
.limit(HISTORY_PAGE_SIZE)
|
||||
.load::<History>(&*conn);
|
||||
|
||||
if let Err(e) = h {
|
||||
error!("failed to load history: {}", e);
|
||||
|
||||
return ApiResponse {
|
||||
json: json!({"message": "internal server error"}),
|
||||
status: Status::InternalServerError,
|
||||
};
|
||||
}
|
||||
|
||||
let user_data: Vec<String> = h.unwrap().iter().map(|i| i.data.to_string()).collect();
|
||||
|
||||
ApiResponse {
|
||||
status: Status::Ok,
|
||||
json: json!({ "history": user_data }),
|
||||
}
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
table! {
|
||||
history (id) {
|
||||
id -> Int8,
|
||||
client_id -> Text,
|
||||
user_id -> Int8,
|
||||
hostname -> Text,
|
||||
timestamp -> Timestamp,
|
||||
data -> Varchar,
|
||||
created_at -> Timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
table! {
|
||||
sessions (id) {
|
||||
id -> Int8,
|
||||
user_id -> Int8,
|
||||
token -> Varchar,
|
||||
}
|
||||
}
|
||||
|
||||
table! {
|
||||
users (id) {
|
||||
id -> Int8,
|
||||
username -> Varchar,
|
||||
email -> Varchar,
|
||||
password -> Varchar,
|
||||
}
|
||||
}
|
||||
|
||||
allow_tables_to_appear_in_same_query!(history, sessions, users,);
|
|
@ -1,3 +1,4 @@
|
|||
/*
|
||||
use self::diesel::prelude::*;
|
||||
use eyre::Result;
|
||||
use rocket::http::Status;
|
||||
|
@ -218,3 +219,4 @@ pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse {
|
|||
json: json!({"session": session.token}),
|
||||
}
|
||||
}
|
||||
*/
|
202
src/server/database.rs
Normal file
202
src/server/database.rs
Normal file
|
@ -0,0 +1,202 @@
|
|||
use async_trait::async_trait;
|
||||
|
||||
use eyre::{eyre, Result};
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
|
||||
use crate::settings::HISTORY_PAGE_SIZE;
|
||||
|
||||
use super::models::{History, NewHistory, NewSession, NewUser, Session, User};
|
||||
|
||||
#[async_trait]
|
||||
pub trait Database {
|
||||
async fn get_session(&self, token: &str) -> Result<Session>;
|
||||
async fn get_session_user(&self, token: &str) -> Result<User>;
|
||||
async fn add_session(&self, session: &NewSession) -> Result<()>;
|
||||
|
||||
async fn get_user(&self, username: String) -> Result<User>;
|
||||
async fn get_user_session(&self, u: &User) -> Result<Session>;
|
||||
async fn add_user(&self, user: NewUser) -> Result<i64>;
|
||||
|
||||
async fn count_history(&self, user: &User) -> Result<i64>;
|
||||
async fn list_history(
|
||||
&self,
|
||||
user: &User,
|
||||
created_since: chrono::NaiveDateTime,
|
||||
since: chrono::NaiveDateTime,
|
||||
host: String,
|
||||
) -> Result<Vec<History>>;
|
||||
async fn add_history(&self, history: &[NewHistory]) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Postgres {
|
||||
pool: sqlx::Pool<sqlx::postgres::Postgres>,
|
||||
}
|
||||
|
||||
impl Postgres {
|
||||
pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(100)
|
||||
.connect(uri)
|
||||
.await?;
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Database for Postgres {
|
||||
async fn get_session(&self, token: &str) -> Result<Session> {
|
||||
let res: Option<Session> =
|
||||
sqlx::query_as::<_, Session>("select * from sessions where token = $1")
|
||||
.bind(token)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
if let Some(s) = res {
|
||||
Ok(s)
|
||||
} else {
|
||||
Err(eyre!("could not find session"))
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_user(&self, username: String) -> Result<User> {
|
||||
let res: Option<User> =
|
||||
sqlx::query_as::<_, User>("select * from users where username = $1")
|
||||
.bind(username)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
if let Some(u) = res {
|
||||
Ok(u)
|
||||
} else {
|
||||
Err(eyre!("could not find user"))
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_session_user(&self, token: &str) -> Result<User> {
|
||||
let res: Option<User> = sqlx::query_as::<_, User>(
|
||||
"select * from users
|
||||
inner join sessions
|
||||
on users.id = sessions.user_id
|
||||
and sessions.token = $1",
|
||||
)
|
||||
.bind(token)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
if let Some(u) = res {
|
||||
Ok(u)
|
||||
} else {
|
||||
Err(eyre!("could not find user"))
|
||||
}
|
||||
}
|
||||
|
||||
async fn count_history(&self, user: &User) -> Result<i64> {
|
||||
let res: (i64,) = sqlx::query_as(
|
||||
"select count(1) from history
|
||||
where user_id = $1",
|
||||
)
|
||||
.bind(user.id)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(res.0)
|
||||
}
|
||||
|
||||
async fn list_history(
|
||||
&self,
|
||||
user: &User,
|
||||
created_since: chrono::NaiveDateTime,
|
||||
since: chrono::NaiveDateTime,
|
||||
host: String,
|
||||
) -> Result<Vec<History>> {
|
||||
let res = sqlx::query_as::<_, History>(
|
||||
"select * from history
|
||||
where user_id = $1
|
||||
and hostname != $2
|
||||
and created_at >= $3
|
||||
and timestamp >= $4
|
||||
order by timestamp asc
|
||||
limit $5",
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(host)
|
||||
.bind(created_since)
|
||||
.bind(since)
|
||||
.bind(HISTORY_PAGE_SIZE)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn add_history(&self, history: &[NewHistory]) -> Result<()> {
|
||||
let mut tx = self.pool.begin().await?;
|
||||
|
||||
for i in history {
|
||||
sqlx::query(
|
||||
"insert into history
|
||||
(client_id, user_id, hostname, timestamp, data)
|
||||
values ($1, $2, $3, $4, $5)
|
||||
on conflict do nothing
|
||||
",
|
||||
)
|
||||
.bind(i.client_id)
|
||||
.bind(i.user_id)
|
||||
.bind(i.hostname)
|
||||
.bind(i.timestamp)
|
||||
.bind(i.data)
|
||||
.execute(&mut tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add_user(&self, user: NewUser) -> Result<i64> {
|
||||
let res: (i64,) = sqlx::query_as(
|
||||
"insert into users
|
||||
(username, email, password)
|
||||
values($1, $2, $3)
|
||||
returning id",
|
||||
)
|
||||
.bind(user.username.as_str())
|
||||
.bind(user.email.as_str())
|
||||
.bind(user.password)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(res.0)
|
||||
}
|
||||
|
||||
async fn add_session(&self, session: &NewSession) -> Result<()> {
|
||||
sqlx::query(
|
||||
"insert into sessions
|
||||
(user_id, token)
|
||||
values($1, $2)",
|
||||
)
|
||||
.bind(session.user_id)
|
||||
.bind(session.token)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_user_session(&self, u: &User) -> Result<Session> {
|
||||
let res: Option<Session> =
|
||||
sqlx::query_as::<_, Session>("select * from sessions where user_id = $1")
|
||||
.bind(u.id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
if let Some(s) = res {
|
||||
Ok(s)
|
||||
} else {
|
||||
Err(eyre!("could not find session"))
|
||||
}
|
||||
}
|
||||
}
|
89
src/server/handlers/history.rs
Normal file
89
src/server/handlers/history.rs
Normal file
|
@ -0,0 +1,89 @@
|
|||
use std::convert::Infallible;
|
||||
|
||||
use warp::{http::StatusCode, reply::json};
|
||||
|
||||
use crate::api::{
|
||||
AddHistoryRequest, CountResponse, ErrorResponse, SyncHistoryRequest, SyncHistoryResponse,
|
||||
};
|
||||
use crate::server::database::Database;
|
||||
use crate::server::models::{NewHistory, User};
|
||||
|
||||
pub async fn count(
|
||||
user: User,
|
||||
db: impl Database + Clone + Send + Sync,
|
||||
) -> Result<Box<dyn warp::Reply>, Infallible> {
|
||||
db.count_history(&user).await.map_or(
|
||||
Ok(Box::new(ErrorResponse::reply(
|
||||
"failed to query history count",
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
))),
|
||||
|count| Ok(Box::new(json(&CountResponse { count }))),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn list(
|
||||
req: SyncHistoryRequest,
|
||||
user: User,
|
||||
db: impl Database + Clone + Send + Sync,
|
||||
) -> Result<Box<dyn warp::Reply>, Infallible> {
|
||||
let history = db
|
||||
.list_history(
|
||||
&user,
|
||||
req.sync_ts.naive_utc(),
|
||||
req.history_ts.naive_utc(),
|
||||
req.host,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = history {
|
||||
error!("failed to load history: {}", e);
|
||||
let resp =
|
||||
ErrorResponse::reply("failed to load history", StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let resp = Box::new(resp);
|
||||
return Ok(resp);
|
||||
}
|
||||
|
||||
let history: Vec<String> = history
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|i| i.data.to_string())
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
"loaded {} items of history for user {}",
|
||||
history.len(),
|
||||
user.id
|
||||
);
|
||||
|
||||
Ok(Box::new(json(&SyncHistoryResponse { history })))
|
||||
}
|
||||
|
||||
pub async fn add(
|
||||
req: Vec<AddHistoryRequest>,
|
||||
user: User,
|
||||
db: impl Database + Clone + Send + Sync,
|
||||
) -> Result<Box<dyn warp::Reply>, Infallible> {
|
||||
debug!("request to add {} history items", req.len());
|
||||
|
||||
let history: Vec<NewHistory> = req
|
||||
.iter()
|
||||
.map(|h| NewHistory {
|
||||
client_id: h.id.as_str(),
|
||||
user_id: user.id,
|
||||
hostname: h.hostname.as_str(),
|
||||
timestamp: h.timestamp.naive_utc(),
|
||||
data: h.data.as_str(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Err(e) = db.add_history(&history).await {
|
||||
error!("failed to add history: {}", e);
|
||||
|
||||
return Ok(Box::new(ErrorResponse::reply(
|
||||
"failed to add history",
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
)));
|
||||
};
|
||||
|
||||
Ok(Box::new(warp::reply()))
|
||||
}
|
6
src/server/handlers/mod.rs
Normal file
6
src/server/handlers/mod.rs
Normal file
|
@ -0,0 +1,6 @@
|
|||
pub mod history;
|
||||
pub mod user;
|
||||
|
||||
pub const fn index() -> &'static str {
|
||||
"\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
|
||||
}
|
140
src/server/handlers/user.rs
Normal file
140
src/server/handlers/user.rs
Normal file
|
@ -0,0 +1,140 @@
|
|||
use std::convert::Infallible;
|
||||
|
||||
use sodiumoxide::crypto::pwhash::argon2id13;
|
||||
use uuid::Uuid;
|
||||
use warp::http::StatusCode;
|
||||
use warp::reply::json;
|
||||
|
||||
use crate::api::{
|
||||
ErrorResponse, LoginRequest, LoginResponse, RegisterRequest, RegisterResponse, UserResponse,
|
||||
};
|
||||
use crate::server::database::Database;
|
||||
use crate::server::models::{NewSession, NewUser};
|
||||
use crate::settings::Settings;
|
||||
use crate::utils::hash_secret;
|
||||
|
||||
pub fn verify_str(secret: &str, verify: &str) -> bool {
|
||||
sodiumoxide::init().unwrap();
|
||||
|
||||
let mut padded = [0_u8; 128];
|
||||
secret.as_bytes().iter().enumerate().for_each(|(i, val)| {
|
||||
padded[i] = *val;
|
||||
});
|
||||
|
||||
match argon2id13::HashedPassword::from_slice(&padded) {
|
||||
Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
username: String,
|
||||
db: impl Database + Clone + Send + Sync,
|
||||
) -> Result<Box<dyn warp::Reply>, Infallible> {
|
||||
let user = match db.get_user(username).await {
|
||||
Ok(user) => user,
|
||||
Err(e) => {
|
||||
debug!("user not found: {}", e);
|
||||
return Ok(Box::new(ErrorResponse::reply(
|
||||
"user not found",
|
||||
StatusCode::NOT_FOUND,
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::new(warp::reply::json(&UserResponse {
|
||||
username: user.username,
|
||||
})))
|
||||
}
|
||||
|
||||
pub async fn register(
|
||||
register: RegisterRequest,
|
||||
settings: Settings,
|
||||
db: impl Database + Clone + Send + Sync,
|
||||
) -> Result<Box<dyn warp::Reply>, Infallible> {
|
||||
if !settings.server.open_registration {
|
||||
return Ok(Box::new(ErrorResponse::reply(
|
||||
"this server is not open for registrations",
|
||||
StatusCode::BAD_REQUEST,
|
||||
)));
|
||||
}
|
||||
|
||||
let hashed = hash_secret(register.password.as_str());
|
||||
|
||||
let new_user = NewUser {
|
||||
email: register.email,
|
||||
username: register.username,
|
||||
password: hashed,
|
||||
};
|
||||
|
||||
let user_id = match db.add_user(new_user).await {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
error!("failed to add user: {}", e);
|
||||
return Ok(Box::new(ErrorResponse::reply(
|
||||
"failed to add user",
|
||||
StatusCode::BAD_REQUEST,
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let token = Uuid::new_v4().to_simple().to_string();
|
||||
|
||||
let new_session = NewSession {
|
||||
user_id,
|
||||
token: token.as_str(),
|
||||
};
|
||||
|
||||
match db.add_session(&new_session).await {
|
||||
Ok(_) => Ok(Box::new(json(&RegisterResponse { session: token }))),
|
||||
Err(e) => {
|
||||
error!("failed to add session: {}", e);
|
||||
Ok(Box::new(ErrorResponse::reply(
|
||||
"failed to register user",
|
||||
StatusCode::BAD_REQUEST,
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
login: LoginRequest,
|
||||
db: impl Database + Clone + Send + Sync,
|
||||
) -> Result<Box<dyn warp::Reply>, Infallible> {
|
||||
let user = match db.get_user(login.username.clone()).await {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
error!("failed to get user {}: {}", login.username.clone(), e);
|
||||
|
||||
return Ok(Box::new(ErrorResponse::reply(
|
||||
"user not found",
|
||||
StatusCode::NOT_FOUND,
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let session = match db.get_user_session(&user).await {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
error!("failed to get session for {}: {}", login.username, e);
|
||||
|
||||
return Ok(Box::new(ErrorResponse::reply(
|
||||
"user not found",
|
||||
StatusCode::NOT_FOUND,
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let verified = verify_str(user.password.as_str(), login.password.as_str());
|
||||
|
||||
if !verified {
|
||||
return Ok(Box::new(ErrorResponse::reply(
|
||||
"user not found",
|
||||
StatusCode::NOT_FOUND,
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Box::new(warp::reply::json(&LoginResponse {
|
||||
session: session.token,
|
||||
})))
|
||||
}
|
23
src/server/mod.rs
Normal file
23
src/server/mod.rs
Normal file
|
@ -0,0 +1,23 @@
|
|||
use std::net::IpAddr;
|
||||
|
||||
use eyre::Result;
|
||||
|
||||
use crate::settings::Settings;
|
||||
|
||||
pub mod auth;
|
||||
pub mod database;
|
||||
pub mod handlers;
|
||||
pub mod models;
|
||||
pub mod router;
|
||||
|
||||
pub async fn launch(settings: &Settings, host: String, port: u16) -> Result<()> {
|
||||
// routes to run:
|
||||
// index, register, add_history, login, get_user, sync_count, sync_list
|
||||
let host = host.parse::<IpAddr>()?;
|
||||
|
||||
let r = router::router(settings).await?;
|
||||
|
||||
warp::serve(r).run((host, port)).await;
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1,10 +1,6 @@
|
|||
use chrono::prelude::*;
|
||||
|
||||
use crate::schema::{history, sessions, users};
|
||||
|
||||
#[derive(Deserialize, Serialize, Identifiable, Queryable, Associations)]
|
||||
#[table_name = "history"]
|
||||
#[belongs_to(User)]
|
||||
#[derive(sqlx::FromRow)]
|
||||
pub struct History {
|
||||
pub id: i64,
|
||||
pub client_id: String, // a client generated ID
|
||||
|
@ -17,7 +13,16 @@ pub struct History {
|
|||
pub created_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Identifiable, Queryable, Associations)]
|
||||
pub struct NewHistory<'a> {
|
||||
pub client_id: &'a str,
|
||||
pub user_id: i64,
|
||||
pub hostname: &'a str,
|
||||
pub timestamp: chrono::NaiveDateTime,
|
||||
|
||||
pub data: &'a str,
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
pub struct User {
|
||||
pub id: i64,
|
||||
pub username: String,
|
||||
|
@ -25,35 +30,19 @@ pub struct User {
|
|||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Queryable, Identifiable, Associations)]
|
||||
#[belongs_to(User)]
|
||||
#[derive(sqlx::FromRow)]
|
||||
pub struct Session {
|
||||
pub id: i64,
|
||||
pub user_id: i64,
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
#[derive(Insertable)]
|
||||
#[table_name = "history"]
|
||||
pub struct NewHistory<'a> {
|
||||
pub client_id: &'a str,
|
||||
pub user_id: i64,
|
||||
pub hostname: String,
|
||||
pub timestamp: chrono::NaiveDateTime,
|
||||
|
||||
pub data: &'a str,
|
||||
pub struct NewUser {
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Insertable)]
|
||||
#[table_name = "users"]
|
||||
pub struct NewUser<'a> {
|
||||
pub username: &'a str,
|
||||
pub email: &'a str,
|
||||
pub password: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Insertable)]
|
||||
#[table_name = "sessions"]
|
||||
pub struct NewSession<'a> {
|
||||
pub user_id: i64,
|
||||
pub token: &'a str,
|
121
src/server/router.rs
Normal file
121
src/server/router.rs
Normal file
|
@ -0,0 +1,121 @@
|
|||
use std::convert::Infallible;
|
||||
|
||||
use eyre::Result;
|
||||
use warp::Filter;
|
||||
|
||||
use super::handlers;
|
||||
use super::{database::Database, database::Postgres};
|
||||
use crate::server::models::User;
|
||||
use crate::{api::SyncHistoryRequest, settings::Settings};
|
||||
|
||||
fn with_settings(
|
||||
settings: Settings,
|
||||
) -> impl Filter<Extract = (Settings,), Error = Infallible> + Clone {
|
||||
warp::any().map(move || settings.clone())
|
||||
}
|
||||
|
||||
fn with_db(
|
||||
db: impl Database + Clone + Send + Sync,
|
||||
) -> impl Filter<Extract = (impl Database + Clone,), Error = Infallible> + Clone {
|
||||
warp::any().map(move || db.clone())
|
||||
}
|
||||
|
||||
fn with_user(
|
||||
postgres: Postgres,
|
||||
) -> impl Filter<Extract = (User,), Error = warp::Rejection> + Clone {
|
||||
warp::header::<String>("authorization").and_then(move |header: String| {
|
||||
// async closures are still buggy :(
|
||||
let postgres = postgres.clone();
|
||||
|
||||
async move {
|
||||
let header: Vec<&str> = header.split(' ').collect();
|
||||
|
||||
let token;
|
||||
|
||||
if header.len() == 2 {
|
||||
if header[0] != "Token" {
|
||||
return Err(warp::reject());
|
||||
}
|
||||
|
||||
token = header[1];
|
||||
} else {
|
||||
return Err(warp::reject());
|
||||
}
|
||||
|
||||
let user = postgres
|
||||
.get_session_user(token)
|
||||
.await
|
||||
.map_err(|_| warp::reject())?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn router(
|
||||
settings: &Settings,
|
||||
) -> Result<impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone> {
|
||||
let postgres = Postgres::new(settings.server.db_uri.as_str()).await?;
|
||||
let index = warp::get().and(warp::path::end()).map(handlers::index);
|
||||
|
||||
let count = warp::get()
|
||||
.and(warp::path("sync"))
|
||||
.and(warp::path("count"))
|
||||
.and(warp::path::end())
|
||||
.and(with_user(postgres.clone()))
|
||||
.and(with_db(postgres.clone()))
|
||||
.and_then(handlers::history::count);
|
||||
|
||||
let sync = warp::get()
|
||||
.and(warp::path("sync"))
|
||||
.and(warp::path("history"))
|
||||
.and(warp::query::<SyncHistoryRequest>())
|
||||
.and(warp::path::end())
|
||||
.and(with_user(postgres.clone()))
|
||||
.and(with_db(postgres.clone()))
|
||||
.and_then(handlers::history::list);
|
||||
|
||||
let add_history = warp::post()
|
||||
.and(warp::path("history"))
|
||||
.and(warp::path::end())
|
||||
.and(warp::body::json())
|
||||
.and(with_user(postgres.clone()))
|
||||
.and(with_db(postgres.clone()))
|
||||
.and_then(handlers::history::add);
|
||||
|
||||
let user = warp::get()
|
||||
.and(warp::path("user"))
|
||||
.and(warp::path::param::<String>())
|
||||
.and(warp::path::end())
|
||||
.and(with_db(postgres.clone()))
|
||||
.and_then(handlers::user::get);
|
||||
|
||||
let register = warp::post()
|
||||
.and(warp::path("register"))
|
||||
.and(warp::path::end())
|
||||
.and(warp::body::json())
|
||||
.and(with_settings(settings.clone()))
|
||||
.and(with_db(postgres.clone()))
|
||||
.and_then(handlers::user::register);
|
||||
|
||||
let login = warp::post()
|
||||
.and(warp::path("login"))
|
||||
.and(warp::path::end())
|
||||
.and(warp::body::json())
|
||||
.and(with_db(postgres))
|
||||
.and_then(handlers::user::login);
|
||||
|
||||
let r = warp::any()
|
||||
.and(
|
||||
index
|
||||
.or(count)
|
||||
.or(sync)
|
||||
.or(add_history)
|
||||
.or(user)
|
||||
.or(register)
|
||||
.or(login),
|
||||
)
|
||||
.with(warp::filters::log::log("atuin::api"));
|
||||
|
||||
Ok(r)
|
||||
}
|
|
@ -161,7 +161,7 @@ impl Settings {
|
|||
// Finally, set the auth token
|
||||
if Path::new(session_path.to_string().as_str()).exists() {
|
||||
let token = std::fs::read_to_string(session_path.to_string())?;
|
||||
s.set("local.session_token", token)?;
|
||||
s.set("local.session_token", token.trim())?;
|
||||
} else {
|
||||
s.set("local.session_token", "not logged in")?;
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ _atuin_precmd(){
|
|||
[[ -z "${ATUIN_HISTORY_ID}" ]] && return
|
||||
|
||||
atuin history end $ATUIN_HISTORY_ID --exit $EXIT
|
||||
export ATUIN_HISTORY_ID=""
|
||||
}
|
||||
|
||||
_atuin_search(){
|
||||
|
|
Loading…
Reference in a new issue