diff --git a/migrations/0003_tasks.sql b/migrations/0003_tasks.sql
index 02a1c26..0b7cb54 100644
--- a/migrations/0003_tasks.sql
+++ b/migrations/0003_tasks.sql
@@ -5,7 +5,7 @@ create table tasks
primary key autoincrement,
payload_key ANY not null
constraint tasks_payload_key
- unique,
+ unique on conflict ignore,
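+ -- duplicate payload_keys are skipped on insert instead of raising a uniqueness error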
payload TEXT not null,
status_id integer not null
constraint tasks_task_statuses_id_fk
diff --git a/src/main.rs b/src/main.rs
index 4129435..2760c1c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,7 +2,7 @@ use std::fs::File;
use clap::Parser;
use readwise_bulk_upload::config::Args;
use readwise_bulk_upload::readwise::DocumentPayload;
-use readwise_bulk_upload::sql::get_database;
+use readwise_bulk_upload::sql::TaskManager;
use readwise_bulk_upload::{Error, Result};
#[tokio::main]
@@ -17,7 +17,9 @@ async fn main() -> Result<()> {
let documents: Vec<DocumentPayload> = serde_json::from_reader(file)?;
- let db = get_database().await?;
+ let task_manager = TaskManager::new().await?;
+
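+ // Queue every document as a pending task; entries whose payload_key (the document url) is already queued are skipped.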
+ task_manager.load_tasks(documents).await?;
Ok(())
}
diff --git a/src/readwise.rs b/src/readwise.rs
index e188c21..33bf9ed 100644
--- a/src/readwise.rs
+++ b/src/readwise.rs
@@ -1,8 +1,9 @@
use chrono::{DateTime, Local};
-use serde::{Deserialize, Deserializer, de};
+use serde::{Deserialize, Deserializer, de, Serialize};
use serde_json::Value;
+use crate::sql::TaskPayload;
-#[derive(Deserialize)]
+#[derive(Deserialize, Serialize, Debug)]
pub struct DocumentPayload {
title: String,
summary: Option<String>,
@@ -13,20 +14,10 @@ pub struct DocumentPayload {
location: String,
}
-fn str_to_int<'de, D: Deserializer<'de>>(deserializer: D) -> Result<u64, D::Error> {
- Ok(match Value::deserialize(deserializer)? {
- Value::String(s) => s.parse().map_err(de::Error::custom)?,
- Value::Number(num) => num.as_u64().ok_or(de::Error::custom("Invalid number"))?,
- _ => return Err(de::Error::custom("wrong type")),
- })
-}
-
-fn str_to_bool<'de, D: Deserializer<'de>>(deserializer: D) -> Result<bool, D::Error> {
- Ok(match Value::deserialize(deserializer)? {
- Value::String(s) => s.parse().map_err(de::Error::custom)?,
- Value::Bool(b) => b,
- _ => return Err(de::Error::custom("wrong type")),
- })
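+// DocumentPayload's task key is its url, so re-importing the same document becomes a no-op at insert time.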
+impl TaskPayload for DocumentPayload {
+ fn get_key(&self) -> String {
+ self.url.clone()
+ }
}
fn single_or_vec<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<String>, D::Error> {
diff --git a/src/sql.rs b/src/sql.rs
index 3c2d1c7..580431e 100644
--- a/src/sql.rs
+++ b/src/sql.rs
@@ -1,25 +1,84 @@
use crate::Error;
use directories::ProjectDirs;
+use serde::Serialize;
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
-use sqlx::SqlitePool;
+use sqlx::{QueryBuilder, Sqlite, SqlitePool};
use tokio::fs;
-pub async fn get_database() -> crate::Result<SqlitePool> {
- let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME"))
- .ok_or(Error::Unhandled("Could not get standard directories"))?;
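+// SQLite's default bind-parameter limit (SQLITE_MAX_VARIABLE_NUMBER) is 32766 since SQLite 3.32.0.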
+const SQLITE_QUERY_LIMIT: usize = 32766;
- let database_file_path = project_dir.data_dir().join("db.sql");
-
- fs::create_dir_all(project_dir.data_dir()).await?;
-
- let opts = SqliteConnectOptions::new()
- .filename(database_file_path)
- .create_if_missing(true)
- .journal_mode(SqliteJournalMode::Wal);
-
- let pool = SqlitePool::connect_with(opts).await?;
-
- sqlx::migrate!("./migrations").run(&pool).await?;
-
- Ok(pool)
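+// Task lifecycle states; the discriminants mirror the ids in the task_statuses lookup table.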
+#[derive(sqlx::Type)]
+#[repr(u8)]
+pub enum TaskStatus {
+ Pending = 1,
+ InProgress = 2,
+ Completed = 3,
+ Failed = 4,
+}
+
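+// Anything stored as a task provides a stable key for the unique payload_key column, enabling dedup on insert.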
+pub trait TaskPayload {
+ fn get_key(&self) -> String;
+}
+
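+// Owns the SQLite pool; all task-queue operations go through this handle.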
+pub struct TaskManager {
+ pool: SqlitePool,
+}
+
+impl TaskManager {
+ pub async fn new() -> crate::Result<Self> {
+ Ok(Self {
+ pool: Self::connect_database().await?,
+ })
+ }
+
+ async fn connect_database() -> crate::Result<SqlitePool> {
+ let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME"))
+ .ok_or(Error::Unhandled("Could not get standard directories"))?;
+
+ let database_file_path = project_dir.data_dir().join("db.sql");
+
+ fs::create_dir_all(project_dir.data_dir()).await?;
+
+ let opts = SqliteConnectOptions::new()
+ .filename(database_file_path)
+ .create_if_missing(true)
+ .journal_mode(SqliteJournalMode::Wal);
+
+ let pool = SqlitePool::connect_with(opts).await?;
+
+ sqlx::migrate!("./migrations").run(&pool).await?;
+
+ Ok(pool)
+ }
+
+ pub async fn load_tasks<T>(&self, values: Vec<T>) -> crate::Result<()>
+ where
+ T: TaskPayload + Serialize,
+ {
+ let mut builder: QueryBuilder<'_, Sqlite> =
+ QueryBuilder::new("insert into tasks(payload_key, payload, status_id)");
+
+ let args: crate::Result<Vec<(String, String)>> = values
+ .iter()
+ .map(|value| Ok((value.get_key(), serde_json::to_string(value)?)))
+ .collect();
+
+ builder.push_values(
+ args?.into_iter().take(SQLITE_QUERY_LIMIT / 3),
+ |mut builder, item| {
+ builder
+ .push_bind(item.0)
+ .push_bind(item.1)
+ .push_bind(TaskStatus::Pending);
+ },
+ );
+
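+ // Upsert clause: conflicting payload_keys are skipped, mirroring the column's "on conflict ignore".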
+ builder.push("ON conflict (payload_key) DO NOTHING");
+
+ let query = builder.build();
+
+ query.execute(&self.pool).await?;
+
+ Ok(())
+ }
}