Skip to content

Commit

Permalink
[Codegen] Check for workgroup level tile sizes in workgroup tiling (#…
Browse files Browse the repository at this point in the history
…18538)

TileDispatchUsingForall relies on lowering configurations having
workgroup level tile sizes, so this PR adds the additional check that
the tilableOp has workgroup level tile sizes. It also adds verification
that there is only one op with a workgroup tiling level.

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 committed Sep 19, 2024
1 parent 73ffafb commit 782f372
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,18 @@ struct TilingInfo {
static FailureOr<TilingInfo>
getTiledAndDistributionInfo(RewriterBase &rewriter,
ArrayRef<Operation *> computeOps) {
// It is expected that at most one compute op has a workgroup tiling level.
Operation *tilableOp = nullptr;
for (Operation *op : llvm::reverse(computeOps)) {
if (getLoweringConfig(op)) {
if (!getLoweringConfig(op).hasWorkgroupTilingLevel()) {
continue;
}
if (tilableOp) {
return op->emitOpError("expected only one op with a workgroup tiling"
"level.");
}
tilableOp = op;
break;
}
}
if (!tilableOp) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ bool LoweringConfigAttr::hasTilingLevel(unsigned level) const {
return !getTileSizeVals(level).empty();
}

bool LoweringConfigAttr::hasWorkgroupTilingLevel() const {
return !getWorkgroupTileSizes().empty();
}

LogicalResult
LoweringConfigAttr::verify(function_ref<InFlightDiagnostic()> emitError,
LoweringConfigTilingLevelsAttr levels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def IREECodegen_LoweringConfigAttr :
"getStaticTilingLevelSizes",
"getTilingLevelSizes",
"hasTilingLevel",
"hasWorkgroupTilingLevel",
]>
]> {
let mnemonic = "lowering_config";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ def IREECodegen_LoweringConfigAttrInterface :
return false;
}]
>,
InterfaceMethod<
/*desc=*/[{
Returns true if the lowering config specifies tile sizes for the
workgroup tiling level.
}],
/*retTy=*/"bool",
/*methodName=*/"hasWorkgroupTilingLevel",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
}]
>,
InterfaceMethod<
/*desc=*/[{
Returns the tile sizes for the specified tiling level. The
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,10 @@ bool LoweringConfigAttr::hasTilingLevel(unsigned level) const {
.empty();
}

bool LoweringConfigAttr::hasWorkgroupTilingLevel() const {
return !getWorkgroupTileSizes().empty();
}

constexpr StringLiteral kMmaKindName = "mma_kind";

IREE::GPU::MmaInterfaceAttr LoweringConfigAttr::getMmaKind() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def IREEGPU_LoweringConfigAttr :
"getStaticTilingLevelSizes",
"getTilingLevelSizes",
"hasTilingLevel",
"hasWorkgroupTilingLevel",
]>
]> {
let mnemonic = "lowering_config";
Expand Down

0 comments on commit 782f372

Please sign in to comment.