Skip to content

Commit

Permalink
Shuffle elements with the same review length randomly
Browse files Browse the repository at this point in the history
  • Loading branch information
dae committed Aug 26, 2023
1 parent 6974454 commit 682579a
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ log = "0.4"
rusqlite = { version = "0.29.0" }
chrono = "0.4.26"
chrono-tz = "0.8.3"
itertools = "0.11.0"
itertools = "0.11.0"
rand = "0.8.5"
15 changes: 14 additions & 1 deletion src/convertor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use chrono::prelude::*;
use chrono_tz::Tz;
use itertools::Itertools;
use rand::Rng;
use rusqlite::{Connection, Result, Row};
use std::collections::HashMap;

Expand Down Expand Up @@ -198,7 +199,19 @@ pub fn anki_to_fsrs() -> Vec<FSRSItem> {
.filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
.flatten()
.collect_vec();
revlogs.sort_by_key(|r| r.reviews.len());
let mut rng = rand::thread_rng();
// each time we see a new size, we assign it a random number
let mut size_to_random_map = HashMap::new();
let mut size_to_random = |r: &FSRSItem| -> i32 {
*size_to_random_map
.entry(r.reviews.len())
.or_insert_with(|| rng.gen::<i32>())
};
revlogs.sort_unstable_by(|a, b| {
let a = size_to_random(a);
let b = size_to_random(b);
a.cmp(&b)
});
revlogs
}

Expand Down
1 change: 0 additions & 1 deletion src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ fn test_from_anki() {
use burn::data::dataloader::DataLoaderBuilder;
let dataloader = DataLoaderBuilder::new(batcher)
.batch_size(1)
.shuffle(42)
.num_workers(4)
.build(dataset);
dbg!(
Expand Down
2 changes: 0 additions & 2 deletions src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,11 @@ pub fn train<B: ADBackend<FloatElem = f32>>(

let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
// .shuffle(config.seed)
.num_workers(config.num_workers)
.build(FSRSDataset::train());

let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
// .shuffle(config.seed)
.num_workers(config.num_workers)
.build(FSRSDataset::test());

Expand Down

0 comments on commit 682579a

Please sign in to comment.