From 7d80f4eeac27faa923a50750247412be41bfd310 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 25 Aug 2023 08:42:58 +0900 Subject: [PATCH] add explanation for the sql query (#23) * add explanation for sql query. --- src/convertor.rs | 60 +++++++++++++++++++------------------------ src/model.rs | 4 +-- src/weight_clipper.rs | 7 ++--- 3 files changed, 32 insertions(+), 39 deletions(-) diff --git a/src/convertor.rs b/src/convertor.rs index 8b71ae6..050a35d 100644 --- a/src/convertor.rs +++ b/src/convertor.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use crate::dataset::{FSRSItem, FSRSReview}; -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Default)] struct RevlogEntry { id: i64, cid: i64, @@ -24,8 +24,8 @@ fn row_to_revlog_entry(row: &Row) -> Result { }) } -fn read_collection() -> Vec { - let db = Connection::open("tests/data/collection.anki21").unwrap(); +fn read_collection() -> Result> { + let db = Connection::open("tests/data/collection.anki21")?; let filter_out_suspended_cards = false; let filter_out_flags = vec![]; let flags_str = if !filter_out_flags.is_empty() { @@ -34,7 +34,7 @@ fn read_collection() -> Vec { filter_out_flags .iter() .map(|x: &i32| x.to_string()) - .collect::>() + .collect::>() .join(", ") ) } else { @@ -48,32 +48,26 @@ fn read_collection() -> Vec { }; let current_timestamp = Utc::now().timestamp() * 1000; - - let query = format!( - "SELECT id, cid, ease, type - FROM revlog - WHERE (type != 4 OR ivl <= 0) - AND (factor != 0 or type != 3) - AND id < {} - AND cid < {} - AND cid IN ( - SELECT id - FROM cards - WHERE queue != 0 - {} - {} - )", - current_timestamp, current_timestamp, suspended_cards_str, flags_str - ); - + // This sql query will be remove in the futrue. See https://github.com/open-spaced-repetition/fsrs-optimizer-burn/pull/14#issuecomment-1685895643 let revlogs = db - .prepare_cached(&query) - .unwrap() - .query_and_then([], row_to_revlog_entry) - .unwrap() - .collect::>>() - .unwrap(); - revlogs + .prepare_cached(&format!( + "SELECT id, cid, ease, type + FROM revlog + WHERE (type != 4 OR ivl <= 0) + AND (factor != 0 or type != 3) + AND id < ?1 + AND cid < ?2 + AND cid IN ( + SELECT id + FROM cards + WHERE queue != 0 + {suspended_cards_str} + {flags_str} + )" + ))? + .query_and_then((current_timestamp, current_timestamp), row_to_revlog_entry)? + .collect::>>()?; + Ok(revlogs) } fn group_by_cid(revlogs: Vec) -> Vec> { @@ -136,9 +130,7 @@ fn convert_to_fsrs_items( // Increment review_kind of all entries by 1 // 将所有 review_kind + 1 - for entry in &mut entries { - entry.review_kind += 1; - } + entries.iter_mut().for_each(|entry| entry.review_kind += 1); // Convert the timestamp and keep the first RevlogEntry for each date // 转换时间戳并保留每个日期的第一个 RevlogEntry @@ -198,7 +190,7 @@ fn convert_to_fsrs_items( } pub fn anki_to_fsrs() -> Vec { - let revlogs = read_collection(); + let revlogs = read_collection().expect("read error"); let revlogs_per_card = group_by_cid(revlogs); revlogs_per_card .into_iter() @@ -219,7 +211,7 @@ mod tests { // https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip #[test] fn test() { - let revlogs = read_collection(); + let revlogs = read_collection().unwrap(); let single_card_revlog = vec![revlogs .iter() .filter(|r| r.cid == 1528947214762) diff --git a/src/model.rs b/src/model.rs index ff1bb55..89888ea 100644 --- a/src/model.rs +++ b/src/model.rs @@ -300,8 +300,8 @@ fn test_forward() { [1.0, 2.0, 3.0, 4.0, 1.0, 2.0], ]); let (stability, difficulty) = model.forward(delta_ts, ratings); - println!("stability {:?}", stability); - println!("difficulty {:?}", difficulty); + dbg!(&stability); + dbg!(&difficulty); } #[cfg(test)] diff --git a/src/weight_clipper.rs b/src/weight_clipper.rs index 51bc62c..6555d68 100644 --- a/src/weight_clipper.rs +++ b/src/weight_clipper.rs @@ -20,9 +20,10 @@ pub fn weight_clipper>(weights: Tensor) -> Ten let val: &mut Vec = &mut weights.to_data().value; - for (i, w) in val.iter_mut().skip(4).enumerate() { - *w = w.clamp(CLAMPS[i].0, CLAMPS[i].1); - } + val.iter_mut() + .skip(4) + .zip(CLAMPS) + .for_each(|(w, (low, high))| *w = w.clamp(low, high)); Tensor::from_data(Data::new(val.clone(), weights.shape())) }