-
Notifications
You must be signed in to change notification settings - Fork 574
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
IREE Custom tilable op (iree_linalg_ext.custom_op
).
#18555
base: main
Are you sure you want to change the base?
IREE Custom tilable op (iree_linalg_ext.custom_op
).
#18555
Conversation
This operation is meant to allow users/front ends to specify computations that can be fused at a tile granularity. IREE by default fuses certain operations (like linalg/linalg_ext ops) at tile granularity, but there might be certain sequences that IREE cannot/is not able to fuse. Previously such custom fusions could only be implemented by using a "black-box" approach where it would by-pass the entire compilation stack and either inject a manually generated binary or use a transform dialect script that implements a custom lowering sequence. This operation allows front ends/users to specify such a fusions within the region of the operation. The `indexing_maps` and `iterator_types` capture all the information necessary for IREE to distribute these operations to tile/distribute this operation, as well as fuse with other operations (like elementwise operations). The operation is meant to implement all interfaces necessary to be able to compiled/executed with IREE. Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
8e3c3b1
to
a165bfd
Compare
return emitOpError("expected number of indexing maps to be same as the " | ||
"number of input/output operands"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: would be nice to also print the values in the error message here, e.g.:
expected number of indexing maps (4) to be same as the number of input/output operands (5)
// Enum definitions | ||
//===----------------------------------------------------------------------===// | ||
|
||
def IREELinalgExt_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to reuse the iterator types from linalg? Or is this to avoid this dependency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1, the iterator type enum attr is not reusable last I tried, but the utils::IteratorType
enum itself should be reusable similar to here:
def IREEGPU_IteratorTypeEnum |
Thank you for this. |
// Enum definitions | ||
//===----------------------------------------------------------------------===// | ||
|
||
def IREELinalgExt_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1, the iterator type enum attr is not reusable last I tried, but the utils::IteratorType
enum itself should be reusable similar to here:
def IREEGPU_IteratorTypeEnum |
The region of the operation represents the computation at a tile granularity. | ||
The basic block arguments correspond to tiles of inputs/outputs. The region | ||
yields tiled results of the same type as the corresponding basic block | ||
arguments for the outputs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the block arguments could use more description. Are statically shaped block arguments allowed? I would assume no, as it would be unclear how to tile along the statically-casted tile dimensions.
The operation has operands similar to `LinalgOp`s | ||
1. Indexing maps : Modification on top of Linalg Op is that the indexing | ||
map can be `affine_map<() -> ()>` indicating that operand cannot | ||
be fused along and that this operand is not tiled |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel that we need more control than per-operand. Take a similar problem to attention that motivated the addition of this op.
iree_linalg_ext.custom_op ins(%q: tensor<?x?>, %k: tensor<?x?>, %v: tensor<?x?>) {
^bb0(%q_tile: tensor<?x?>, %k_tile: tensor<?x?>, %v_tile: tensor<?x?>)
%0 = linalg.matmul_transpose_b ins(%q_tile, %k_tile)
%1 = linalg.matmul_transpose_b ins(%0, %v_tile)
linalg_ext.yield %1
}
The question here is how to design the indexing maps for this op. This op effectively has 4
loops hidden inside of it, the M
and N
loops of the second matmul, and the reduction loop for each matmul.
iterator_types = [parallel, parallel, reduction, reduction]
There is an issue here though, where the first reduction loop is not tilable as a part of this operation. Tiling the first reduction loop would need to produce a loop around only the first matmul, which is incompatible with the tiling interface which forms the loop then invokes the interface on a clone of the full operation.
So if we can't tile that reduction loop, how do we design affine maps for the Q
operand. Well the Q
operation is shaped tensor<M x K1>
, which is a tilable parallel loop and untilable dimension so there isn't a clear answer to what the indexing map should be.
I think we should just make the iteration space on the op explicit to mitigate issues like this. The status of the K1
loop as a reduction loop would be recovered after decomposing (inlining) this custom op, at which point we can do further lowerings/tilings if needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also had the same concern which I pointed out to Mahesh. One suggestion I had was representing these non-tilable dimensions with a special iterator type that explicitly states that these dimensions are non-tilable. These dimensions would also need to be the inner-most dimensions of the operation, otherwise you cannot tile any dimension that occurs after these non-tilable dimensions (you cannot move a loop out of a non-tilable loop, because it is non-perfectly nested).
I do think that there is something hacky going on here. I'm going to write out more examples related to why I think there will be problems with attention and hopefully present them next week when we sync.
This operation is meant to allow users/front ends to specify computations that can be fused at a tile granularity. IREE by default fuses certain operations (like linalg/linalg_ext ops) at tile granularity, but there might be certain sequences that IREE cannot/is not able to fuse. Previously such custom fusions could only be implemented by using a "black-box" approach where it would by-pass the entire compilation stack and either inject a manually generated binary or use a transform dialect script that implements a custom lowering sequence.
This operation allows front ends/users to specify such a fusions within the region of the operation. The
indexing_maps
anditerator_types
capture all the information necessary for IREE to distribute these operations to tile/distribute this operation, as well as fuse with other operations (like elementwise operations).The operation is meant to implement all interfaces necessary to be able to compiled/executed with IREE.