Skip to content

Commit

Permalink
cargo fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Aug 26, 2023
1 parent ac1b57b commit 6de05f8
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 10 deletions.
5 changes: 4 additions & 1 deletion src/convertor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ fn group_by_cid(revlogs: Vec<RevlogEntry>) -> Vec<Vec<RevlogEntry>> {
.push(revlog);
}

grouped.into_values().sorted_by_cached_key(|revlog| revlog.get(0).unwrap().cid).collect()
grouped
.into_values()
.sorted_by_cached_key(|revlog| revlog.get(0).unwrap().cid)
.collect()
}

fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> chrono::NaiveDate {
Expand Down
15 changes: 8 additions & 7 deletions src/dataloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ where
/// Creates a new shuffled dataset.
pub fn new(dataset: D, batch_size: usize, rng: &mut StdRng) -> Self {
let len = dataset.len();

// 计算批数
let num_batches = (len + batch_size - 1) / batch_size;

// 创建一个批数索引的向量并打乱
let mut batch_indices: Vec<usize> = (0..num_batches).collect();
batch_indices.shuffle(rng);

// 为每个打乱的批次生成相应的元素索引
let mut indices: Vec<usize> = Vec::new();
for &batch_index in &batch_indices {
Expand Down Expand Up @@ -63,14 +63,14 @@ where
}
}


#[test]
fn test_batch_shuffle() {
use crate::dataset::FSRSDataset;
let dataset = FSRSDataset::train();
let batch_size = 10;
let seed = 42;
let batch_shuffled_dataset: BatchShuffledDataset<FSRSDataset, crate::dataset::FSRSItem> = BatchShuffledDataset::with_seed(dataset, batch_size, seed);
let batch_shuffled_dataset: BatchShuffledDataset<FSRSDataset, crate::dataset::FSRSItem> =
BatchShuffledDataset::with_seed(dataset, batch_size, seed);
for i in 0..batch_shuffled_dataset.len() {
println!("{:?}", batch_shuffled_dataset.get(i).unwrap());
if i > batch_size {
Expand All @@ -85,11 +85,12 @@ fn test_item_shuffle() {
use burn::data::dataset::transform::ShuffledDataset;
let dataset = FSRSDataset::train();
let seed = 42;
let shuffled_dataset: ShuffledDataset<FSRSDataset, crate::dataset::FSRSItem> = ShuffledDataset::with_seed(dataset, seed);
let shuffled_dataset: ShuffledDataset<FSRSDataset, crate::dataset::FSRSItem> =
ShuffledDataset::with_seed(dataset, seed);
for i in 0..shuffled_dataset.len() {
println!("{:?}", shuffled_dataset.get(i).unwrap());
if i > 10 {
break;
}
}
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ pub mod model;
pub mod training;
mod weight_clipper;

mod dataloader;
mod dataloader;
6 changes: 5 additions & 1 deletion src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ pub fn train<B: ADBackend<FloatElem = f32>>(

let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.build(BatchShuffledDataset::with_seed(FSRSDataset::train(), config.batch_size, config.seed));
.build(BatchShuffledDataset::with_seed(
FSRSDataset::train(),
config.batch_size,
config.seed,
));

let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
Expand Down

0 comments on commit 6de05f8

Please sign in to comment.