Merge pull request #16900 from gnecula:poly_dot2
PiperOrigin-RevId: 552684253
jax authors committed Aug 1, 2023
2 parents f049aee + 2eaf545 commit b8019dc
Showing 7 changed files with 143 additions and 59 deletions.
32 changes: 14 additions & 18 deletions jax/_src/lax/lax.py
@@ -2754,32 +2754,28 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
handled = lambda dt: (dtypes.issubdtype(dt, np.floating) or
dtypes.issubdtype(dt, np.integer))
if not (handled(lhs_dtype) and handled(rhs_dtype)):
dt = mlir.dtype_to_ir_type(aval_out.dtype)
lhs = hlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, dt), lhs
).result
rhs = hlo.ConvertOp(ir.RankedTensorType.get(rhs_aval.shape, dt), rhs
).result
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
lhs_dtype = rhs_dtype = aval_out.dtype
else: # cpu and gpu
dt = mlir.dtype_to_ir_type(aval_out.dtype)
lhs = hlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, dt), lhs
).result
rhs = hlo.ConvertOp(ir.RankedTensorType.get(rhs_aval.shape, dt), rhs
).result
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
lhs_dtype = rhs_dtype = aval_out.dtype

# TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul
if ctx.module_context.platform == "cpu":
if lhs_dtype == np.float16:
f32 = mlir.dtype_to_ir_type(np.dtype(np.float32))
lhs = hlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32),
lhs).result
lhs_dtype = np.dtype('float32')
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
core.ShapedArray(lhs_aval.shape, np.float32))

if rhs_dtype == np.float16:
f32 = mlir.dtype_to_ir_type(np.dtype(np.float32))
rhs = hlo.ConvertOp(ir.RankedTensorType.get(rhs_aval.shape, f32),
rhs).result
rhs_dtype = np.dtype('float32')
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
core.ShapedArray(rhs_aval.shape, np.float32))


dot_dnums = hlo.DotDimensionNumbers.get(
lhs_batching_dimensions=list(lhs_batch),
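The sketch below is not part of the commit; it is a minimal example of a call that exercises the lowering path patched above. With float16 operands and `preferred_element_type=float32`, both operands are converted to the output dtype, and on CPU the float16 matmul is upcast to float32 per the TODO(b/195364460) workaround. The dimension numbers and dtypes are chosen purely for illustration.

```python
# Minimal sketch (assumption: any recent JAX release; not taken from the diff).
import numpy as np
import jax
import jax.numpy as jnp

lhs = jnp.asarray(np.arange(6, dtype=np.float16).reshape(2, 3))
rhs = jnp.asarray(np.arange(12, dtype=np.float16).reshape(3, 4))

# Contract lhs dim 1 with rhs dim 0, no batch dims; ask for a float32 result.
out = jax.lax.dot_general(
    lhs, rhs,
    dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.float32)
print(out.dtype, out.shape)  # float32 (2, 4)
```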
24 changes: 14 additions & 10 deletions jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md
@@ -1,11 +1,11 @@
# Primitives with limited JAX support

*Last generated on: 2022-11-07* (YYYY-MM-DD)
*Last generated on: 2023-07-31* (YYYY-MM-DD)

## Supported data types for primitives

We use a set of 7308 test harnesses to test
the implementation of 130 numeric JAX primitives.
We use a set of 7554 test harnesses to test
the implementation of 133 numeric JAX primitives.
We consider a JAX primitive supported for a particular data
type if it is supported on at least one device type.
The following table shows the dtypes at which primitives
@@ -46,6 +46,7 @@ be updated.
| add | 16 | inexact, integer | bool |
| add_any | 14 | inexact, integer | bool |
| and | 11 | bool, integer | inexact |
| approx_top_k | 24 | floating | bool, complex, integer |
| argmax | 64 | bool, floating, integer | complex |
| argmin | 64 | bool, floating, integer | complex |
| asin | 6 | inexact | bool, integer |
@@ -64,7 +65,7 @@ be updated.
| complex | 4 | float32, float64 | bfloat16, bool, complex, float16, integer |
| concatenate | 17 | all | |
| conj | 5 | complex, float32, float64 | bfloat16, bool, float16, integer |
| conv_general_dilated | 114 | inexact, int16, int32, int8 | bool, int64, unsigned |
| conv_general_dilated | 132 | inexact, signed | bool, unsigned |
| convert_element_type | 201 | all | |
| cos | 6 | inexact | bool, integer |
| cosh | 6 | inexact | bool, integer |
@@ -77,7 +78,7 @@ be updated.
| device_put | 16 | all | |
| digamma | 4 | floating | bool, complex, integer |
| div | 20 | inexact, integer | bool |
| dot_general | 245 | all | |
| dot_general | 400 | all | |
| dynamic_slice | 68 | all | |
| dynamic_update_slice | 46 | all | |
| eig | 72 | inexact | bool, integer |
@@ -88,16 +89,17 @@ be updated.
| erfc | 4 | floating | bool, complex, integer |
| exp | 6 | inexact | bool, integer |
| expm1 | 6 | inexact | bool, integer |
| fft | 20 | complex, float32, float64 | bfloat16, bool, float16, integer |
| fft | 32 | complex, float32, float64 | bfloat16, bool, float16, integer |
| floor | 4 | floating | bool, complex, integer |
| gather | 150 | all | |
| gather | 164 | all | |
| ge | 17 | all | |
| gt | 17 | all | |
| igamma | 6 | floating | bool, complex, integer |
| igammac | 6 | floating | bool, complex, integer |
| imag | 2 | complex | bool, floating, integer |
| integer_pow | 108 | inexact, integer | bool |
| iota | 16 | inexact, integer | bool |
| iota_2x32_shape | 3 | uint32 | bool, inexact, signed, uint16, uint64, uint8 |
| is_finite | 4 | floating | bool, complex, integer |
| le | 17 | all | |
| lgamma | 4 | floating | bool, complex, integer |
@@ -106,8 +108,8 @@ be updated.
| logistic | 6 | inexact | bool, integer |
| lt | 17 | all | |
| lu | 18 | inexact | bool, integer |
| max | 33 | all | |
| min | 33 | all | |
| max | 27 | all | |
| min | 27 | all | |
| mul | 16 | inexact, integer | bool |
| ne | 17 | all | |
| neg | 14 | inexact, integer | bool |
@@ -128,6 +130,7 @@ be updated.
| reduce_max | 15 | all | |
| reduce_min | 15 | all | |
| reduce_or | 1 | bool | inexact, integer |
| reduce_precision | 32 | floating | bool, complex, integer |
| reduce_prod | 14 | inexact, integer | bool |
| reduce_sum | 14 | inexact, integer | bool |
| reduce_window_add | 50 | inexact, integer | bool |
@@ -194,7 +197,8 @@ and search for "limitation".
| --- | --- | --- | --- |
|cholesky|unimplemented|float16|cpu, gpu|
|clamp|unimplemented|bool, complex|cpu, gpu, tpu|
|conv_general_dilated|preferred_element_type not implemented for integers|int16, int32, int8|gpu|
|conv_general_dilated|preferred_element_type not implemented for integers|signed|gpu|
|dot_general|preferred_element_type must be floating for integer dtype|integer|gpu|
|dot_general|preferred_element_type must match dtype for floating point|inexact|gpu|
|eig|only supported on CPU in JAX|all|tpu, gpu|
|eig|unimplemented|bfloat16, float16|cpu|
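The new `dot_general` limitation rows above refer to calls like the following sketch (not part of the commit), where an integer matmul requests a wider accumulator via `preferred_element_type`; per the table, on GPU JAX requires `preferred_element_type` to be floating for integer inputs.

```python
# Hedged sketch (not from the diff): int8 x int8 dot_general accumulating in int32.
import numpy as np
import jax
import jax.numpy as jnp

a = jnp.asarray(np.arange(6, dtype=np.int8).reshape(2, 3))
b = jnp.asarray(np.arange(12, dtype=np.int8).reshape(3, 4))

out = jax.lax.dot_general(
    a, b,
    dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.int32)
print(out.dtype)  # int32 (the table flags integer preferred_element_type as unimplemented on GPU)
```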
20 changes: 15 additions & 5 deletions jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md
@@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf

*Last generated on (YYYY-MM-DD): 2022-11-07*
*Last generated on (YYYY-MM-DD): 2023-07-31*

This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
@@ -61,15 +61,21 @@ More detailed information can be found in the

| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- |
| approx_top_k | TF error: compilation not supported for float64. | float64 | cpu, gpu | compiled |
| approx_top_k | TF error: op not defined for dtype | floating | cpu, gpu | eager, graph |
| bessel_i0e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| bessel_i1e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
| clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph |
| conv_general_dilated | TF error: Numeric comparison disabled: Non-deterministic NaN for conv_general_dilated with preferred_element_type | int16, int32, int64 | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | signed | gpu | compiled, eager, graph |
| digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| dot_general | TF test skipped: TF error: Numeric comparison disabled: Crash when lhs_dtype != rhs_dtype for non-native serialization on TPU | all | tpu | compiled, eager, graph |
| dot_general | TF error: Numeric comparison disabled: Errors when lhs_dtype != rhs_dtype for non-native serialization on CPU and GPU | all | cpu, gpu, tpu | compiled, eager, graph |
| dot_general | TF error: Numeric comparison disabled: Large tolerances when upcasting with preferred_element_type on CPU (b/241740367) | all | cpu, gpu, tpu | compiled, eager, graph |
| dot_general | TF error: Numeric comparison disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled, eager, graph |
| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type must be floating for integer dtype | integer | gpu | compiled, eager, graph |
| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type must match dtype for floating point | inexact | gpu | compiled, eager, graph |
| dot_general | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| eig | TF test skipped: Not implemented in JAX: only supported on CPU in JAX | all | gpu, tpu | compiled, eager, graph |
@@ -79,8 +85,8 @@ More detailed information can be found in the
| eigh | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| eigh | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
| erf_inv | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| fft | TF error: TF function not compilable | float64 | cpu, gpu | compiled |
| fft | TF error: TF function not compilable for IFFT and IRFFT | complex128 | cpu, gpu | compiled |
| fft | TF error: TF function not compilableble | float64 | cpu, gpu | compiled |
| fft | TF error: TF function not compilableble for IFFT and IRFFT | complex128 | cpu, gpu | compiled |
| igamma | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| igammac | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| lgamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
@@ -91,7 +97,8 @@ More detailed information can be found in the
| reduce_max | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| reduce_min | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| rem | TF error: Numeric comparison disabled: TF division of inf by inf returns inf while in JAX returns nan | float32 | gpu | compiled, eager, graph |
| rem | TF error: Numeric comparison disabled: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| scatter | TF error: Numeric comparison disabled: out-of-bounds scatters are not supported in graph and eager mode | inexact | cpu, gpu, tpu | eager, graph |
| scatter_add | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
@@ -120,6 +127,7 @@ with jax2tf. The following table lists that cases when this does not quite hold:
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- |
| acosh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
| approx_top_k | custom numeric comparison | floating | cpu, gpu | eager, graph |
| argmax | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph |
| argmin | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph |
| asin | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
@@ -140,7 +148,9 @@ with jax2tf. The following table lists that cases when this does not quite hold:
| integer_pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph |
| lu | May return different, but also correct, results when the decomposition is not unique | all | cpu, gpu | compiled, eager, graph |
| max | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
| max | TF and JAX use different values of the compiler flag xla_cpu_enable_fast_min_max compiler flag and therefore have different behavior of NaN propagation through min/max. | all | cpu | compiled, eager, graph |
| min | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
| min | TF and JAX use different values of the compiler flag xla_cpu_enable_fast_min_max compiler flag and therefore have different behavior of NaN propagation through min/max. | all | cpu | compiled, eager, graph |
| pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph |
| random_split | Returns JAX key arrays, so compare underlying base array | all | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_add | Numeric comparison disabled: Large deviations on TPU for enable_xla=False | float16, float32 | tpu | compiled, eager, graph |
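As background for the max/min rows above (this sketch is not part of the commit): JAX propagates NaN through max/min, while TF may return the other operand; on CPU the behavior also depends on the xla_cpu_enable_fast_min_max flag mentioned in the table.

```python
# Hedged sketch (not from the diff); the exact TF output can depend on backend and flags.
import numpy as np
import jax.numpy as jnp
import tensorflow as tf

x, y = np.float32(np.nan), np.float32(1.0)
print(jnp.maximum(x, y))         # nan -- JAX always propagates NaN
print(tf.maximum(x, y).numpy())  # often 1.0 -- TF may return the value NaN is compared with
```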
13 changes: 13 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -2049,6 +2049,19 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
# TODO(b/293247337): we ought to turn on this safety check, but this leads to
# failures. Since we are going to turn on native serialization soon, wait
# until then to turn on this check.
# lhs_aval, rhs_aval = _in_avals
# if lhs_aval.dtype != rhs_aval.dtype:
# # There are multiple kinds of errors: handling jnp.bfloat16 in xla.py and
# # returning different result dtype than JAX expects for various combinations
# # of types. We ought to implement the same workarounds as in the
# # native dot_general lowering rules, but this is not a high priority now
# # that we deprecate non-native serialization.
# raise NotImplementedError(
# "dot_general with different lhs_dtype and rhs_dtype is not supported "
# "in non-native serialization")
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
dnums_proto = xla_data_pb2.DotDimensionNumbers()
dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
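The commented-out check above only matters for the non-native, einsum-based path. The sketch below (not part of the commit) shows the usual way a dot goes through jax2tf with native serialization, which instead embeds the StableHLO produced by the JAX lowering patched in jax/_src/lax/lax.py; the `native_serialization` argument and `tf.function` wrapping are shown only for illustration of typical usage at the time of this commit.

```python
# Hedged sketch (not from the diff): converting a dot_general via jax2tf.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x, y):
  return jnp.dot(x, y)  # lowers to lax.dot_general_p

# With native serialization the JAX lowering is reused, so the non-native
# _dot_general rule above is bypassed.
f_tf = tf.function(jax2tf.convert(f, native_serialization=True), autograph=False)
print(f_tf(tf.ones((2, 3)), tf.ones((3, 4))).shape)  # (2, 4)
```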
26 changes: 26 additions & 0 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -346,7 +346,18 @@ def custom_assert(tst, result_jax, result_tf, *, tol, err_msg, **_):

@classmethod
def conv_general_dilated(cls, harness: primitive_harness.Harness):
prefer_elem = harness.params["preferred_element_type"]
return [
Jax2TfLimitation(
"Non-deterministic NaN for conv_general_dilated with preferred_element_type",
dtypes=[
jnp.int32, np.int16, np.int64
],
devices=["cpu", "gpu", "tpu"],
modes=("eager", "graph", "compiled"),
enabled=(prefer_elem is not None
and prefer_elem in [jnp.bfloat16, np.float16, np.float32, np.float64]),
skip_comparison=True),
# Even in compiled mode, for GPU we see a bit of discrepancy but
# very minor.
custom_numeric(dtypes=[np.float32], devices="gpu",
@@ -485,6 +496,21 @@ def dot_general(cls, harness: primitive_harness.Harness):
devices=["cpu", "gpu", "tpu"],
enabled=prefer_elem and np.dtype(harness.dtype) < np.dtype(prefer_elem),
skip_comparison=True),
# TODO(necula): look into this, but this is only for non-native serialization
Jax2TfLimitation(
"Errors when lhs_dtype != rhs_dtype for non-native serialization on CPU and GPU",
devices=["cpu", "gpu", "tpu"],
enabled=(harness.dtype != harness.params["rhs_dtype"]),
skip_comparison=True),
# TODO(necula): look into this, but this is only for non-native serialization
Jax2TfLimitation(
"Crash when lhs_dtype != rhs_dtype for non-native serialization on TPU",
devices=["tpu"],
enabled=(harness.dtype != harness.params["rhs_dtype"] and
(harness.dtype in [np.complex64, np.complex128] or
harness.params["rhs_dtype"] in [np.complex64, np.complex128])),
skip_comparison=True,
skip_tf_run=True),
# JAX performs float16 matmuls in float32 on CPU, so the JAX result
# may be more precise.
custom_numeric(dtypes=[np.float16], devices=["cpu"], tol=1e-2,
@@ -29,6 +29,7 @@
from absl.testing import absltest

from jax._src import test_util as jtu
from jax._src import maps # Needed for config flags.
from jax import config

import numpy as np
@@ -48,7 +49,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
# If you want to run this test for only one harness, add parameter
# `one_containing="foo"` to parameterized below.
@primitive_harness.parameterized(primitive_harness.all_harnesses,
#one_containing="gather_from_slicing_name",
#one_containing="",
include_jax_unimpl=True)
@jtu.ignore_warning(category=UserWarning,
message="Using reduced precision for gradient.*")