Integrate StableHLO at openxla/stablehlo@ab709fe4
PiperOrigin-RevId: 589908773
GleasonK authored and jax authors committed Dec 11, 2023
1 parent 384e29e commit 184e3a8
Showing 9 changed files with 37 additions and 25 deletions.
19 changes: 12 additions & 7 deletions jax/_src/interpreters/mlir.py
@@ -72,6 +72,11 @@
 def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
   return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
 
+def dense_int_array(xs) -> Union[ir.DenseIntElementsAttr, ir.DenseI64ArrayAttr]:
+  if hlo.get_api_version() < 5:
+    return dense_int_elements(xs)
+  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
+
 def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
   a = np.packbits(np.array(xs, np.bool_), bitorder='little')
   # TODO(b/209005197): Work around for MLIR crash for non-splat single element
@@ -1844,9 +1849,9 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
       x, start_indices, limit_indices, strides)
   else:
     return hlo.slice(x,
-                     dense_int_elements(start_indices),
-                     dense_int_elements(limit_indices),
-                     dense_int_elements(strides))
+                     dense_int_array(start_indices),
+                     dense_int_array(limit_indices),
+                     dense_int_array(strides))
 
 def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
                   start_indices) -> ir.Value:
@@ -1881,7 +1886,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
       shape_tensor([1] * len(start_indices))
     )
   else:
-    return hlo.dynamic_slice(x, start_indices, dense_int_elements(slice_sizes))
+    return hlo.dynamic_slice(x, start_indices, dense_int_array(slice_sizes))
 
 def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
                          start_indices) -> ir.Value:
@@ -1906,9 +1911,9 @@ def pad(ctx: LoweringRuleContext, aval_out,
   if all(core.is_constant_shape(s) for s in (padding_low,
                                              padding_high, padding_interior)):
     return hlo.pad(x, padding_value,
-                   dense_int_elements(padding_low),
-                   dense_int_elements(padding_high),
-                   dense_int_elements(padding_interior))
+                   dense_int_array(padding_low),
+                   dense_int_array(padding_high),
+                   dense_int_array(padding_interior))
   else:
     padding_low = eval_dynamic_shape_as_tensor(ctx, padding_low)
     padding_high = eval_dynamic_shape_as_tensor(ctx, padding_high)
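For readers unfamiliar with the two attribute kinds the new dense_int_array helper switches between: DenseIntElementsAttr is the tensor-shaped attribute that pre-version-5 StableHLO ops expect, while DenseI64ArrayAttr is the flat array attribute newer ops accept. Below is a minimal standalone sketch of the difference, not part of the commit; the import path mirrors the one jax's internals use, and the printed forms are the standard MLIR textual syntax.

import numpy as np
from jax._src.lib.mlir import ir  # assumed: the same bindings module mlir.py imports as `ir`

with ir.Context():
    xs = np.asarray([0, 2, 4], np.int64)
    legacy = ir.DenseIntElementsAttr.get(xs)  # what dense_int_elements builds
    modern = ir.DenseI64ArrayAttr.get(xs)     # what dense_int_array builds on API version >= 5
    print(legacy)  # dense<[0, 2, 4]> : tensor<3xi64>
    print(modern)  # array<i64: 0, 2, 4>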
6 changes: 3 additions & 3 deletions jax/_src/interpreters/pxla.py
@@ -1284,7 +1284,7 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
     dims_unsqueezed = dims.copy()
     dims_unsqueezed.insert(in_axis, 1)
     dynamic_slice_result = hlo.dynamic_slice(
-        x, idxs, mlir.dense_int_elements(dims_unsqueezed))
+        x, idxs, mlir.dense_int_array(dims_unsqueezed))
     return [
         hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
     ]
@@ -1335,7 +1335,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
     padded = mlir.full_like_aval(ctx, 0, padded_aval)
     zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
     idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
-    broadcast_result = hlo.broadcast(x, mlir.dense_int_elements([1]))
+    broadcast_result = hlo.broadcast(x, mlir.dense_int_array([1]))
     padded = hlo.dynamic_update_slice(padded, broadcast_result, idxs)
     replica_groups = mlir.dense_int_elements(
         axis_groups(axis_env, axis_env.names[-1]))
@@ -1346,7 +1346,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
       perm.insert(out_axis, 0)
       transposed_dims = list(dims)
       transposed_dims.insert(out_axis, axis_env.sizes[-1])
-      out = hlo.transpose(out, mlir.dense_int_elements(perm))
+      out = hlo.transpose(out, mlir.dense_int_array(perm))
 
     return out
   else:
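Note that _hlo_unshard above still builds replica_groups with mlir.dense_int_elements. That is presumably because replica_groups is a rank-2 i64 tensor attribute (one row per group), whereas DenseI64ArrayAttr holds a flat list, so only 1-D parameters such as sizes and permutations migrate. The following illustration of that distinction uses made-up group values and is not code from the commit.

import numpy as np
from jax._src.lib.mlir import ir  # assumed import path, as in the sketch above

with ir.Context():
    # Rank-2 attribute: two replica groups of two devices each.
    replica_groups = ir.DenseIntElementsAttr.get(np.array([[0, 1], [2, 3]], np.int64))
    # Flat 1-D attribute: e.g. the broadcast sizes passed to hlo.broadcast above.
    sizes = ir.DenseI64ArrayAttr.get(np.asarray([1], np.int64))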
2 changes: 1 addition & 1 deletion jax/_src/lax/fft.py
@@ -117,7 +117,7 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
     raise NotImplementedError("Shape polymorphism for FFT with non-constant fft_length is not implemented for TPU and GPU")
   return [
       hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
-                mlir.dense_int_elements(fft_lengths)).result
+                mlir.dense_int_array(fft_lengths)).result
   ]
 
 
6 changes: 3 additions & 3 deletions jax/_src/lax/lax.py
@@ -3423,7 +3423,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
 def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
   aval_out, = ctx.avals_out
   if dimensions is not None:
-    x = hlo.transpose(x, mlir.dense_int_elements(dimensions))
+    x = hlo.transpose(x, mlir.dense_int_array(dimensions))
   if dyn_shape:
     aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
   return [mlir.reshape(ctx, x, aval_out)]
@@ -3467,7 +3467,7 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
 batching.primitive_batchers[rev_p] = _rev_batch_rule
 
 def _rev_lower(ctx, x, *, dimensions):
-  return [hlo.reverse(x, mlir.dense_int_elements(dimensions))]
+  return [hlo.reverse(x, mlir.dense_int_array(dimensions))]
 mlir.register_lowering(rev_p, _rev_lower)
 
 
@@ -3499,7 +3499,7 @@ def _transpose_lower(ctx, x, *, permutation):
                             aval_out.dtype).shape
   trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
   permutation = [*permutation, *trailing_dims]
-  return [hlo.transpose(x, mlir.dense_int_elements(permutation))]
+  return [hlo.transpose(x, mlir.dense_int_array(permutation))]
 
 transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
                                  'transpose')
2 changes: 1 addition & 1 deletion jax/experimental/sparse/bcsr.py
@@ -681,7 +681,7 @@ def _bcsr_dot_general_gpu_lowering(
       dot_general_fn = csr_matmat_lowering
       x_dtype = 'B_dtype'
       if rhs_contract[0] == 1:
-        rhs = hlo.transpose(rhs, permutation=mlir.dense_int_elements([1, 0]))
+        rhs = hlo.transpose(rhs, permutation=mlir.dense_int_array([1, 0]))
     else:
       raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.")
 
2 changes: 1 addition & 1 deletion jax/experimental/sparse/coo.py
@@ -229,7 +229,7 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
   result = coo_todense_hlo(
       data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
   return (
-      [hlo.transpose(result, mlir.dense_int_elements([1, 0]))]
+      [hlo.transpose(result, mlir.dense_int_array([1, 0]))]
       if transpose else [result])
 
 
1 change: 1 addition & 0 deletions jax/interpreters/mlir.py
@@ -35,6 +35,7 @@
     core_call_lowering as core_call_lowering,
     custom_call as custom_call,
     dense_bool_elements as dense_bool_elements,
+    dense_int_array as dense_int_array,
     dense_int_elements as dense_int_elements,
     dtype_to_ir_type as dtype_to_ir_type,
     emit_python_callback as emit_python_callback,
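The one-line re-export makes the version-aware helper reachable through the public jax.interpreters.mlir module, so downstream lowering rules can migrate the same way the in-tree rules above do. A hedged sketch of such a call site follows; the rule itself is hypothetical, while hlo.reverse is the same op used in the lax.py hunk above.

from jax.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo  # assumed path for the StableHLO op bindings

def _my_rev_lowering(ctx, x, *, dimensions):  # hypothetical lowering rule
    # Emits DenseIntElementsAttr or DenseI64ArrayAttr depending on the loaded
    # StableHLO API version, matching the in-tree _rev_lower.
    return [hlo.reverse(x, mlir.dense_int_array(dimensions))]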
20 changes: 11 additions & 9 deletions jaxlib/gpu_solver.py
@@ -28,7 +28,7 @@
 
 from .hlo_helpers import (
     DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
-    custom_call, ensure_hlo_s32, hlo_s32)
+    custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)
 
 try:
   from .cuda import _blas as _cublas  # pytype: disable=import-error
@@ -408,20 +408,20 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
         operand_output_aliases={0: 0}).results
     vt = hlo.transpose(
         v,
-        ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
+        dense_int_array(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
     if np.issubdtype(dtype, np.complexfloating):
       vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
     if not full_matrices and not econ:
       u = hlo.slice(
           u,
-          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
-          ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))),
-          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
+          dense_int_array(np.zeros([len(dims)], np.int64)),
+          dense_int_array(np.array(batch_dims + (m, min(m, n)))),
+          dense_int_array(np.ones([len(dims)], np.int64)))
       vt = hlo.slice(
           vt,
-          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
-          ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))),
-          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
+          dense_int_array(np.zeros([len(dims)], np.int64)),
+          dense_int_array(np.array(batch_dims + (min(m, n), n))),
+          dense_int_array(np.ones([len(dims)], np.int64)))
   elif m < n:
     lwork, opaque = gpu_solver.build_gesvd_descriptor(
         np.dtype(dtype), b, n, m, compute_uv, full_matrices)
@@ -535,10 +535,12 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
   # lower=False case. The correct result is returned in the `e` vector so we can
   # simply copy it back to where it needs to be:
   intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
+  intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64))
   if not lower and platform == "cu" and m > 1:
     start = (0,) * len(batch_dims) + (0,)
     end = batch_dims + (1,)
-    s = hlo.slice(e, intattr(start), intattr(end), intattr([1] * len(start)))
+    s = hlo.slice(
+        e, intarrattr(start), intarrattr(end), intarrattr([1] * len(start)))
     s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
     s = hlo.broadcast_in_dim(s_type, s, intattr(range(len(dims) - 1)))
     # The diagonals are always real; convert to complex if needed.
4 changes: 4 additions & 0 deletions jaxlib/hlo_helpers.py
@@ -108,6 +108,10 @@ def hlo_s32(x: int):
 def ensure_hlo_s32(x: DimensionSize):
   return hlo_s32(x) if isinstance(x, int) else x
 
+def dense_int_array(xs) -> Union[ir.DenseIntElementsAttr, ir.DenseI64ArrayAttr]:
+  if hlo.get_api_version() < 5:
+    return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
+  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
+
 def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
   if type(x) is int:
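Both copies of the helper (the jax one in jax/_src/interpreters/mlir.py and the jaxlib one above) key off the same probe, so a quick way to see which attribute kind a given installation will emit is to query the StableHLO API version directly. The sketch below assumes jaxlib's hlo_helpers imports the stablehlo dialect module as hlo; treat the import path as an assumption.

from jaxlib.mlir.dialects import stablehlo as hlo  # assumed import path

# The commit gates on version 5: older dialects keep DenseIntElementsAttr,
# newer ones switch to DenseI64ArrayAttr.
api_version = hlo.get_api_version()
kind = "DenseIntElementsAttr" if api_version < 5 else "DenseI64ArrayAttr"
print(f"StableHLO API version {api_version}: dense_int_array emits {kind}")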
