diff --git a/onnxmltools/convert/sparkml/utils.py b/onnxmltools/convert/sparkml/utils.py index 58ccee36..c64986bb 100644 --- a/onnxmltools/convert/sparkml/utils.py +++ b/onnxmltools/convert/sparkml/utils.py @@ -14,16 +14,16 @@ def buildInitialTypesSimple(dataframe): def getTensorTypeFromSpark(sparktype): - if sparktype == 'StringType': + if sparktype == 'StringType' or sparktype == 'StringType()': return StringTensorType([1, 1]) - elif sparktype == 'DecimalType' \ - or sparktype == 'DoubleType' \ - or sparktype == 'FloatType' \ - or sparktype == 'LongType' \ - or sparktype == 'IntegerType' \ - or sparktype == 'ShortType' \ - or sparktype == 'ByteType' \ - or sparktype == 'BooleanType': + elif sparktype == 'DecimalType' or sparktype == 'DecimalType()' \ + or sparktype == 'DoubleType' or sparktype == 'DoubleType()' \ + or sparktype == 'FloatType' or sparktype == 'FloatType()' \ + or sparktype == 'LongType' or sparktype == 'LongType()' \ + or sparktype == 'IntegerType' or sparktype == 'IntegerType()' \ + or sparktype == 'ShortType' or sparktype == 'ShortType()' \ + or sparktype == 'ByteType' or sparktype == 'ByteType()' \ + or sparktype == 'BooleanType' or sparktype == 'BooleanType()': return FloatTensorType([1, 1]) else: raise TypeError("Cannot map this type to Onnx types: " + sparktype)