Skip to content

Commit

Permalink
add test_create_loader
Browse files Browse the repository at this point in the history
  • Loading branch information
yuedongli1 committed Jun 29, 2023
1 parent b045390 commit c52180a
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,7 @@ jobs:
pylint mindyolo --rcfile=.github/pylint.conf
- name: Test with unit test (UT) pytest
run: |
wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
unzip coco128.zip
pytest tests/modules/*.py
2 changes: 1 addition & 1 deletion mindyolo/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
x[:, 0] = 0

n = len(labels) # number of images
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
bi = np.floor(np.arange(n) / batch_size).astype(np.int_) # batch index
nb = bi[-1] + 1 # number of batches
self.batch = bi # batch index of image

Expand Down
63 changes: 63 additions & 0 deletions tests/modules/test_create_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import sys

sys.path.append(".")

import os
import pytest

import mindspore as ms

from mindyolo.data import COCODataset, create_loader


@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("drop_remainder", [True, False])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("batch_size", [1, 4])
def test_create_loader(mode, drop_remainder, shuffle, batch_size):
ms.set_context(mode=mode)
dataset_path = './coco128'
transforms_dict = [
{'func_name': 'mosaic', 'prob': 1.0, 'mosaic9_prob': 0.0, 'translate': 0.1, 'scale': 0.9},
{'func_name': 'mixup', 'prob': 0.1, 'alpha': 8.0, 'beta': 8.0, 'needed_mosaic': True},
{'func_name': 'hsv_augment', 'prob': 1.0, 'hgain': 0.015, 'sgain': 0.7, 'vgain': 0.4},
{'func_name': 'label_norm', 'xyxy2xywh_': True},
{'func_name': 'albumentations'},
{'func_name': 'fliplr', 'prob': 0.5},
{'func_name': 'label_pad', 'padding_size': 160, 'padding_value': -1},
{'func_name': 'image_norm', 'scale': 255.},
{'func_name': 'image_transpose', 'bgr2rgb': True, 'hwc2chw': True}
]

dataset = COCODataset(
dataset_path=dataset_path,
transforms_dict=transforms_dict,
img_size=640,
is_training=True,
augment=True,
batch_size=batch_size,
stride=64,
)
dataloader = create_loader(
dataset=dataset,
batch_collate_fn=dataset.train_collate_fn,
dataset_column_names=dataset.dataset_column_names,
batch_size=batch_size,
epoch_size=1,
shuffle=shuffle,
drop_remainder=drop_remainder,
num_parallel_workers=1,
python_multiprocessing=True,
)

out_batch_size = dataloader.get_batch_size()
out_shapes = dataloader.output_shapes()[0]
assert out_batch_size == batch_size
assert out_shapes == [batch_size, 3, 640, 640]

for data in dataset:
assert data is not None


if __name__ == '__main__':
test_create_loader()
2 changes: 1 addition & 1 deletion tests/modules/test_create_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_create_trainer(yaml_name, mode):
network.set_train(True)

# Create Dataloaders
bs = 8
bs = 6
# create data
x = np.random.randn(bs, 3, 32, 32).astype(np.float32)
y = np.random.rand(bs, 160, 6).astype(np.float32)
Expand Down

0 comments on commit c52180a

Please sign in to comment.