diff --git a/src/nnfusion/core/operators/generic_op/generic_op.hpp b/src/nnfusion/core/operators/generic_op/generic_op.hpp index 30ba4eeb2..9e3d616a8 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op.hpp +++ b/src/nnfusion/core/operators/generic_op/generic_op.hpp @@ -264,7 +264,9 @@ namespace nnfusion std::vector shape_def; for (int d = 0; d < shape.size(); d++) { - shape_def.push_back(shape[d] == 0 ? "1" : ("N" + to_string(d))); + // Tensor with shape [0] is treated as scalar value and convert its shape to [1] + shape_def.push_back((shape.size() == 1 && shape[d] == 0) ? "1" + : ("N" + to_string(d))); } return shape_def; } diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Concat.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Concat.cpp index ad5336cb4..de87f56f0 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Concat.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Concat.cpp @@ -46,11 +46,20 @@ REGISTER_OP(Concat) R"( @input@@input_layout@.when(@dim@ < @offset@, @recursive@) )"; auto final_input_template = R"(@input@@input_layout@)"; std::string inputs_body = R"(@recursive@)"; + + size_t num_valid_inputs = 0; + for (int in_id = 0; in_id < curr->get_input_size(); ++in_id) + if (curr->get_input_shape(in_id)[axis] > 0) + num_valid_inputs++; + + size_t processed_inputs = 0; for (int in_id = 0; in_id < curr->get_input_size(); ++in_id) { std::vector in_data_layout(data_layout); in_data_layout[axis] = in_data_layout[axis] + " - " + to_string(offset); auto dim_size = curr->get_input_shape(in_id)[axis]; + if (dim_size == 0) + continue; offset += dim_size; op::OpConfig::any in_config; @@ -58,9 +67,10 @@ REGISTER_OP(Concat) in_config["input_layout"] = vector_to_string>(in_data_layout); in_config["dim"] = data_layout[axis]; in_config["offset"] = offset; + processed_inputs++; std::string cur_body; - if (in_id != curr->get_input_size() - 1) + if (processed_inputs < num_valid_inputs) cur_body = op::create_code_from_template(recursive_input_template, in_config); else cur_body = op::create_code_from_template(final_input_template, in_config);