Skip to content

Commit

Permalink
Add support for passing array attributes via ffi_call
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Sep 30, 2024
1 parent ff1c2ac commit d133a1e
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 5 deletions.
4 changes: 4 additions & 0 deletions examples/ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ find_package(nanobind CONFIG REQUIRED)
nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})

nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc")
target_include_directories(_attrs PUBLIC ${XLA_DIR})
install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
66 changes: 66 additions & 0 deletions examples/ffi/src/jax_ffi_example/attrs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/ffi.h"

namespace nb = nanobind;
namespace ffi = xla::ffi;

ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
ffi::Result<ffi::BufferR0<ffi::S32>> res) {
int64_t total = 0;
for (int32_t x : array) {
total += x;
}
res->typed_data()[0] = total;
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl,
ffi::Ffi::Bind()
.Attr<ffi::Span<const int32_t>>("array")
.Ret<ffi::BufferR0<ffi::S32>>());

ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs,
ffi::Result<ffi::BufferR0<ffi::S32>> secret,
ffi::Result<ffi::BufferR0<ffi::S32>> count) {
auto maybe_secret = attrs.get<int64_t>("secret");
if (maybe_secret.has_error()) {
return maybe_secret.error();
}
secret->typed_data()[0] = maybe_secret.value();
count->typed_data()[0] = attrs.size();
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl,
ffi::Ffi::Bind()
.Attrs()
.Ret<ffi::BufferR0<ffi::S32>>()
.Ret<ffi::BufferR0<ffi::S32>>());

NB_MODULE(_attrs, m) {
m.def("registrations", []() {
nb::dict registrations;
registrations["array_attr"] =
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
registrations["dictionary_attr"] =
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));
return registrations;
});
}
47 changes: 47 additions & 0 deletions examples/ffi/src/jax_ffi_example/attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An example demonstrating the different ways that attributes can be passed to
the FFI.
For example, we can pass arrays, variadic attributes, and user-defined types.
Full support of user-defined types isn't yet supported by XLA, so that example
will be added in the future.
"""

import numpy as np

import jax
import jax.extend as jex

from jax_ffi_example import _attrs

for name, target in _attrs.registrations().items():
jex.ffi.register_ffi_target(name, target)


def array_attr(num: int):
return jex.ffi.ffi_call(
"array_attr",
jax.ShapeDtypeStruct((), np.int32),
array=np.arange(num, dtype=np.int32),
)


def dictionary_attr(**kwargs):
return jex.ffi.ffi_call(
"dictionary_attr",
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
**kwargs,
)
58 changes: 58 additions & 0 deletions examples/ffi/tests/attrs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu

from jax_ffi_example import attrs

jax.config.parse_flags_with_absl()


class AttrsTests(jtu.JaxTestCase):
def test_array_attr(self):
self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum())
self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum())

def test_array_attr_jit_cache(self):
jit_array_attr = jax.jit(attrs.array_attr, static_argnums=(0,))
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 1) # compiles once the first time
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 0) # cache hit
self.assertNotIn("_HashableByObjectId", jit_array_attr.lower(5).as_text())

def test_dictionary_attr(self):
secret, count = attrs.dictionary_attr(secret=5)
self.assertEqual(secret, 5)
self.assertEqual(count, 1)

secret, count = attrs.dictionary_attr(secret=3, a_string="hello")
self.assertEqual(secret, 3)
self.assertEqual(count, 2)

with self.assertRaisesRegex(Exception, "Unexpected attribute"):
attrs.dictionary_attr()

with self.assertRaisesRegex(Exception, "Wrong attribute type"):
attrs.dictionary_attr(secret="invalid")


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
20 changes: 18 additions & 2 deletions jax/_src/extend/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import os
from typing import Any

from jax._src import api
from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import util
from jax._src.callback import _check_shape_dtype, callback_batching_rule
Expand Down Expand Up @@ -256,6 +256,22 @@ def ffi_call(
return results[0]


def ffi_call_impl(
*args,
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool,
**kwargs: Any,
):
@api.jit
def impl(*args):
return ffi_call_p.bind(
*args, result_avals=result_avals, target_name=target_name,
vectorized=vectorized, **kwargs)

return impl(*args)


class FfiEffect(effects.Effect):
def __str__(self):
return "FFI"
Expand Down Expand Up @@ -308,7 +324,7 @@ def ffi_call_lowering(

ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True
ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p))
ffi_call_p.def_impl(ffi_call_impl)
ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval)
ad.primitive_jvps[ffi_call_p] = ffi_call_jvp
ad.primitive_transposes[ffi_call_p] = ffi_call_transpose
Expand Down
26 changes: 23 additions & 3 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,12 @@ def ir_constant(val: Any) -> IrValues:
raise TypeError(f"No constant handler for type: {type(val)}")

def _numpy_array_constant(x: np.ndarray | np.generic) -> IrValues:
attr = _numpy_array_attribute(x)
element_type = dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
x = np.packbits(x, bitorder='little') # type: ignore
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
return hlo.constant(attr)


Expand Down Expand Up @@ -359,13 +364,26 @@ def _numpy_scalar_attribute(val: Any) -> ir.Attribute:
else:
raise TypeError(f"Unsupported scalar attribute type: {type(val)}")

_dtype_to_array_attr: dict[Any, AttributeHandler] = {
np.dtype(np.bool_): ir.DenseBoolArrayAttr.get,
np.dtype(np.float32): ir.DenseF32ArrayAttr.get,
np.dtype(np.float64): ir.DenseF64ArrayAttr.get,
np.dtype(np.int32): ir.DenseI32ArrayAttr.get,
np.dtype(np.int64): ir.DenseI64ArrayAttr.get,
np.dtype(np.int8): ir.DenseI8ArrayAttr.get,
}

def _numpy_array_attribute(x: np.ndarray | np.generic) -> ir.Attribute:
element_type = dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
x = np.packbits(x, bitorder='little') # type: ignore
x = np.ascontiguousarray(x)
return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
builder = _dtype_to_array_attr.get(x.dtype, None)
if builder:
return builder(x)
else:
element_type = dtype_to_ir_type(x.dtype)
return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore

def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute:
if 0 in val.strides and val.size > 0:
Expand Down Expand Up @@ -407,6 +425,8 @@ def _sequence_attribute_handler(val: Sequence[Any]) -> ir.Attribute:

register_attribute_handler(list, _sequence_attribute_handler)
register_attribute_handler(tuple, _sequence_attribute_handler)
register_attribute_handler(ir.Attribute, lambda x: x)
register_attribute_handler(ir.Type, lambda x: x)

def ir_attribute(val: Any) -> ir.Attribute:
"""Convert a Python value to an MLIR attribute."""
Expand Down

0 comments on commit d133a1e

Please sign in to comment.