diff --git a/src/enzyme_ad/jax/__init__.py b/src/enzyme_ad/jax/__init__.py
index 6a5b459e..4f40ee0b 100644
--- a/src/enzyme_ad/jax/__init__.py
+++ b/src/enzyme_ad/jax/__init__.py
@@ -6,4 +6,5 @@
     JaXPipeline,
     optimize_module,
     export,
+    hlo_opts
 )
diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc
index 08b484eb..c7d3689e 100644
--- a/src/enzyme_ad/jax/enzyme_call.cc
+++ b/src/enzyme_ad/jax/enzyme_call.cc
@@ -78,6 +78,8 @@ class CpuKernel {
   uint64_t addr;
 
 public:
+  static constexpr size_t UNKNOWN_PLATFORM = 0x1000000000;
+
   CpuKernel(int64_t identifier, size_t num_out, uint64_t addr)
       : identifier(identifier), num_out(num_out), addr(addr) {}
 
@@ -893,7 +895,10 @@ class CpuKernel {
                  llvm::ArrayRef<std::string> out_names,
                  llvm::ArrayRef<llvm::SmallVector<int64_t>> in_shapes,
                  llvm::ArrayRef<std::string> in_names, PyObject *pyargv, ABI mode,
-                 Language lang, bool xla_runtime, const std::string &pass_pipeline) {
+                 Language lang, bool xla_runtime, const std::string &pass_pipeline,
+                 const std::string &platform) {
+    if (platform != "cpu")
+      return std::make_tuple(UNKNOWN_PLATFORM, 0);
     llvm::sys::SmartScopedWriter lock(kernel_mutex);
     size_t identifier = last_identifier++;
 
@@ -993,10 +998,13 @@ std::unique_ptr<llvm::orc::LLJIT> CpuKernel::JIT = nullptr;
 // CpuKernel::ES(std::move(*llvm::orc::SelfExecutorProcessControl::Create()));
 } // namespace
 
-void CpuCallback(void *out, void **ins) {
+void Callback(void *out, void **ins) {
   int64_t identifier = *reinterpret_cast<int64_t *>(ins[0]);
   CpuKernel *kernel = CpuKernel::get(identifier);
   if (!kernel) {
+    if (identifier == CpuKernel::UNKNOWN_PLATFORM) {
+      throw pybind11::value_error("Unknown platform callback could not be executed");
+    }
     // TODO: find a way to fail more gracefully.
     llvm::report_fatal_error("couldn't find enzyme kernel");
   }
@@ -1047,12 +1055,13 @@ PYBIND11_MODULE(enzyme_call, m) {
       .value("Reverse", ABI::Reverse)
       .value("Tape", ABI::Tape);
 
-  m.def("create_enzyme_cpu_kernel",
+  m.def("create_enzyme_kernel",
         [](const std::string &source, const std::string &fn,
            const pybind11::list &py_out_shapes,
            const pybind11::list &py_in_shapes, pybind11::object pyargv, ABI mode,
            Language lang, bool xla_runtime,
-           const std::string &pass_pipeline) -> std::tuple<size_t, size_t> {
+           const std::string &pass_pipeline,
+           const std::string &platform) -> std::tuple<size_t, size_t> {
          llvm::SmallVector<llvm::SmallVector<int64_t>> out_shapes;
          out_shapes.reserve(pybind11::len(py_out_shapes));
          llvm::SmallVector<llvm::SmallVector<int64_t>> in_shapes;
@@ -1088,7 +1097,7 @@ PYBIND11_MODULE(enzyme_call, m) {
           }
           return CpuKernel::create(fn, source, out_shapes, out_types,
                                    in_shapes, in_types, pyargv.ptr(), mode, (Language)lang,
-                                   xla_runtime, pass_pipeline);
+                                   xla_runtime, pass_pipeline, platform);
         });
 
   m.def("tmp_size",
@@ -1193,8 +1202,8 @@ PYBIND11_MODULE(enzyme_call, m) {
                          pyargv.ptr(), (Language)lang, xla_runtime,
                          pass_pipeline);
         });
 
-  m.def("get_cpu_callback", []() {
-    return pybind11::capsule(reinterpret_cast<void *>(&CpuCallback),
+  m.def("get_callback", []() {
+    return pybind11::capsule(reinterpret_cast<void *>(&Callback),
                              "xla._CUSTOM_CALL_TARGET");
   });
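The contract introduced by the C++ change above: `create` hands back the sentinel `UNKNOWN_PLATFORM` instead of a real kernel identifier whenever the requested platform is not `"cpu"`, and the renamed, platform-agnostic `Callback` converts that sentinel into a Python-visible `ValueError` rather than a hard `report_fatal_error`. A minimal Python model of that behavior (illustrative only; the real implementation is the C++ in `CpuKernel::create` / `Callback`):

```python
# Illustrative model of the sentinel contract; not part of the patch.
UNKNOWN_PLATFORM = 0x1000000000

_kernels = {}  # identifier -> compiled kernel (stand-in: a callable)
_last_identifier = 0


def create_enzyme_kernel(platform):
    """Return a kernel identifier, or the sentinel for non-CPU platforms."""
    global _last_identifier
    if platform != "cpu":
        return UNKNOWN_PLATFORM
    _last_identifier += 1
    _kernels[_last_identifier] = lambda: None  # stand-in for the JIT-ed kernel
    return _last_identifier


def callback(identifier):
    """Model of Callback: sentinel -> ValueError, unknown id -> hard error."""
    if identifier == UNKNOWN_PLATFORM:
        raise ValueError("Unknown platform callback could not be executed")
    if identifier not in _kernels:
        raise RuntimeError("couldn't find enzyme kernel")
    _kernels[identifier]()
```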
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index ca0a659e..d2e9679e 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -573,17 +573,13 @@ def absmaketup(ty):
 
 def lower(fn, vals, parameters=None):
     if hasattr(fn, "trace"):
-        if parameters is not None:
-            return fn.trace(*vals).lower(_private_parameters=parameters)
-        else:
-            return fn.trace(*vals).lower()
+        return fn.trace(*vals).lower(_private_parameters=parameters)
     else:
         if parameters is not None:
             return fn.lower(*vals, _experimental_lowering_parameters=parameters)
         else:
             return fn.lower(*vals)
 
-
 def _enzyme_aug_abstract_eval(
     *args_flat: jax.core.ShapedArray,
     source,
@@ -829,7 +825,8 @@ def _enzyme_primal_lowering(
             print(out_shapes, "\n", results, "\n", nmod)
             assert len(results) == len(out_shapes)
         else:
-            identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
+            assert len(ctx.module_context.platforms) == 1
+            identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
                 source,
                 fn,
                 out_shapes,
@@ -839,6 +836,7 @@ def _enzyme_primal_lowering(
                 lang,
                 pipeline_options.xla_runtime(),
                 pass_pipeline,
+                ctx.module_context.platforms[0]
             )
             identifier_attr = jax_mlir.dense_int_elements([identifier])
             identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -874,7 +872,8 @@ def _enzyme_primal_lowering(
 
             results = tuple(results2)
         else:
-            identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
+            assert len(ctx.module_context.platforms) == 1
+            identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
                 source,
                 fn,
                 out_shapes,
@@ -884,6 +883,7 @@ def _enzyme_primal_lowering(
                 lang,
                 pipeline_options.xla_runtime(),
                 pass_pipeline,
+                ctx.module_context.platforms[0]
             )
             identifier_attr = jax_mlir.dense_int_elements([identifier])
             identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -940,7 +940,8 @@ def _enzyme_fwd_lowering(
     in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]
 
     argv = argv + ("-resource-dir", resource_dir()) + cflags()
-    identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
+    assert len(ctx.module_context.platforms) == 1
+    identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
         source,
         fn,
         out_shapes,
@@ -950,6 +951,7 @@ def _enzyme_fwd_lowering(
         lang,
         pipeline_options.xla_runtime(),
         pipeline_options.pass_pipeline(),
+        ctx.module_context.platforms[0]
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -1004,7 +1006,8 @@ def _enzyme_aug_lowering(
     in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]
 
     argv = argv + ("-resource-dir", resource_dir()) + cflags()
-    identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
+    assert len(ctx.module_context.platforms) == 1
+    identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
         source,
         fn,
         out_shapes,
@@ -1014,6 +1017,7 @@ def _enzyme_aug_lowering(
         lang,
         pipeline_options.xla_runtime(),
         pipeline_options.pass_pipeline(),
+        ctx.module_context.platforms[0]
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -1075,7 +1079,8 @@ def _enzyme_rev_lowering(
     )
 
     argv = tuple(argv) + ("-resource-dir", resource_dir()) + cflags()
-    identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
+    assert len(ctx.module_context.platforms) == 1
+    identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
         source,
         fn,
         out_shapes,
@@ -1085,6 +1090,7 @@ def _enzyme_rev_lowering(
         lang,
         pipeline_options.xla_runtime(),
         pipeline_options.pass_pipeline(),
+        ctx.module_context.platforms[0]
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -1160,7 +1166,7 @@ def cpp_call(
 
 jax_mlir.register_lowering(_enzyme_primal_p, _enzyme_primal_lowering)
 xla_client.register_custom_call_target(
-    "jaxzyme.primal", enzyme_call.get_cpu_callback(), platform="cpu"
+    "jaxzyme.primal", enzyme_call.get_callback()
 )
 
 _enzyme_fwd_p = jax.core.Primitive("enzyme_fwd")
@@ -1170,7 +1176,7 @@ def cpp_call(
 
 jax_mlir.register_lowering(_enzyme_fwd_p, _enzyme_fwd_lowering)
 xla_client.register_custom_call_target(
-    "jaxzyme.fwd", enzyme_call.get_cpu_callback(), platform="cpu"
+    "jaxzyme.fwd", enzyme_call.get_callback()
 )
 
@@ -1282,7 +1288,16 @@ def dejaxify(x):
 
 jax_mlir.register_lowering(_enzyme_aug_p, _enzyme_aug_lowering)
 xla_client.register_custom_call_target(
-    "jaxzyme.aug", enzyme_call.get_cpu_callback(), platform="cpu"
+    "jaxzyme.aug", enzyme_call.get_callback(), platform="cpu"
+)
+xla_client.register_custom_call_target(
+    "jaxzyme.aug", enzyme_call.get_callback(), platform="CUDA"
+)
+xla_client.register_custom_call_target(
+    "jaxzyme.aug", enzyme_call.get_callback(), platform="ROCM"
+)
+xla_client.register_custom_call_target(
+    "jaxzyme.aug", enzyme_call.get_callback(), platform="tpu"
 )
 
 _enzyme_shadow_aug_p = jax.core.Primitive("enzyme_shadow_aug")
@@ -1297,7 +1312,16 @@ def dejaxify(x):
 
 jax_mlir.register_lowering(_enzyme_rev_p, _enzyme_rev_lowering)
 xla_client.register_custom_call_target(
-    "jaxzyme.rev", enzyme_call.get_cpu_callback(), platform="cpu"
+    "jaxzyme.rev", enzyme_call.get_callback(), platform="cpu"
+)
+xla_client.register_custom_call_target(
+    "jaxzyme.rev", enzyme_call.get_callback(), platform="CUDA"
+)
+xla_client.register_custom_call_target(
+    "jaxzyme.rev", enzyme_call.get_callback(), platform="ROCM"
+)
+xla_client.register_custom_call_target(
+    "jaxzyme.rev", enzyme_call.get_callback(), platform="tpu"
 )
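Two details of the lowering change are worth spelling out. Each lowering now asserts `len(ctx.module_context.platforms) == 1` and forwards `platforms[0]`, so a non-CPU target is rejected by the C++ side at kernel-creation time (via the sentinel) rather than failing later at dispatch. And the four back-to-back `register_custom_call_target` calls for `jaxzyme.aug`/`jaxzyme.rev` differ only in the `platform` argument; a loop form is equivalent, sketched here only for clarity (the diff deliberately spells the calls out, and `xla_client`/`enzyme_call` are assumed in scope as in primitives.py):

```python
# Equivalent loop form of the aug/rev registrations above (sketch).
for target in ("jaxzyme.aug", "jaxzyme.rev"):
    for platform in ("cpu", "CUDA", "ROCM", "tpu"):
        xla_client.register_custom_call_target(
            target, enzyme_call.get_callback(), platform=platform
        )
```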
- "jaxzyme.fwd", enzyme_call.get_cpu_callback(), platform="cpu" + "jaxzyme.fwd", enzyme_call.get_callback() ) @@ -1282,7 +1288,16 @@ def dejaxify(x): jax_mlir.register_lowering(_enzyme_aug_p, _enzyme_aug_lowering) xla_client.register_custom_call_target( - "jaxzyme.aug", enzyme_call.get_cpu_callback(), platform="cpu" + "jaxzyme.aug", enzyme_call.get_callback(), platform="cpu" +) +xla_client.register_custom_call_target( + "jaxzyme.aug", enzyme_call.get_callback(), platform="CUDA" +) +xla_client.register_custom_call_target( + "jaxzyme.aug", enzyme_call.get_callback(), platform="ROCM" +) +xla_client.register_custom_call_target( + "jaxzyme.aug", enzyme_call.get_callback(), platform="tpu" ) _enzyme_shadow_aug_p = jax.core.Primitive("enzyme_shadow_aug") @@ -1297,7 +1312,16 @@ def dejaxify(x): jax_mlir.register_lowering(_enzyme_rev_p, _enzyme_rev_lowering) xla_client.register_custom_call_target( - "jaxzyme.rev", enzyme_call.get_cpu_callback(), platform="cpu" + "jaxzyme.rev", enzyme_call.get_callback(), platform="cpu" +) +xla_client.register_custom_call_target( + "jaxzyme.rev", enzyme_call.get_callback(), platform="CUDA" +) +xla_client.register_custom_call_target( + "jaxzyme.rev", enzyme_call.get_callback(), platform="ROCM" +) +xla_client.register_custom_call_target( + "jaxzyme.rev", enzyme_call.get_callback(), platform="tpu" ) diff --git a/test/BUILD b/test/BUILD index d162e27e..6558078e 100644 --- a/test/BUILD +++ b/test/BUILD @@ -104,7 +104,9 @@ py_test( name = "bench_vs_xla", srcs = [ "bench_vs_xla.py", + "test_utils.py" ], + imports = ["."], deps = TEST_DEPS, ) @@ -112,6 +114,9 @@ py_test( name = "llama", srcs = [ "llama.py", + "test_utils.py", ], + imports = ["."], deps = TEST_DEPS, + timeout='long' ) diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index d08cc284..943945ac 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -3,242 +3,7 @@ from enzyme_ad.jax import enzyme_jax_ir, NewXLAPipeline, OldXLAPipeline, JaXPipeline from absl.testing import absltest import timeit - -argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11") -number = 1000 - -devices = [] -if jax.default_backend() != "cpu": - devices = [jax.default_backend()] - -AllBackends = ["cpu"] + devices - -AllPipelines = [ - ("JaX", None, AllBackends), - ("JaXPipeline", JaXPipeline(), AllBackends), - # ("NewXLAMLIR", NewXLAPipeline(mlirad=True)), - # ("NewXLA", NewXLAPipeline()), - ("OldXLA", OldXLAPipeline(), ["cpu"]), -] -PrimalPipelines = AllPipelines[:] -FwdPipelines = AllPipelines[:-1] -RevPipelines = AllPipelines[:-1] - - -def no_newxla(x): - return [(name, a, b) for (name, a, b) in x if name != "NewXLAMLIR" and name != "NewXLA"] - - -def no_newxlamlir(x): - return [(name, a, b) for (name, a, b) in x if name != "NewXLAMLIR"] - -def nomlir(x): - return [ - (name, a, b) - for (name, a, b) in x - if name != "NewXLAMLIR" and name != "NewXLA" # and name != "OldXLA" - ] - -def justjax(x): - return [ - (name, a, b) - for (name, a, b) in x - if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" - ] - - -# @jax.jit -# def fwd_jax(in0, in1, din0, din1): -# . 
diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py
index d08cc284..943945ac 100644
--- a/test/bench_vs_xla.py
+++ b/test/bench_vs_xla.py
@@ -3,242 +3,7 @@
 from enzyme_ad.jax import enzyme_jax_ir, NewXLAPipeline, OldXLAPipeline, JaXPipeline
 from absl.testing import absltest
 import timeit
-
-argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
-number = 1000
-
-devices = []
-if jax.default_backend() != "cpu":
-    devices = [jax.default_backend()]
-
-AllBackends = ["cpu"] + devices
-
-AllPipelines = [
-    ("JaX", None, AllBackends),
-    ("JaXPipeline", JaXPipeline(), AllBackends),
-    # ("NewXLAMLIR", NewXLAPipeline(mlirad=True)),
-    # ("NewXLA", NewXLAPipeline()),
-    ("OldXLA", OldXLAPipeline(), ["cpu"]),
-]
-PrimalPipelines = AllPipelines[:]
-FwdPipelines = AllPipelines[:-1]
-RevPipelines = AllPipelines[:-1]
-
-
-def no_newxla(x):
-    return [
-        (name, a, b) for (name, a, b) in x if name != "NewXLAMLIR" and name != "NewXLA"
-    ]
-
-
-def no_newxlamlir(x):
-    return [(name, a, b) for (name, a, b) in x if name != "NewXLAMLIR"]
-
-
-def nomlir(x):
-    return [
-        (name, a, b)
-        for (name, a, b) in x
-        if name != "NewXLAMLIR" and name != "NewXLA"  # and name != "OldXLA"
-    ]
-
-
-def justjax(x):
-    return [
-        (name, a, b)
-        for (name, a, b) in x
-        if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
-    ]
-
-
-# @jax.jit
-# def fwd_jax(in0, in1, din0, din1):
-#     return jax.jvp(add_one_jax, (in0, in1), (din0, din1))
-def splatjvp(in_fn):
-    def fwd(*args):
-        assert len(args) % 2 == 0
-        return jax.jvp(
-            in_fn, tuple(args[: len(args) // 2]), tuple(args[len(args) // 2 :])
-        )
-
-    return fwd
-
-
-# @jax.jit
-# def rev_jax(dout, in0, in1):
-#     primals, f_vjp = jax.vjp(add_one_jax, in0, in1)
-#     grads = f_vjp(dout)
-#     return primals, grads
-def splatvjp(in_fn):
-    def rev(dout, *args):
-        primals, f_vjp = jax.vjp(in_fn, *args)
-        grads = f_vjp(dout)
-        return primals, grads
-
-    return rev
-
-
-def to_backend(x, backend):
-    return jax.device_put(x, jax.local_devices(backend=backend)[0])
-
-
-class EnzymeJaxTest(absltest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.primfilter = lambda x: x
-        self.fwdfilter = lambda x: x
-        self.revfilter = lambda x: x
-
-    def setUp(self):
-        self.name = None
-
-    def test(self):
-        if self.name is None:
-            return
-        self.harness(self.name, self.fn, self.ins, self.dins, self.douts)
-
-    def harness(self, name, in_fn, ins, dins, douts):
-        assert len(ins) == len(dins)
-
-        assert 1 == len(douts)
-
-        primalstr = "fn(" + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")"
-
-        fwdstr = (
-            "fwd("
-            + (", ".join(["in" + str(i) for i in range(len(ins))]))
-            + ", "
-            + (", ".join(["din" + str(i) for i in range(len(dins))]))
-            + ")"
-        )
-
-        revstr = (
-            "rev(dout, " + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")"
-        )
-
-        for backend in AllBackends:
-            ins_backend = [to_backend(x, backend) for x in ins]
-            dins_backend = [to_backend(x, backend) for x in dins]
-            douts_backend = [to_backend(x, backend) for x in douts]
-
-            primalins = {("in" + str(i)): ins_backend[i] for i in range(len(ins))}
-            fwdins = primalins | {
-                ("din" + str(i)): dins_backend[i] for i in range(len(dins))
-            }
-            revins = primalins | {"dout": douts_backend[0]}
-
-            primres = None
-
-            for (pname, pipeline, pbackends) in self.primfilter(PrimalPipelines):
-                if backend in pbackends:
-                    rfn_enzyme = jax.jit(
-                        in_fn
-                        if pipeline is None
-                        else enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(in_fn),
-                        # backend=backend
-                    )
-                    ao = rfn_enzyme(*ins_backend)
-                    if primres is None:
-                        primres = ao
-                    else:
-                        self.assertTrue((jnp.abs(ao - primres) < 1e-6).all())
-
-                    print(
-                        name,
-                        ",",
-                        pname,
-                        ",",
-                        backend,
-                        ",",
-                        "Primal,",
-                        timeit.Timer(
-                            primalstr,
-                            globals={
-                                "fn": rfn_enzyme,
-                            }
-                            | primalins,
-                        ).timeit(number)
-                        / number,
-                    )
-
-            assert primres is not None
-            fwdres = None
-
-            for (pname, pipeline, pbackends) in self.fwdfilter(FwdPipelines):
-                if backend in pbackends:
-                    rfn_enzyme = (
-                        in_fn
-                        if pipeline is None
-                        else jax.jit(
-                            enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(in_fn),
-                            # backend=backend
-                        )
-                    )
-                    fwd_enzyme = jax.jit(
-                        splatjvp(rfn_enzyme),
-                        # backend=backend
-                    )
-
-                    primals, tangents = fwd_enzyme(*(ins_backend + dins_backend))
-
-                    self.assertTrue((jnp.abs(primals - primres) < 1e-6).all())
-
-                    if fwdres is None:
-                        fwdres = tangents
-                    else:
-                        if len(tangents.shape) == 0:
-                            self.assertTrue((jnp.abs(tangents - fwdres) < 1e-6).all())
-                        else:
-                            for t, t_p in zip(tangents, fwdres):
-                                self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
-
-                    print(
-                        name,
-                        ",",
-                        pname,
-                        ",",
-                        backend,
-                        ",",
-                        "Fwd",
-                        ",",
-                        timeit.Timer(
-                            fwdstr,
-                            globals={
-                                "fwd": fwd_enzyme,
-                            }
-                            | fwdins,
-                        ).timeit(number)
-                        / number,
-                    )
-
-            assert fwdres is not None
-
-            revres = None
-
-            for (pname, pipeline, pbackends) in self.revfilter(RevPipelines):
-                if backend in pbackends:
-                    rfn_enzyme = (
-                        in_fn
-                        if pipeline is None
-                        else jax.jit(
-                            enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(in_fn),
-                            # backend=backend
-                        )
-                    )
-                    rev_enzyme = jax.jit(
-                        splatvjp(rfn_enzyme),
-                        # backend=backend
-                    )
-
-                    primals, grads = rev_enzyme(*douts_backend, *ins_backend)
-                    self.assertTrue((jnp.abs(primals - primres) < 1e-6).all())
-
-                    if revres is None:
-                        revres = grads
-                    else:
-                        for i, (g, g_p) in enumerate(zip(grads, revres)):
-                            self.assertTrue((jnp.abs(g - g_p) < 1e-6).all())
-
-                    print(
-                        name,
-                        ",",
-                        pname,
-                        ",",
-                        backend,
-                        ",",
-                        "Rev",
-                        ",",
-                        timeit.Timer(
-                            revstr,
-                            globals={
-                                "rev": rev_enzyme,
-                            }
-                            | revins,
-                        ).timeit(number)
-                        / number,
-                    )
-            assert revres is not None
-
+from test_utils import *
 
 class AddOne(EnzymeJaxTest):
     def setUp(self):
@@ -263,7 +28,6 @@ def add_one(x, y):
 
         self.name = "add_one"
 
-
 class AddTwo(EnzymeJaxTest):
     def setUp(self):
         self.ins = [
@@ -295,7 +59,7 @@ def sum(x):
             return jnp.sum(x)
 
         self.fn = sum
-        self.name = "sum"
+        self.name = "sum "
 
 
 class Cache(EnzymeJaxTest):
@@ -360,7 +124,7 @@ def f(x):
             return kcl
 
         self.fn = f
-        self.name = "activitymismatch"
+        self.name = "actmtch"
 
 
 class GenDot(EnzymeJaxTest):
@@ -374,13 +138,6 @@ def setUp(self):
             )
         ]
 
-        def nomlir(x):
-            return [
-                (name, a)
-                for (name, a) in x
-                if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
-            ]
-
         self.primfilter = no_newxla
         self.fwdfilter = no_newxla
         # No new xla runs but gets wrong answer
@@ -417,7 +174,8 @@ def setUp(self):
         ]
         self.douts = [jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)]
 
-        self.revfilter = nomlir
+        self.revfilter = justjax
+        # self.revfilter = nomlir
 
         def f(x, y):
             return jnp.concat([x, y], axis=None)
@@ -469,6 +227,5 @@ def f(x, y):
             self.assertTrue(
                 (jnp.abs(res[1] - to_backend(prevres[1], backend)) < 1e-6).all()
             )
-
 if __name__ == "__main__":
     absltest.main()
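With the harness factored out into test_utils.py, adding a benchmark is just a subclass that fills in the fixture fields in `setUp`, mirroring `AddOne`/`AddTwo` above. A minimal sketch (the class name and function are made up for illustration; `EnzymeJaxTest` is assumed to be exported by test_utils as used throughout bench_vs_xla.py):

```python
# Sketch of a new benchmark in the style of AddOne/AddTwo above.
import jax.numpy as jnp
from test_utils import *


class MulTwo(EnzymeJaxTest):
    def setUp(self):
        # Primal inputs, tangents for jvp, and cotangents for vjp.
        self.ins = [jnp.array(range(50), dtype=jnp.float32)]
        self.dins = [jnp.array([i * i for i in range(50)], dtype=jnp.float32)]
        self.douts = [jnp.array([i * i for i in range(50)], dtype=jnp.float32)]

        def mul_two(x):
            return x * 2

        self.fn = mul_two
        self.name = "mul_two"
```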
diff --git a/test/llama.py b/test/llama.py
index 204a8f16..b0041c1d 100644
--- a/test/llama.py
+++ b/test/llama.py
@@ -3,8 +3,10 @@
 import jax.random
 import jax.lax
 import enzyme_ad.jax as enzyme_jax
+from enzyme_ad.jax import enzyme_jax_ir, NewXLAPipeline, OldXLAPipeline, JaXPipeline, hlo_opts
 import numpy as np
 import timeit
+from test_utils import *
 
 argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
 
@@ -239,200 +241,7 @@ def forward(x, config, weights, key_cache, value_cache):
 
     return x
 
-
-class Llama(absltest.TestCase):
-    def test_llama_random(self):
-        config = {
-            "dim": 288,
-            "hidden_dim": 768,
-            "n_layers": 6,
-            "n_heads": 6,
-            "n_kv_heads": 6,
-            "vocab_size": 32000,
-            "seq_len": 256,
-        }
-
-        n_layers = config["n_layers"]
-        seq_len = config["seq_len"]
-        n_heads = config["n_heads"]
-        dim = config["dim"]
-        n_kv_heads = config["n_kv_heads"]
-        vocab_size = config["vocab_size"]
-        hidden_dim = config["hidden_dim"]
-        kv_dim = dim // n_heads * n_kv_heads
-        head_size = dim // n_heads
-
-        key = jax.random.PRNGKey(0)
-        weights = {}
-        dweights = {}
-
-        for name, shape in [
-            ("rms_att_weight", (n_layers, dim)),
-            ("wq", (n_layers, dim, n_heads * head_size)),
-            ("wk", (n_layers, dim, n_kv_heads * head_size)),
-            ("wv", (n_layers, dim, n_kv_heads * head_size)),
-            ("wo", (n_layers, dim, dim)),
-            ("rms_ffn_weight", (n_layers, dim)),
-            ("w1", (n_layers, hidden_dim, dim)),
-            ("w2", (n_layers, dim, hidden_dim)),
-            ("w3", (n_layers, hidden_dim, dim)),
-            ("rms_final_weight", (dim,)),
-            ("wcls", (vocab_size, dim)),
-        ]:
-            key, subkey = jax.random.split(key)
-            key, subkey2 = jax.random.split(key)
-            weights[name] = jax.random.uniform(subkey, shape=shape)
-            dweights[name] = jax.random.uniform(subkey2, shape=shape)
-
-        key, subkey = jax.random.split(key)
-        x = jax.random.uniform(subkey, shape=(dim,))
-        key, subkey = jax.random.split(key)
-        dx = jax.random.uniform(subkey, shape=(dim,))
-
-        def partial(func, config):
-            def sfn(x, weights, key_cache, value_cache):
-                return func(x, config, weights, key_cache, value_cache)
-
-            return sfn
-
-        pos = 1
-        key_cache = jnp.zeros((n_layers, pos, kv_dim))
-        value_cache = jnp.zeros((n_layers, pos, kv_dim))
-
-        key, subkey = jax.random.split(key)
-        dkc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim))
-        key, subkey = jax.random.split(key)
-        dvc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim))
-
-        func = partial(forward, config)
-
-        jfunc = jax.jit(func)
-
-        efunc = jax.jit(
-            enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=pipeline)(func)
-        )
-
-        number = 1000
-        if False:
-            eres = efunc(x, weights, key_cache, value_cache)
-            print("Enzyme primal", eres)
-            res = jfunc(x, weights, key_cache, value_cache)
-            print("Jax primal", res)
-            print(" max error", jnp.max(jnp.abs(eres - res)))
-            assert (jnp.abs(eres - res) < 1e-3).all()
-
-        print(
-            "Enzyme primal",
-            timeit.Timer(
-                "efunc(x, weights, key_cache, value_cache)",
-                globals={
-                    "efunc": efunc,
-                    "x": x,
-                    "weights": weights,
-                    "key_cache": key_cache,
-                    "value_cache": value_cache,
-                },
-            ).timeit(number),
-        )
-        print(
-            "JaX primal",
-            timeit.Timer(
-                "jfunc(x, weights, key_cache, value_cache)",
-                globals={
-                    "jfunc": jfunc,
-                    "x": x,
-                    "weights": weights,
-                    "key_cache": key_cache,
-                    "value_cache": value_cache,
-                },
-            ).timeit(number),
-        )
-        # jfunc = jax.jit(partial(forward, config))
-        # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")
-
-        if False:
-
-            @jax.jit
-            def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
-                return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
-
-            @jax.jit
-            def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
-                return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
-
-            eres = efwd(
-                x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache
-            )
-            print("Enzyme fwd", eres)
-            jres = jfwd(
-                x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache
-            )
-            print("Jax fwd", jres)
-            print(
-                "Enzyme fwd",
-                timeit.Timer(
-                    "efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)",
-                    globals={
-                        "efwd": efwd,
-                        "x": x,
-                        "dx": dx,
-                        "weights": weights,
-                        "dweights": dweights,
-                        "key_cache": key_cache,
-                        "value_cache": value_cache,
-                    },
-                ).timeit(number),
-            )
-            print(
-                "JaX fwd",
-                timeit.Timer(
-                    "jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)",
-                    globals={
-                        "jfwd": jfwd,
-                        "x": x,
-                        "dx": dx,
-                        "weights": weights,
-                        "dweights": dweights,
-                        "key_cache": key_cache,
-                        "value_cache": value_cache,
-                    },
-                ).timeit(number),
-            )
-
-        @jax.jit
-        def jrev(x, weights, kc, vc, dx, dkc, dvc):
-            primals, f_vjp = jax.vjp(jfunc, x, weights, kc, vc)
-            return f_vjp(dx)  # , dkc, dvc)
-
-        @jax.jit
-        def erev(x, weights, kc, vc, dx, dkc, dvc):
-            primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc)
-            return f_vjp(dx)  # , dkc, dvc)
-
-        eres = erev(x, weights, key_cache, value_cache, dx, dkc, dvc)
-        # print("Enzyme rev", eres)
-        jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
-        # print("Jax rev", jres)
-
-        jrev2 = jax.jit(
-            enzyme_jax.enzyme_jax_ir(
-                argv=argv,
-                pipeline_options=enzyme_jax.JaXPipeline(
-                    "inline{default-pipeline=canonicalize max-iterations=4},"
-                    + "canonicalize,cse,enzyme-hlo-opt,cse"
-                ),
-            ),
-        )(jrev)
-
-        jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)
-        # print("Jax2 rev", jres2)
-
-        jrev3 = jax.jit(
-            enzyme_jax.enzyme_jax_ir(
-                argv=argv,
-                pipeline_options=enzyme_jax.JaXPipeline(
-                    "inline{default-pipeline=canonicalize max-iterations=4},"
-                    + """canonicalize,cse,
+partialopt = "inline{default-pipeline=canonicalize max-iterations=4}," + """canonicalize,cse,
 enzyme-hlo-generate-td{
             patterns=compare_op_canon<16>;
 transpose_transpose<16>;
@@ -534,139 +343,104 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
 },
 transform-interpreter,
 enzyme-hlo-remove-transform,cse"""
-                ),
-            )(jrev)
-        )
-        unused = """
-
-
-
-reshape_iota<16>;
-slice_reshape_slice<1>;
-dot_general_simplify<16>;
-transpose_simplify<16>;
-reshape_empty_broadcast<1>;
-add_pad_pad_to_concat<1>;
-broadcast_reshape<1>;
-
-slice_reshape_concat<1>;
-slice_reshape_elementwise<1>;
-slice_reshape_transpose<1>;
-slice_reshape_dot_general<1>;
-concat_pad<1>;
-reduce_pad<1>;
-broadcast_pad<1>;
-
-zero_product_reshape_pad<1>;
-mul_zero_pad<1>;
-div_zero_pad<1>;
-
-binop_const_reshape_pad<1>;
-binop_const_pad_add<1>;
-binop_const_pad_subtract<1>;
-binop_const_pad_mul<1>;
-binop_const_pad_div<1>;
-
-slice_reshape_pad<1>;
-binop_binop_pad_pad_add<1>;
-binop_binop_pad_pad_mul<1>;
-binop_pad_pad_add<1>;
-binop_pad_pad_subtract<1>;
-binop_pad_pad_mul<1>;
-binop_pad_pad_div<1>;
-binop_pad_pad_min<1>;
-binop_pad_pad_max<1>;
-
-unary_pad_push_convert<1>;
-unary_pad_push_tanh<1>;
-unary_pad_push_exp<1>;
-transpose_pad<1>;
-
-transpose_dot_reorder<1>;
-dot_transpose<1>;
-convert_convert_float<1>;
-concat_to_pad<1>;
-concat_appending_reshape<1>;
-reshape_iota<1>;
-broadcast_reduce<1>;
-slice_dot_general<1>;
+pipelines = [
+    ("JaX ", None, CurBackends),
+    ("JaXPipe", JaXPipeline(), CurBackends),
+    ("HLOOpt", JaXPipeline(
+        "inline{default-pipeline=canonicalize max-iterations=4},"
+        + "canonicalize,cse,enzyme-hlo-opt,cse"
+    ), CurBackends),
+    ("PartOpt", JaXPipeline(partialopt), CurBackends),
+    ("DefOpt", JaXPipeline(hlo_opts()), CurBackends),
+]
 
-dot_reshape_pad<1>;
-pad_dot_general<1>(0);
+class Llama(EnzymeJaxTest):
+    def setUp(self):
+        config = {
+            "dim": 288,
+            "hidden_dim": 768,
+            "n_layers": 6,
+            "n_heads": 6,
+            "n_kv_heads": 6,
+            "vocab_size": 32000,
+            "seq_len": 256,
+        }
 
-dot_reshape_pad<1>;
-pad_dot_general<1>(1);
-"""
-
-        jres3 = jrev3(x, weights, key_cache, value_cache, dx, dkc, dvc)
-        # print("Jax3 rev", jres3)
-
-        print(
-            "Enzyme rev",
-            timeit.Timer(
-                "erev(x, weights, key_cache, value_cache, dx, dkc, dvc)",
-                globals={
-                    "erev": erev,
-                    "x": x,
-                    "weights": weights,
-                    "key_cache": key_cache,
-                    "value_cache": value_cache,
-                    "dx": dx,
-                    "dkc": dkc,
-                    "dvc": dvc,
-                },
-            ).timeit(number),
-        )
-        print(
-            "JaX rev",
-            timeit.Timer(
-                "jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)",
-                globals={
-                    "jrev": jrev,
-                    "x": x,
-                    "weights": weights,
-                    "key_cache": key_cache,
-                    "value_cache": value_cache,
-                    "dx": dx,
-                    "dkc": dkc,
-                    "dvc": dvc,
-                },
-            ).timeit(number),
-        )
-        print(
-            "JaX2 rev",
-            timeit.Timer(
-                "jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)",
-                globals={
-                    "jrev2": jrev2,
-                    "x": x,
-                    "weights": weights,
-                    "key_cache": key_cache,
-                    "value_cache": value_cache,
-                    "dx": dx,
-                    "dkc": dkc,
-                    "dvc": dvc,
-                },
-            ).timeit(number),
-        )
-        print(
-            "JaX3 rev",
-            timeit.Timer(
-                "jrev3(x, weights, key_cache, value_cache, dx, dkc, dvc)",
-                globals={
-                    "jrev3": jrev3,
-                    "x": x,
-                    "weights": weights,
-                    "key_cache": key_cache,
-                    "value_cache": value_cache,
-                    "dx": dx,
-                    "dkc": dkc,
-                    "dvc": dvc,
-                },
-            ).timeit(number),
-        )
+        n_layers = config["n_layers"]
+        seq_len = config["seq_len"]
+        n_heads = config["n_heads"]
+        dim = config["dim"]
+        n_kv_heads = config["n_kv_heads"]
+        vocab_size = config["vocab_size"]
+        hidden_dim = config["hidden_dim"]
+        kv_dim = dim // n_heads * n_kv_heads
+        head_size = dim // n_heads
+
+        key = jax.random.PRNGKey(0)
+        weights = {}
+        dweights = {}
+
+        for name, shape in [
+            ("rms_att_weight", (n_layers, dim)),
+            ("wq", (n_layers, dim, n_heads * head_size)),
+            ("wk", (n_layers, dim, n_kv_heads * head_size)),
+            ("wv", (n_layers, dim, n_kv_heads * head_size)),
+            ("wo", (n_layers, dim, dim)),
+            ("rms_ffn_weight", (n_layers, dim)),
+            ("w1", (n_layers, hidden_dim, dim)),
+            ("w2", (n_layers, dim, hidden_dim)),
+            ("w3", (n_layers, hidden_dim, dim)),
+            ("rms_final_weight", (dim,)),
+            ("wcls", (vocab_size, dim)),
+        ]:
+            key, subkey = jax.random.split(key)
+            key, subkey2 = jax.random.split(key)
+            weights[name] = jax.random.uniform(subkey, shape=shape)
+            dweights[name] = jax.random.uniform(subkey2, shape=shape)
+
+        key, subkey = jax.random.split(key)
+        x = jax.random.uniform(subkey, shape=(dim,))
+        key, subkey = jax.random.split(key)
+        dx = jax.random.uniform(subkey, shape=(dim,))
+
+        def partial(func, config):
+            def sfn(x, weights, key_cache, value_cache):
+                return func(x, config, weights, key_cache, value_cache)
+
+            return sfn
+
+        pos = 1
+        key_cache = jnp.zeros((n_layers, pos, kv_dim))
+        value_cache = jnp.zeros((n_layers, pos, kv_dim))
+
+        key, subkey = jax.random.split(key)
+        dkc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim))
+        key, subkey = jax.random.split(key)
+        dvc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim))
+        self.fn = partial(forward, config)
+        self.name = "llama"
+        self.count = 1000
+        self.revprimal = False
+        self.AllPipelines = pipelines
+        self.AllBackends = CurBackends
+
+        self.ins = [
+            x,
+            weights,
+            key_cache,
+            value_cache
+        ]
+        self.dins = [
+            dx,
+            weights,
+            key_cache,
+            value_cache
+        ]
+        self.douts = [
+            dx
+        ]
 
 if __name__ == "__main__":
     absltest.main()
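For context on how the `pipelines` table above is consumed: each entry's `JaXPipeline` string is handed to `enzyme_jax_ir`, which the harness then wraps in `jax.jit`, exactly as the decorators are used in llama.py and bench_vs_xla.py. A standalone sketch on a toy function (the function is made up; this assumes, as the "DefOpt" row suggests, that `hlo_opts()` returns a default pass-pipeline string):

```python
# Sketch: applying the "DefOpt" pipeline from the table above to a toy function.
import jax
import jax.numpy as jnp
from enzyme_ad.jax import enzyme_jax_ir, JaXPipeline, hlo_opts


@jax.jit
@enzyme_jax_ir(pipeline_options=JaXPipeline(hlo_opts()))
def rms_norm(weight, x):
    # Same shape of computation as llama.py's rmsnorm, reduced to one line.
    ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5)
    return weight * x * ss


weight = jnp.ones(8)
x = jnp.arange(8, dtype=jnp.float32)
print(rms_norm(weight, x))
```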