diff --git a/.idea/runConfigurations/Load_Tasks.xml b/.idea/runConfigurations/Load_Tasks.xml deleted file mode 100644 index 41f2816..0000000 --- a/.idea/runConfigurations/Load_Tasks.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/Query_Tasks.xml b/.idea/runConfigurations/Query_Tasks.xml deleted file mode 100644 index ef4b918..0000000 --- a/.idea/runConfigurations/Query_Tasks.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/Run_Tasks.xml b/.idea/runConfigurations/Run_Tasks.xml deleted file mode 100644 index ef08a15..0000000 --- a/.idea/runConfigurations/Run_Tasks.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml index 8d3e42f..35eb1dd 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -1,10 +1,5 @@ - - - diff --git a/Cargo.lock b/Cargo.lock index dfdfe03..87f2582 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,19 +17,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" -[[package]] -name = "ahash" -version = "0.8.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" -dependencies = [ - "cfg-if", - "getrandom 0.3.2", - "once_cell", - "version_check", - "zerocopy", -] - [[package]] name = "aho-corasick" version = "1.1.3" @@ -185,12 +172,6 @@ version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" -[[package]] -name = "bytecount" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" - [[package]] name = "bytemuck" version = "1.23.0" @@ -508,21 +489,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "futures" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - [[package]] name = "futures-channel" version = "0.3.31" @@ -567,17 +533,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" -[[package]] -name = "futures-macro" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "futures-sink" version = "0.3.31" @@ -596,10 +551,8 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ - "futures-channel", "futures-core", "futures-io", - "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1106,17 +1059,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "papergrid" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30268a8d20c2c0d126b2b6610ab405f16517f6ba9f244d8c59ac2c512a8a1ce7" -dependencies = [ - "ahash", - "bytecount", - "unicode-width", -] - [[package]] name = "parking" version = "2.2.1" @@ -1232,28 +1174,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "proc-macro-error-attr2" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" -dependencies = [ - "proc-macro2", - "quote", -] - -[[package]] -name = "proc-macro-error2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" -dependencies = [ - "proc-macro-error-attr2", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "proc-macro2" version = "1.0.95" @@ -1329,11 +1249,9 @@ dependencies = [ "clap", "directories", "figment", - "futures", "serde", "serde_json", "sqlx", - "tabled", "thiserror", "tokio", "tracing", @@ -1844,30 +1762,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tabled" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "228d124371171cd39f0f454b58f73ddebeeef3cef3207a82ffea1c29465aea43" -dependencies = [ - "papergrid", - "tabled_derive", - "testing_table", -] - -[[package]] -name = "tabled_derive" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea5d1b13ca6cff1f9231ffd62f15eefd72543dab5e468735f1a456728a02846" -dependencies = [ - "heck", - "proc-macro-error2", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "tempfile" version = "3.19.1" @@ -1881,15 +1775,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "testing_table" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f8daae29995a24f65619e19d8d31dea5b389f3d853d8bf297bbf607cd0014cc" -dependencies = [ - "unicode-width", -] - [[package]] name = "thiserror" version = "2.0.12" @@ -2087,12 +1972,6 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" -[[package]] -name = "unicode-width" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" - [[package]] name = "url" version = "2.5.4" diff --git a/Cargo.toml b/Cargo.toml index 3f51301..da95ddc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,5 +16,3 @@ tracing = "0.1.41" tracing-subscriber = { version = "0.3.19" , features = ["env-filter"]} figment = { version = "0.10.19", features = ["env"] } tracing-core = "0.1.33" -tabled = "0.19.0" -futures = "0.3.31" \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 04bdbff..680fdf7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use clap::{Parser, Subcommand, ValueEnum}; +use clap::{Parser, ValueEnum}; use serde::{Deserialize, Serialize}; use std::fmt; use std::path::PathBuf; @@ -57,32 +57,9 @@ impl Into for VerbosityLevel { } } -#[derive(Debug, Subcommand)] -#[clap(rename_all = "snake_case")] -pub enum Command { - /// Load task into the database from [path] - LoadTasks{ - /// Path to the file - path: PathBuf, - }, - Query, - Run, - #[clap(skip)] - None, -} - -impl Default for Command { - fn default() -> Self { - Command::None - } -} - #[derive(Debug, Parser, Serialize, Deserialize)] pub struct Config { - #[command(subcommand)] - #[serde(skip)] - pub command: Command, - + path: PathBuf, #[arg( long, short = 'v', @@ -94,6 +71,9 @@ pub struct Config { } impl Config { + pub fn path(&self) -> &PathBuf { + &self.path + } pub fn log_level(&self) -> LevelFilter { self.log_level.clone().into() diff --git a/src/lib.rs b/src/lib.rs index f63d5eb..dabd6d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ mod error; -pub mod task_manager; +pub mod sql; pub mod config; pub mod readwise; diff --git a/src/main.rs b/src/main.rs index ff1f2ac..985c0e4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,16 @@ -use clap::{CommandFactory, Parser}; -use figment::{ - providers::{Env, Serialized}, - Figment, -}; -use readwise_bulk_upload::config::{Command, Config}; +use clap::Parser; +use readwise_bulk_upload::config::Config; use readwise_bulk_upload::readwise::DocumentPayload; -use readwise_bulk_upload::task_manager::{TaskManager, TaskStatus}; +use readwise_bulk_upload::sql::TaskManager; use readwise_bulk_upload::{Error, Result}; use std::fs::File; -use tabled::Table; use tracing_subscriber; +use figment::{Figment, providers::{Serialized, Env}}; #[tokio::main] async fn main() -> Result<()> { - let cli = Config::parse(); let args: Config = Figment::new() - .merge(Serialized::defaults(&cli)) + .merge(Serialized::defaults(Config::parse())) .merge(Env::prefixed("APP_")) .extract()?; @@ -23,43 +18,18 @@ async fn main() -> Result<()> { .with_max_level(args.log_level()) .init(); - run(&cli.command).await?; + let file = File::open(args.path()).map_err(|_| { + Error::Runtime(format!( + r#"The file "{}" could not be open"#, + args.path().display() + )) + })?; - Ok(()) -} + let documents: Vec = serde_json::from_reader(file)?; -async fn run(command: &Command) -> Result<()> { let task_manager = TaskManager::new().await?; - match command { - Command::LoadTasks { path } => { - let file = File::open(path).map_err(|_| { - Error::Runtime(format!( - r#"The file "{}" could not be open"#, - path.display() - )) - })?; - let documents: Vec = serde_json::from_reader(file)?; - - - task_manager.load_tasks(documents).await?; - } - Command::Query => { - let tasks = task_manager.get_tasks::(None, Some(25)).await?; - - println!("{}", Table::new(tasks)); - } - Command::Run => { - task_manager.run_tasks::(|task| { - println!("{}", task.get_key()); - - TaskStatus::Completed - }).await?; - } - Command::None => { - Config::command().print_help()?; - } - } + task_manager.load_tasks(documents).await?; Ok(()) } diff --git a/src/readwise.rs b/src/readwise.rs index e3d7a42..33bf9ed 100644 --- a/src/readwise.rs +++ b/src/readwise.rs @@ -1,8 +1,7 @@ -use crate::task_manager::TaskPayload; use chrono::{DateTime, Local}; -use serde::{de, Deserialize, Deserializer, Serialize}; +use serde::{Deserialize, Deserializer, de, Serialize}; use serde_json::Value; -use std::fmt::Display; +use crate::sql::TaskPayload; #[derive(Deserialize, Serialize, Debug)] pub struct DocumentPayload { @@ -15,16 +14,6 @@ pub struct DocumentPayload { location: String, } -impl Display for DocumentPayload { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - serde_json::to_string_pretty(self).map_err(|_| std::fmt::Error)? - ) - } -} - impl TaskPayload for DocumentPayload { fn get_key(&self) -> String { self.url.clone() diff --git a/src/sql.rs b/src/sql.rs new file mode 100644 index 0000000..7d2a442 --- /dev/null +++ b/src/sql.rs @@ -0,0 +1,94 @@ +use crate::Error; +use directories::ProjectDirs; +use serde::Serialize; +use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; +use sqlx::{QueryBuilder, Sqlite, SqlitePool}; +use tokio::fs; +use tracing::{info, instrument}; + +static SQLITE_BIND_LIMIT: usize = 32766; + +#[derive(sqlx::Type)] +#[repr(u8)] +pub enum TaskStatus { + Pending = 1, + InProgress = 2, + Completed = 3, + Failed = 4, +} + +pub trait TaskPayload { + fn get_key(&self) -> String; +} + +#[derive(Debug)] +pub struct TaskManager { + pool: SqlitePool, +} + +impl TaskManager { + pub async fn new() -> Result { + Ok(Self { + pool: Self::connect_database().await?, + }) + } + + async fn connect_database() -> crate::Result { + let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME")) + .ok_or(Error::Unhandled("Could not get standard directories"))?; + + let database_file_path = project_dir.data_dir().join("db.sql"); + + fs::create_dir_all(project_dir.data_dir()).await?; + + let opts = SqliteConnectOptions::new() + .filename(database_file_path) + .create_if_missing(true) + .journal_mode(SqliteJournalMode::Wal); + + let pool = SqlitePool::connect_with(opts).await?; + + sqlx::migrate!("./migrations").run(&pool).await?; + + Ok(pool) + } + + #[instrument(skip(self, values))] + pub async fn load_tasks(&self, values: Vec) -> crate::Result<()> + where + T: TaskPayload + Serialize + std::fmt::Debug, + { + let mut tx = self.pool.begin().await?; + let mut builder: QueryBuilder<'_, Sqlite> = + QueryBuilder::new("insert into tasks(payload_key, payload, status_id)"); + + let args: crate::Result> = values + .iter() + .map(|value| Ok((value.get_key(), serde_json::to_string(value)?))) + .collect(); + + + let mut affected_rows = 0; + // Chunk the query by the size limit of bind params + for chunk in args?.chunks(SQLITE_BIND_LIMIT / 3) { + builder.push_values(chunk, |mut builder, item| { + builder + .push_bind(&item.0) + .push_bind(&item.1) + .push_bind(TaskStatus::Pending); + }); + builder.push("ON conflict (payload_key) DO NOTHING"); + + let query = builder.build(); + + affected_rows += query.execute(&mut *tx).await?.rows_affected(); + builder.reset(); + } + + tx.commit().await?; + + info!("{} rows inserted.", affected_rows); + + Ok(()) + } +} diff --git a/src/task_manager.rs b/src/task_manager.rs deleted file mode 100644 index 11401e2..0000000 --- a/src/task_manager.rs +++ /dev/null @@ -1,196 +0,0 @@ -use crate::Error; -use chrono::Utc; -use directories::ProjectDirs; -use futures::{StreamExt, TryStreamExt}; -use serde::de::DeserializeOwned; -use serde::Serialize; -use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; -use sqlx::{QueryBuilder, Sqlite, SqlitePool}; -use std::fmt::Display; -use tabled::Tabled; -use tokio::fs; -use tracing::{info, instrument}; - -static SQLITE_BIND_LIMIT: usize = 32766; - -#[derive(sqlx::Type, Debug)] -#[repr(u8)] -pub enum TaskStatus { - Pending = 1, - InProgress = 2, - Completed = 3, - Failed = 4, -} - -impl Display for TaskStatus { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TaskStatus::Pending => { - write!(f, "Pending") - } - TaskStatus::InProgress => { - write!(f, "In Progress") - } - TaskStatus::Completed => { - write!(f, "Completed") - } - TaskStatus::Failed => { - write!(f, "Failed") - } - } - } -} - -pub trait TaskPayload { - fn get_key(&self) -> String; -} - -pub type TaskJob = fn(&Task) -> TaskStatus; - -#[derive(sqlx::FromRow, Tabled, Debug)] -pub struct Task { - id: u32, - payload_key: String, - #[sqlx(json)] - #[tabled(skip)] - payload: T, - #[sqlx(rename = "status_id")] - status: TaskStatus, - created_at: chrono::DateTime, - #[tabled(display = "display_option_date")] - updated_at: Option>, -} - -impl Task { - pub fn get_key(&self) -> String { - self.payload_key.clone() - } -} - -fn display_option_date(o: &Option>) -> String { - match o { - Some(s) => format!("{}", s), - None => String::from(""), - } -} - -pub trait _Task: DeserializeOwned + Send + Unpin + 'static + Display {} -impl _Task for T {} - -#[derive(Debug)] -pub struct TaskManager { - pool: SqlitePool, -} - -impl TaskManager { - pub async fn new() -> Result { - Ok(Self { - pool: Self::connect_database().await?, - }) - } - - async fn connect_database() -> crate::Result { - let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME")) - .ok_or(Error::Unhandled("Could not get standard directories"))?; - - let database_file_path = project_dir.data_dir().join("db.sql"); - - fs::create_dir_all(project_dir.data_dir()).await?; - - let opts = SqliteConnectOptions::new() - .filename(database_file_path) - .create_if_missing(true) - .journal_mode(SqliteJournalMode::Wal); - - let pool = SqlitePool::connect_with(opts).await?; - - sqlx::migrate!("./migrations").run(&pool).await?; - - Ok(pool) - } - - fn get_task_builder( - status: Option, - limit: Option, - ) -> QueryBuilder<'static, Sqlite> { - let mut builder: QueryBuilder<'_, Sqlite> = QueryBuilder::new( - "select id, payload_key, payload, status_id, created_at, updated_at from tasks ", - ); - - if let Some(status) = status { - builder.push("where status_id = ").push_bind(status); - } - - builder.push("ORDER BY created_at DESC "); - - if let Some(limit) = limit { - builder.push("LIMIT ").push_bind(limit); - } - builder - } - - pub async fn get_tasks( - &self, - status: Option, - limit: Option, - ) -> crate::Result>> { - let mut builder = Self::get_task_builder(status, limit); - - let tasks: Vec> = builder.build_query_as().fetch_all(&self.pool).await?; - - Ok(tasks) - } - - #[instrument(skip(self, values))] - pub async fn load_tasks(&self, values: Vec) -> crate::Result<()> - where - T: TaskPayload + Serialize + std::fmt::Debug, - { - let mut tx = self.pool.begin().await?; - let mut builder: QueryBuilder<'_, Sqlite> = - QueryBuilder::new("insert into tasks(payload_key, payload, status_id)"); - - let args: crate::Result> = values - .iter() - .map(|value| Ok((value.get_key(), serde_json::to_string(value)?))) - .collect(); - - let mut affected_rows = 0; - // Chunk the query by the size limit of bind params - for chunk in args?.chunks(SQLITE_BIND_LIMIT / 3) { - builder.push_values(chunk, |mut builder, item| { - builder - .push_bind(&item.0) - .push_bind(&item.1) - .push_bind(TaskStatus::Pending); - }); - builder.push("ON conflict (payload_key) DO NOTHING"); - - let query = builder.build(); - - affected_rows += query.execute(&mut *tx).await?.rows_affected(); - builder.reset(); - } - - tx.commit().await?; - - info!("{} rows inserted.", affected_rows); - - Ok(()) - } - - pub async fn run_tasks(&self, func: TaskJob) -> crate::Result<()> { - let mut builder = Self::get_task_builder(Some(TaskStatus::Pending), None); - - let rows = builder.build_query_as::>().fetch(&self.pool); - - let result: Vec<(Task, TaskStatus)> = rows.map(|x| { - let task = x.unwrap(); - let status = func(&task); - - (task, status) - }).collect().await; - - Ok(()) - } -}