-
Notifications
You must be signed in to change notification settings - Fork 11.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][vector] Refine vectorisation of tensor.extract
This PR fixes a bug in `isLoopInvariantIdx`. It makes sure that the following case is vectorised as `vector.gather` (as opposed to attempting a contiguous load): ```mlir func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> { %c0 = arith.constant 0 : index %0 = tensor.empty() : tensor<8x1xf32> %res = linalg.generic { indexing_maps = [#map], iterator_types = ["parallel", "parallel"] } outs(%0 : tensor<8x1xf32>) { ^bb0(%arg1: f32): %1 = linalg.index 0 : index %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32> linalg.yield %extracted : f32 } -> tensor<8x1xf32> return %res : tensor<8x1xf32> } ``` Specifically, when looking for loop-invariant indices in `tensor.extract` Ops, any `linalg.index` Op that's used in address colcluation should only access loop dims that are == 1. In the example above, the following does not meet that criteria: ```mlir %1 = linalg.index 0 : index ``` Note that this PR also effectively addresses the issue fixed in #107922, i.e. exercised by: * `@vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load` `getNonUnitLoopDim` introduced in #107922 is still valid though. In fact, it is required to identify that the following case is a contiguous load: ```mlir func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> { %c0 = arith.constant 0 : index %0 = tensor.empty() : tensor<8x1xf32> %res = linalg.generic { indexing_maps = [#map], iterator_types = ["parallel", "parallel"] } outs(%0 : tensor<8x1xf32>) { ^bb0(%arg1: f32): %1 = linalg.index 0 : index %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32> linalg.yield %extracted : f32 } -> tensor<8x1xf32> return %res : tensor<8x1xf32> } ``` Some logic is still missing to lower the above to `vector.transfer_read`, so it is conservatively lowered to `vector.gather` (see TODO in `getTensorExtractMemoryAccessPattern`). There's a few additional changes: * `getNonUnitLoopDim` is simplified and renamed as `getTrailingNonUnitLoopDimIdx`, additional comments are added (note that the functionality didn't change); * extra comments in a few places, variable names in comments update to use Markdown (which is the preferred approach in MLIR). This is a follow-on for: * #107922 * #102321
- Loading branch information
1 parent
2f1e04f
commit 32c39dd
Showing
2 changed files
with
128 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters