[mlir][vector] Refine vectorisation of tensor.extract #109580

Open · wants to merge 1 commit into main

Conversation

banach-space
Contributor

This PR fixes a bug in `isLoopInvariantIdx`. It makes sure that the
following case is vectorised as `vector.gather` (as opposed to
attempting a contiguous load):
```mlir
  func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.empty() : tensor<8x1xf32>
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel"]
    } outs(%0 : tensor<8x1xf32>) {
    ^bb0(%arg1: f32):
        %1 = linalg.index 0 : index
      %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
        linalg.yield %extracted : f32
    } -> tensor<8x1xf32>
    return %res : tensor<8x1xf32>
  }
```
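
After this patch, the example above is vectorised using `vector.gather`. Roughly, abridged from the new test case added in this PR (the SSA value names are illustrative; see the CHECK lines in the updated test for the exact output):
```mlir
  %c128 = arith.constant dense<128> : vector<1x8xindex>
  %c0 = arith.constant 0 : index
  %pass_thru = arith.constant dense<0.000000e+00> : vector<8x1xf32>
  %mask = arith.constant dense<true> : vector<8x1xi1>
  %idx_vec = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
  %out = tensor.empty() : tensor<8x1xf32>
  // Per-element offsets: row index * 128 (the row stride of the source).
  %bcast = vector.broadcast %idx_vec : vector<8xindex> to vector<1x8xindex>
  %mul = arith.muli %bcast, %c128 : vector<1x8xindex>
  %offsets = vector.transpose %mul, [1, 0] : vector<1x8xindex> to vector<8x1xindex>
  // A gather rather than a contiguous transfer_read.
  %gather = vector.gather %src[%c0, %c0] [%offsets], %mask, %pass_thru
    : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
  %res = vector.transfer_write %gather, %out[%c0, %c0] {in_bounds = [true, true]}
    : vector<8x1xf32>, tensor<8x1xf32>
```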

Specifically, when looking for loop-invariant indices in
`tensor.extract` Ops, any `linalg.index` Op used in the address
calculation may only access loop dims that are equal to 1 (i.e. unit dims). In
the example above, the following does not meet that criterion:
```mlir
  %1 = linalg.index 0 : index
```
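
By contrast, an index over the unit dim of the `8x1` iteration space stays constant across iterations and would therefore be treated as loop-invariant (hypothetical, not part of the example above):
```mlir
  // Dim 1 has a trip count of 1, so this value never changes.
  %2 = linalg.index 1 : index
```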

Note that this PR also effectively addresses the issue fixed in llvm#107922, i.e.
exercised by:
  * `@vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load`

`getNonUnitLoopDim` introduced in llvm#107922 is still valid though. In
fact, it is required to identify that the following case is a contiguous
load:
```mlir
  func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.empty() : tensor<8x1xf32>
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel"]
    } outs(%0 : tensor<8x1xf32>) {
    ^bb0(%arg1: f32):
        %1 = linalg.index 0 : index
      %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
        linalg.yield %extracted : f32
    } -> tensor<8x1xf32>
    return %res : tensor<8x1xf32>
  }
```
Some logic is still missing to lower the above to
`vector.transfer_read`, so it is conservatively lowered to
`vector.gather` (see TODO in `getTensorExtractMemoryAccessPattern`).
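
For illustration only, a contiguous lowering of the column-vector case above would presumably need a `vector.transfer_read` with a non-identity permutation map, along these lines (a sketch reusing `%src`/`%c0` from the example; this PR does not generate it and the exact form may differ):
```mlir
  %pad = arith.constant 0.000000e+00 : f32
  // Hypothetical: read 8 contiguous elements from row 0 and present them as a
  // column vector by permuting which source dim each vector dim walks over.
  %read = vector.transfer_read %src[%c0, %c0], %pad
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
    : tensor<8x128xf32>, vector<8x1xf32>
```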

There are a few additional changes:
  * `getNonUnitLoopDim` is simplified and renamed to
    `getTrailingNonUnitLoopDimIdx`, and additional comments have been added
    (the functionality is unchanged);
  * extra comments in a few places, and variable names in comments have been
    updated to use Markdown (which is the preferred approach in MLIR).

This is a follow-on for:
  * llvm#107922
  * llvm#102321
@llvmbot
Collaborator

llvmbot commented Sep 22, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Full diff: https://github.com/llvm/llvm-project/pull/109580.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+50-35)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+78)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6800a0fec278c6..8891299c4a099b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,27 +810,35 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
-/// Find the non-unit dim in a linalgOp.
-/// When executing this hook, it is expected that only one dim will be non-unit.
-/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
-/// loads before calling this method. This is used for finding contiguous loads
-/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
-/// this condition is expected to hold for statically shaped Linalg Ops only.
-static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
-  uint64_t nonUnitDim = 0;
-  uint64_t countNonUnitDim = 0;
-  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
-    if (tripCount.value() != 1) {
-      nonUnitDim = tripCount.index();
-      countNonUnitDim++;
-    }
-  }
-
+/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
+/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
+/// represents a contiguous load operation.
+///
+/// Note that when calling this hook, it is assumed that the output vector is
+/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
+/// labelled as a gather load before entering this method.
+///
+/// Following on from the above, it is assumed that:
+///   * for statically shaped loops, when no masks are used, only one dim is !=
+///   1 (that's what the shape of the output vector is based on).
+///   * for dynamically shaped loops, there might be more non-unit dims
+///   as the output vector type is user-specified.
+///
+/// TODO: Statically shaped loops + vector masking
+static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
+  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
   assert(linalgOp.hasDynamicShape() ||
-         countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
-                                 "non-unit loop dim is expected");
-  (void)countNonUnitDim;
-  return nonUnitDim;
+         llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) ==
+                 1 &&
+             "For statically shaped Linalg Ops, only one "
+             "non-unit loop dim is expected");
+
+  size_t idx = loopRanges.size() - 1;
+  for (; idx >= 0; idx--)
+    if (loopRanges[idx] != 1)
+      break;
+
+  return idx;
 }
 
 /// Checks whether `val` can be used for calculating a loop invariant index.
@@ -854,11 +862,11 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
   assert(defOp && "This is neither a block argument nor an operation result");
 
   // IndexOp is loop invariant as long as its result remains constant across
-  // iterations. Given the assumptions on the loop ranges above, only the
-  // trailing loop dim ever changes.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() != trailingLoopDim);
+  // iterations. Note that for dynamic shapes, the corresponding dim will also
+  // be conservatively treated as != 1.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
+  }
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
@@ -877,7 +885,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
   return result;
 }
 
-/// Check whether \p val could be used for calculating the trailing index for a
+/// Check whether `val` could be used for calculating the trailing index for a
 /// contiguous load operation.
 ///
 /// There are currently 3 types of values that are allowed here:
@@ -886,13 +894,14 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
 ///   3. results of basic arithmetic operations (linear and continuous)
 ///      involving 1., 2. and 3.
 /// This method returns True if indeed only such values are used in calculating
-/// \p val.
+/// `val`.
 ///
 /// Additionally, the trailing index for a contiguous load operation should
 /// increment by 1 with every loop iteration, i.e. be based on:
 ///   * `linalg.index <dim>` ,
-/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
-/// updated to `true` when such an op is found.
+/// where <dim> is the trailing non-unit dim of the iteration space (this way,
+/// `linalg.index <dim>` increments by 1 with every loop iteration).
+/// `foundIndexOp` is updated to `true` when such Op is found.
 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                 bool &foundIndexOp, VectorType resType) {
 
@@ -912,12 +921,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // Given the assumption on the loop ranges above, we expect only 1 non-unit
-  // loop dim.
-  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
-
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
+    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
+
+    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
     return true;
   }
 
@@ -1012,7 +1019,10 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   bool foundIndexOp = false;
   bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                               foundIndexOp, resType);
-  isContiguousLoad &= foundIndexOp;
+  // TODO: Support generating contiguous loads for column vectors - that will
+  // require adding a permutation map to transfer_read Ops.
+  bool isRowVector = resType.getShape().back() != 1;
+  isContiguousLoad &= (foundIndexOp && isRowVector);
 
   if (isContiguousLoad) {
     LDBG("Found contigous load: " << extractOp);
@@ -1073,6 +1083,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //  b. contiguous loads.
   // Both cases use vector.transfer_read.
 
+  assert(llvm::count_if(resultType.getShape(),
+                        [](uint64_t dim) { return dim != 1; }) &&
+         "Contiguous loads and scalar loads + broadcast only support 1-D "
+         "vectors ATM!");
+
   // Collect indices for `vector.transfer_read`. At this point, the indices will
   // either be scalars or would have been broadcast to vectors matching the
   // result type. For indices that are vectors, there are two options:
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index ad3a8d9f926082..a140ae616bb6c2 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -307,6 +307,84 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Reading a 1D column vector (hence a candidate for a contiguous load), but given
+// %1, it's a gather load.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg1: f32):
+      %1 = linalg.index 0 : index
+    %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
+      linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %res : tensor<8x1xf32>
+}
+
+// CHECK-LABEL:   func.func @index_from_output_column_vector_gather_load(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
+// CHECK:           %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
+// CHECK:           %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:           %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+// CHECK:           return %[[RES]] : tensor<8x1xf32>
+
+// Same as above, but the access indices have been swapped and hence this is
+// a contiguous load. Currently not supported and lowered as vector.gather
+// instead.
+// TODO: Make sure that this is lowered as a contiguous load.
+
+func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg1: f32):
+      %1 = linalg.index 0 : index
+    %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
+      linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %res : tensor<8x1xf32>
+}
+
+// CHECK-LABEL:   func.func @index_from_output_column_vector_contiguous_load(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
+// CHECK:           %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:           %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+// CHECK:           return %[[RES]] : tensor<8x1xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 #map = affine_map<(d0) -> (d0)>
 func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32>, %arg1: tensor<5xi32>) -> tensor<5xf32> {
  %c5 = arith.constant 5 : index

@banach-space
Contributor Author

@nirvedhmeshram, given that you touched this code recently, would you have the bandwidth to review my PR?

The actual functional changes are small, but I appreciate that I'm mixing a few separate changes here. I've tried to capture that in the summary, but I'm also happy to split this patch if that makes reviewing easier 😅 There's more to come :)
