diff --git a/cli/bin/readwise/external_interface.rs b/cli/bin/readwise/external_interface.rs index 13b2f65..a484ca9 100644 --- a/cli/bin/readwise/external_interface.rs +++ b/cli/bin/readwise/external_interface.rs @@ -1,4 +1,4 @@ -use lib_sync_core::task_manager::TaskPayload; +use lib_sync_core::tasks::TaskPayload; use chrono::{DateTime, Local}; use serde::{de, Deserialize, Deserializer, Serialize}; use serde_json::Value; diff --git a/cli/bin/readwise/main.rs b/cli/bin/readwise/main.rs index bdfc584..29eee78 100644 --- a/cli/bin/readwise/main.rs +++ b/cli/bin/readwise/main.rs @@ -4,7 +4,7 @@ use figment::{ Figment, providers::{Env, Serialized}, }; -use lib_sync_core::task_manager::{TaskManager, TaskStatus}; +use lib_sync_core::tasks::{TaskStatus}; use cli::config::{Command, Config}; use cli::{Error, Result}; use std::fs::File; diff --git a/lib_sync_core/src/database.rs b/lib_sync_core/src/database.rs index 0674329..f2bc877 100644 --- a/lib_sync_core/src/database.rs +++ b/lib_sync_core/src/database.rs @@ -1,14 +1,6 @@ -use crate::task_manager::{Task, TaskPayload, TaskStatus}; -use futures::stream::BoxStream; -use futures::{Stream, StreamExt, TryStreamExt}; -use serde::Serialize; -use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; -use sqlx::{Error, QueryBuilder, Sqlite, SqlitePool}; -use std::path::PathBuf; -use tokio::fs; -use tracing::{info, instrument}; - -static SQLITE_BIND_LIMIT: usize = 32766; +use crate::tasks::{Task, TaskPayload, TaskStatus}; +use futures::{Stream}; +mod sqlite; #[derive(Default, Clone)] pub struct TaskPagination { @@ -77,121 +69,9 @@ impl TasksPage { } pub trait TaskStorage { - fn insert_tasks(&self, tasks: Vec>) -> crate::Result<()>; + async fn insert_tasks(&self, tasks: Vec>) -> crate::Result<()>; fn get_tasks(&self, options: TaskStatus) -> impl Stream>>; async fn get_paginated_tasks(&self, page: &TaskPagination) -> crate::Result>; } -#[derive(Debug)] -pub struct Database { - pool: SqlitePool, -} - -impl Database { - pub async fn new>(base_path: P) -> crate::Result { - Ok(Self { - pool: Self::connect_database(base_path).await?, - }) - } - - async fn connect_database>(base_path: P) -> crate::Result { - let base_path = base_path.into(); - - let database_file_path = base_path.join("db.sql"); - - fs::create_dir_all(base_path).await?; - - let opts = SqliteConnectOptions::new() - .filename(database_file_path) - .create_if_missing(true) - .journal_mode(SqliteJournalMode::Wal); - - let pool = SqlitePool::connect_with(opts).await?; - - sqlx::migrate!("../migrations").run(&pool).await?; - - Ok(pool) - } - - #[instrument(skip(self, values))] - pub async fn load_tasks(&self, values: Vec) -> crate::Result<()> - where - T: TaskPayload + Serialize + std::fmt::Debug, - { - let mut tx = self.pool.begin().await?; - let mut builder: QueryBuilder<'_, Sqlite> = - QueryBuilder::new("insert into tasks(payload_key, payload, status_id)"); - - let args: crate::Result> = values - .iter() - .map(|value| Ok((value.get_key(), serde_json::to_string(value)?))) - .collect(); - - let mut affected_rows = 0; - // Chunk the query by the size limit of bind params - for chunk in args?.chunks(SQLITE_BIND_LIMIT / 3) { - builder.push_values(chunk, |mut builder, item| { - builder - .push_bind(&item.0) - .push_bind(&item.1) - .push_bind(TaskStatus::Pending); - }); - builder.push("ON conflict (payload_key) DO NOTHING"); - - let query = builder.build(); - - affected_rows += query.execute(&mut *tx).await?.rows_affected(); - builder.reset(); - } - - tx.commit().await?; - - info!("{} rows inserted.", affected_rows); - - Ok(()) - } -} - -impl TaskStorage for Database { - fn insert_tasks(&self, tasks: Vec>) -> crate::error::Result<()> { - todo!() - } - - fn get_tasks(&self, task_status: TaskStatus) -> impl Stream>> { - let query= sqlx::query_as::<_, Task>( - " - SELECT id, payload_key, payload, status_id, created_at, updated_at - FROM tasks - WHERE status_id = ? - ORDER BY created_at DESC - ", - ).bind(task_status); - - query.fetch(&self.pool).err_into::() - } - - async fn get_paginated_tasks(&self, page: &TaskPagination) -> crate::Result> { - let mut builder: QueryBuilder<'_, Sqlite> = QueryBuilder::new( - "select id, payload_key, payload, status_id, created_at, updated_at from tasks ", - ); - - if let Some(status) = &page.status { - builder.push("where status_id = ").push_bind(status); - } - - builder.push("ORDER BY created_at DESC "); - - if let Some(limit) = &page.offset { - builder.push("OFFSET ").push_bind(limit); - } - - if let Some(limit) = &page.limit { - builder.push("LIMIT ").push_bind(limit); - } - - let tasks = builder.build_query_as::>().fetch_all(&self.pool).await?; - - Ok(TasksPage::new(tasks, page.clone())) - } -} diff --git a/lib_sync_core/src/database/sqlite.rs b/lib_sync_core/src/database/sqlite.rs new file mode 100644 index 0000000..c7fc993 --- /dev/null +++ b/lib_sync_core/src/database/sqlite.rs @@ -0,0 +1,117 @@ +use sqlx::{QueryBuilder, SqlitePool}; +use std::path::PathBuf; +use tokio::fs; +use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; +use tracing::{info, instrument}; +use futures::{Stream, TryStreamExt}; +use crate::database::{TaskPagination, TaskStorage, TasksPage}; +use crate::tasks::{Task, TaskPayload, TaskStatus}; + +#[allow(unused)] +static SQLITE_BIND_LIMIT: usize = 32766; + +#[derive(Debug)] +pub struct Sqlite { + pool: SqlitePool, +} + +impl Sqlite { + pub async fn new>(base_path: P) -> crate::Result { + Ok(Self { + pool: Self::connect_database(base_path).await?, + }) + } + + async fn connect_database>(base_path: P) -> crate::Result { + let base_path = base_path.into(); + + let database_file_path = base_path.join("db.sql"); + + fs::create_dir_all(base_path).await?; + + let opts = SqliteConnectOptions::new() + .filename(database_file_path) + .create_if_missing(true) + .journal_mode(SqliteJournalMode::Wal); + + let pool = SqlitePool::connect_with(opts).await?; + + sqlx::migrate!("../migrations").run(&pool).await?; + + Ok(pool) + } +} + +impl TaskStorage for Sqlite { + #[instrument(skip(self, tasks))] + async fn insert_tasks(&self, tasks: Vec>) -> crate::Result<()> { + let mut tx = self.pool.begin().await?; + let mut builder: QueryBuilder<'_, sqlx::Sqlite> = + QueryBuilder::new("insert into tasks(payload_key, payload, status_id)"); + + let args: crate::Result> = tasks + .iter() + .map(|value| Ok((value.get_key(), serde_json::to_string(value.payload())?))) + .collect(); + + let mut affected_rows = 0; + // Chunk the query by the size limit of bind params + for chunk in args?.chunks(SQLITE_BIND_LIMIT / 3) { + builder.push_values(chunk, |mut builder, item| { + builder + .push_bind(&item.0) + .push_bind(&item.1) + .push_bind(TaskStatus::Pending); + }); + builder.push("ON conflict (payload_key) DO NOTHING"); + + let query = builder.build(); + + affected_rows += query.execute(&mut *tx).await?.rows_affected(); + builder.reset(); + } + + tx.commit().await?; + + info!("{} rows inserted.", affected_rows); + + Ok(()) + } + + fn get_tasks(&self, task_status: TaskStatus) -> impl Stream>> { + let query= sqlx::query_as::<_, Task>( + " + SELECT id, payload_key, payload, status_id, created_at, updated_at + FROM tasks + WHERE status_id = ? + ORDER BY created_at DESC + ", + ).bind(task_status); + + query.fetch(&self.pool).err_into::() + } + + async fn get_paginated_tasks(&self, page: &TaskPagination) -> crate::Result> { + let mut builder: QueryBuilder<'_, sqlx::Sqlite> = QueryBuilder::new( + "select id, payload_key, payload, status_id, created_at, updated_at from tasks ", + ); + + if let Some(status) = &page.status { + builder.push("where status_id = ").push_bind(status); + } + + builder.push("ORDER BY created_at DESC "); + + if let Some(limit) = &page.offset { + builder.push("OFFSET ").push_bind(limit); + } + + if let Some(limit) = &page.limit { + builder.push("LIMIT ").push_bind(limit); + } + + let tasks = builder.build_query_as::>().fetch_all(&self.pool).await?; + + Ok(TasksPage::new(tasks, page.clone())) + } +} \ No newline at end of file diff --git a/lib_sync_core/src/lib.rs b/lib_sync_core/src/lib.rs index 0feb7fb..586934c 100644 --- a/lib_sync_core/src/lib.rs +++ b/lib_sync_core/src/lib.rs @@ -1,7 +1,7 @@ pub mod error; pub(crate) use error::*; -pub mod task_manager; +pub mod tasks; mod database; pub fn add(left: u64, right: u64) -> u64 { diff --git a/lib_sync_core/src/task_manager.rs b/lib_sync_core/src/tasks.rs similarity index 53% rename from lib_sync_core/src/task_manager.rs rename to lib_sync_core/src/tasks.rs index 8d7cdd4..04c22b7 100644 --- a/lib_sync_core/src/task_manager.rs +++ b/lib_sync_core/src/tasks.rs @@ -1,17 +1,11 @@ -use crate::error::Error; use chrono::Utc; -use directories::ProjectDirs; -use futures::{StreamExt, TryStreamExt}; +use futures::StreamExt; use serde::de::DeserializeOwned; use serde::Serialize; -use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; -use sqlx::{QueryBuilder, Sqlite, SqlitePool}; use std::fmt::Display; -use std::path::PathBuf; -use futures::stream::BoxStream; use tabled::Tabled; -use tokio::fs; -use tracing::{info, instrument}; + +mod manager; #[derive(sqlx::Type, Debug, Clone)] #[repr(u8)] @@ -45,13 +39,19 @@ pub trait TaskPayloadKey { fn get_key(&self) -> String; } -pub trait TaskPayload: DeserializeOwned + Send + Unpin + 'static + Display + TaskPayloadKey {} -impl TaskPayload for T {} +pub trait TaskPayload: + Serialize + DeserializeOwned + Send + Unpin + 'static + Display + TaskPayloadKey +{ +} +impl + TaskPayload for T +{ +} -pub type TaskJob = fn(&Task) -> TaskStatus; +pub type TaskJob = fn(&Task) -> TaskStatus; #[derive(sqlx::FromRow, Tabled, Debug)] -pub struct Task { +pub struct Task { id: u32, payload_key: String, #[sqlx(json)] @@ -64,7 +64,13 @@ pub struct Task { updated_at: Option>, } -impl Task { +impl Task { + pub fn payload(&self) -> &T { + &self.payload + } +} + +impl Task { pub fn get_key(&self) -> String { self.payload_key.clone() } @@ -77,22 +83,3 @@ fn display_option_date(o: &Option>) -> String { } } - -struct TaskManager{} - -impl TaskManager { - // pub async fn run_tasks(&self, func: TaskJob) -> crate::Result<()> { - // let mut builder = Self::get_task_builder(Some(TaskStatus::Pending), None); - // - // let rows = builder.build_query_as::>().fetch(&self.pool); - // - // let result: Vec<(Task, TaskStatus)> = rows.map(|x| { - // let task = x.unwrap(); - // let status = func(&task); - // - // (task, status) - // }).collect().await; - // - // Ok(()) - // } -} diff --git a/lib_sync_core/src/tasks/manager.rs b/lib_sync_core/src/tasks/manager.rs new file mode 100644 index 0000000..bee3cb8 --- /dev/null +++ b/lib_sync_core/src/tasks/manager.rs @@ -0,0 +1,32 @@ +use futures::StreamExt; +use std::marker::PhantomData; +use crate::database::TaskStorage; +use crate::tasks::{Task, TaskJob, TaskPayload, TaskStatus}; + +struct TaskManager> +{ + storage: T, + _marker: PhantomData, +} + +impl> TaskManager { + pub fn new(storage: T) -> Self { + Self { + storage, + _marker: PhantomData, + } + } + + pub async fn run_tasks(&self, func: TaskJob) -> crate::Result<()> { + let rows = self.storage.get_tasks(TaskStatus::Pending); + + let result: Vec<(Task, TaskStatus)> = rows.map(|x| { + let task = x.unwrap(); + let status = func(&task); + + (task, status) + }).collect().await; + + Ok(()) + } +} \ No newline at end of file