
Commit 7f83dbb

fixes
wsmoses committed Aug 25, 2024
1 parent e8c6c16 commit 7f83dbb
Showing 7 changed files with 542 additions and 61 deletions.
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/__init__.py
@@ -6,5 +6,5 @@
     JaXPipeline,
     optimize_module,
     export,
-    hlo_opts
+    hlo_opts,
 )
3 changes: 2 additions & 1 deletion src/enzyme_ad/jax/enzyme_call.cc
@@ -1003,7 +1003,8 @@ void Callback(void *out, void **ins) {
   CpuKernel *kernel = CpuKernel::get(identifier);
   if (!kernel) {
     if (identifier == CpuKernel::UNKNOWN_PLATFORM) {
-      throw pybind11::value_error("Unknown platform callback could not be executed");
+      throw pybind11::value_error(
+          "Unknown platform callback could not be executed");
     }
     // TODO: find a way to fail more gracefully.
     llvm::report_fatal_error("couldn't find enzyme kernel");
19 changes: 8 additions & 11 deletions src/enzyme_ad/jax/primitives.py
@@ -580,6 +580,7 @@ def lower(fn, vals, parameters=None):
     else:
         return fn.lower(*vals)
 
+
 def _enzyme_aug_abstract_eval(
     *args_flat: jax.core.ShapedArray,
     source,
@@ -836,7 +837,7 @@ def _enzyme_primal_lowering(
         lang,
         pipeline_options.xla_runtime(),
         pass_pipeline,
-        ctx.module_context.platforms[0]
+        ctx.module_context.platforms[0],
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -883,7 +884,7 @@ def _enzyme_primal_lowering(
         lang,
         pipeline_options.xla_runtime(),
         pass_pipeline,
-        ctx.module_context.platforms[0]
+        ctx.module_context.platforms[0],
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -951,7 +952,7 @@ def _enzyme_fwd_lowering(
         lang,
         pipeline_options.xla_runtime(),
         pipeline_options.pass_pipeline(),
-        ctx.module_context.platforms[0]
+        ctx.module_context.platforms[0],
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -1017,7 +1018,7 @@ def _enzyme_aug_lowering(
         lang,
         pipeline_options.xla_runtime(),
         pipeline_options.pass_pipeline(),
-        ctx.module_context.platforms[0]
+        ctx.module_context.platforms[0],
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -1090,7 +1091,7 @@ def _enzyme_rev_lowering(
         lang,
         pipeline_options.xla_runtime(),
         pipeline_options.pass_pipeline(),
-        ctx.module_context.platforms[0]
+        ctx.module_context.platforms[0],
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -1165,19 +1166,15 @@ def cpp_call(
 _enzyme_primal_p.def_abstract_eval(_enzyme_primal_abstract_eval)
 jax_mlir.register_lowering(_enzyme_primal_p, _enzyme_primal_lowering)
 
-xla_client.register_custom_call_target(
-    "jaxzyme.primal", enzyme_call.get_callback()
-)
+xla_client.register_custom_call_target("jaxzyme.primal", enzyme_call.get_callback())
 
 _enzyme_fwd_p = jax.core.Primitive("enzyme_fwd")
 _enzyme_fwd_p.multiple_results = True
 _enzyme_fwd_p.def_impl(_enzyme_fwd_impl)
 _enzyme_fwd_p.def_abstract_eval(_enzyme_fwd_abstract_eval)
 jax_mlir.register_lowering(_enzyme_fwd_p, _enzyme_fwd_lowering)
 
-xla_client.register_custom_call_target(
-    "jaxzyme.fwd", enzyme_call.get_callback()
-)
+xla_client.register_custom_call_target("jaxzyme.fwd", enzyme_call.get_callback())
 
 
 def enzyme_jvp(arg_primals, arg_tangents, **kwargs):
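
Note on the registration pattern shown above: a primitive gets an eager implementation via def_impl, a shape/dtype rule via def_abstract_eval, and an MLIR lowering via jax_mlir.register_lowering, while the "jaxzyme.*" custom-call targets are registered with XLA so the lowered module can call back into enzyme_call. A minimal self-contained sketch of the same wiring, using a hypothetical "double" primitive and lowering through jax_mlir.lower_fun instead of a custom call:

import jax
import jax.numpy as jnp
from jax.interpreters import mlir as jax_mlir

# Hypothetical primitive; mirrors the def_impl / def_abstract_eval /
# register_lowering sequence used for _enzyme_primal_p and _enzyme_fwd_p.
_double_p = jax.core.Primitive("double")


def _double_impl(x):
    # Eager fallback used when the primitive is evaluated outside of jit.
    return 2 * x


def _double_abstract_eval(x):
    # Output has the same shape and dtype as the input.
    return jax.core.ShapedArray(x.shape, x.dtype)


_double_p.def_impl(_double_impl)
_double_p.def_abstract_eval(_double_abstract_eval)

# For brevity this lowers by tracing the Python impl; primitives.py instead
# emits a stablehlo custom call bound to the registered "jaxzyme.*" targets.
jax_mlir.register_lowering(
    _double_p, jax_mlir.lower_fun(_double_impl, multiple_results=False)
)


def double(x):
    return _double_p.bind(x)


print(jax.jit(double)(jnp.arange(3.0)))  # [0. 2. 4.]
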
43 changes: 37 additions & 6 deletions test/bench_vs_xla.py
@@ -5,6 +5,7 @@
 import timeit
 from test_utils import *
 
+
 class AddOne(EnzymeJaxTest):
     def setUp(self):
         self.ins = [
@@ -28,6 +29,7 @@ def add_one(x, y):
 
         self.name = "add_one"
 
+
 class AddTwo(EnzymeJaxTest):
     def setUp(self):
         self.ins = [
@@ -207,8 +209,13 @@ def f(x, y):
                 )
 
                 g = jax.value_and_grad(
-                    f if pipeline is None else jax.jit(enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(f),
-                    #backend=backend
+                    (
+                        f
+                        if pipeline is None
+                        else jax.jit(
+                            enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(f),
+                            # backend=backend
+                        )
+                    ),
                     has_aux=True,
                     allow_int=True,
@@ -221,11 +228,35 @@ def f(x, y):
                     name = "valueandgrad"
                     print(name + " JaX(", pname, "): ", prevres)
                     print(name + " EnzymeMLIR(", pname, "): ", res)
-                    self.assertTrue((jnp.abs(res[0][0] - to_backend(prevres[0][0], backend)) < 1e-6).all())
-                    self.assertTrue((jnp.abs(res[0][1][0] - to_backend(prevres[0][1][0], backend)) < 1e-6).all())
-                    self.assertTrue((jnp.abs(res[0][1][1] - to_backend(prevres[0][1][1], backend)) < 1e-6).all())
+                    self.assertTrue(
+                        (
+                            jnp.abs(res[0][0] - to_backend(prevres[0][0], backend))
+                            < 1e-6
+                        ).all()
+                    )
+                    self.assertTrue(
+                        (
+                            jnp.abs(
+                                res[0][1][0] - to_backend(prevres[0][1][0], backend)
+                            )
+                            < 1e-6
+                        ).all()
+                    )
+                    self.assertTrue(
+                        (
+                            jnp.abs(
+                                res[0][1][1] - to_backend(prevres[0][1][1], backend)
+                            )
+                            < 1e-6
+                        ).all()
+                    )
 
-                    self.assertTrue((jnp.abs(res[1] - to_backend(prevres[1], backend)) < 1e-6).all())
+                    self.assertTrue(
+                        (
+                            jnp.abs(res[1] - to_backend(prevres[1], backend)) < 1e-6
+                        ).all()
+                    )
+
 
 if __name__ == "__main__":
     absltest.main()
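
The assertions above unpack the structure returned by jax.value_and_grad(..., has_aux=True): res[0][0] is the primal value, res[0][1] holds the auxiliary outputs, and res[1] is the gradient. A small standalone sketch of that calling convention (toy function, not taken from this test suite):

import jax
import jax.numpy as jnp


def f(x, y):
    # With has_aux=True the function must return (value, aux).
    loss = jnp.sum(x * y)
    aux = (jnp.abs(x), jnp.abs(y))
    return loss, aux


g = jax.value_and_grad(f, has_aux=True, allow_int=True)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([10.0, 20.0, 30.0])

res = g(x, y)
print(res[0][0])  # primal value: 140.0
print(res[0][1])  # aux outputs: (|x|, |y|)
print(res[1])     # d loss / d x, equal to y: [10. 20. 30.]
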
48 changes: 26 additions & 22 deletions test/llama.py
@@ -3,7 +3,13 @@
 import jax.random
 import jax.lax
 import enzyme_ad.jax as enzyme_jax
-from enzyme_ad.jax import enzyme_jax_ir, NewXLAPipeline, OldXLAPipeline, JaXPipeline, hlo_opts
+from enzyme_ad.jax import (
+    enzyme_jax_ir,
+    NewXLAPipeline,
+    OldXLAPipeline,
+    JaXPipeline,
+    hlo_opts,
+)
 import numpy as np
 import timeit
 from test_utils import *
@@ -241,7 +247,10 @@ def forward(x, config, weights, key_cache, value_cache):
 
     return x
 
-partialopt = "inline{default-pipeline=canonicalize max-iterations=4}," + """canonicalize,cse,
+
+partialopt = (
+    "inline{default-pipeline=canonicalize max-iterations=4},"
+    + """canonicalize,cse,
 enzyme-hlo-generate-td{
 patterns=compare_op_canon<16>;
 transpose_transpose<16>;
@@ -343,18 +352,24 @@ def forward(x, config, weights, key_cache, value_cache):
 },
 transform-interpreter,
 enzyme-hlo-remove-transform,cse"""
+)
 
 pipelines = [
     ("JaX ", None, CurBackends),
     ("JaXPipe", JaXPipeline(), CurBackends),
-    ("HLOOpt", JaXPipeline(
-        "inline{default-pipeline=canonicalize max-iterations=4},"
-        + "canonicalize,cse,enzyme-hlo-opt,cse"
-    ), CurBackends),
+    (
+        "HLOOpt",
+        JaXPipeline(
+            "inline{default-pipeline=canonicalize max-iterations=4},"
+            + "canonicalize,cse,enzyme-hlo-opt,cse"
+        ),
+        CurBackends,
+    ),
     ("PartOpt", JaXPipeline(partialopt), CurBackends),
     ("DefOpt", JaXPipeline(hlo_opts()), CurBackends),
 ]
 
+
 class Llama(EnzymeJaxTest):
     def setUp(self):
         config = {
@@ -425,22 +440,11 @@ def sfn(x, weights, key_cache, value_cache):
         self.revprimal = False
         self.AllPipelines = pipelines
         self.AllBackends = CurBackends
 
-        self.ins = [
-            x,
-            weights,
-            key_cache,
-            value_cache
-        ]
-        self.dins = [
-            dx,
-            weights,
-            key_cache,
-            value_cache
-        ]
-        self.douts = [
-            dx
-        ]
+        self.ins = [x, weights, key_cache, value_cache]
+        self.dins = [dx, weights, key_cache, value_cache]
+        self.douts = [dx]
 
+
 if __name__ == "__main__":
     absltest.main()
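
Each pipelines entry above pairs a label with the pipeline_options handed to enzyme_jax_ir, and the benchmark harness compiles the model once per entry. A minimal sketch of applying one such pipeline to an ordinary function (toy function; assumes enzyme_ad.jax is installed, and the argv keyword, used elsewhere in the tests for extra clang flags, is left empty):

import jax
import jax.numpy as jnp
from enzyme_ad.jax import JaXPipeline, enzyme_jax_ir

# Same construction as the "HLOOpt" entry above; JaXPipeline(hlo_opts()) gives
# the "DefOpt" entry and JaXPipeline() the plain "JaXPipe" entry.
pipeline = JaXPipeline(
    "inline{default-pipeline=canonicalize max-iterations=4},"
    + "canonicalize,cse,enzyme-hlo-opt,cse"
)


@jax.jit
@enzyme_jax_ir(pipeline_options=pipeline, argv=())
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)


print(f(jnp.arange(8.0)))
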
82 changes: 62 additions & 20 deletions test/test.py
@@ -3,7 +3,7 @@
 import jax.numpy as jnp
 from enzyme_ad.jax import cpp_call, enzyme_jax_ir, optimize_module
 
-jax.config.update('jax_platform_name', 'cpu')
+jax.config.update("jax_platform_name", "cpu")
 
 argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
 
@@ -84,10 +84,17 @@ def do_something(ones):
         self.assertTrue((primals[0] == 43).all())
         self.assertTrue((primals[1] == 85).all())
         self.assertTrue((primals[2][0] == 56).all())
 
-        self.assertTrue((grads[1] == jnp.array([
-            [128., 128., 128.],
-        ])).all())
+        self.assertTrue(
+            (
+                grads[1]
+                == jnp.array(
+                    [
+                        [128.0, 128.0, 128.0],
+                    ]
+                )
+            ).all()
+        )
 
     def test_enzyme_mlir_jit(self):
         @jax.jit
@@ -102,26 +109,61 @@ def add_one(x: jax.Array, y) -> jax.Array:
             (jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])),
             (jnp.array([0.1, 0.2, 0.3]), jnp.array([50.0, 70.0, 110.0])),
         )
-        self.assertTrue((primals == jnp.array([
-            [12., 23., 34.],
-        ])).all())
-        self.assertTrue((tangents == jnp.array([
-            [50.1, 70.2, 110.3],
-        ])).all())
+        self.assertTrue(
+            (
+                primals
+                == jnp.array(
+                    [
+                        [12.0, 23.0, 34.0],
+                    ]
+                )
+            ).all()
+        )
+        self.assertTrue(
+            (
+                tangents
+                == jnp.array(
+                    [
+                        [50.1, 70.2, 110.3],
+                    ]
+                )
+            ).all()
+        )
 
         primals, f_vjp = jax.vjp(
             add_one, jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])
         )
         grads = f_vjp(jnp.array([500.0, 700.0, 110.0]))
-        self.assertTrue((primals == jnp.array([
-            [12., 23., 34.],
-        ])).all())
-        self.assertTrue((grads[0] == jnp.array([
-            [500., 700., 110.],
-        ])).all())
-        self.assertTrue((grads[1] == jnp.array([
-            [500., 700., 110.],
-        ])).all())
+        self.assertTrue(
+            (
+                primals
+                == jnp.array(
+                    [
+                        [12.0, 23.0, 34.0],
+                    ]
+                )
+            ).all()
+        )
+        self.assertTrue(
+            (
+                grads[0]
+                == jnp.array(
+                    [
+                        [500.0, 700.0, 110.0],
+                    ]
+                )
+            ).all()
+        )
+        self.assertTrue(
+            (
+                grads[1]
+                == jnp.array(
+                    [
+                        [500.0, 700.0, 110.0],
+                    ]
+                )
+            ).all()
+        )
 
 
 if __name__ == "__main__":
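
The expected values in these checks are consistent with add_one computing x + y + 1 (its body is collapsed in the diff above); a plain-JAX version of the same jvp/vjp calls, without the enzyme_jax_ir decoration, reproduces them:

import jax
import jax.numpy as jnp


def add_one(x, y):
    return x + y + 1


# Forward mode: the tangents of x and y add, the constant 1 contributes nothing.
primals, tangents = jax.jvp(
    add_one,
    (jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])),
    (jnp.array([0.1, 0.2, 0.3]), jnp.array([50.0, 70.0, 110.0])),
)
print(primals)   # [12. 23. 34.]
print(tangents)  # [50.1 70.2 110.3]

# Reverse mode: the cotangent flows through unchanged to both inputs.
primals, f_vjp = jax.vjp(
    add_one, jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])
)
grads = f_vjp(jnp.array([500.0, 700.0, 110.0]))
print(grads[0])  # [500. 700. 110.]
print(grads[1])  # [500. 700. 110.]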