Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Commit

Permalink
[mlir][Tosa]: Add folder to ReciprocalOp of splat constant inputs (#7…
Browse files Browse the repository at this point in the history
…8137)
  • Loading branch information
AviadCo committed Jan 17, 2024
1 parent e8af89e commit d89a0a6
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 16 deletions.
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,17 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
let results = (outs
Tosa_Tensor:$output
);

let extraClassDeclaration = [{
/// Return the reciprocal result on the operand.
static inline APFloat calcOneElement(const APFloat &operand) {
APFloat recip = APFloat(operand.getSemantics(), 1);
recip.divide(operand, APFloat::rmNearestTiesToEven);
return recip;
}
}];

let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
Expand All @@ -25,6 +26,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

Expand Down Expand Up @@ -1036,3 +1038,21 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
getOperation()->setOperands(concatOperands);
return getResult();
}

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
auto input = adaptor.getInput1();

auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
// Fold splat inputs only.
if (!inputAttr || !inputAttr.isSplat())
return {};

auto shapeType = llvm::cast<ShapedType>(getType());
if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
auto floatVal = inputAttr.getSplatValue<APFloat>();
return DenseElementsAttr::get(shapeType,
ReciprocalOp::calcOneElement(floatVal));
}

return {};
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

Expand Down
20 changes: 4 additions & 16 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ using namespace mlir::tosa;

namespace {

/// Rounding mode to be used on floating point operations that require rounding.
static constexpr llvm::RoundingMode tosaRoundingMode =
llvm::APFloat::rmNearestTiesToEven;

/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
Expand All @@ -44,14 +40,14 @@ static constexpr llvm::RoundingMode tosaRoundingMode =
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
const std::function<TargetValType(const SrcValType &)> &toApply,
TargetType targetType) {
SmallVector<TargetValType> transformedValues;
// We already know the amount of values we will insert, reserve space for
// all of them to avoid dynamic resizing
transformedValues.reserve(toTransform.getNumElements());
for (auto val : toTransform.getValues<SrcValType>()) {
auto transformedVal = toApply(val, targetType);
auto transformedVal = toApply(val);
transformedValues.push_back(transformedVal);
}

Expand All @@ -64,7 +60,7 @@ DenseElementsAttr applyElementWise(

template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
const DenseElementsAttr &toTransform,
const std::function<APFloat(const APFloat &, FloatType)> &toApply,
const std::function<APFloat(const APFloat &)> &toApply,
FloatType targetType);

/// Function that checks if the type contained in \p toCheck is float.
Expand Down Expand Up @@ -249,14 +245,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

using OpRewritePattern::OpRewritePattern;

static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
auto recipAttr = FloatAttr::get(floatTy, 1.0);
APFloat recip = recipAttr.getValue();
recip.divide(floatVal, tosaRoundingMode);

return recip;
}

LogicalResult matchAndRewrite(ReciprocalOp recip,
PatternRewriter &rewriter) const override {
auto inputTensor = recip.getInput1();
Expand All @@ -281,7 +269,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

// Create a new tensor with the updated values
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
inputValues, &computeReciprocal,
inputValues, &ReciprocalOp::calcOneElement,
cast<FloatType>(inputValues.getElementType()));

// Replace the use of the reciprocal with the transformed tensor
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,27 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}

// -----

// CHECK-LABEL: @fold_reciprocal
func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
// CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
// CHECK: }
%0 = "tosa.const"(){ value = dense<116.0>: tensor<f32> }: () -> tensor<f32>
%1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
return %2 : tensor<3x600x1200xf32>
}

// -----

// CHECK-LABEL: @do_not_fold_reciprocal_int
func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
// CHECK: tosa.reciprocal
%0 = "tosa.const"(){ value = dense<11>: tensor<i32> }: () -> tensor<i32>
%1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
return %2 : tensor<3x600x1200xi32>
}

0 comments on commit d89a0a6

Please sign in to comment.