atuin/atuin-server/src/database.rs
Ellie Huxtable f4240aa62b
Initial implementation of calendar API (#298)
This can be used in the future for sync so that we can be more
intelligent with what we're doing, and only sync up what's needed

I'd like to eventually replace this with something more like a merkle
tree, hence the hash field I've exposed, but that can come later

Although this does include a much larger number of count queries, it
should also be significantly more cache-able. I'll follow up with that
later, and also follow up with using this for sync :)
2022-04-13 18:29:18 +01:00

417 lines
12 KiB
Rust

use async_trait::async_trait;
use std::collections::HashMap;
use eyre::{eyre, Result};
use sqlx::postgres::PgPoolOptions;
use crate::settings::HISTORY_PAGE_SIZE;
use super::calendar::{TimePeriod, TimePeriodInfo};
use super::models::{History, NewHistory, NewSession, NewUser, Session, User};
use chrono::{Datelike, TimeZone};
use chronoutil::RelativeDuration;
use atuin_common::utils::get_days_from_month;
/// Storage operations the server needs for users, sessions and command
/// history. `Postgres` in this file is an implementation of this trait.
#[async_trait]
pub trait Database {
/// Look up the session identified by `token`.
async fn get_session(&self, token: &str) -> Result<Session>;
/// Look up the user that owns the session identified by `token`.
async fn get_session_user(&self, token: &str) -> Result<User>;
/// Persist a new session.
async fn add_session(&self, session: &NewSession) -> Result<()>;
/// Look up a user by username.
async fn get_user(&self, username: &str) -> Result<User>;
/// Look up a session belonging to user `u`.
async fn get_user_session(&self, u: &User) -> Result<Session>;
/// Persist a new user and return its id.
async fn add_user(&self, user: &NewUser) -> Result<i64>;
/// Count every history entry stored for `user`.
async fn count_history(&self, user: &User) -> Result<i64>;
/// Count `user`'s history entries whose timestamp is at or after `start`
/// and strictly before `end`.
async fn count_history_range(
&self,
user: &User,
start: chrono::NaiveDateTime,
end: chrono::NaiveDateTime,
) -> Result<i64>;
/// Count `user`'s history entries for the calendar day `date`.
async fn count_history_day(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>;
/// Count `user`'s history entries for the calendar month containing `date`.
async fn count_history_month(&self, user: &User, date: chrono::NaiveDate) -> Result<i64>;
/// Count `user`'s history entries for the calendar year `year`.
async fn count_history_year(&self, user: &User, year: i32) -> Result<i64>;
/// List one page of `user`'s history, filtered by `created_after`, `since`
/// and `host`. In the Postgres implementation: entries from hosts other
/// than `host`, with `created_at >= created_after` and
/// `timestamp >= since`, oldest first, at most `HISTORY_PAGE_SIZE` rows.
async fn list_history(
&self,
user: &User,
created_after: chrono::NaiveDateTime,
since: chrono::NaiveDateTime,
host: &str,
) -> Result<Vec<History>>;
/// Persist a batch of history entries.
async fn add_history(&self, history: &[NewHistory]) -> Result<()>;
/// Return `user`'s history entry with the earliest timestamp.
async fn oldest_history(&self, user: &User) -> Result<History>;
/// Build per-period history counts for `user`, keyed by period number:
/// years for `TimePeriod::YEAR`, months of `year` for `TimePeriod::MONTH`,
/// and days of `month` in `year` for `TimePeriod::DAY`.
async fn calendar(
&self,
user: &User,
period: TimePeriod,
year: u64,
month: u64,
) -> Result<HashMap<u64, TimePeriodInfo>>;
}
/// `Database` implementation backed by a sqlx Postgres connection pool.
#[derive(Clone)]
pub struct Postgres {
// Connection pool that every query in this file runs against.
pool: sqlx::Pool<sqlx::postgres::Postgres>,
}
impl Postgres {
pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
let pool = PgPoolOptions::new()
.max_connections(100)
.connect(uri)
.await?;
sqlx::migrate!("./migrations").run(&pool).await?;
Ok(Self { pool })
}
}
#[async_trait]
impl Database for Postgres {
    /// Fetches the session whose token equals `token`.
    ///
    /// # Errors
    ///
    /// Fails if the query errors or no matching session exists.
    async fn get_session(&self, token: &str) -> Result<Session> {
        sqlx::query_as::<_, Session>("select * from sessions where token = $1")
            .bind(token)
            .fetch_optional(&self.pool)
            .await?
            .ok_or_else(|| eyre!("could not find session"))
    }

    /// Fetches the user whose username equals `username`.
    ///
    /// # Errors
    ///
    /// Fails if the query errors or no matching user exists.
    async fn get_user(&self, username: &str) -> Result<User> {
        sqlx::query_as::<_, User>("select * from users where username = $1")
            .bind(username)
            .fetch_optional(&self.pool)
            .await?
            .ok_or_else(|| eyre!("could not find user"))
    }

    /// Fetches the user that owns the session identified by `token`.
    ///
    /// # Errors
    ///
    /// Fails if the query errors or no user is joined to that token.
    async fn get_session_user(&self, token: &str) -> Result<User> {
        // Select only the `users` columns: with a bare `*` the joined
        // `sessions` columns come back too, and columns sharing a name
        // (such as `id`) become ambiguous when the row is decoded by name.
        sqlx::query_as::<_, User>(
            "select users.* from users
            inner join sessions
            on users.id = sessions.user_id
            and sessions.token = $1",
        )
        .bind(token)
        .fetch_optional(&self.pool)
        .await?
        .ok_or_else(|| eyre!("could not find user"))
    }

    /// Counts every history row stored for `user`.
    async fn count_history(&self, user: &User) -> Result<i64> {
        let (count,): (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1",
        )
        .bind(user.id)
        .fetch_one(&self.pool)
        .await?;

        Ok(count)
    }

    /// Counts `user`'s history rows with `start <= timestamp < end`.
    ///
    /// Both bounds are cast to `date` in the query, so any time-of-day
    /// component of `start` and `end` is discarded before comparing.
    async fn count_history_range(
        &self,
        user: &User,
        start: chrono::NaiveDateTime,
        end: chrono::NaiveDateTime,
    ) -> Result<i64> {
        let (count,): (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1
            and timestamp >= $2::date
            and timestamp < $3::date",
        )
        .bind(user.id)
        .bind(start)
        .bind(end)
        .fetch_one(&self.pool)
        .await?;

        Ok(count)
    }

    /// Counts `user`'s history for the calendar year `year` (UTC).
    async fn count_history_year(&self, user: &User, year: i32) -> Result<i64> {
        let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0);
        let end = start + RelativeDuration::years(1);

        self.count_history_range(user, start.naive_utc(), end.naive_utc())
            .await
    }

    /// Counts `user`'s history for the calendar month containing `month`
    /// (UTC). Only the year and month of the argument are used.
    async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> Result<i64> {
        let start = chrono::Utc
            .ymd(month.year(), month.month(), 1)
            .and_hms_nano(0, 0, 0, 0);
        // Adding a relative month to the first of the month lands on the
        // first of the next month, rolling December over into January.
        let end = start + RelativeDuration::months(1);

        debug!("start: {}, end: {}", start, end);

        self.count_history_range(user, start.naive_utc(), end.naive_utc())
            .await
    }

    /// Counts `user`'s history for the calendar day `day` (UTC).
    async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> Result<i64> {
        let start = chrono::Utc
            .ymd(day.year(), day.month(), day.day())
            .and_hms_nano(0, 0, 0, 0);
        // Use date arithmetic for the exclusive upper bound. Building it as
        // `day.day() + 1` yields an invalid date (and a panic) on the last
        // day of every month.
        let end = start + RelativeDuration::days(1);

        self.count_history_range(user, start.naive_utc(), end.naive_utc())
            .await
    }

    /// Lists one page of `user`'s history recorded on hosts other than
    /// `host`, with `created_at >= created_after` and `timestamp >= since`.
    ///
    /// Rows are ordered by ascending timestamp and limited to
    /// `HISTORY_PAGE_SIZE`.
    async fn list_history(
        &self,
        user: &User,
        created_after: chrono::NaiveDateTime,
        since: chrono::NaiveDateTime,
        host: &str,
    ) -> Result<Vec<History>> {
        let rows = sqlx::query_as::<_, History>(
            "select * from history
            where user_id = $1
            and hostname != $2
            and created_at >= $3
            and timestamp >= $4
            order by timestamp asc
            limit $5",
        )
        .bind(user.id)
        .bind(host)
        .bind(created_after)
        .bind(since)
        .bind(HISTORY_PAGE_SIZE)
        .fetch_all(&self.pool)
        .await?;

        Ok(rows)
    }

    /// Inserts every entry of `history` inside a single transaction.
    ///
    /// Rows that conflict with existing ones are skipped
    /// (`on conflict do nothing`). The transaction is only committed once
    /// all inserts have succeeded.
    async fn add_history(&self, history: &[NewHistory]) -> Result<()> {
        let mut tx = self.pool.begin().await?;

        for entry in history {
            let client_id: &str = &entry.client_id;
            let hostname: &str = &entry.hostname;
            let data: &str = &entry.data;

            sqlx::query(
                "insert into history
                (client_id, user_id, hostname, timestamp, data)
                values ($1, $2, $3, $4, $5)
                on conflict do nothing
                ",
            )
            .bind(client_id)
            .bind(entry.user_id)
            .bind(hostname)
            .bind(entry.timestamp)
            .bind(data)
            .execute(&mut tx)
            .await?;
        }

        tx.commit().await?;

        Ok(())
    }

    /// Inserts `user` and returns the id of the new row.
    async fn add_user(&self, user: &NewUser) -> Result<i64> {
        let email: &str = &user.email;
        let username: &str = &user.username;
        let password: &str = &user.password;

        let (id,): (i64,) = sqlx::query_as(
            "insert into users
            (username, email, password)
            values($1, $2, $3)
            returning id",
        )
        .bind(username)
        .bind(email)
        .bind(password)
        .fetch_one(&self.pool)
        .await?;

        Ok(id)
    }

    /// Inserts `session` (user id and token).
    async fn add_session(&self, session: &NewSession) -> Result<()> {
        let token: &str = &session.token;

        sqlx::query(
            "insert into sessions
            (user_id, token)
            values($1, $2)",
        )
        .bind(session.user_id)
        .bind(token)
        .execute(&self.pool)
        .await?;

        Ok(())
    }

    /// Fetches a session belonging to user `u`.
    ///
    /// # Errors
    ///
    /// Fails if the query errors or the user has no session.
    async fn get_user_session(&self, u: &User) -> Result<Session> {
        sqlx::query_as::<_, Session>("select * from sessions where user_id = $1")
            .bind(u.id)
            .fetch_optional(&self.pool)
            .await?
            .ok_or_else(|| eyre!("could not find session"))
    }

    /// Fetches `user`'s history row with the smallest timestamp.
    ///
    /// # Errors
    ///
    /// Fails if the query errors, including when the user has no history
    /// (the query requires exactly one row via `fetch_one`).
    async fn oldest_history(&self, user: &User) -> Result<History> {
        let oldest = sqlx::query_as::<_, History>(
            "select * from history
            where user_id = $1
            order by timestamp asc
            limit 1",
        )
        .bind(user.id)
        .fetch_one(&self.pool)
        .await?;

        Ok(oldest)
    }

    /// Builds per-period history counts for `user`.
    ///
    /// * `TimePeriod::YEAR`: one entry per year, from the year of the
    ///   oldest history row through the current UTC year. `year` and
    ///   `month` are ignored.
    /// * `TimePeriod::MONTH`: one entry for each of the 12 months of
    ///   `year`. `month` is ignored.
    /// * `TimePeriod::DAY`: one entry for every day of `month` in `year`,
    ///   including the last one.
    ///
    /// Keys are the year, month or day number. Every `hash` is currently
    /// the empty string.
    ///
    /// # Errors
    ///
    /// Fails if any query errors (for `YEAR` this includes a user with no
    /// history, since the oldest row is required), or if `year`/`month`
    /// do not form a valid calendar date.
    async fn calendar(
        &self,
        user: &User,
        period: TimePeriod,
        year: u64,
        month: u64,
    ) -> Result<HashMap<u64, TimePeriodInfo>> {
        // TODO: Support different timezones. Right now we assume UTC and
        // everything is stored as such. But it _should_ be possible to
        // interpret the stored date with a different TZ

        let info = |count: i64| TimePeriodInfo {
            count: count as u64,
            hash: String::new(),
        };
        let mut ret = HashMap::new();

        match period {
            TimePeriod::YEAR => {
                // Work out how far back to go: start at the year of the
                // oldest history item and run through the current year.
                let oldest = self.oldest_history(user).await?.timestamp.year();
                let current_year = chrono::Utc::now().year();

                for year in oldest..=current_year {
                    let count = self.count_history_year(user, year).await?;
                    ret.insert(year as u64, info(count));
                }
            }

            TimePeriod::MONTH => {
                // `year` is caller-supplied: convert with a range check
                // rather than a truncating `as` cast.
                let year = <i32 as std::convert::TryFrom<u64>>::try_from(year)
                    .map_err(|_| eyre!("year {} is out of range", year))?;

                for month in 1..=12u32 {
                    let first = chrono::NaiveDate::from_ymd_opt(year, month, 1)
                        .ok_or_else(|| eyre!("invalid year/month: {}-{}", year, month))?;
                    let count = self.count_history_month(user, first).await?;
                    ret.insert(u64::from(month), info(count));
                }
            }

            TimePeriod::DAY => {
                let year = <i32 as std::convert::TryFrom<u64>>::try_from(year)
                    .map_err(|_| eyre!("year {} is out of range", year))?;
                let month = <u32 as std::convert::TryFrom<u64>>::try_from(month)
                    .map_err(|_| eyre!("month {} is out of range", month))?;

                // Reject impossible year/month pairs with an error before
                // they reach the panicking date constructors below.
                chrono::NaiveDate::from_ymd_opt(year, month, 1)
                    .ok_or_else(|| eyre!("invalid year/month: {}-{}", year, month))?;

                // Inclusive upper bound: the helper returns the number of
                // days in the month, and the last day must be counted too.
                for day in 1..=get_days_from_month(year, month) {
                    let date = chrono::Utc.ymd(year, month, day as u32).naive_utc();
                    let count = self.count_history_day(user, date).await?;
                    ret.insert(day as u64, info(count));
                }
            }
        }

        Ok(ret)
    }
}