Skip to content

Commit

Permalink
implement BatchShuffledDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Aug 26, 2023
1 parent d99d925 commit ac1b57b
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 36 deletions.
48 changes: 15 additions & 33 deletions src/convertor.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use chrono::prelude::*;
use chrono_tz::Tz;
use itertools::Itertools;
use rand::Rng;
use rusqlite::{Connection, Result, Row};
use std::collections::HashMap;

Expand Down Expand Up @@ -81,7 +80,7 @@ fn group_by_cid(revlogs: Vec<RevlogEntry>) -> Vec<Vec<RevlogEntry>> {
.push(revlog);
}

grouped.into_values().collect()
grouped.into_values().sorted_by_cached_key(|revlog| revlog.get(0).unwrap().cid).collect()
}

fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> chrono::NaiveDate {
Expand Down Expand Up @@ -194,40 +193,13 @@ fn convert_to_fsrs_items(
pub fn anki_to_fsrs() -> Vec<FSRSItem> {
let revlogs = read_collection().expect("read error");
let revlogs_per_card = group_by_cid(revlogs);
// collect FSRS items into a map by sequence size
let mut total_fsrs_items = 0;
let mut revlogs_by_seq_size: HashMap<usize, Vec<FSRSItem>> = HashMap::new();
revlogs_per_card
let mut revlogs = revlogs_per_card
.into_iter()
.filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
.flatten()
.for_each(|r| {
total_fsrs_items += 1;
revlogs_by_seq_size
.entry(r.reviews.len())
.or_default()
.push(r)
});
let mut sizes = revlogs_by_seq_size.keys().copied().collect_vec();
let mut rng = rand::thread_rng();
let mut out: Vec<FSRSItem> = Vec::with_capacity(total_fsrs_items);
while !sizes.is_empty() {
// pick a random sequence size
let size_idx = rng.gen_range(0..sizes.len() as u32) as usize;
let size = &mut sizes[size_idx];
let items = revlogs_by_seq_size.get_mut(size).unwrap();
// add up to 512 items from it to the output vector
for _ in 0..512 {
let Some(item) = items.pop() else {
// this size has run out of items; clear it from available sizes
sizes.swap_remove(size_idx);
break;
};
out.push(item);
}
}

out
.collect_vec();
revlogs.sort_by_cached_key(|r| r.reviews.len());
revlogs
}

#[cfg(test)]
Expand Down Expand Up @@ -392,4 +364,14 @@ mod tests {
);
assert_eq!(res.labels.to_data(), Data::from([1]));
}

#[test]
fn test_order() {
    // group_by_cid must return card groups in a deterministic, cid-sorted
    // order: the very first revlog entry of the first card is a known row.
    let grouped = group_by_cid(read_collection().unwrap());
    let first_entry = grouped.first().unwrap().first().unwrap();
    assert_eq!(first_entry.id, 1528956429777);
}
}
95 changes: 95 additions & 0 deletions src/dataloader.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use burn::data::dataset::Dataset;
use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng};
use std::marker::PhantomData;

/// A dataset wrapper that shuffles at *batch* granularity: consecutive
/// runs of `batch_size` items keep their internal order, but the batches
/// themselves are presented in a random order.
pub struct BatchShuffledDataset<D, I> {
    /// The wrapped underlying dataset.
    dataset: D,
    /// Flattened lookup table mapping a position in the shuffled view to
    /// an index in the underlying dataset.
    indices: Vec<usize>,
    /// Marker tying the wrapper to item type `I` without storing one.
    input: PhantomData<I>,
}

impl<D, I> BatchShuffledDataset<D, I>
where
    D: Dataset<I>,
{
    /// Creates a new batch-shuffled dataset.
    ///
    /// Items are partitioned into consecutive batches of `batch_size`
    /// (the final batch may be shorter); the order of the batches is
    /// shuffled with `rng`, and a flat index table is precomputed so
    /// that `get` can remap lookups in O(1).
    ///
    /// # Panics
    /// Panics if `batch_size` is zero (division by zero).
    pub fn new(dataset: D, batch_size: usize, rng: &mut StdRng) -> Self {
        let len = dataset.len();

        // Number of batches, rounding up so a trailing partial batch counts.
        let num_batches = (len + batch_size - 1) / batch_size;

        // Shuffle the order of the batches, not of the individual items.
        let mut batch_indices: Vec<usize> = (0..num_batches).collect();
        batch_indices.shuffle(rng);

        // Expand each shuffled batch into its element indices. The table
        // covers every item exactly once, so preallocate `len` slots.
        let mut indices: Vec<usize> = Vec::with_capacity(len);
        for &batch_index in &batch_indices {
            let start_index = batch_index * batch_size;
            let end_index = (start_index + batch_size).min(len);
            indices.extend(start_index..end_index);
        }

        Self {
            dataset,
            indices,
            input: PhantomData,
        }
    }

    /// Creates a new batch-shuffled dataset with a fixed seed, for
    /// reproducible shuffles.
    pub fn with_seed(dataset: D, batch_size: usize, seed: u64) -> Self {
        let mut rng = StdRng::seed_from_u64(seed);
        Self::new(dataset, batch_size, &mut rng)
    }
}

impl<D, I> Dataset<I> for BatchShuffledDataset<D, I>
where
    D: Dataset<I>,
    I: Clone + Send + Sync,
{
    /// Fetches the item at `index` in the batch-shuffled order by
    /// translating the index through the precomputed lookup table.
    /// Returns `None` when `index` is out of range.
    fn get(&self, index: usize) -> Option<I> {
        // `?` replaces the verbose match-and-early-return on `indices.get`.
        let shuffled_index = *self.indices.get(index)?;
        self.dataset.get(shuffled_index)
    }

    /// Length of the shuffled view; it covers every item of the
    /// underlying dataset exactly once, so the lengths coincide.
    fn len(&self) -> usize {
        self.dataset.len()
    }
}


#[test]
fn test_batch_shuffle() {
    use crate::dataset::FSRSDataset;
    // Smoke-test the batch shuffle: print the first batch_size + 2 items
    // of the seeded shuffled view.
    let dataset = FSRSDataset::train();
    let batch_size = 10;
    let seed = 42;
    let batch_shuffled_dataset: BatchShuffledDataset<FSRSDataset, crate::dataset::FSRSItem> =
        BatchShuffledDataset::with_seed(dataset, batch_size, seed);
    let limit = batch_shuffled_dataset.len().min(batch_size + 2);
    for i in 0..limit {
        println!("{:?}", batch_shuffled_dataset.get(i).unwrap());
    }
}

#[test]
fn test_item_shuffle() {
    use crate::dataset::FSRSDataset;
    use burn::data::dataset::transform::ShuffledDataset;
    // Smoke-test burn's per-item shuffle for comparison: print the first
    // 12 items of the seeded shuffled view.
    let dataset = FSRSDataset::train();
    let seed = 42;
    let shuffled_dataset: ShuffledDataset<FSRSDataset, crate::dataset::FSRSItem> =
        ShuffledDataset::with_seed(dataset, seed);
    let limit = shuffled_dataset.len().min(12);
    for i in 0..limit {
        println!("{:?}", shuffled_dataset.get(i).unwrap());
    }
}
1 change: 1 addition & 0 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ fn test_from_anki() {
use burn::data::dataloader::DataLoaderBuilder;
let dataloader = DataLoaderBuilder::new(batcher)
.batch_size(1)
.shuffle(42)
.num_workers(4)
.build(dataset);
dbg!(
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ pub mod dataset;
pub mod model;
pub mod training;
mod weight_clipper;

mod dataloader;
5 changes: 2 additions & 3 deletions src/training.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::cosine_annealing::CosineAnnealingLR;
use crate::dataloader::BatchShuffledDataset;
use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
use crate::model::{Model, ModelConfig};
use crate::weight_clipper::weight_clipper;
Expand Down Expand Up @@ -117,12 +118,10 @@ pub fn train<B: ADBackend<FloatElem = f32>>(

let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.num_workers(config.num_workers)
.build(FSRSDataset::train());
.build(BatchShuffledDataset::with_seed(FSRSDataset::train(), config.batch_size, config.seed));

let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.num_workers(config.num_workers)
.build(FSRSDataset::test());

let lr_scheduler = CosineAnnealingLR::init(
Expand Down

0 comments on commit ac1b57b

Please sign in to comment.