use crate::Error; use directories::ProjectDirs; use serde::Serialize; use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; use sqlx::{QueryBuilder, Sqlite, SqlitePool}; use tokio::fs; use tracing::{info, instrument}; static SQLITE_BIND_LIMIT: usize = 32766; #[derive(sqlx::Type)] #[repr(u8)] pub enum TaskStatus { Pending = 1, InProgress = 2, Completed = 3, Failed = 4, } pub trait TaskPayload { fn get_key(&self) -> String; } #[derive(Debug)] pub struct TaskManager { pool: SqlitePool, } impl TaskManager { pub async fn new() -> Result { Ok(Self { pool: Self::connect_database().await?, }) } async fn connect_database() -> crate::Result { let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME")) .ok_or(Error::Unhandled("Could not get standard directories"))?; let database_file_path = project_dir.data_dir().join("db.sql"); fs::create_dir_all(project_dir.data_dir()).await?; let opts = SqliteConnectOptions::new() .filename(database_file_path) .create_if_missing(true) .journal_mode(SqliteJournalMode::Wal); let pool = SqlitePool::connect_with(opts).await?; sqlx::migrate!("./migrations").run(&pool).await?; Ok(pool) } #[instrument(skip(self, values))] pub async fn load_tasks(&self, values: Vec) -> crate::Result<()> where T: TaskPayload + Serialize + std::fmt::Debug, { let mut tx = self.pool.begin().await?; let mut builder: QueryBuilder<'_, Sqlite> = QueryBuilder::new("insert into tasks(payload_key, payload, status_id)"); let args: crate::Result> = values .iter() .map(|value| Ok((value.get_key(), serde_json::to_string(value)?))) .collect(); let mut affected_rows = 0; // Chunk the query by the size limit of bind params for chunk in args?.chunks(SQLITE_BIND_LIMIT / 3) { builder.push_values(chunk, |mut builder, item| { builder .push_bind(&item.0) .push_bind(&item.1) .push_bind(TaskStatus::Pending); }); builder.push("ON conflict (payload_key) DO NOTHING"); let query = builder.build(); affected_rows += query.execute(&mut *tx).await?.rows_affected(); builder.reset(); } tx.commit().await?; info!("{} rows inserted.", affected_rows); Ok(()) } }