diff --git a/.idea/sqldialects.xml b/.idea/sqldialects.xml
index 604ccd2..325e5ee 100644
--- a/.idea/sqldialects.xml
+++ b/.idea/sqldialects.xml
@@ -2,7 +2,6 @@
-
\ No newline at end of file
diff --git a/migrations/0003_tasks.sql b/migrations/0003_tasks.sql
index 02a1c26..0b7cb54 100644
--- a/migrations/0003_tasks.sql
+++ b/migrations/0003_tasks.sql
@@ -5,7 +5,7 @@ create table tasks
         primary key autoincrement,
     payload_key ANY not null
         constraint tasks_payload_key
-            unique,
+            unique on conflict ignore,
     payload TEXT not null,
     status_id integer not null
         constraint tasks_task_statuses_id_fk
diff --git a/src/main.rs b/src/main.rs
index 4129435..2760c1c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,7 +2,7 @@ use std::fs::File;
 use clap::Parser;
 use readwise_bulk_upload::config::Args;
 use readwise_bulk_upload::readwise::DocumentPayload;
-use readwise_bulk_upload::sql::get_database;
+use readwise_bulk_upload::sql::TaskManager;
 use readwise_bulk_upload::{Error, Result};
 
 #[tokio::main]
@@ -17,7 +17,9 @@ async fn main() -> Result<()> {
 
     let documents: Vec<DocumentPayload> = serde_json::from_reader(file)?;
 
-    let db = get_database().await?;
+    let task_manager = TaskManager::new().await?;
+
+    task_manager.load_tasks(documents).await?;
 
     Ok(())
 }
diff --git a/src/readwise.rs b/src/readwise.rs
index e188c21..33bf9ed 100644
--- a/src/readwise.rs
+++ b/src/readwise.rs
@@ -1,8 +1,9 @@
 use chrono::{DateTime, Local};
-use serde::{Deserialize, Deserializer, de};
+use serde::{Deserialize, Deserializer, de, Serialize};
 use serde_json::Value;
+use crate::sql::TaskPayload;
 
-#[derive(Deserialize)]
+#[derive(Deserialize, Serialize, Debug)]
 pub struct DocumentPayload {
     title: String,
     summary: Option<String>,
@@ -13,20 +14,10 @@ pub struct DocumentPayload {
     location: String,
 }
 
-fn str_to_int<'de, D: Deserializer<'de>>(deserializer: D) -> Result<u64, D::Error> {
-    Ok(match Value::deserialize(deserializer)? {
-        Value::String(s) => s.parse().map_err(de::Error::custom)?,
-        Value::Number(num) => num.as_u64().ok_or(de::Error::custom("Invalid number"))?,
-        _ => return Err(de::Error::custom("wrong type")),
-    })
-}
-
-fn str_to_bool<'de, D: Deserializer<'de>>(deserializer: D) -> Result<bool, D::Error> {
-    Ok(match Value::deserialize(deserializer)? {
-        Value::String(s) => s.parse().map_err(de::Error::custom)?,
-        Value::Bool(b) => b,
-        _ => return Err(de::Error::custom("wrong type")),
-    })
+impl TaskPayload for DocumentPayload {
+    fn get_key(&self) -> String {
+        self.url.clone()
+    }
 }
 
 fn single_or_vec<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<String>, D::Error> {
diff --git a/src/sql.rs b/src/sql.rs
index 3c2d1c7..580431e 100644
--- a/src/sql.rs
+++ b/src/sql.rs
@@ -1,25 +1,84 @@
 use crate::Error;
 use directories::ProjectDirs;
+use serde::Serialize;
 use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
-use sqlx::SqlitePool;
+use sqlx::{QueryBuilder, Sqlite, SqlitePool};
 use tokio::fs;
 
-pub async fn get_database() -> crate::Result<SqlitePool> {
-    let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME"))
-        .ok_or(Error::Unhandled("Could not get standard directories"))?;
+static SQLITE_QUERY_LIMIT: usize = 32766;
 
-    let database_file_path = project_dir.data_dir().join("db.sql");
-
-    fs::create_dir_all(project_dir.data_dir()).await?;
-
-    let opts = SqliteConnectOptions::new()
-        .filename(database_file_path)
-        .create_if_missing(true)
-        .journal_mode(SqliteJournalMode::Wal);
-
-    let pool = SqlitePool::connect_with(opts).await?;
-
-    sqlx::migrate!("./migrations").run(&pool).await?;
-
-    Ok(pool)
+#[derive(sqlx::Type)]
+#[repr(u8)]
+pub enum TaskStatus {
+    Pending = 1,
+    InProgress = 2,
+    Completed = 3,
+    Failed = 4,
+}
+
+pub trait TaskPayload {
+    fn get_key(&self) -> String;
+}
+
+pub struct TaskManager {
+    pool: SqlitePool,
+}
+
+impl TaskManager {
+    pub async fn new() -> Result<Self, Error> {
+        Ok(Self {
+            pool: Self::connect_database().await?,
+        })
+    }
+
+    async fn connect_database() -> crate::Result<SqlitePool> {
+        let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME"))
+            .ok_or(Error::Unhandled("Could not get standard directories"))?;
+
+        let database_file_path = project_dir.data_dir().join("db.sql");
+
+        fs::create_dir_all(project_dir.data_dir()).await?;
+
+        let opts = SqliteConnectOptions::new()
+            .filename(database_file_path)
+            .create_if_missing(true)
+            .journal_mode(SqliteJournalMode::Wal);
+
+        let pool = SqlitePool::connect_with(opts).await?;
+
+        sqlx::migrate!("./migrations").run(&pool).await?;
+
+        Ok(pool)
+    }
+
+    pub async fn load_tasks<T>(&self, values: Vec<T>) -> crate::Result<()>
+    where
+        T: TaskPayload + Serialize,
+    {
+        let mut builder: QueryBuilder<'_, Sqlite> =
+            QueryBuilder::new("insert into tasks(payload_key, payload, status_id) ");
+
+        let args: crate::Result<Vec<(String, String)>> = values
+            .iter()
+            .map(|value| Ok((value.get_key(), serde_json::to_string(value)?)))
+            .collect();
+
+        builder.push_values(
+            args?.into_iter().take(SQLITE_QUERY_LIMIT / 3),
+            |mut builder, item| {
+                builder
+                    .push_bind(item.0)
+                    .push_bind(item.1)
+                    .push_bind(TaskStatus::Pending);
+            },
+        );
+
+        builder.push(" on conflict (payload_key) do nothing");
+
+        let query = builder.build();
+
+        query.execute(&self.pool).await?;
+
+        Ok(())
+    }
 }
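
A note on the batching in load_tasks above: SQLite caps the number of bound parameters in a single statement (SQLITE_MAX_VARIABLE_NUMBER, 32766 by default since SQLite 3.32.0), and each task row binds three values, so one statement can hold at most 32766 / 3 = 10922 rows; the take call enforces that cap by silently dropping any rows past it. Below is a minimal sketch of a chunked variant that inserts every row instead. It assumes it sits in src/sql.rs alongside the types in the diff (reusing the imports already at the top of that file); load_tasks_chunked, BIND_LIMIT, and BINDS_PER_ROW are hypothetical names, not part of the change.

// Hypothetical chunked variant; assumes the TaskManager, TaskPayload, and
// TaskStatus definitions (and imports) from src/sql.rs in the diff above.

// SQLITE_MAX_VARIABLE_NUMBER defaults to 32766 on SQLite >= 3.32.0.
const BIND_LIMIT: usize = 32766;
// Each row binds payload_key, payload, and status_id.
const BINDS_PER_ROW: usize = 3;

impl TaskManager {
    pub async fn load_tasks_chunked<T>(&self, values: Vec<T>) -> crate::Result<()>
    where
        T: TaskPayload + Serialize,
    {
        // Serialize every payload up front so a bad document fails the batch early.
        let rows: crate::Result<Vec<(String, String)>> = values
            .iter()
            .map(|value| Ok((value.get_key(), serde_json::to_string(value)?)))
            .collect();
        let rows = rows?;

        // One INSERT per chunk stays under the bind-parameter cap without
        // discarding anything, unlike `take(SQLITE_QUERY_LIMIT / 3)`.
        for chunk in rows.chunks(BIND_LIMIT / BINDS_PER_ROW) {
            let mut builder: QueryBuilder<'_, Sqlite> =
                QueryBuilder::new("insert into tasks(payload_key, payload, status_id) ");

            builder.push_values(chunk, |mut b, (key, payload)| {
                b.push_bind(key)
                    .push_bind(payload)
                    .push_bind(TaskStatus::Pending);
            });

            builder.push(" on conflict (payload_key) do nothing");

            builder.build().execute(&self.pool).await?;
        }

        Ok(())
    }
}

Either way, duplicates stay harmless: the `unique ... on conflict ignore` clause added in the migration already makes a repeated payload_key a no-op at the schema level, so the statement-level `on conflict (payload_key) do nothing` is a redundant but safe second guard, and re-running the loader on the same export file is idempotent.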