From ac1b57ba82f9199eb92b51238550722ee720e5fa Mon Sep 17 00:00:00 2001
From: Jarrett Ye
Date: Sat, 26 Aug 2023 16:18:07 +0800
Subject: [PATCH] implement BatchShuffledDataset

---
 src/convertor.rs  | 48 ++++++++----------------
 src/dataloader.rs | 95 +++++++++++++++++++++++++++++++++++++++++++++++
 src/dataset.rs    |  1 +
 src/lib.rs        |  2 +
 src/training.rs   |  5 +--
 5 files changed, 115 insertions(+), 36 deletions(-)
 create mode 100644 src/dataloader.rs

diff --git a/src/convertor.rs b/src/convertor.rs
index c49fd85..ae090c1 100644
--- a/src/convertor.rs
+++ b/src/convertor.rs
@@ -1,7 +1,6 @@
 use chrono::prelude::*;
 use chrono_tz::Tz;
 use itertools::Itertools;
-use rand::Rng;
 use rusqlite::{Connection, Result, Row};
 use std::collections::HashMap;
 
@@ -81,7 +80,7 @@ fn group_by_cid(revlogs: Vec<RevlogEntry>) -> Vec<Vec<RevlogEntry>> {
             .push(revlog);
     }
 
-    grouped.into_values().collect()
+    grouped.into_values().sorted_by_cached_key(|revlog| revlog.get(0).unwrap().cid).collect()
 }
 
 fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> chrono::NaiveDate {
@@ -194,40 +193,13 @@ fn convert_to_fsrs_items(
 pub fn anki_to_fsrs() -> Vec<FSRSItem> {
     let revlogs = read_collection().expect("read error");
     let revlogs_per_card = group_by_cid(revlogs);
-    // collect FSRS items into a map by sequence size
-    let mut total_fsrs_items = 0;
-    let mut revlogs_by_seq_size: HashMap<usize, Vec<FSRSItem>> = HashMap::new();
-    revlogs_per_card
+    let mut revlogs = revlogs_per_card
         .into_iter()
         .filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
         .flatten()
-        .for_each(|r| {
-            total_fsrs_items += 1;
-            revlogs_by_seq_size
-                .entry(r.reviews.len())
-                .or_default()
-                .push(r)
-        });
-    let mut sizes = revlogs_by_seq_size.keys().copied().collect_vec();
-    let mut rng = rand::thread_rng();
-    let mut out: Vec<FSRSItem> = Vec::with_capacity(total_fsrs_items);
-    while !sizes.is_empty() {
-        // pick a random sequence size
-        let size_idx = rng.gen_range(0..sizes.len() as u32) as usize;
-        let size = &mut sizes[size_idx];
-        let items =
revlogs_by_seq_size.get_mut(size).unwrap();
-        // add up to 512 items from it to the output vector
-        for _ in 0..512 {
-            let Some(item) = items.pop() else {
-                // this size has run out of items; clear it from available sizes
-                sizes.swap_remove(size_idx);
-                break;
-            };
-            out.push(item);
-        }
-    }
-
-    out
+        .collect_vec();
+    revlogs.sort_by_cached_key(|r| r.reviews.len());
+    revlogs
 }
 
 #[cfg(test)]
@@ -392,4 +364,14 @@ mod tests {
         );
         assert_eq!(res.labels.to_data(), Data::from([1]));
     }
+
+    #[test]
+    fn test_order() {
+        let revlogs = read_collection().unwrap();
+        let revlogs_per_card = group_by_cid(revlogs);
+        assert_eq!(
+            revlogs_per_card.get(0).unwrap().get(0).unwrap().id,
+            1528956429777
+        );
+    }
 }
diff --git a/src/dataloader.rs b/src/dataloader.rs
new file mode 100644
index 0000000..f3b9f8d
--- /dev/null
+++ b/src/dataloader.rs
@@ -0,0 +1,95 @@
+use burn::data::dataset::Dataset;
+use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng};
+use std::marker::PhantomData;
+
+pub struct BatchShuffledDataset<D, I> {
+    dataset: D,
+    indices: Vec<usize>,
+    input: PhantomData<I>,
+}
+
+impl<D, I> BatchShuffledDataset<D, I>
+where
+    D: Dataset<I>,
+{
+    /// Creates a new shuffled dataset.
+    pub fn new(dataset: D, batch_size: usize, rng: &mut StdRng) -> Self {
+        let len = dataset.len();
+
+        // Number of batches (ceiling division).
+        let num_batches = (len + batch_size - 1) / batch_size;
+
+        // Build a vector of batch indices and shuffle it.
+        let mut batch_indices: Vec<usize> = (0..num_batches).collect();
+        batch_indices.shuffle(rng);
+
+        // Expand each shuffled batch into the element indices it covers.
+        let mut indices: Vec<usize> = Vec::new();
+        for &batch_index in &batch_indices {
+            let start_index = batch_index * batch_size;
+            let end_index = std::cmp::min(start_index + batch_size, len);
+            indices.extend(start_index..end_index);
+        }
+
+        Self {
+            dataset,
+            indices,
+            input: PhantomData,
+        }
+    }
+
+    /// Creates a new shuffled dataset with a fixed seed.
+    pub fn with_seed(dataset: D, batch_size: usize, seed: u64) -> Self {
+        let mut rng = StdRng::seed_from_u64(seed);
+        Self::new(dataset, batch_size, &mut rng)
+    }
+}
+
+impl<D, I> Dataset<I> for BatchShuffledDataset<D, I>
+where
+    D: Dataset<I>,
+    I: Clone + Send + Sync,
+{
+    fn get(&self, index: usize) -> Option<I> {
+        let index = match self.indices.get(index) {
+            Some(index) => index,
+            None => return None,
+        };
+        self.dataset.get(*index)
+    }
+
+    fn len(&self) -> usize {
+        self.dataset.len()
+    }
+}
+
+
+#[test]
+fn test_batch_shuffle() {
+    use crate::dataset::FSRSDataset;
+    let dataset = FSRSDataset::train();
+    let batch_size = 10;
+    let seed = 42;
+    let batch_shuffled_dataset: BatchShuffledDataset<FSRSDataset, crate::convertor::FSRSItem> = BatchShuffledDataset::with_seed(dataset, batch_size, seed);
+    for i in 0..batch_shuffled_dataset.len() {
+        println!("{:?}", batch_shuffled_dataset.get(i).unwrap());
+        if i > batch_size {
+            break;
+        }
+    }
+}
+
+#[test]
+fn test_item_shuffle() {
+    use crate::dataset::FSRSDataset;
+    use burn::data::dataset::transform::ShuffledDataset;
+    let dataset = FSRSDataset::train();
+    let seed = 42;
+    let shuffled_dataset: ShuffledDataset<FSRSDataset, crate::convertor::FSRSItem> = ShuffledDataset::with_seed(dataset, seed);
+    for i in 0..shuffled_dataset.len() {
+        println!("{:?}", shuffled_dataset.get(i).unwrap());
+        if i > 10 {
+            break;
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/dataset.rs b/src/dataset.rs
index 5582f77..1ec77c9 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -169,6 +169,7 @@ fn test_from_anki() {
     use burn::data::dataloader::DataLoaderBuilder;
     let dataloader = DataLoaderBuilder::new(batcher)
         .batch_size(1)
+        .shuffle(42)
         .num_workers(4)
         .build(dataset);
     dbg!(
diff --git a/src/lib.rs b/src/lib.rs
index 5ac6c27..52c9684 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,3 +6,5 @@ pub mod dataset;
 pub mod model;
 pub mod training;
 mod weight_clipper;
+
+mod dataloader;
\ No newline at end of file
diff --git a/src/training.rs b/src/training.rs
index f71ac8e..8949176 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -1,4 +1,5 @@
 use crate::cosine_annealing::CosineAnnealingLR;
+use crate::dataloader::BatchShuffledDataset;
 use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
 use crate::model::{Model, ModelConfig};
 use crate::weight_clipper::weight_clipper;
@@ -117,12 +118,10 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
 
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
-        .num_workers(config.num_workers)
-        .build(FSRSDataset::train());
+        .build(BatchShuffledDataset::with_seed(FSRSDataset::train(), config.batch_size, config.seed));
 
     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
-        .num_workers(config.num_workers)
         .build(FSRSDataset::test());
 
     let lr_scheduler = CosineAnnealingLR::init(