use crate::Error;
use chrono::Utc;
use directories::ProjectDirs;
use futures::TryStreamExt;
use serde::de::DeserializeOwned;
use serde::Serialize;
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
use sqlx::{QueryBuilder, Sqlite, SqlitePool};
use std::fmt::Display;
use tabled::Tabled;
use tokio::fs;
use tracing::{info, instrument};

/// Default `SQLITE_MAX_VARIABLE_NUMBER` (bound parameters per statement) in SQLite 3.32+.
const SQLITE_BIND_LIMIT: usize = 32766;

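/// Lifecycle states for a task; stored as the `repr(u8)` discriminant in the
/// `status_id` column via the derived `sqlx::Type`.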
#[derive(sqlx::Type, Debug)]
#[repr(u8)]
pub enum TaskStatus {
    Pending = 1,
    InProgress = 2,
    Completed = 3,
    Failed = 4,
}

impl Display for TaskStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TaskStatus::Pending => write!(f, "Pending"),
            TaskStatus::InProgress => write!(f, "In Progress"),
            TaskStatus::Completed => write!(f, "Completed"),
            TaskStatus::Failed => write!(f, "Failed"),
        }
    }
}

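/// Implemented by payload types that can be queued. `get_key` must return a
/// stable identifier: it is stored as `payload_key`, and keys that already
/// exist are skipped on insert.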
pub trait TaskPayload {
    fn get_key(&self) -> String;
}

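/// A job run against a single task, returning the status it produced.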
pub type TaskJob<T: _Task> = fn(&Task<T>) -> TaskStatus;

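/// A row of the `tasks` table. The payload is stored as JSON and decoded via
/// `#[sqlx(json)]`; `Tabled` is derived so task lists can be printed as a table.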
#[derive(sqlx::FromRow, Tabled, Debug)]
pub struct Task<T: DeserializeOwned + Display> {
    id: u32,
    payload_key: String,
    #[sqlx(json)]
    #[tabled(skip)]
    payload: T,
    #[sqlx(rename = "status_id")]
    status: TaskStatus,
    created_at: chrono::DateTime<Utc>,
    #[tabled(display = "display_option_date")]
    updated_at: Option<chrono::DateTime<Utc>>,
}

impl<T: DeserializeOwned + Display> Task<T> {
    pub fn get_key(&self) -> String {
        self.payload_key.clone()
    }
}

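/// `Tabled` display helper: formats `updated_at`, rendering `None` as an empty cell.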
fn display_option_date(o: &Option<chrono::DateTime<Utc>>) -> String {
    match o {
        Some(s) => format!("{}", s),
        None => String::from(""),
    }
}

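/// Bound alias for payloads that can be decoded from a row and processed;
/// blanket-implemented for every type satisfying the listed traits.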
pub trait _Task: DeserializeOwned + Send + Unpin + 'static + Display {}
impl<T: DeserializeOwned + Send + Unpin + 'static + Display> _Task for T {}

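/// Task queue backed by a SQLite database in the platform data directory
/// (WAL mode, migrations applied on connect).
///
/// Illustrative sketch only (not compiled): `ReportJob` is a hypothetical
/// payload type satisfying the `TaskPayload` and `_Task` bounds, and
/// `report_jobs` is a `Vec<ReportJob>`.
///
/// ```ignore
/// let manager = TaskManager::new().await?;
/// manager.load_tasks(report_jobs).await?; // enqueue, deduplicated by payload key
/// manager
///     .run_tasks::<ReportJob>(|task| {
///         println!("processing {}", task.get_key());
///         TaskStatus::Completed
///     })
///     .await?;
/// ```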
#[derive(Debug)]
pub struct TaskManager {
    pool: SqlitePool,
}

impl TaskManager {
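    /// Connects to (or creates) the task database and returns a ready manager.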
    pub async fn new() -> Result<TaskManager, Error> {
        Ok(Self {
            pool: Self::connect_database().await?,
        })
    }

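    /// Opens the SQLite pool at `<data dir>/db.sql`, creating the file and its
    /// parent directory if missing, and runs the bundled migrations.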
    async fn connect_database() -> crate::Result<SqlitePool> {
        let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME"))
            .ok_or(Error::Unhandled("Could not get standard directories"))?;

        let database_file_path = project_dir.data_dir().join("db.sql");

        fs::create_dir_all(project_dir.data_dir()).await?;

        let opts = SqliteConnectOptions::new()
            .filename(database_file_path)
            .create_if_missing(true)
            .journal_mode(SqliteJournalMode::Wal);

        let pool = SqlitePool::connect_with(opts).await?;

        sqlx::migrate!("./migrations").run(&pool).await?;

        Ok(pool)
    }

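    /// Builds the base `select` over `tasks`, optionally filtered by status and
    /// capped by `limit`, ordered newest first.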
    fn get_task_builder(
        status: Option<TaskStatus>,
        limit: Option<u16>,
    ) -> QueryBuilder<'static, Sqlite> {
        let mut builder: QueryBuilder<'_, Sqlite> = QueryBuilder::new(
            "select id, payload_key, payload, status_id, created_at, updated_at from tasks ",
        );

        if let Some(status) = status {
            builder.push("where status_id = ").push_bind(status);
        }

        // Leading space keeps the clause from running into the status placeholder.
        builder.push(" ORDER BY created_at DESC ");

        if let Some(limit) = limit {
            builder.push("LIMIT ").push_bind(limit);
        }
        builder
    }

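    /// Fetches matching rows and decodes them into `Task<T>` values.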
    pub async fn get_tasks<T: _Task>(
        &self,
        status: Option<TaskStatus>,
        limit: Option<u16>,
    ) -> crate::Result<Vec<Task<T>>> {
        let mut builder = Self::get_task_builder(status, limit);

        let tasks: Vec<Task<T>> = builder.build_query_as().fetch_all(&self.pool).await?;

        Ok(tasks)
    }

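    /// Serializes each payload to JSON and bulk-inserts it as a pending task,
    /// chunked to stay under the bind-parameter limit; keys that already exist
    /// are skipped, and the whole load runs in one transaction.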
    #[instrument(skip(self, values))]
    pub async fn load_tasks<T>(&self, values: Vec<T>) -> crate::Result<()>
    where
        T: TaskPayload + Serialize + std::fmt::Debug,
    {
        let mut tx = self.pool.begin().await?;
        let mut builder: QueryBuilder<'_, Sqlite> =
            QueryBuilder::new("insert into tasks(payload_key, payload, status_id) ");

        let args: crate::Result<Vec<(String, String)>> = values
            .iter()
            .map(|value| Ok((value.get_key(), serde_json::to_string(value)?)))
            .collect();

        let mut affected_rows = 0;
        // Chunk the query by the size limit of bind params (three binds per row)
        for chunk in args?.chunks(SQLITE_BIND_LIMIT / 3) {
            builder.push_values(chunk, |mut builder, item| {
                builder
                    .push_bind(&item.0)
                    .push_bind(&item.1)
                    .push_bind(TaskStatus::Pending);
            });
            builder.push(" ON CONFLICT (payload_key) DO NOTHING");

            let query = builder.build();

            affected_rows += query.execute(&mut *tx).await?.rows_affected();
            builder.reset();
        }

        tx.commit().await?;

        info!("{} rows inserted.", affected_rows);

        Ok(())
    }

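    /// Streams every pending task from the database and runs `func` on each one,
    /// collecting the status each job reports.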
    pub async fn run_tasks<T: _Task>(&self, func: TaskJob<T>) -> crate::Result<()> {
        let mut builder = Self::get_task_builder(Some(TaskStatus::Pending), None);

        let rows = builder.build_query_as::<Task<T>>().fetch(&self.pool);

        // Propagate database errors instead of panicking on them; the collected
        // statuses are not yet written back to the tasks table.
        let _results: Vec<(Task<T>, TaskStatus)> = rows
            .map_ok(|task| {
                let status = func(&task);
                (task, status)
            })
            .try_collect()
            .await?;

        Ok(())
    }
}