Skip to content

Commit

Permalink
Merge pull request #16802 from gnecula:ser_versions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 549839838
  • Loading branch information
jax authors committed Jul 21, 2023
2 parents 3d556b7 + 71e2d28 commit 8f9a34a
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 4 deletions.
2 changes: 1 addition & 1 deletion jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def update_thread_local_jit_state(**kw):
'The version number to use for native serialization. This must be '
'within the range of versions supported by the tf.XlaCallModule '
'used in your deployment environment. '
'See https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code.'
'See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.'
)
)

Expand Down
67 changes: 66 additions & 1 deletion jax/experimental/jax2tf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ invoked with actual argument `arg`:

* `arg.shape[0] >= 1`
* `arg.shape[1] == arg.shape[0]`
* `arg.shape[2] % 2 == 0` and `arg.shape[0] // 2 >= 1`
* `arg.shape[2] % 2 == 0` and `arg.shape[2] // 2 >= 1`

When using native serialization these are checked by the `tf.XlaCallModule`
op (starting with serialization
Expand Down Expand Up @@ -738,6 +738,71 @@ polymorphic_shapes = ["a, 2*a, b"]
polymorphic_shapes = ["a * a, a"]
```

## Native serialization versions

We use a serialization version number to help evolve the serialization
mechanism while allowing serialized artifacts to be used by consumers built
at different code versions.

If consumers use the `tf.XlaCallModule` op, e.g. when using the TensorFlow
SavedModel, then they support a range of serialization versions.
See [tf.XlaCallModule code](https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code).
There is also an API to get the maximum version number supported by your
installed version of TensorFlow:

```
from tensorflow.compiler.tf2xla.python import xla as tfxla
tfxla.call_module_maximum_supported_version()
```

For **backward compatibility**, we want to allow a freshly built consumer
to load artifacts that have been serialized in the past 6 months
(by a serializer using the latest version supported at the time). Thus,
the minimum supported version number should match the maximum supported
version number from 6 months in the past.

The serialization version used by JAX is determined by the
`--jax_serialization_version` flag, or if missing, the
`JAX_SERIALIZATION_VERSION` environment variable. The default value is
specified in the [`config.py` file](https://github.com/search?q=repo%3Agoogle%2Fjax+path%3Aconfig.py+JAX_SERIALIZATION_VERSION&type=code).

For **forward compatibility**, we want freshly serialized artifacts to be
loadable by consumers that have been built in the last 1 month.
Thus, we bump the default serialization version
number about 1 month after the `tf.XlaCallModule` is upgraded to a
given version number.

You can use `--jax_serialization_version` to adjust the serialization version
to your deployed consumer. We reserve the right to remove support for
generating or consuming old serialization versions, e.g., older than 6 months.


## Serialization version numbers

We list here a history of the serialization version numbers:

* Version 1 used MHLO & CHLO to serialize the code, not supported anymore.
* Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported
anymore.
* Version 3 supports platform checking and multiple platforms.
Used from February 2023. Not supported anymore.
* Version 4 supports StableHLO with compatibility guarantees.
This is the earliest version at the time of the JAX native serialization
launch.
Used in JAX from March 15, 2023 (cl/516885716). Starting with
March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493).
* Version 5 adds support for `call_tf_graph`. This is currently used
for some specialized use cases. Used in JAX from May 3rd, 2023
(cl/529106145).
* Version 6 adds support for the `disabled_checks` attribute. This version
mandates a non-empty `platforms` attribute.
Used in JAX since June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and
for `shape_assertions` specified in `disabled_checks`.
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
Used in JAX serialization since July 20th, 2023 (JAX 0.4.14).


## Known issues

`jax2tf` has been in use since 2020 and the vast majority of users encounter
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/jax_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class Exported:
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'
mlir_module_serialized: the serialized lowered VHLO module.
xla_call_module_version: a version number for the serialized module.
See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code
See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
because they are not used. Same length as `in_shardings`.
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def eq(self, other: DimSize) -> bool:
def inconclusive_comparison(self, operation: str, op: Any) -> Exception:
return InconclusiveDimensionOperation(
f"Symbolic dimension comparison '{self}' {operation} '{op}' is inconclusive.\n"
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic0dimensions-is-partially-supported.")
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.")

def ge(self, other: DimSize) -> bool:
lb, ub = _ensure_poly(self - other, "ge").bounds()
Expand Down

0 comments on commit 8f9a34a

Please sign in to comment.