diff --git a/python/orca/src/bigdl/orca/inference/inference_model.py b/python/orca/src/bigdl/orca/inference/inference_model.py index cfe5c1d13f6..23311bdb742 100644 --- a/python/orca/src/bigdl/orca/inference/inference_model.py +++ b/python/orca/src/bigdl/orca/inference/inference_model.py @@ -146,3 +146,18 @@ def predict(self, inputs): jinputs, input_is_table) return KerasNet.convert_output(output) + + def distributed_predict(self, inputs, sc): + data_type = inputs.map(lambda x: x.__class__.__name__).first() + input_is_table = False + if data_type == "list": + input_is_table = True + jinputs = inputs.map(lambda x: Layer.check_input(x)[0]) + + output = callZooFunc(self.bigdl_type, + "inferenceModelDistriPredict", + self.value, + sc, + jinputs, + input_is_table) + return output.map(lambda x: KerasNet.convert_output(x))