Add support for IterableDataset to TorchDataLoaderAdapter. (#19176)
hertschuh committed Feb 14, 2024
1 parent e385873 commit 4917d97
Showing 6 changed files with 113 additions and 13 deletions.
6 changes: 4 additions & 2 deletions keras/backend/jax/trainer.py
@@ -334,7 +334,7 @@ def fit(
                 (x, y, sample_weight), validation_split=validation_split
             )
 
-        if validation_data:
+        if validation_data is not None:
             (
                 val_x,
                 val_y,
@@ -428,7 +428,9 @@ def fit(
             epoch_logs = self.get_metrics_result()
 
             # Run validation.
-            if validation_data and self._should_eval(epoch, validation_freq):
+            if validation_data is not None and self._should_eval(
+                epoch, validation_freq
+            ):
                 # Create JAXEpochIterator for evaluation and cache it.
                 if getattr(self, "_eval_epoch_iterator", None) is None:
                     self._eval_epoch_iterator = JAXEpochIterator(
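Note: the change from `if validation_data:` to `if validation_data is not None:` (applied identically in the TensorFlow and Torch trainers below) is what makes length-less validation data usable. Truth-testing a DataLoader falls back to its __len__, which raises when the underlying IterableDataset defines none. A minimal sketch of the failure mode (the Stream class is illustrative, not part of this commit):

    import torch

    class Stream(torch.utils.data.IterableDataset):
        def __iter__(self):
            yield torch.zeros(4)

    loader = torch.utils.data.DataLoader(Stream(), batch_size=None)
    # bool(loader) falls back to DataLoader.__len__, which calls len(dataset):
    # `if loader:` raises TypeError for a length-less IterableDataset,
    # while `if loader is not None:` is always safe.
    assert loader is not None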
6 changes: 4 additions & 2 deletions keras/backend/tensorflow/trainer.py
@@ -277,7 +277,7 @@ def fit(
                 (x, y, sample_weight), validation_split=validation_split
             )
 
-        if validation_data:
+        if validation_data is not None:
             (
                 val_x,
                 val_y,
@@ -331,7 +331,9 @@ def fit(
             epoch_logs = self.get_metrics_result()
 
             # Run validation.
-            if validation_data and self._should_eval(epoch, validation_freq):
+            if validation_data is not None and self._should_eval(
+                epoch, validation_freq
+            ):
                 # Create EpochIterator for evaluation and cache it.
                 if getattr(self, "_eval_epoch_iterator", None) is None:
                     self._eval_epoch_iterator = TFEpochIterator(
6 changes: 4 additions & 2 deletions keras/backend/torch/trainer.py
@@ -197,7 +197,7 @@ def fit(
                 (x, y, sample_weight), validation_split=validation_split
             )
 
-        if validation_data:
+        if validation_data is not None:
             (
                 val_x,
                 val_y,
@@ -260,7 +260,9 @@ def fit(
             self.eval()
 
             # Run validation.
-            if validation_data and self._should_eval(epoch, validation_freq):
+            if validation_data is not None and self._should_eval(
+                epoch, validation_freq
+            ):
                 # Create TorchEpochIterator for evaluation and cache it.
                 if getattr(self, "_eval_epoch_iterator", None) is None:
                     self._eval_epoch_iterator = TorchEpochIterator(
2 changes: 1 addition & 1 deletion keras/trainers/data_adapters/tf_dataset_adapter.py
@@ -75,7 +75,7 @@ def num_batches(self):
         else:
             # However, in the case of `DistributedDataset`, it's a np.int64.
             cardinality = int(cardinality)
-        # Return None for Unknown and Infiite cardinality datasets
+        # Return None for Unknown and Infinite cardinality datasets
         if cardinality < 0:
             return None
         return cardinality
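For context, the negative values filtered here are the tf.data cardinality sentinels; a quick illustration (assuming TensorFlow 2.x):

    import tensorflow as tf

    finite = tf.data.Dataset.range(8).batch(3)
    infinite = tf.data.Dataset.range(8).repeat()
    unknown = tf.data.Dataset.range(8).filter(lambda x: True)

    print(int(tf.data.experimental.cardinality(finite)))    # 3
    print(int(tf.data.experimental.cardinality(infinite)))  # -1 == tf.data.INFINITE_CARDINALITY
    print(int(tf.data.experimental.cardinality(unknown)))   # -2 == tf.data.UNKNOWN_CARDINALITY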
12 changes: 9 additions & 3 deletions keras/trainers/data_adapters/torch_data_loader_adapter.py
@@ -20,8 +20,14 @@ def __init__(self, dataloader):

         self._dataloader = dataloader
         self._batch_size = dataloader.batch_size
-        self._size = len(dataloader)
-        self._partial_batch_size = len(dataloader.dataset) % self._batch_size
+        self._num_batches = None
+        self._partial_batch_size = None
+        if hasattr(dataloader.dataset, "__len__"):
+            self._num_batches = len(dataloader)
+            if self._batch_size is not None:
+                self._partial_batch_size = (
+                    len(dataloader.dataset) % self._batch_size
+                )

     def get_numpy_iterator(self):
         for batch in self._dataloader:
@@ -70,7 +76,7 @@ def get_tensor_spec(x):

     @property
     def num_batches(self):
-        return self._size
+        return self._num_batches
 
     @property
     def batch_size(self):
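The hasattr(dataloader.dataset, "__len__") guard is needed because len() on a DataLoader over an IterableDataset raises unless the dataset implements __len__, in which case the DataLoader reports ceil(len(dataset) / batch_size) batches (with drop_last=False). A minimal sketch of the two cases (class names are illustrative):

    import torch

    class Stream(torch.utils.data.IterableDataset):
        def __iter__(self):
            return iter(range(10))

    class SizedStream(Stream):
        def __len__(self):
            return 10

    # len() raises TypeError here, so the adapter reports num_batches=None:
    unsized = torch.utils.data.DataLoader(Stream(), batch_size=3)

    # len() works here, so the adapter reports num_batches=4:
    sized = torch.utils.data.DataLoader(SizedStream(), batch_size=3)
    print(len(sized))  # 4, i.e. ceil(10 / 3)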
94 changes: 91 additions & 3 deletions keras/trainers/data_adapters/torch_data_loader_adapter_test.py
@@ -1,3 +1,5 @@
+import math
+
 import jax
 import numpy as np
 import tensorflow as tf
@@ -18,9 +20,9 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
     def test_basic_dataloader(self, iterator_type):
         x = torch.normal(2, 3, size=(34, 4))
         y = torch.normal(1, 3, size=(34, 2))
-        base_ds = torch.utils.data.TensorDataset(x, y)
-        base_dataloader = torch.utils.data.DataLoader(base_ds, batch_size=16)
-        adapter = TorchDataLoaderAdapter(base_dataloader)
+        ds = torch.utils.data.TensorDataset(x, y)
+        dataloader = torch.utils.data.DataLoader(ds, batch_size=16)
+        adapter = TorchDataLoaderAdapter(dataloader)
 
         self.assertEqual(adapter.num_batches, 3)
         self.assertEqual(adapter.batch_size, 16)
@@ -53,3 +55,89 @@ def test_basic_dataloader(self, iterator_type):
             else:
                 self.assertEqual(bx.shape, (2, 4))
                 self.assertEqual(by.shape, (2, 2))

+    @parameterized.named_parameters(
+        named_product(
+            batch_size=[None, 3],
+            implements_len=[True, False],
+            iterator_type=["np", "tf", "jax", "torch"],
+        )
+    )
+    def test_dataloader_iterable_dataset(
+        self, batch_size, implements_len, iterator_type
+    ):
+
+        class TestIterableDataset(torch.utils.data.IterableDataset):
+            def __init__(self):
+                self.x = torch.normal(2, 3, size=(16, 4))
+                self.y = torch.normal(1, 3, size=(16, 2))
+
+            def __iter__(self):
+                for _ in range(10):
+                    yield (self.x, self.y)
+
+        class TestIterableDatasetWithLen(TestIterableDataset):
+            def __len__(self):
+                return 10
+
+        ds = (
+            TestIterableDatasetWithLen()
+            if implements_len
+            else TestIterableDataset()
+        )
+        dataloader = torch.utils.data.DataLoader(ds, batch_size=batch_size)
+        adapter = TorchDataLoaderAdapter(dataloader)
+
+        if implements_len and batch_size:
+            self.assertEqual(adapter.num_batches, math.ceil(10 / batch_size))
+            self.assertEqual(adapter.batch_size, batch_size)
+            self.assertEqual(adapter.has_partial_batch, True)
+            self.assertEqual(adapter.partial_batch_size, 10 % batch_size)
+        elif implements_len:
+            self.assertEqual(adapter.num_batches, 10)
+            self.assertEqual(adapter.batch_size, None)
+            self.assertEqual(adapter.has_partial_batch, None)
+            self.assertEqual(adapter.partial_batch_size, None)
+        else:
+            self.assertIsNone(adapter.num_batches)
+            self.assertEqual(adapter.batch_size, batch_size)
+            self.assertIsNone(adapter.has_partial_batch)
+            self.assertIsNone(adapter.partial_batch_size)
+
+        if iterator_type == "np":
+            it = adapter.get_numpy_iterator()
+            expected_class = np.ndarray
+        elif iterator_type == "tf":
+            it = adapter.get_tf_dataset()
+            expected_class = tf.Tensor
+        elif iterator_type == "jax":
+            it = adapter.get_jax_iterator()
+            expected_class = jax.Array
+        elif iterator_type == "torch":
+            it = adapter.get_torch_dataloader()
+            expected_class = torch.Tensor
+
+        batch_count = 0
+        for i, batch in enumerate(it):
+            batch_count += 1
+            self.assertEqual(len(batch), 2)
+            bx, by = batch
+            self.assertIsInstance(bx, expected_class)
+            self.assertIsInstance(by, expected_class)
+            self.assertEqual(bx.dtype, by.dtype)
+            self.assertContainsExactSubsequence(str(bx.dtype), "float32")
+            if batch_size:
+                if i < 3:
+                    self.assertEqual(bx.shape, (batch_size, 16, 4))
+                    self.assertEqual(by.shape, (batch_size, 16, 2))
+                else:
+                    self.assertEqual(bx.shape, (10 % batch_size, 16, 4))
+                    self.assertEqual(by.shape, (10 % batch_size, 16, 2))
+            else:
+                self.assertEqual(bx.shape, (16, 4))
+                self.assertEqual(by.shape, (16, 2))
+
+        if batch_size:
+            self.assertEqual(batch_count, math.ceil(10 / batch_size))
+        else:
+            self.assertEqual(batch_count, 10)

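Taken together, these changes let Model.fit() consume a DataLoader built on a length-less IterableDataset: the adapter reports num_batches=None and each epoch simply iterates the loader to exhaustion. A minimal end-to-end sketch (the Stream class, model, and shapes are illustrative, not part of this commit):

    import torch
    import keras

    class Stream(torch.utils.data.IterableDataset):
        def __iter__(self):
            for _ in range(100):
                yield torch.randn(8), torch.randn(1)

    loader = torch.utils.data.DataLoader(Stream(), batch_size=16)

    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    # No len() is taken on the loader; the batch count is unknown up front.
    model.fit(loader, epochs=2)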