Skip to content

Commit

Permalink
[export] Add support for serialization and deserialization of Exported
Browse files Browse the repository at this point in the history
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
  • Loading branch information
gnecula authored and jax authors committed Dec 12, 2023
1 parent 4347950 commit b077483
Show file tree
Hide file tree
Showing 11 changed files with 1,504 additions and 78 deletions.
1 change: 1 addition & 0 deletions build/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ absl-py
build
cloudpickle
colorama>=0.4.4
flatbuffers
hypothesis
numpy>=1.22
pillow>=9.1.0
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ sphinx-design
myst-nb>=1.0.0

# Packages used for CI tests.
flatbuffers
pytest
pytest-xdist

Expand Down
6 changes: 3 additions & 3 deletions jax/_src/internal_test_util/export_back_compat_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def serialize(self,

module_str = str(exported.mlir_module())
serialized = exported.mlir_module_serialized
module_version = exported.serialization_version
module_version = exported.mlir_module_serialization_version
nr_devices = exported.nr_devices
return serialized, module_str, module_version, nr_devices

Expand Down Expand Up @@ -330,9 +330,9 @@ def _get_vjp(_):
lowering_platforms=(data.platform,),
ordered_effects=(),
unordered_effects=(),
disabled_checks=(),
disabled_safety_checks=(),
mlir_module_serialized=data.mlir_module_serialized,
serialization_version=data.xla_call_module_version,
mlir_module_serialization_version=data.xla_call_module_version,
nr_devices=data.nr_devices,
module_kept_var_idx=tuple(range(len(in_avals))),
uses_shape_polymorphism=any(not core.is_constant_shape(a.shape)
Expand Down
4 changes: 3 additions & 1 deletion jax/experimental/export/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ py_library(
name = "export",
srcs = [
"export.py",
"serialization.py",
"serialization_generated.py",
"shape_poly.py",
],
srcs_version = "PY3",
Expand All @@ -40,5 +42,5 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//jax",
] + py_deps("numpy"),
] + py_deps("numpy") + py_deps("flatbuffers"),
)
35 changes: 16 additions & 19 deletions jax/experimental/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
import functools
import itertools
import re
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TypeVar, Union

from absl import logging

import numpy as np

import jax
Expand Down Expand Up @@ -156,7 +155,7 @@ class Exported:
unordered_effects: the unordered effects present in the serialized module.
This is present from serialization version 9.
mlir_module_serialized: the serialized lowered VHLO module.
serialization_version: a version number for the serialized module.
mlir_module_serialization_version: a version number for the serialized module.
See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
Expand All @@ -166,7 +165,7 @@ class Exported:
variables, or due to inner calls of Exported modules that have
dimension variables or platform index arguments. Such modules need
shape refinement before XLA compilation.
disabled_checks: a list of descriptors of safety checks that have been
disabled_safety_checks: a list of descriptors of safety checks that have been
disabled at export time. See docstring for `DisabledSafetyCheck`.
_get_vjp: an optional function that takes the current exported function and
returns the exported VJP function.
Expand Down Expand Up @@ -282,10 +281,10 @@ class Exported:
lowering_platforms: tuple[str, ...]
ordered_effects: tuple[effects.Effect, ...]
unordered_effects: tuple[effects.Effect, ...]
disabled_checks: Sequence[DisabledSafetyCheck]
disabled_safety_checks: Sequence[DisabledSafetyCheck]

mlir_module_serialized: bytes
serialization_version: int
mlir_module_serialization_version: int
module_kept_var_idx: tuple[int, ...]
uses_shape_polymorphism: bool

Expand All @@ -299,6 +298,9 @@ def __str__(self):
# do not want the entire serialized module to end up in locations.
return f"Exported(fun_name={self.fun_name}, ...)"

def has_vjp(self) -> bool:
return self._get_vjp is not None

def vjp(self) -> "Exported":
"""Gets the exported VJP.
Expand Down Expand Up @@ -496,7 +498,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
mlir_module_attrs["jax.uses_shape_polymorphism"] = (
mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars))

mlir_module_serialized = _serialize_module(mlir_module)
mlir_module_serialized = _module_to_bytecode(mlir_module)

# Figure out the result types and shapes
if "global_out_avals" in lowering.compile_args:
Expand Down Expand Up @@ -554,17 +556,17 @@ def export_sharding(s: LoweringSharding,
lowering_platforms=actual_lowering_platforms,
ordered_effects=ordered_effects,
unordered_effects=unordered_effects,
disabled_checks=tuple(disabled_checks),
disabled_safety_checks=tuple(disabled_checks),
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
serialization_version=version, # type: ignore
mlir_module_serialization_version=version, # type: ignore
_get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))

return do_export


def _serialize_module(module: ir.Module) -> bytes:
def _module_to_bytecode(module: ir.Module) -> bytes:
mlir_str = mlir.module_to_bytecode(module)
if hlo.get_api_version() < 4:
target_version = hlo.get_earliest_forward_compatible_version()
Expand Down Expand Up @@ -1042,9 +1044,9 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
apply_jit=True)
return export(fun_vjp_jax,
lowering_platforms=primal.lowering_platforms,
disabled_checks=primal.disabled_checks)(*vjp_in_avals)
disabled_checks=primal.disabled_safety_checks)(*vjp_in_avals)

### Importing
### Calling the exported function

def call_exported(exported: Exported) -> Callable[..., jax.Array]:
if not isinstance(exported, Exported):
Expand Down Expand Up @@ -1215,7 +1217,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
if platform in exported.lowering_platforms:
callee_lowering_platform_index.append(
exported.lowering_platforms.index(platform))
elif DisabledSafetyCheck.platform() in exported.disabled_checks:
elif DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
callee_lowering_platform_index.append(0)
else:
raise ValueError(
Expand Down Expand Up @@ -1249,7 +1251,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
else:
assert len(lowering_platforms) == 1

if _keep_main_tokens(exported.serialization_version):
if _keep_main_tokens(exported.mlir_module_serialization_version):
ordered_effects = exported.ordered_effects
else:
ordered_effects = ()
Expand Down Expand Up @@ -1282,11 +1284,6 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
)
return results


# for _p in ("cpu", "tpu", "cuda", "rocm"):
# mlir.register_lowering(call_exported_p,
# functools.partial(_call_exported_lowering, platform=_p),
# platform=_p)
mlir.register_lowering(call_exported_p, _call_exported_lowering)

def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
Expand Down
129 changes: 129 additions & 0 deletions jax/experimental/export/serialization.fbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright 2023 The JAX Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// To regenerate the serialization_generated.py, install flatc (e.g.,
// from Homebrew) and then:
//
// 1. Run flatc --python --gen-onefile serialization.fbs
// 2. Delete the trailing newlines at the end
// 3. Add back the licence comment at the start
//

namespace jax.experimental.export.serialization;

enum PyTreeDefKind: byte {
leaf = 0,
none = 1,
tuple = 2,
list = 3,
dict = 4,
}

table PyTreeDef {
kind: PyTreeDefKind;
children: [PyTreeDef];
children_names: [string]; // only for "dict"
}

enum AbstractValueKind: byte {
shapedArray = 0,
abstractToken = 1,
}

enum DType: byte {
bool,
i8,
i16,
i32,
i64,
ui8,
ui16,
ui32,
ui64,
f16,
f32,
f64,
c64,
c128,

bf16,

i4,
ui4,

f8_e4m3b11fnuz,
f8_e4m3fn,
f8_e4m3fnuz,
f8_e5m2,
f8_e5m2fnuz,
}

table AbstractValue {
kind: AbstractValueKind;
shape: [string]; // we support shape polymorphism
dtype: DType;
}

enum ShardingKind: byte {
unspecified,
hlo_sharding,
}

table Sharding {
kind: ShardingKind;
hlo_sharding_proto: [byte];
}

table Effect {
type_name: string;
}

enum DisabledSafetyCheckKind: byte {
platform,
custom_call,
shape_assertions,
}

table DisabledSafetyCheck {
kind: DisabledSafetyCheckKind;
custom_call_target: string;
}

table Exported {
serialization_version: uint16;

function_name: string;
in_tree: PyTreeDef;
in_avals: [AbstractValue];
out_tree: PyTreeDef;
out_avals: [AbstractValue];
nr_devices: short;
in_shardings: [Sharding];
out_shardings: [Sharding];

lowering_platforms: [string];

ordered_effects: [Effect];
unordered_effects: [Effect];
disabled_checks: [DisabledSafetyCheck];

mlir_module_serialized: [byte];
mlir_module_serialization_version: uint16;
module_kept_var_idx: [uint16];
uses_shape_polymorphism: bool;

vjp: Exported;
}

root_type Exported;
Loading

0 comments on commit b077483

Please sign in to comment.