Skip to content

Commit

Permalink
[mlir][sparse] introduce sparse_tensor.extract_value operation. (ll…
Browse files Browse the repository at this point in the history
  • Loading branch information
Peiming Liu authored Jul 30, 2024
1 parent 99fb40d commit 12189f8
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 0 deletions.
25 changes: 25 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,31 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
let hasVerifier = 1;
}

def ExtractValOp : SparseTensor_Op<"extract_value", [
Pure,
TypesMatchWith<"result type matches element type of tensor",
"tensor", "result",
"::llvm::cast<TensorType>($_self).getElementType()">]> {
let summary = "Extracts a value from a sparse tensor using an iterator.";
let description = [{
The `sparse_tensor.extract_value` operation extracts the value
pointed to by a sparse iterator from a sparse tensor.

Example:

```mlir
%val = sparse_tensor.extract_value %sp at %it
: tensor<?x?xf32, #CSR>, !sparse_tensor.iterator<#CSR, lvl = 1>
```
}];

let arguments = (ins AnySparseTensor:$tensor, AnySparseIterator:$iterator);
let results = (outs AnyType:$result);

let assemblyFormat = "$tensor `at` $iterator attr-dict `:` type($tensor)`,` qualified(type($iterator))";
let hasVerifier = 1;
}

def IterateOp : SparseTensor_Op<"iterate",
[RecursiveMemoryEffects, RecursivelySpeculatable,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2267,6 +2267,19 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}

LogicalResult ExtractValOp::verify() {
auto stt = getSparseTensorType(getTensor());
auto itTp = getIterator().getType();

if (stt.getEncoding() != itTp.getEncoding())
return emitOpError("mismatch in tensor encoding and iterator encoding.");

if (stt.getLvlRank() != itTp.getHiLvl())
return emitOpError("must use last-level iterator to extract values. ");

return success();
}

struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,42 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
return
}

// -----

#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>

#CSR = #sparse_tensor.encoding<{
map = (i, j) -> (
i : dense,
j : compressed
)
}>

func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 1>) -> f32 {
// expected-error@+1 {{'sparse_tensor.extract_value' op mismatch in tensor encoding and iterator encoding.}}
%f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 1>
return %f : f32
}

// -----

#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>

func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) -> f32 {
// expected-error@+1 {{'sparse_tensor.extract_value' op must use last-level iterator to extract values.}}
%f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
return %f : f32
}

// -----

Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,27 @@ func.func @sparse_has_runtime() -> i1 {
return %has_runtime : i1
}

// -----

#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>

// CHECK-LABEL: func.func @sparse_extract_value(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
// CHECK-SAME: %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse, lvls = 1>) -> f32 {
// CHECK: %[[VAL_2:.*]] = sparse_tensor.extract_value %[[VAL_0]] at %[[VAL_1]] : tensor<4x8xf32, #sparse>, !sparse_tensor.iterator<#sparse, lvls = 1>
// CHECK: return %[[VAL_2]] : f32
// CHECK: }
func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 1>) -> f32 {
%f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 1>
return %f : f32
}


// -----

#COO = #sparse_tensor.encoding<{
Expand Down

0 comments on commit 12189f8

Please sign in to comment.