Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a bunch of rules for scalar and non-differentiable functions #90

Merged
merged 75 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
8261a8a
Add rule for `stablehlo.min`
mofeing Jun 1, 2024
edfa5b7
Mark non-differentiable StableHLO ops
mofeing Jun 1, 2024
0bc17ba
Mark non-differentiable ops from CHLO
mofeing Jun 1, 2024
8b2c7ea
More diff rules
mofeing Jun 1, 2024
e9278fc
Add rule for `chlo.conj`
mofeing Jun 1, 2024
37c2c70
Format code
mofeing Jun 1, 2024
2f74175
Fix typos
mofeing Jun 1, 2024
2a69429
Unmark bit operators as inactive
mofeing Jun 1, 2024
ac571d0
Refactor diff rule of `min`
mofeing Jun 1, 2024
6fb3415
Fix type in `max`/`min`
mofeing Jun 1, 2024
15d1cc3
Unmark bit-shift ops as inactive
mofeing Jun 1, 2024
1e9ec05
Unmark `clz` as inactive
mofeing Jun 1, 2024
c7b9d9b
Move CHLO ops to new file
mofeing Jun 1, 2024
fcdaf4d
Add `ConstantOp`
mofeing Jun 1, 2024
0da91d4
Move `conj` diff rule to `CHLODerivatives.td`
mofeing Jun 1, 2024
7e3fa1f
Build CHLODerivatives.td
mofeing Jun 1, 2024
f53ef6c
Move `einsum` tests to new directory
mofeing Jun 5, 2024
781ca0f
Add forward-rule for `atan2`
mofeing Jun 5, 2024
768d5a2
Add forward-rule of `cosine`
mofeing Jun 5, 2024
1a026d0
Prototype tests
mofeing Jun 5, 2024
969eabe
Rename file
mofeing Jun 5, 2024
959724b
Import StableHLO ops to CHLO for diff rules
mofeing Jun 5, 2024
2ab2491
Fix `TorchIndexSelecOp` reference
mofeing Jun 5, 2024
ec32707
Add forward rules of `add`,`subtract`
mofeing Jun 5, 2024
a73002e
Fix some test prototypes
mofeing Jun 6, 2024
0619fbd
Test reverse AD
mofeing Jun 6, 2024
9c23a29
Test forward diff rules
mofeing Jun 6, 2024
2ca817e
Fix and test `reverse`
mofeing Jun 6, 2024
198d484
Move `stablehlo` ops diff rules test to new directory
mofeing Jun 6, 2024
6a815f5
Move `grad`,`convolution` diff tests
mofeing Jun 6, 2024
b623409
Temporary fix for CI
mofeing Jun 8, 2024
df957da
Remove tempory fix
mofeing Jun 8, 2024
d991a4f
Fix incorrent number of arg activity states in `select` tests
mofeing Jun 8, 2024
352c23b
Add diffrules for `digamma`,`polygamma` ops
mofeing Jun 9, 2024
f88e7e1
Add diffrule for `fft`
mofeing Jun 9, 2024
1db10ce
Add forward test to `fft`
mofeing Jun 10, 2024
72d24d6
Prototype tests for some CHLO ops
mofeing Jun 10, 2024
83edf46
Test forward diff rules of `fft`
mofeing Jun 19, 2024
f91acf7
Add multiplier to FFT
mofeing Jun 19, 2024
8dcbbf6
Update derivative rules for CHLO
mofeing Jun 19, 2024
7261b64
Register CHLO dialect
mofeing Jun 19, 2024
c863ec9
Fix CHLO dialect registration
mofeing Jun 20, 2024
dd15ca8
Include StableHLO ops in CHLO autodiff interface
mofeing Jun 20, 2024
c35a00b
Fix some CHLO rules
mofeing Jun 20, 2024
c8bca95
Test `stablehlo.multiply` on complex data
mofeing Jun 20, 2024
2c261cb
Remove wrong forward diff rule of `chlo.acos`
mofeing Jun 20, 2024
b345c62
Format code
mofeing Jun 20, 2024
aede53b
Fix inferred type annotation in CHLO tests
mofeing Jun 20, 2024
f7831c5
Fix typo
mofeing Jun 20, 2024
1ca0bc2
Fix `stablehlo.tanh` rule and test
mofeing Jun 20, 2024
d28a106
Fix `stablehlo.dot_general` tests
mofeing Jun 20, 2024
7579a50
Revert square with `mul` in `Tanh` diff rule
mofeing Jun 20, 2024
774dbd8
Fix `stablehlo.tanh` test
mofeing Jun 20, 2024
4ffc886
Fix reverse diff test of `stablehlo.select`
mofeing Jun 20, 2024
8d9c2a5
Fix reverse diff test of `stablehlo.log`
mofeing Jun 20, 2024
f124414
Try to remove corrupt file automatically
mofeing Jun 20, 2024
489d0ea
Fix `stablehlo.dot_general` test on forward diff with batching dim
mofeing Jun 20, 2024
6d0d50a
Fix `chlo.conj` forward diff test
mofeing Jun 20, 2024
29c58e9
Replace squaring for multiplication in CHLO diff rules
mofeing Jun 20, 2024
2ac7025
Minor fix on `stablehlo.dot_general`
mofeing Jun 20, 2024
e8418f5
Minor fix on `stablehlo.dot_general`
mofeing Jun 20, 2024
8d6fa76
fixup
wsmoses Jun 21, 2024
af7a011
Format code
mofeing Jun 21, 2024
048ec86
Test reverse diff of `stablehlo.complex`
mofeing Jun 21, 2024
20ccb25
Fix diff tests of `stablehlo.einsum` on complex data
mofeing Jun 21, 2024
c28ef6b
Fix `stablehlo.dot_general` reverse diff rule test on batch
mofeing Jun 21, 2024
f3ac252
Test `stablehlo.dot_general` diff rules on complex data
mofeing Jun 21, 2024
25762cf
Fix passes in `stablehlo.dot_general` test
mofeing Jun 21, 2024
413d603
Load StableHLO on CHLO loading
mofeing Jun 21, 2024
3c418a9
Test diff rules of CHLO scalar ops
mofeing Jun 21, 2024
a572461
Test `chlo.conj` reverse diff rule
mofeing Jun 21, 2024
e193586
Fix passes in `chlo.conj` test
mofeing Jun 21, 2024
e589cca
Try fix `stablehlo.fft` tests
mofeing Jun 21, 2024
4ce0546
Fix text prefix
mofeing Jun 21, 2024
ed850c9
Disable RFFT,IRFFT tests
mofeing Jun 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
path: "~/.cache/bazel"
key: bazel-${{ matrix.os }}
- run: sudo find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \;
- run: sudo find ~/.cache/bazel -name "A-server.jar" -exec rm -rf $(dirname {}) \;
- run: |
bazel build :enzyme_ad @llvm-project//llvm:FileCheck
bazel cquery "allpaths(//src/enzyme_ad/jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps
Expand Down
1 change: 0 additions & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ http_archive(
urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
)


load("@xla//third_party/llvm:workspace.bzl", llvm = "repo")
llvm("llvm-raw")
load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
Expand Down
17 changes: 17 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,22 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "chlo-derivatives",
tbl_outs = [(
["-gen-mlir-derivatives"],
"Implementations/CHLODerivatives.inc",
)],
tblgen = "@enzyme//:enzyme-tblgen",
td_file = "Implementations/CHLODerivatives.td",
td_srcs = [
"Implementations/CHLODerivatives.td",
],
deps = [
":EnzymeImplementationsCommonTdFiles",
],
)

td_library(
name = "EnzymeXLAPassesTdFiles",
srcs = [
Expand Down Expand Up @@ -228,6 +244,7 @@ cc_library(
":EnzyeHLOPatternsIncGen",
":mhlo-derivatives",
":stablehlo-derivatives",
":chlo-derivatives",
"@enzyme//:EnzymeMLIR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
Expand Down
52 changes: 52 additions & 0 deletions src/enzyme_ad/jax/Implementations/CHLOAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===- CHLOAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR arithmetic dialect.
//
//===----------------------------------------------------------------------===//

#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

#include "Dialect/Ops.h"
#include "mlir/IR/TypeSupport.h"

#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"

using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::chlo;
using namespace mlir::stablehlo;

static int64_t to_i64(int64_t x) { return x; }
static int64_t to_i64(llvm::APInt x) { return x.getSExtValue(); }

static mlir::DenseI64ArrayAttr getI64Attr(OpBuilder &builder,
llvm::ArrayRef<int64_t> vals) {
return builder.getDenseI64ArrayAttr(vals);
}

namespace {
#include "src/enzyme_ad/jax/Implementations/CHLODerivatives.inc"
} // namespace

void mlir::enzyme::registerCHLODialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, chlo::ChloDialect *) {
registerInterfaces(context);
context->loadDialect<stablehlo::StablehloDialect>();
});
}
144 changes: 144 additions & 0 deletions src/enzyme_ad/jax/Implementations/CHLODerivatives.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
include "src/enzyme_ad/jax/Implementations/Common.td"

class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps, dag forwardOps=(ForwardFromSummedReverse)> : MLIRDerivative<"chlo", opName_, patternToMatch, resultOps, forwardOps>;

class HLOInst<string m, string postopt=""> : Inst<m, "chlo", postopt>;

class HLOMemoryIdentityOp<string opName_, list<int> ptrargs_, list<int> storedargs_ = [], dag patternToMatch=(Unimplemented), list<dag> reverse_ = []> : MemoryIdentityOp<"chlo", opName_, ptrargs_, storedargs_, patternToMatch, reverse_>;

class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0], dag patternToMatch=(Unimplemented), list<dag> reverse_ = []> : ReadOnlyIdentityOp<"chlo", opName_, ptrargs_, patternToMatch, reverse_>;

class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"chlo", opName_, impl_>;

class HLOConstantFP<string m> : ConstantFP<m, "chlo", "ConstantOp", "mlir::ElementsAttr">;

class HLORegionTerminatorOp<string m> : RegionTerminatorOp<"chlo", m>;

class HLOInactiveOp<string m> : InactiveOp<"chlo", m>;

// Required operations from the StableHLO dialect
def Add : Inst<"AddOp", "stablehlo">;
def Sub : Inst<"SubtractOp", "stablehlo">;
def Mul : Inst<"MulOp", "stablehlo">;
def Div : Inst<"DivOp", "stablehlo">;
def Neg : Inst<"NegOp", "stablehlo">;
def Sqrt : Inst<"SqrtOp", "stablehlo">;
def Pow : Inst<"PowOp", "stablehlo">;
def Cos : Inst<"CosineOp", "stablehlo">;
def Sin : Inst<"SineOp", "stablehlo">;

// Operations
/// CHLO - binary elementwise operations
def BroadcastAdd : HLOInst<"BroadcastAddOp">;
def BroadcastAtan2 : HLOInst<"BroadcastAtan2Op">;
def BroadcastDiv : HLOInst<"BroadcastDivOp">;
def BroadcastMax : HLOInst<"BroadcastMaxOp">;
def BroadcastMin : HLOInst<"BroadcastMinOp">;
def BroadcastMul : HLOInst<"BroadcastMulOp">;
def BroadcastNextAfter : HLOInst<"BroadcastNextAfterOp">;
def BroadcastPolygamma : HLOInst<"BroadcastPolygammaOp">;
def BroadcastPow : HLOInst<"BroadcastPowOp">;
def BroadcastRem : HLOInst<"BroadcastRemOp">;
def BroadcastShiftLeft : HLOInst<"BroadcastShiftLeftOp">;
def BroadcastShiftRightArithmetic : HLOInst<"BroadcastShiftRightArithmeticOp">;
def BroadcastShiftRightLogical : HLOInst<"BroadcastShiftRightLogicalOp">;
def BroadcastSub : HLOInst<"BroadcastSubOp">;
def BroadcastZeta : HLOInst<"BroadcastZetaOp">;

/// CHLO - binary logical elementwise operations
def BroadcastAnd : HLOInst<"BroadcastAndOp">;
def BroadcastOr : HLOInst<"BroadcastOrOp">;
def BroadcastXor : HLOInst<"BroadcastXorOp">;

/// CHLO - non-broadcasting binary operations
def NextAfter : HLOInst<"NextAfterOp">;
def Polygamma : HLOInst<"PolygammaOp">;
def Zeta : HLOInst<"ZetaOp">;

/// CHLO - complex broadcasting operation
def BroadcastComplex : HLOInst<"BroadcastComplexOp">;

/// CHLO - unary elementwise operations
def Acos : HLOInst<"AcosOp">;
def Acosh : HLOInst<"AcoshOp">;
def Asin : HLOInst<"AsinOp">;
def Asinh : HLOInst<"AsinhOp">;
def Atan : HLOInst<"AtanOp">;
def Atanh : HLOInst<"AtanhOp">;
def BesselI1e : HLOInst<"BesselI1eOp">;
def Conj : HLOInst<"ConjOp">;
def Cosh : HLOInst<"CoshOp">;
def Sinh : HLOInst<"SinhOp">;
def Tan : HLOInst<"TanOp">;
def Constant : HLOInst<"ConstantOp">;
def ConstantLike : HLOInst<"ConstantLikeOp">;
def Digamma : HLOInst<"DigammaOp">;
def Erf : HLOInst<"ErfOp">;
def ErfInv : HLOInst<"ErfInvOp">;
def Erfc : HLOInst<"ErfcOp">;
def IsInf : HLOInst<"IsInfOp">;
def IsNegInf : HLOInst<"IsNegInfOp">;
def IsPosInf : HLOInst<"IsPosInfOp">;
def Lgamma : HLOInst<"LgammaOp">;

/// CHLO - broadcasting compare operation
def BroadcastCompare : HLOInst<"BroadcastCompareOp">;

/// CHLO - broadcasting select operation
def BroadcastSelect : HLOInst<"BroadcastSelectOp">;

/// CHLO - miscelaneous operations
def TopK : HLOInst<"TopKOp">;

// Derivative rules
def : HLODerivative<"AcosOp", (Op $x), [
(Neg (Div (DiffeRet), (Sqrt (Sub (HLOConstantFP<"1"> $x), (Mul $x, $x)))))
]>;

def : HLODerivative<"AcoshOp", (Op $x), [
(Div (DiffeRet), (Mul (Sqrt (Sub (Mul $x, $x), (HLOConstantFP<"1"> $x))), (Sqrt (Add (Mul $x, $x), (HLOConstantFP<"1"> $x)))))
]>;

def : HLODerivative<"AsinOp", (Op $x), [
(Div (DiffeRet), (Sqrt (Sub (HLOConstantFP<"1"> $x), (Mul $x, $x))))
]>;

def : HLODerivative<"AsinhOp", (Op $x), [
(Div (DiffeRet), (Sqrt (Add (HLOConstantFP<"1"> $x), (Mul $x, $x))))
]>;

def : HLODerivative<"AtanOp", (Op $x), [
(Div (DiffeRet), (Add (Mul $x, $x), (HLOConstantFP<"1"> $x)))
]>;

def : HLODerivative<"AtanhOp", (Op $x), [
(Div (DiffeRet), (Sub (Mul $x, $x), (HLOConstantFP<"1"> $x)))
]>;

def : HLODerivative<"ConjOp", (Op $z), [(Conj (DiffeRet))], (Conj (Shadow $z))>;

def : HLODerivative<"CoshOp", (Op $x), [(Mul (DiffeRet), (Sinh $x))]>;

def : HLODerivative<"DigammaOp", (Op $x),
[(Mul (DiffeRet), (Polygamma $x, (HLOConstantFP<"2"> $x)))],
(Mul (Shadow $x), (Polygamma $x, (HLOConstantFP<"2"> $x)))
>;

def : HLOInactiveOp<"IsInfOp">;

def : HLOInactiveOp<"IsNegInfOp">;

def : HLOInactiveOp<"IsPosInfOp">;

def : HLODerivative<"PolygammaOp", (Op $x, $n),
[
(Mul (DiffeRet), (Polygamma $x, (Add $n, (HLOConstantFP<"1"> $n)))),
(AssertingInactiveArg)
]
>;

def : HLODerivative<"SinhOp", (Op $x), [(Mul (DiffeRet), (Cosh $x))]>;

def : HLODerivative<"TanOp", (Op $x), [
(Div (DiffeRet), (Mul (Cos $x), (Cos $x)))
]>;
Loading
Loading