Skip to content

Commit

Permalink
Test stablehlo.multiply on complex data
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jun 20, 2024
1 parent f24303c commit 1443121
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions test/lit_tests/diffrules/stablehlo/multiply.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,11 @@ func.func @main(%a : tensor<2xf32>, %b : tensor<2xf32>) -> tensor<2xf32> {
// REVERSE-NEXT: %4 = arith.addf %3, %cst : tensor<2xf32>
// REVERSE-NEXT: return %2, %4 : tensor<2xf32>, tensor<2xf32>
// REVERSE-NEXT: }

// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=mul_complex outfn= retTys=enzyme_dup argTys=enzyme_dup,enzyme_dup mode=ForwardMode" | FileCheck %s --check-prefix=FORWARD-COMPLEX
// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=mul_complex outfn= retTys=enzyme_active argTys=enzyme_active,enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops | FileCheck %s --check-prefix=REVERSE-COMPLEX

func.func @mul_complex(%a : tensor<2xcomplex<f32>>, %b : tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
%c = stablehlo.multiply %a, %b : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
func.return %c : tensor<2xcomplex<f32>>
}

0 comments on commit 1443121

Please sign in to comment.