diff --git a/python/orca/src/bigdl/orca/learn/pytorch/estimator.py b/python/orca/src/bigdl/orca/learn/pytorch/estimator.py index 0e36e11c7d3..1aad55557a8 100644 --- a/python/orca/src/bigdl/orca/learn/pytorch/estimator.py +++ b/python/orca/src/bigdl/orca/learn/pytorch/estimator.py @@ -229,7 +229,7 @@ def evaluate(self, data, validation_methods=None, batch_size=32): val_feature_set = FeatureSet.sample_rdd(data.rdd.flatMap(to_sample)) return self.estimator.evaluate(val_feature_set, validation_methods, batch_size) elif isinstance(data, DataLoader) or callable(data): - val_feature_set = FeatureSet.pytorch_dataloader(data, "", "") + val_feature_set = FeatureSet.pytorch_dataloader(data) return self.estimator.evaluate_minibatch(val_feature_set, validation_methods) else: raise ValueError("Data should be a SparkXShards, a DataLoader or a callable "