Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Aug 25, 2024
1 parent 4723ae2 commit 7848d78
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 592 deletions.
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
JaXPipeline,
optimize_module,
export,
hlo_opts
)
23 changes: 16 additions & 7 deletions src/enzyme_ad/jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class CpuKernel {
uint64_t addr;

public:
static constexpr size_t UNKNOWN_PLATFORM = 0x1000000000;

CpuKernel(int64_t identifier, size_t num_out, uint64_t addr)
: identifier(identifier), num_out(num_out), addr(addr) {}

Expand Down Expand Up @@ -893,7 +895,10 @@ class CpuKernel {
llvm::ArrayRef<std::string> out_names,
llvm::ArrayRef<llvm::SmallVector<int64_t>> in_shapes,
llvm::ArrayRef<std::string> in_names, PyObject *pyargv, ABI mode,
Language lang, bool xla_runtime, const std::string &pass_pipeline) {
Language lang, bool xla_runtime, const std::string &pass_pipeline,
const std::string &platform) {
if (platform != "cpu")
return std::make_tuple(UNKNOWN_PLATFORM, 0);
llvm::sys::SmartScopedWriter<true> lock(kernel_mutex);
size_t identifier = last_identifier++;

Expand Down Expand Up @@ -993,10 +998,13 @@ std::unique_ptr<llvm::orc::LLJIT> CpuKernel::JIT = nullptr;
// CpuKernel::ES(std::move(*llvm::orc::SelfExecutorProcessControl::Create()));
} // namespace

void CpuCallback(void *out, void **ins) {
void Callback(void *out, void **ins) {
int64_t identifier = *reinterpret_cast<int64_t *>(ins[0]);
CpuKernel *kernel = CpuKernel::get(identifier);
if (!kernel) {
if (identifier == CpuKernel::UNKNOWN_PLATFORM) {
throw pybind11::value_error("Unknown platform callback could not be executed");
}
// TODO: find a way to fail more gracefully.
llvm::report_fatal_error("couldn't find enzyme kernel");
}
Expand Down Expand Up @@ -1047,12 +1055,13 @@ PYBIND11_MODULE(enzyme_call, m) {
.value("Reverse", ABI::Reverse)
.value("Tape", ABI::Tape);

m.def("create_enzyme_cpu_kernel",
m.def("create_enzyme_kernel",
[](const std::string &source, const std::string &fn,
const pybind11::list &py_out_shapes,
const pybind11::list &py_in_shapes, pybind11::object pyargv,
ABI mode, Language lang, bool xla_runtime,
const std::string &pass_pipeline) -> std::tuple<size_t, size_t> {
const std::string &pass_pipeline,
const std::string &platform) -> std::tuple<size_t, size_t> {
llvm::SmallVector<llvm::SmallVector<int64_t>> out_shapes;
out_shapes.reserve(pybind11::len(py_out_shapes));
llvm::SmallVector<llvm::SmallVector<int64_t>> in_shapes;
Expand Down Expand Up @@ -1088,7 +1097,7 @@ PYBIND11_MODULE(enzyme_call, m) {
}
return CpuKernel::create(fn, source, out_shapes, out_types, in_shapes,
in_types, pyargv.ptr(), mode, (Language)lang,
xla_runtime, pass_pipeline);
xla_runtime, pass_pipeline, platform);
});

m.def("tmp_size",
Expand Down Expand Up @@ -1193,8 +1202,8 @@ PYBIND11_MODULE(enzyme_call, m) {
pyargv.ptr(), (Language)lang, xla_runtime, pass_pipeline);
});

m.def("get_cpu_callback", []() {
return pybind11::capsule(reinterpret_cast<void *>(&CpuCallback),
m.def("get_callback", []() {
return pybind11::capsule(reinterpret_cast<void *>(&Callback),
"xla._CUSTOM_CALL_TARGET");
});

Expand Down
52 changes: 38 additions & 14 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,17 +573,13 @@ def absmaketup(ty):

def lower(fn, vals, parameters=None):
if hasattr(fn, "trace"):
if parameters is not None:
return fn.trace(*vals).lower(_private_parameters=parameters)
else:
return fn.trace(*vals).lower()
return fn.trace(*vals).lower(_private_parameters=parameters)
else:
if parameters is not None:
return fn.lower(*vals, _experimental_lowering_parameters=parameters)
else:
return fn.lower(*vals)


def _enzyme_aug_abstract_eval(
*args_flat: jax.core.ShapedArray,
source,
Expand Down Expand Up @@ -829,7 +825,8 @@ def _enzyme_primal_lowering(
print(out_shapes, "\n", results, "\n", nmod)
assert len(results) == len(out_shapes)
else:
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
assert len(ctx.module_context.platforms) == 1
identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
source,
fn,
out_shapes,
Expand All @@ -839,6 +836,7 @@ def _enzyme_primal_lowering(
lang,
pipeline_options.xla_runtime(),
pass_pipeline,
ctx.module_context.platforms[0]
)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)
Expand Down Expand Up @@ -874,7 +872,8 @@ def _enzyme_primal_lowering(

results = tuple(results2)
else:
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
assert len(ctx.module_context.platforms) == 1
identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
source,
fn,
out_shapes,
Expand All @@ -884,6 +883,7 @@ def _enzyme_primal_lowering(
lang,
pipeline_options.xla_runtime(),
pass_pipeline,
ctx.module_context.platforms[0]
)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)
Expand Down Expand Up @@ -940,7 +940,8 @@ def _enzyme_fwd_lowering(
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]

argv = argv + ("-resource-dir", resource_dir()) + cflags()
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
assert len(ctx.module_context.platforms) == 1
identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
source,
fn,
out_shapes,
Expand All @@ -950,6 +951,7 @@ def _enzyme_fwd_lowering(
lang,
pipeline_options.xla_runtime(),
pipeline_options.pass_pipeline(),
ctx.module_context.platforms[0]
)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)
Expand Down Expand Up @@ -1004,7 +1006,8 @@ def _enzyme_aug_lowering(
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]

argv = argv + ("-resource-dir", resource_dir()) + cflags()
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
assert len(ctx.module_context.platforms) == 1
identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
source,
fn,
out_shapes,
Expand All @@ -1014,6 +1017,7 @@ def _enzyme_aug_lowering(
lang,
pipeline_options.xla_runtime(),
pipeline_options.pass_pipeline(),
ctx.module_context.platforms[0]
)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)
Expand Down Expand Up @@ -1075,7 +1079,8 @@ def _enzyme_rev_lowering(
)

argv = tuple(argv) + ("-resource-dir", resource_dir()) + cflags()
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
assert len(ctx.module_context.platforms) == 1
identifier, tmpBuf = enzyme_call.create_enzyme_kernel(
source,
fn,
out_shapes,
Expand All @@ -1085,6 +1090,7 @@ def _enzyme_rev_lowering(
lang,
pipeline_options.xla_runtime(),
pipeline_options.pass_pipeline(),
ctx.module_context.platforms[0]
)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)
Expand Down Expand Up @@ -1160,7 +1166,7 @@ def cpp_call(
jax_mlir.register_lowering(_enzyme_primal_p, _enzyme_primal_lowering)

xla_client.register_custom_call_target(
"jaxzyme.primal", enzyme_call.get_cpu_callback(), platform="cpu"
"jaxzyme.primal", enzyme_call.get_callback()
)

_enzyme_fwd_p = jax.core.Primitive("enzyme_fwd")
Expand All @@ -1170,7 +1176,7 @@ def cpp_call(
jax_mlir.register_lowering(_enzyme_fwd_p, _enzyme_fwd_lowering)

xla_client.register_custom_call_target(
"jaxzyme.fwd", enzyme_call.get_cpu_callback(), platform="cpu"
"jaxzyme.fwd", enzyme_call.get_callback()
)


Expand Down Expand Up @@ -1282,7 +1288,16 @@ def dejaxify(x):
jax_mlir.register_lowering(_enzyme_aug_p, _enzyme_aug_lowering)

xla_client.register_custom_call_target(
"jaxzyme.aug", enzyme_call.get_cpu_callback(), platform="cpu"
"jaxzyme.aug", enzyme_call.get_callback(), platform="cpu"
)
xla_client.register_custom_call_target(
"jaxzyme.aug", enzyme_call.get_callback(), platform="CUDA"
)
xla_client.register_custom_call_target(
"jaxzyme.aug", enzyme_call.get_callback(), platform="ROCM"
)
xla_client.register_custom_call_target(
"jaxzyme.aug", enzyme_call.get_callback(), platform="tpu"
)

_enzyme_shadow_aug_p = jax.core.Primitive("enzyme_shadow_aug")
Expand All @@ -1297,7 +1312,16 @@ def dejaxify(x):
jax_mlir.register_lowering(_enzyme_rev_p, _enzyme_rev_lowering)

xla_client.register_custom_call_target(
"jaxzyme.rev", enzyme_call.get_cpu_callback(), platform="cpu"
"jaxzyme.rev", enzyme_call.get_callback(), platform="cpu"
)
xla_client.register_custom_call_target(
"jaxzyme.rev", enzyme_call.get_callback(), platform="CUDA"
)
xla_client.register_custom_call_target(
"jaxzyme.rev", enzyme_call.get_callback(), platform="ROCM"
)
xla_client.register_custom_call_target(
"jaxzyme.rev", enzyme_call.get_callback(), platform="tpu"
)


Expand Down
5 changes: 5 additions & 0 deletions test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,19 @@ py_test(
name = "bench_vs_xla",
srcs = [
"bench_vs_xla.py",
"test_utils.py"
],
imports = ["."],
deps = TEST_DEPS,
)

py_test(
name = "llama",
srcs = [
"llama.py",
"test_utils.py",
],
imports = ["."],
deps = TEST_DEPS,
timeout='long'
)
Loading

0 comments on commit 7848d78

Please sign in to comment.