support timezones in calendar (#1259)

This commit is contained in:
Conrad Ludgate 2023-09-29 17:49:38 +01:00 committed by GitHub
parent a195c389b6
commit b4428c27c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 109 additions and 153 deletions

View file

@ -1,11 +1,12 @@
// Calendar data // Calendar data
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use time::Month;
pub enum TimePeriod { pub enum TimePeriod {
YEAR, Year,
MONTH, Month { year: i32 },
DAY, Day { year: i32, month: Month },
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]

View file

@ -6,6 +6,7 @@ pub mod models;
use std::{ use std::{
collections::HashMap, collections::HashMap,
fmt::{Debug, Display}, fmt::{Debug, Display},
ops::Range,
}; };
use self::{ use self::{
@ -15,7 +16,7 @@ use self::{
use async_trait::async_trait; use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use time::{Date, Duration, Month, OffsetDateTime, PrimitiveDateTime, Time}; use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
use tracing::instrument; use tracing::instrument;
#[derive(Debug)] #[derive(Debug)]
@ -74,12 +75,8 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID) // Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>; async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>;
async fn count_history_range( async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>)
&self, -> DbResult<i64>;
user: &User,
start: PrimitiveDateTime,
end: PrimitiveDateTime,
) -> DbResult<i64>;
async fn list_history( async fn list_history(
&self, &self,
@ -94,107 +91,74 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
async fn oldest_history(&self, user: &User) -> DbResult<History>; async fn oldest_history(&self, user: &User) -> DbResult<History>;
/// Count the history for a given year
#[instrument(skip_all)]
async fn count_history_year(&self, user: &User, year: i32) -> DbResult<i64> {
let start = Date::from_calendar_date(year, time::Month::January, 1)?;
let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?;
let res = self
.count_history_range(
user,
start.with_time(Time::MIDNIGHT),
end.with_time(Time::MIDNIGHT),
)
.await?;
Ok(res)
}
/// Count the history for a given month
#[instrument(skip_all)]
async fn count_history_month(&self, user: &User, year: i32, month: Month) -> DbResult<i64> {
let start = Date::from_calendar_date(year, month, 1)?;
let days = time::util::days_in_year_month(year, month);
let end = start + Duration::days(days as i64);
tracing::debug!("start: {}, end: {}", start, end);
let res = self
.count_history_range(
user,
start.with_time(Time::MIDNIGHT),
end.with_time(Time::MIDNIGHT),
)
.await?;
Ok(res)
}
/// Count the history for a given day
#[instrument(skip_all)]
async fn count_history_day(&self, user: &User, day: Date) -> DbResult<i64> {
let end = day
.next_day()
.ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?;
let res = self
.count_history_range(
user,
day.with_time(Time::MIDNIGHT),
end.with_time(Time::MIDNIGHT),
)
.await?;
Ok(res)
}
#[instrument(skip_all)] #[instrument(skip_all)]
async fn calendar( async fn calendar(
&self, &self,
user: &User, user: &User,
period: TimePeriod, period: TimePeriod,
year: u64, tz: UtcOffset,
month: Month,
) -> DbResult<HashMap<u64, TimePeriodInfo>> { ) -> DbResult<HashMap<u64, TimePeriodInfo>> {
// TODO: Support different timezones. Right now we assume UTC and
// everything is stored as such. But it _should_ be possible to
// interpret the stored date with a different TZ
match period {
TimePeriod::YEAR => {
let mut ret = HashMap::new(); let mut ret = HashMap::new();
let iter: Box<dyn Iterator<Item = DbResult<(u64, Range<Date>)>> + Send> = match period {
TimePeriod::Year => {
// First we need to work out how far back to calculate. Get the // First we need to work out how far back to calculate. Get the
// oldest history item // oldest history item
let oldest = self.oldest_history(user).await?.timestamp.year(); let oldest = self
let current_year = OffsetDateTime::now_utc().year(); .oldest_history(user)
.await?
.timestamp
.to_offset(tz)
.year();
let current_year = OffsetDateTime::now_utc().to_offset(tz).year();
// All the years we need to get data for // All the years we need to get data for
// The upper bound is exclusive, so include current +1 // The upper bound is exclusive, so include current +1
let years = oldest..current_year + 1; let years = oldest..current_year + 1;
for year in years { Box::new(years.map(|year| {
let count = self.count_history_year(user, year).await?; let start = Date::from_calendar_date(year, time::Month::January, 1)?;
let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?;
ret.insert( Ok((year as u64, start..end))
year as u64, }))
TimePeriodInfo {
count: count as u64,
hash: "".to_string(),
},
);
} }
Ok(ret) TimePeriod::Month { year } => {
}
TimePeriod::MONTH => {
let mut ret = HashMap::new();
let months = let months =
std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12); std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12);
for month in months {
let count = self.count_history_month(user, year as i32, month).await?; Box::new(months.map(move |month| {
let start = Date::from_calendar_date(year, month, 1)?;
let days = time::util::days_in_year_month(year, month);
let end = start + Duration::days(days as i64);
Ok((month as u64, start..end))
}))
}
TimePeriod::Day { year, month } => {
let days = 1..time::util::days_in_year_month(year, month);
Box::new(days.map(move |day| {
let start = Date::from_calendar_date(year, month, day)?;
let end = start
.next_day()
.ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?;
Ok((day as u64, start..end))
}))
}
};
for x in iter {
let (index, range) = x?;
let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz);
let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz);
let count = self.count_history_range(user, start..end).await?;
ret.insert( ret.insert(
month as u64, index,
TimePeriodInfo { TimePeriodInfo {
count: count as u64, count: count as u64,
hash: "".to_string(), hash: "".to_string(),
@ -204,26 +168,4 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
Ok(ret) Ok(ret)
} }
TimePeriod::DAY => {
let mut ret = HashMap::new();
for day in 1..time::util::days_in_year_month(year as i32, month) {
let count = self
.count_history_day(user, Date::from_calendar_date(year as i32, month, day)?)
.await?;
ret.insert(
day as u64,
TimePeriodInfo {
count: count as u64,
hash: "".to_string(),
},
);
}
Ok(ret)
}
}
}
} }

View file

@ -1,3 +1,5 @@
use std::ops::Range;
use async_trait::async_trait; use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
@ -176,8 +178,7 @@ impl Database for Postgres {
async fn count_history_range( async fn count_history_range(
&self, &self,
user: &User, user: &User,
start: PrimitiveDateTime, range: Range<OffsetDateTime>,
end: PrimitiveDateTime,
) -> DbResult<i64> { ) -> DbResult<i64> {
let res: (i64,) = sqlx::query_as( let res: (i64,) = sqlx::query_as(
"select count(1) from history "select count(1) from history
@ -186,8 +187,8 @@ impl Database for Postgres {
and timestamp < $3::date", and timestamp < $3::date",
) )
.bind(user.id) .bind(user.id)
.bind(start) .bind(into_utc(range.start))
.bind(end) .bind(into_utc(range.end))
.fetch_one(&self.pool) .fetch_one(&self.pool)
.await .await
.map_err(fix_error)?; .map_err(fix_error)?;

View file

@ -6,7 +6,7 @@ use axum::{
Json, Json,
}; };
use http::StatusCode; use http::StatusCode;
use time::Month; use time::{Month, UtcOffset};
use tracing::{debug, error, instrument}; use tracing::{debug, error, instrument};
use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use super::{ErrorResponse, ErrorResponseStatus, RespExt};
@ -166,53 +166,65 @@ pub async fn add<DB: Database>(
Ok(()) Ok(())
} }
#[derive(serde::Deserialize, Debug)]
pub struct CalendarQuery {
#[serde(default = "serde_calendar::zero")]
year: i32,
#[serde(default = "serde_calendar::one")]
month: u8,
#[serde(default = "serde_calendar::utc")]
tz: UtcOffset,
}
mod serde_calendar {
use time::UtcOffset;
pub fn zero() -> i32 {
0
}
pub fn one() -> u8 {
1
}
pub fn utc() -> UtcOffset {
UtcOffset::UTC
}
}
#[instrument(skip_all, fields(user.id = user.id))] #[instrument(skip_all, fields(user.id = user.id))]
pub async fn calendar<DB: Database>( pub async fn calendar<DB: Database>(
Path(focus): Path<String>, Path(focus): Path<String>,
Query(params): Query<HashMap<String, u64>>, Query(params): Query<CalendarQuery>,
UserAuth(user): UserAuth, UserAuth(user): UserAuth,
state: State<AppState<DB>>, state: State<AppState<DB>>,
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { ) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
let focus = focus.as_str(); let focus = focus.as_str();
let year = params.get("year").unwrap_or(&0); let year = params.year;
let month = params.get("month").unwrap_or(&1); let month = Month::try_from(params.month).map_err(|e| ErrorResponseStatus {
let month = Month::try_from(*month as u8).map_err(|e| ErrorResponseStatus {
error: ErrorResponse { error: ErrorResponse {
reason: e.to_string().into(), reason: e.to_string().into(),
}, },
status: http::StatusCode::BAD_REQUEST, status: http::StatusCode::BAD_REQUEST,
})?; })?;
let period = match focus {
"year" => TimePeriod::Year,
"month" => TimePeriod::Month { year },
"day" => TimePeriod::Day { year, month },
_ => {
return Err(ErrorResponse::reply("invalid focus: use year/month/day")
.with_status(StatusCode::BAD_REQUEST))
}
};
let db = &state.0.database; let db = &state.0.database;
let focus = match focus { let focus = db.calendar(&user, period, params.tz).await.map_err(|_| {
"year" => db
.calendar(&user, TimePeriod::YEAR, *year, month)
.await
.map_err(|_| {
ErrorResponse::reply("failed to query calendar") ErrorResponse::reply("failed to query calendar")
.with_status(StatusCode::INTERNAL_SERVER_ERROR) .with_status(StatusCode::INTERNAL_SERVER_ERROR)
}), })?;
"month" => db
.calendar(&user, TimePeriod::MONTH, *year, month)
.await
.map_err(|_| {
ErrorResponse::reply("failed to query calendar")
.with_status(StatusCode::INTERNAL_SERVER_ERROR)
}),
"day" => db
.calendar(&user, TimePeriod::DAY, *year, month)
.await
.map_err(|_| {
ErrorResponse::reply("failed to query calendar")
.with_status(StatusCode::INTERNAL_SERVER_ERROR)
}),
_ => Err(ErrorResponse::reply("invalid focus: use year/month/day")
.with_status(StatusCode::BAD_REQUEST)),
}?;
Ok(Json(focus)) Ok(Json(focus))
} }