Skip to content

Commit

Permalink
fix: getTensorTypeFromSpark fails for Spark 3.3.0+
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryz committed Feb 23, 2023
1 parent c29abfd commit 10b7a1b
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions onnxmltools/convert/sparkml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ def buildInitialTypesSimple(dataframe):


def getTensorTypeFromSpark(sparktype):
if sparktype == 'StringType':
if sparktype == 'StringType' or sparktype == 'StringType()':
return StringTensorType([1, 1])
elif sparktype == 'DecimalType' \
or sparktype == 'DoubleType' \
or sparktype == 'FloatType' \
or sparktype == 'LongType' \
or sparktype == 'IntegerType' \
or sparktype == 'ShortType' \
or sparktype == 'ByteType' \
or sparktype == 'BooleanType':
elif sparktype == 'DecimalType' or sparktype == 'DecimalType()' \
or sparktype == 'DoubleType' or sparktype == 'DoubleType()' \
or sparktype == 'FloatType' or sparktype == 'FloatType()' \
or sparktype == 'LongType' or sparktype == 'LongType()' \
or sparktype == 'IntegerType' or sparktype == 'IntegerType()' \
or sparktype == 'ShortType' or sparktype == 'ShortType()' \
or sparktype == 'ByteType' or sparktype == 'ByteType()' \
or sparktype == 'BooleanType' or sparktype == 'BooleanType()':
return FloatTensorType([1, 1])
else:
raise TypeError("Cannot map this type to Onnx types: " + sparktype)
Expand Down

0 comments on commit 10b7a1b

Please sign in to comment.