[CUTLASS] Conv2d activation fusion, part 2: Sigmoid fp16, SiLU and HardSwish (#9795)

* [Torch] do not pad if pad widths are all zero

* silu fusion supported

* adding hardswish support

* support fast_math sigmoid op

* fixed type inference for yolov5 + silu fusion

* use include_non_call_ops=False in AnnotateTarget

* update cutlass

* revert change in build.py

* simplify codegen

* lint
masahi committed Dec 23, 2021
1 parent b29a443 commit 1afcf36
Showing 10 changed files with 155 additions and 33 deletions.
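
For orientation, here is a minimal sketch of how the fusions and the new use_fast_math knob added in this commit might be exercised end to end. It assumes the usual TVM CUTLASS BYOC flow on an sm_80 GPU with an fp16 NHWC/OHWI conv2d; exact arguments and layouts may differ across TVM versions.

import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels

# A small conv2d + bias + SiLU graph in the layout the CUTLASS integration expects.
data = relay.var("data", shape=(1, 32, 32, 16), dtype="float16")
weight = relay.var("weight", shape=(32, 3, 3, 16), dtype="float16")
bias = relay.var("bias", shape=(32,), dtype="float16")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1),
                       data_layout="NHWC", kernel_layout="OHWI", out_dtype="float16")
biased = relay.nn.bias_add(conv, bias, axis=3)
silu = biased * relay.sigmoid(biased)  # matched by the new conv2d_bias_silu pattern
mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], silu))

mod = partition_for_cutlass(mod)                 # offload the fused pattern to CUTLASS
mod, num_partitions = tune_cutlass_kernels(mod, sm=80)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda")
# use_fast_math=True compiles the generated kernels with the tanh-based sigmoid
# approximation (-DCUTLASS_USE_TANH_FOR_SIGMOID) wired up in build.py below.
lib = build_cutlass_kernels(lib, sm=80, lib_path="compile.so", use_fast_math=True)
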
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 163 files
30 changes: 25 additions & 5 deletions python/tvm/contrib/cutlass/build.py
@@ -40,7 +40,7 @@ def _get_cutlass_path():
return cutlass_path


def _get_cutlass_compile_options(sm, threads):
def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
cutlass_root = _get_cutlass_path()
cutlass_include = os.path.join(cutlass_root, "include")
cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
@@ -58,6 +58,8 @@ def _get_cutlass_compile_options(sm, threads):
"-I" + cutlass_include,
"-I" + cutlass_util_include,
]
if use_fast_math:
kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
cuda_path = find_cuda_path()
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver >= 11.2:
@@ -222,6 +224,10 @@ def handle_conv2d(
cutlass_op_def = out["opdef_bias_relu"]
elif op_type == "cutlass.conv2d_bias_sigmoid":
cutlass_op_def = out["opdef_bias_sigmoid"]
elif op_type == "cutlass.conv2d_bias_silu":
cutlass_op_def = out["opdef_bias_silu"]
elif op_type == "cutlass.conv2d_bias_hardswish":
cutlass_op_def = out["opdef_bias_hardswish"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

@@ -339,7 +345,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
return mod, num_cutlass_partition


def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1):
def build_cutlass_kernels(
lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1, use_fast_math=False
):
"""Compile CUTLASS kernels in lib and return the runtime module ready to run.
Parameters
@@ -361,18 +369,27 @@ def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so", threa
The number of threads to use for compiling generated kernels. Only available for
CUDA 11.2 or later. Use all physical cores by default.
use_fast_math : bool, optional
Whether or not to use faster but less accurate math intrinsics.
Returns
-------
updated_lib : runtime.Module
The updated module with compiled cutlass kernels.
"""
kwargs = _get_cutlass_compile_options(sm, threads)
kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math)
lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
return runtime.load_module(lib_path)


def build_cutlass_kernels_vm(
vm_exec, sm, tmp_dir="./tmp", lib_path="compile.so", vmcode_path="vmcode.ro", threads=-1
vm_exec,
sm,
tmp_dir="./tmp",
lib_path="compile.so",
vmcode_path="vmcode.ro",
threads=-1,
use_fast_math=False,
):
"""Compile CUTLASS kernels in vm_exec and return a VM executable ready to run.
@@ -398,13 +415,16 @@ def build_cutlass_kernels_vm(
The number of threads to use for compiling generated kernels. Only available for
CUDA 11.2 or later. Use all physical cores by default.
use_fast_math : bool, optional
Whether or not to use faster but less accurate math intrinsics.
Returns
-------
updated_vm_exec: vm.Executable
The updated exectuable with compiled cutlass kernels.
"""
code, lib = vm_exec.save()
kwargs = _get_cutlass_compile_options(sm, threads)
kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math)
lib_path = os.path.join(tmp_dir, lib_path)
vmcode_path = os.path.join(tmp_dir, vmcode_path)
lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
12 changes: 10 additions & 2 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -90,9 +90,17 @@ def create_conv2d_operator(
EpilogueFunctor.LinearCombinationBias,
EpilogueFunctor.LinearCombinationRelu,
EpilogueFunctor.LinearCombinationSigmoid,
EpilogueFunctor.LinearCombinationSilu,
EpilogueFunctor.LinearCombinationHardSwish,
],
["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid"],
[True, True, False],
[
"opdef_bias",
"opdef_bias_relu",
"opdef_bias_sigmoid",
"opdef_bias_silu",
"opdef_bias_hardswish",
],
[True, True, False, False, False],
):
op = Conv2dOperation(
ConvKind.Fprop,
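
The three parallel lists above (epilogue functors, keys under which the generated op definitions are stored, and a per-variant boolean) appear to be consumed in lockstep, so supporting a new activation only means appending one entry to each. A rough illustration of that pattern with hypothetical names; the boolean mirrors the no_bias_scaling distinction visible in codegen.cc further down:

# Hypothetical sketch, not the real emitter in gen_conv2d.py.
epilogues = ["LinearCombinationBias", "LinearCombinationRelu", "LinearCombinationSigmoid",
             "LinearCombinationSilu", "LinearCombinationHardSwish"]
opdef_keys = ["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid",
              "opdef_bias_silu", "opdef_bias_hardswish"]
no_bias_scaling = [True, True, False, False, False]

generated = {}
for epilogue, key, flag in zip(epilogues, opdef_keys, no_bias_scaling):
    # Each epilogue variant yields its own kernel definition, looked up later
    # by handle_conv2d via out["opdef_bias_silu"] and friends.
    generated[key] = {"epilogue": epilogue, "no_bias_scaling": flag}
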
6 changes: 5 additions & 1 deletion python/tvm/contrib/cutlass/library.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name,line-too-long
"""Various type definitions to help instantiate CUTLASS kernels."""
import re
import enum
@@ -149,6 +149,8 @@ class EpilogueFunctor(enum.Enum):
LinearCombinationBias = enum_auto()
LinearCombinationGelu = enum_auto()
LinearCombinationSigmoid = enum_auto()
LinearCombinationSilu = enum_auto()
LinearCombinationHardSwish = enum_auto()


EpilogueFunctorTag = {
@@ -157,6 +159,8 @@ class EpilogueFunctor(enum.Enum):
EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination",
EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU",
EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu",
EpilogueFunctor.LinearCombinationHardSwish: "cutlass::epilogue::thread::LinearCombinationHardSwish",
}


7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -1736,14 +1736,19 @@ def pad(inputs, input_types):
paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)]

const_paddings = []
non_zero_found = False
for pad in paddings:
const_paddings.append([])
for p in pad:
if not isinstance(p, int):
p = int(_infer_value(p, {}).numpy())
const_paddings[-1].append(p)
if p != 0:
non_zero_found = True

if mode == "constant":
if not non_zero_found:
return data
elif mode == "constant":
return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode)
else:
return _op.nn.pad(data, const_paddings, pad_mode=mode)
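
The converter change above simply returns the input when every pad width is zero, instead of emitting a no-op nn.pad. A standalone sketch of the same check (a hypothetical helper, not the converter itself):

from tvm import relay

def pad_if_needed(data, pad_widths, pad_value=0.0, mode="constant"):
    """Return data unchanged when all pad widths are zero; otherwise emit nn.pad."""
    if all(p == 0 for pair in pad_widths for p in pair):
        return data  # a zero pad adds nothing and only gets in the way of later passes
    if mode == "constant":
        return relay.nn.pad(data, pad_widths, pad_value=pad_value, pad_mode=mode)
    return relay.nn.pad(data, pad_widths, pad_mode=mode)
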
21 changes: 20 additions & 1 deletion python/tvm/relay/op/contrib/cutlass.py
@@ -75,6 +75,15 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
return is_op("nn.relu")(conv2d_out)
if with_act == "sigmoid":
return is_op("sigmoid")(conv2d_out)
if with_act == "silu":
return is_op("multiply")(conv2d_out, is_op("sigmoid")(conv2d_out))
if with_act == "hardswish":
rhs = is_op("divide")(
is_op("clip")(is_op("add")(conv2d_out, is_constant())), is_constant()
)
return is_op("multiply")(conv2d_out, rhs)

raise ValueError("Unknown activation %s." % with_act)

return conv2d_out

@@ -149,6 +158,16 @@ def partition_for_cutlass(mod, params=None):
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
(
"cutlass.conv2d_bias_hardswish",
make_conv2d_pattern(with_bias=True, with_act="hardswish"),
check_conv2d,
),
(
"cutlass.conv2d_bias_silu",
make_conv2d_pattern(with_bias=True, with_act="silu"),
check_conv2d,
),
(
"cutlass.conv2d_bias_relu",
make_conv2d_pattern(with_bias=True, with_act="relu"),
@@ -180,7 +199,7 @@ def partition_for_cutlass(mod, params=None):
[
transform.InferType(),
transform.MergeComposite(cutlass_patterns),
transform.AnnotateTarget(["cutlass"]),
transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
transform.PartitionGraph(bind_constants=False),
]
)
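
For reference, the new patterns are meant to match the Relay that PyTorch models typically lower to: SiLU becomes multiply(conv2d_out, sigmoid(conv2d_out)), and hardswish becomes multiply(conv2d_out, divide(clip(add(conv2d_out, c1)), c2)), usually with c1 = 3 and c2 = 6 (the pattern checks only that constants are present, not their values). A hedged sketch of expressions these patterns should match:

from tvm import relay

x = relay.var("x", shape=(1, 32, 32, 16), dtype="float16")
w = relay.var("w", shape=(32, 3, 3, 16), dtype="float16")
b = relay.var("b", shape=(32,), dtype="float16")
conv_out = relay.nn.bias_add(
    relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1),
                    data_layout="NHWC", kernel_layout="OHWI", out_dtype="float16"),
    b, axis=3)

# cutlass.conv2d_bias_silu: multiply(conv2d_out, sigmoid(conv2d_out))
silu = conv_out * relay.sigmoid(conv_out)

# cutlass.conv2d_bias_hardswish: multiply(conv2d_out, divide(clip(add(conv2d_out, c)), c))
hardswish = conv_out * (relay.clip(conv_out + relay.const(3.0, "float16"), 0.0, 6.0)
                        / relay.const(6.0, "float16"))
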
38 changes: 24 additions & 14 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -173,13 +173,9 @@ void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string& kernel)

std::string DenseOp(std::string id, const Str2StrMap& attrs,
const std::vector<std::string>& func_args) {
bool has_bias = false;
bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
bool is_gelu =
attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos; // fp32 or fp16
if (attrs.at("op_type") == "cutlass.dense_bias" ||
attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) {
has_bias = true;
}
std::ostringstream gemm_decl;
AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0, 1);

@@ -263,10 +259,10 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {

std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
const std::vector<std::string>& func_args) {
bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
attrs.at("op_type") == "cutlass.conv2d_bias_relu" ||
attrs.at("op_type") == "cutlass.conv2d_bias_sigmoid";
bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid";
bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid" &&
attrs.at("op_type") != "cutlass.conv2d_bias_silu" &&
attrs.at("op_type") != "cutlass.conv2d_bias_hardswish";

std::ostringstream conv2d_decl;
CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
@@ -505,6 +501,20 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "sigmoid"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_sigmoid", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias_silu") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_silu", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias_hardswish") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_hardswish", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
}

LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -546,14 +556,11 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
ret.outputs.push_back(output);
}
decl_stream << ");";
if (func_name == "cutlass_dense" || func_name == "cutlass_dense_bias" ||
func_name == "cutlass_dense_bias_relu" || func_name == "cutlass_dense_bias_gelu") {
if (func_name.find("dense") != std::string::npos) {
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_batch_matmul") {
ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
func_name == "cutlass_conv2d_bias_relu" ||
func_name == "cutlass_conv2d_bias_sigmoid") {
} else if (func_name.find("conv2d") != std::string::npos) {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
}

@@ -613,6 +620,9 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
code_stream_ << "#include <cutlass/conv/device/implicit_gemm_convolution.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_bias_relu.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_gelu.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_sigmoid.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_silu.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_hardswish.h>\n";

ICHECK(ref->IsInstance<FunctionNode>());
auto res = GenCutlassFunc(Downcast<Function>(ref));
3 changes: 3 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
@@ -467,6 +467,9 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
int64_t num_axis = dshape.size();

const auto* begin = types[1].as<TensorTypeNode>();
if (begin == nullptr) {
return false;
}
ICHECK(begin);

// calculate output shape
5 changes: 5 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -289,6 +289,11 @@ bool StackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "cast: expect input type to be TupleType but get " << types[0];
return false;
}
for (auto field : tensor_tuple->fields) {
if (field.as<IncompleteTypeNode>()) {
return false;
}
}
const auto* param = attrs.as<StackAttrs>();
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
const int ndim = static_cast<int>(first->shape.size());