Skip to content

Commit

Permalink
fixing unit test breaks
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Wang <jasowang@microsoft.com>
  • Loading branch information
memoryz committed Jun 1, 2022
1 parent 9b98131 commit 81e003a
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion onnxmltools/convert/sparkml/operator_converters/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ def convert_sparkml_k_means_model(scope: Scope, operator: Operator, container: M
)

# input * Transpose(Center): [N x K]
zeros_variable_name = scope.get_unique_variable_name("zeros")
container.add_initializer(
zeros_variable_name,
onnx_proto.TensorProto.FLOAT,
[1, K],
np.zeros([1, K]).flatten().astype(np.float32)
)
gemm_output_variable_name = scope.get_unique_variable_name("gemm_output")
gemm_attrs = {
"name": scope.get_unique_operator_name("GeMM"),
Expand All @@ -109,7 +116,7 @@ def convert_sparkml_k_means_model(scope: Scope, operator: Operator, container: M
}
container.add_node(
op_type="Gemm",
inputs=[operator.inputs[0].full_name, centers_variable_name],
inputs=[operator.inputs[0].full_name, centers_variable_name, zeros_variable_name],
outputs=[gemm_output_variable_name],
**gemm_attrs
)
Expand Down

0 comments on commit 81e003a

Please sign in to comment.