diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index f31df080d7811a..ff9858d5832ba8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1531,6 +1531,31 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
   let hasVerifier = 1;
 }
 
+def ExtractValOp : SparseTensor_Op<"extract_value", [
+    Pure,
+    TypesMatchWith<"result type matches element type of tensor",
+                   "tensor", "result",
+                   "::llvm::cast<TensorType>($_self).getElementType()">]> {
+  let summary = "Extracts a value from a sparse tensor using an iterator.";
+  let description = [{
+      The `sparse_tensor.extract_value` operation extracts the value
+      pointed to by a sparse iterator from a sparse tensor.
+
+      Example:
+
+      ```mlir
+      %val = sparse_tensor.extract_value %sp at %it
+           : tensor<?x?xf32, #CSR>, !sparse_tensor.iterator<#CSR, lvls = 1>
+      ```
+  }];
+
+  let arguments = (ins AnySparseTensor:$tensor, AnySparseIterator:$iterator);
+  let results = (outs AnyType:$result);
+
+  let assemblyFormat = "$tensor `at` $iterator attr-dict `:` type($tensor)`,` qualified(type($iterator))";
+  let hasVerifier = 1;
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
     [RecursiveMemoryEffects, RecursivelySpeculatable,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index eb0dc01be25b93..61cc9be88685cc 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1099,6 +1099,42 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
   return
 }
 
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+
+func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 1>) -> f32 {
+  // expected-error@+1 {{'sparse_tensor.extract_value' op mismatch in tensor encoding and iterator encoding.}}
+  %f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 1>
+  return %f : f32
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) -> f32 {
+  // expected-error@+1 {{'sparse_tensor.extract_value' op must use last-level iterator to extract values.}}
+  %f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  return %f : f32
+}
 
 // -----
 
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index bce0b41a99828a..055709ee69eb71 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -739,6 +739,27 @@ func.func @sparse_has_runtime() -> i1 {
   return %has_runtime : i1
 }
 
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL: func.func @sparse_extract_value(
+// CHECK-SAME:  %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME:  %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse, lvls = 1>) -> f32 {
+// CHECK:       %[[VAL_2:.*]] = sparse_tensor.extract_value %[[VAL_0]] at %[[VAL_1]] : tensor<4x8xf32, #sparse>, !sparse_tensor.iterator<#sparse, lvls = 1>
+// CHECK:       return %[[VAL_2]] : f32
+// CHECK:       }
+func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 1>) -> f32 {
+  %f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 1>
+  return %f : f32
+}
+
+
 // -----
 
 #COO = #sparse_tensor.encoding<{
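
Usage note (not part of this patch): `extract_value` is meant to be paired with the sparse iteration ops, where a last-level iterator produced by `sparse_tensor.iterate` identifies a stored entry and `extract_value` reads that entry's value. The sketch below is only illustrative; the `#CSR` encoding, the `%sp`/`%space`/`%init` values, and the exact `iterate`/`yield` syntax are assumptions based on the existing `sparse_tensor.iterate` documentation and may differ in detail.

```mlir
#CSR = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i : dense,
    j : compressed
  )
}>

// Illustrative only: sum all stored values visited by a last-level iterator.
// Assumes %sp : tensor<4x8xf32, #CSR>, %space (a level-1 iteration space
// extracted from %sp at an outer iterator), and %init : f32 are defined above.
%sum = sparse_tensor.iterate %it in %space iter_args(%acc = %init)
    : !sparse_tensor.iter_space<#CSR, lvls = 1> -> f32 {
  // Read the stored value at the iterator's current position; the f32 result
  // type is inferred from the tensor's element type and is not spelled out.
  %v = sparse_tensor.extract_value %sp at %it
      : tensor<4x8xf32, #CSR>, !sparse_tensor.iterator<#CSR, lvls = 1>
  %next = arith.addf %acc, %v : f32
  sparse_tensor.yield %next : f32
}
```

The last-level restriction exercised in invalid.mlir follows from the storage layout: values are stored only at positions of the last level, so only a last-level iterator pins down a unique value to extract.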