Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:Mosaic] Support tpu.bitcast for i16, i8. #18556

Merged
merged 1 commit into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3115,9 +3115,10 @@ Value selectTilesFromRotatedRowVregs(

const IntegerType i1 = builder.getI1Type();
const auto mask_vreg_ty =
dst_layout.packing() == 2
? VectorType::get(
ArrayRef<int64_t>{target_shape[0], target_shape[1], 2}, i1)
dst_layout.packing() > 1
? VectorType::get(ArrayRef<int64_t>{target_shape[0], target_shape[1],
dst_layout.packing()},
i1)
: VectorType::get(target_shape, i1);

auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder,
Expand Down Expand Up @@ -3528,12 +3529,9 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
// (8,128) -> (16,128) tiling change for packed 16-bit types.
src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
vty.getElementType() == builder.getBF16Type() &&
src.offsets() == dst.offsets() &&
vty.getElementTypeBitWidth() == 16 && src.offsets() == dst.offsets() &&
src.tiling() == std::array<int64_t, 2>{8, 128} &&
dst.tiling() == std::array<int64_t, 2>{16, 128}) {
VectorType vreg_f32 =
VectorType::get(target_shape, builder.getF32Type());
const VectorLayout new_src(src.bitwidth(), src.offsets(), dst.tiling());
xla::Array<Value> src_tiles_retiled(
new_src.tileArrayShape(vty.getShape(), target_shape));
Expand All @@ -3548,10 +3546,15 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
}
Value src_row2 = src_tiles(src_idx);
const int vreg_part = idx[idx.size() - 1] % 2;

VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
: VectorType::get(target_shape, builder.getF32Type());
auto half_row1 = builder.create<tpu::UnpackSubelementsOp>(
v.getLoc(), vreg_f32, src_row1, vreg_part);
v.getLoc(), vreg_x32, src_row1, vreg_part);
auto half_row2 = builder.create<tpu::UnpackSubelementsOp>(
v.getLoc(), vreg_f32, src_row2, vreg_part);
v.getLoc(), vreg_x32, src_row2, vreg_part);
*tile = builder.create<tpu::PackSubelementsOp>(
v.getLoc(), src_row1.getType(), ValueRange{half_row1, half_row2});
});
Expand Down
19 changes: 10 additions & 9 deletions jaxlib/mosaic/python/apply_vector_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,8 +890,8 @@ def select_tiles_from_rotated_row_vregs(

i1 = ir.IntegerType.get_signless(1)
mask_vreg_ty = (
ir.VectorType.get((*TARGET_SHAPE, 2), i1)
if dst_layout.packing == 2
ir.VectorType.get((*TARGET_SHAPE, dst_layout.packing), i1)
if dst_layout.packing > 1
else ir.VectorType.get(TARGET_SHAPE, i1)
)

Expand Down Expand Up @@ -1287,13 +1287,10 @@ def relayout(
src = dst
src_tiles = src_tiles_retiled

# TODO(apaszke): Generalize retiling to general 16-bit types (might need to
# use a different unpacking op).
# (8,128) -> (16,128) tiling change for packed 16-bit types.
elif (
src.implicit_dim is None
and dst.implicit_dim is None
and ir.BF16Type.isinstance(vty.element_type)
and src.bitwidth == 16
and src.offsets == dst.offsets
and src.tiling == (8, 128)
and dst.tiling == (16, 128)
Expand All @@ -1307,9 +1304,13 @@ def relayout(
src_row2 = src_tiles[(*batch_idx, src_row2_row, dst_col // 2)]

vreg_part = dst_col % 2
vreg_f32 = ir.VectorType.get(TARGET_SHAPE, ir.F32Type.get())
half_row1 = tpu.UnpackSubelementsOp(vreg_f32, src_row1, vreg_part)
half_row2 = tpu.UnpackSubelementsOp(vreg_f32, src_row2, vreg_part)
if ir.IntegerType.isinstance(vty.element_type):
unpacked_ty = ir.IntegerType.get_signless(32)
else:
unpacked_ty = ir.F32Type.get()
vreg_x32 = ir.VectorType.get(TARGET_SHAPE, unpacked_ty)
half_row1 = tpu.UnpackSubelementsOp(vreg_x32, src_row1, vreg_part)
half_row2 = tpu.UnpackSubelementsOp(vreg_x32, src_row2, vreg_part)
src_tiles_retiled[(*batch_idx, dst_row, dst_col)] = tpu.PackSubelementsOp(
src_row1.type, [half_row1, half_row2]
)
Expand Down