Skip to content

Commit

Permalink
[XLA:Mosaic] Support tpu.bitcast for i16, i8.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 582897025
  • Loading branch information
bythew3i authored and jax authors committed Nov 16, 2023
1 parent 95de3d0 commit f44e939
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 16 deletions.
28 changes: 19 additions & 9 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3084,9 +3084,10 @@ Value selectTilesFromRotatedRowVregs(

const IntegerType i1 = builder.getI1Type();
const auto mask_vreg_ty =
dst_layout.packing() == 2
? VectorType::get(
ArrayRef<int64_t>{target_shape[0], target_shape[1], 2}, i1)
dst_layout.packing() > 1
? VectorType::get(ArrayRef<int64_t>{target_shape[0], target_shape[1],
dst_layout.packing()},
i1)
: VectorType::get(target_shape, i1);

auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder,
Expand Down Expand Up @@ -3461,12 +3462,10 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
}
// TODO(b/265133506): Generalize retiling to general 16-bit types (might
// need to use a different unpacking op).
VectorType vreg_f32 = VectorType::get(target_shape, builder.getF32Type());
// (8,128) -> (16,128) tiling change for packed 16-bit types.
if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
vty.getElementType() == builder.getBF16Type() &&
src.offsets() == dst.offsets() &&
vty.getElementTypeBitWidth() == 16 && src.offsets() == dst.offsets() &&
src.tiling() == std::array<int64_t, 2>{8, 128} &&
dst.tiling() == std::array<int64_t, 2>{16, 128}) {
const VectorLayout new_src(src.bitwidth(), src.offsets(), dst.tiling());
Expand All @@ -3483,10 +3482,15 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
}
Value src_row2 = src_tiles(src_idx);
const int vreg_part = idx[idx.size() - 1] % 2;

VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
: VectorType::get(target_shape, builder.getF32Type());
auto half_row1 = builder.create<tpu::UnpackSubelementsOp>(
v.getLoc(), vreg_f32, src_row1, vreg_part);
v.getLoc(), vreg_x32, src_row1, vreg_part);
auto half_row2 = builder.create<tpu::UnpackSubelementsOp>(
v.getLoc(), vreg_f32, src_row2, vreg_part);
v.getLoc(), vreg_x32, src_row2, vreg_part);
*tile = builder.create<tpu::PackSubelementsOp>(
v.getLoc(), src_row1.getType(), ValueRange{half_row1, half_row2});
});
Expand Down Expand Up @@ -3633,8 +3637,14 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
if (src_tiles.dimensions()[src_tiles.num_dimensions() - 1] > 1) {
auto boundIdxConst =
std::bind(IdxConst, std::placeholders::_1, builder, v.getLoc());
VectorType mask_vreg_ty =
packing > 1
? VectorType::get(ArrayRef<int64_t>{target_shape[0],
target_shape[1], packing},
builder.getI1Type())
: VectorType::get(target_shape, builder.getI1Type());
maybe_create_mask = builder.create<tpu::CreateMaskOp>(
v.getLoc(), VectorType::get(target_shape, builder.getI1Type()),
v.getLoc(), mask_vreg_ty,
ValueRange{boundIdxConst(0), boundIdxConst(0)},
ValueRange{boundIdxConst(target_shape[0]),
boundIdxConst(col_diff)});
Expand Down
25 changes: 18 additions & 7 deletions jaxlib/mosaic/python/apply_vector_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,8 +927,8 @@ def select_tiles_from_rotated_row_vregs(

i1 = ir.IntegerType.get_signless(1)
mask_vreg_ty = (
ir.VectorType.get((*TARGET_SHAPE, 2), i1)
if dst_layout.packing == 2
ir.VectorType.get((*TARGET_SHAPE, dst_layout.packing), i1)
if dst_layout.packing > 1
else ir.VectorType.get(TARGET_SHAPE, i1)
)

Expand Down Expand Up @@ -1287,7 +1287,7 @@ def relayout(
if (
src.implicit_dim is None
and dst.implicit_dim is None
and ir.BF16Type.isinstance(vty.element_type)
and src.bitwidth == 16
and src.offsets == dst.offsets
and src.tiling == (8, 128)
and dst.tiling == (16, 128)
Expand All @@ -1301,9 +1301,13 @@ def relayout(
src_row2 = src_tiles[(*batch_idx, src_row2_row, dst_col // 2)]

vreg_part = dst_col % 2
vreg_f32 = ir.VectorType.get(TARGET_SHAPE, ir.F32Type.get())
half_row1 = tpu.UnpackSubelementsOp(vreg_f32, src_row1, vreg_part)
half_row2 = tpu.UnpackSubelementsOp(vreg_f32, src_row2, vreg_part)
if ir.IntegerType.isinstance(vty.element_type):
unpacked_ty = ir.IntegerType.get_signless(32)
else:
unpacked_ty = ir.F32Type.get()
vreg_x32 = ir.VectorType.get(TARGET_SHAPE, unpacked_ty)
half_row1 = tpu.UnpackSubelementsOp(vreg_x32, src_row1, vreg_part)
half_row2 = tpu.UnpackSubelementsOp(vreg_x32, src_row2, vreg_part)
src_tiles_retiled[(*batch_idx, dst_row, dst_col)] = tpu.PackSubelementsOp(
src_row1.type, [half_row1, half_row2]
)
Expand Down Expand Up @@ -1423,9 +1427,16 @@ def relayout(
sublane_diff_attr = ir.IntegerAttr.get(
ir.IntegerType.get_signed(32), sublane_diff
)
mask_vreg_ty = (
ir.VectorType.get(
(*TARGET_SHAPE, packing), ir.IntegerType.get_signless(1)
)
if packing > 1
else ir.VectorType.get(TARGET_SHAPE, ir.IntegerType.get_signless(1))
)
if src_tiles.shape[-1] > 1:
mask = tpu.CreateMaskOp(
ir.VectorType.get(TARGET_SHAPE, ir.IntegerType.get_signless(1)),
mask_vreg_ty,
low=list(map(ix_cst, [0, 0])),
high=list(map(ix_cst, [TARGET_SHAPE[0], col_diff])))
for idx, tile in np.ndenumerate(src_tiles):
Expand Down

0 comments on commit f44e939

Please sign in to comment.