use crate::Error; use chrono::Utc; use directories::ProjectDirs; use futures::{StreamExt, TryStreamExt}; use serde::de::DeserializeOwned; use serde::Serialize; use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; use sqlx::{QueryBuilder, Sqlite, SqlitePool}; use std::fmt::Display; use tabled::Tabled; use tokio::fs; use tracing::{info, instrument}; static SQLITE_BIND_LIMIT: usize = 32766; #[derive(sqlx::Type, Debug)] #[repr(u8)] pub enum TaskStatus { Pending = 1, InProgress = 2, Completed = 3, Failed = 4, } impl Display for TaskStatus { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TaskStatus::Pending => { write!(f, "Pending") } TaskStatus::InProgress => { write!(f, "In Progress") } TaskStatus::Completed => { write!(f, "Completed") } TaskStatus::Failed => { write!(f, "Failed") } } } } pub trait TaskPayload { fn get_key(&self) -> String; } pub type TaskJob = fn(&Task) -> TaskStatus; #[derive(sqlx::FromRow, Tabled, Debug)] pub struct Task { id: u32, payload_key: String, #[sqlx(json)] #[tabled(skip)] payload: T, #[sqlx(rename = "status_id")] status: TaskStatus, created_at: chrono::DateTime, #[tabled(display = "display_option_date")] updated_at: Option>, } impl Task { pub fn get_key(&self) -> String { self.payload_key.clone() } } fn display_option_date(o: &Option>) -> String { match o { Some(s) => format!("{}", s), None => String::from(""), } } pub trait _Task: DeserializeOwned + Send + Unpin + 'static + Display {} impl _Task for T {} #[derive(Debug)] pub struct TaskManager { pool: SqlitePool, } impl TaskManager { pub async fn new() -> Result { Ok(Self { pool: Self::connect_database().await?, }) } async fn connect_database() -> crate::Result { let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME")) .ok_or(Error::Unhandled("Could not get standard directories"))?; let database_file_path = project_dir.data_dir().join("db.sql"); fs::create_dir_all(project_dir.data_dir()).await?; let opts = SqliteConnectOptions::new() .filename(database_file_path) .create_if_missing(true) .journal_mode(SqliteJournalMode::Wal); let pool = SqlitePool::connect_with(opts).await?; sqlx::migrate!("./migrations").run(&pool).await?; Ok(pool) } fn get_task_builder( status: Option, limit: Option, ) -> QueryBuilder<'static, Sqlite> { let mut builder: QueryBuilder<'_, Sqlite> = QueryBuilder::new( "select id, payload_key, payload, status_id, created_at, updated_at from tasks ", ); if let Some(status) = status { builder.push("where status_id = ").push_bind(status); } builder.push("ORDER BY created_at DESC "); if let Some(limit) = limit { builder.push("LIMIT ").push_bind(limit); } builder } pub async fn get_tasks( &self, status: Option, limit: Option, ) -> crate::Result>> { let mut builder = Self::get_task_builder(status, limit); let tasks: Vec> = builder.build_query_as().fetch_all(&self.pool).await?; Ok(tasks) } #[instrument(skip(self, values))] pub async fn load_tasks(&self, values: Vec) -> crate::Result<()> where T: TaskPayload + Serialize + std::fmt::Debug, { let mut tx = self.pool.begin().await?; let mut builder: QueryBuilder<'_, Sqlite> = QueryBuilder::new("insert into tasks(payload_key, payload, status_id)"); let args: crate::Result> = values .iter() .map(|value| Ok((value.get_key(), serde_json::to_string(value)?))) .collect(); let mut affected_rows = 0; // Chunk the query by the size limit of bind params for chunk in args?.chunks(SQLITE_BIND_LIMIT / 3) { builder.push_values(chunk, |mut builder, item| { builder .push_bind(&item.0) .push_bind(&item.1) .push_bind(TaskStatus::Pending); }); builder.push("ON conflict (payload_key) DO NOTHING"); let query = builder.build(); affected_rows += query.execute(&mut *tx).await?.rows_affected(); builder.reset(); } tx.commit().await?; info!("{} rows inserted.", affected_rows); Ok(()) } pub async fn run_tasks(&self, func: TaskJob) -> crate::Result<()> { let mut builder = Self::get_task_builder(Some(TaskStatus::Pending), None); let rows = builder.build_query_as::>().fetch(&self.pool); let result: Vec<(Task, TaskStatus)> = rows.map(|x| { let task = x.unwrap(); let status = func(&task); (task, status) }).collect().await; Ok(()) } }