Skip to content

Commit

Permalink
add callZooFunc and change all callBigDlFunc to callZooFunc (intel-an…
Browse files Browse the repository at this point in the history
  • Loading branch information
qiuxin2012 committed Nov 26, 2019
1 parent 39d1c2b commit df08ffd
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions pyspark/bigdl/common/zooUtils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from bigdl.util.common import Sample as BSample, JTensor as BJTensor, callBigDlFunc
from bigdl.util.common import Sample as BSample, JTensor as BJTensor,\
JavaCreator, _get_gateway, _java2py, _py2java
import numpy as np


def to_list_of_numpy(elements):

if isinstance(elements, np.ndarray):
return [elements]
elif np.isscalar(elements):
Expand All @@ -39,7 +39,28 @@ def to_list_of_numpy(elements):


def set_core_number(num):
callBigDlFunc("float", "setCoreNumber", num)
callZooFunc("float", "setCoreNumber", num)


def callZooFunc(bigdl_type, name, *args):
""" Call API in PythonBigDL """
gateway = _get_gateway()
args = [_py2java(gateway, a) for a in args]
error = Exception("Cannot find function: %s" % name)
for jinvoker in JavaCreator.instance(bigdl_type, gateway).value:
# hasattr(jinvoker, name) always return true here,
# so you need to invoke the method to check if it exist or not
try:
api = getattr(jinvoker, name)
java_result = api(*args)
result = _java2py(gateway, java_result)
except Exception as e:
error = e
if "does not exist" not in str(e):
raise e
else:
return result
raise error


class JTensor(BJTensor):
Expand Down

0 comments on commit df08ffd

Please sign in to comment.