Commit
Rewrite mosaic concat to support operand shapes that do not align with native shapes. Expand tests to cover multi-operand concat, batch-dim concat, etc.

PiperOrigin-RevId: 675835891
Google-ML-Automation committed Oct 8, 2024
1 parent 2f67710 commit e998f9e
Showing 5 changed files with 222 additions and 55 deletions.
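For orientation, here is a minimal JAX-level sketch of the kind of concatenation this commit targets. The shapes are illustrative assumptions, not taken from the new tests: a 10-row f32 operand ends partway through the 8-sublane native tile, so lowering can no longer assume tile-aligned operand shapes.

import jax.numpy as jnp

# Hypothetical shapes: 10 is not a multiple of the 8-sublane f32 tile, so
# each operand ends partway through a vreg after lowering.
x = jnp.zeros((10, 128), jnp.float32)
y = jnp.ones((10, 128), jnp.float32)
z = jnp.concatenate([x, y], axis=0)
assert z.shape == (20, 128)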
1 change: 1 addition & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -401,6 +401,7 @@ def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> {
let assemblyFormat = [{
$sources `in` $dimension attr-dict `:` type($sources) `->` type($output)
}];
let hasVerifier = 1;
}

def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> {
30 changes: 30 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -659,6 +659,36 @@ LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op,
}
return success();
}

LogicalResult ConcatenateOp::verify() {
auto dimension = getDimension();
if (getOperands().size() < 2) {
return emitOpError("Expected at least 2 operands for concatenate op.");
}
auto first_type = cast<VectorType>(getOperand(0).getType());
auto first_shape = first_type.getShape();
auto first_dtype = first_type.getElementType();
for (int i = 1; i < getNumOperands(); ++i) {
auto operand = getOperand(i);
auto vty = cast<VectorType>(operand.getType());
auto shape = vty.getShape();
auto dtype = vty.getElementType();
if (dtype != first_dtype) {
return emitOpError(
"Not implemented: Expected all operands to have the same element "
"type.");
}
// Guard against rank mismatch before comparing dims.
if (shape.size() != first_shape.size()) {
return emitOpError(
"Not implemented: Expected all operands to have the same rank.");
}
for (size_t dim = 0; dim < shape.size(); ++dim) {
if (dim != dimension && shape[dim] != first_shape[dim]) {
return emitOpError(
"Not implemented: Expected all operands to have "
"the same shape outside of the concat dim.");
}
}
}
return success();
}

} // namespace tpu
} // namespace mlir
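The checks above mirror standard concatenate shape rules. A Python restatement for reference; the helper name and signature are ours, not an API from this change:

def verify_concat(shapes, dtypes, dimension):
    """Illustrative restatement of ConcatenateOp::verify, not the real API."""
    if len(shapes) < 2:
        raise ValueError("Expected at least 2 operands for concatenate op.")
    if any(d != dtypes[0] for d in dtypes):
        raise ValueError("Expected all operands to have the same element type.")
    for shape in shapes:
        if len(shape) != len(shapes[0]):
            raise ValueError("Expected all operands to have the same rank.")
        for dim, (a, b) in enumerate(zip(shape, shapes[0])):
            if dim != dimension and a != b:
                raise ValueError("Expected the same shape outside the concat dim.")

verify_concat([(10, 128), (6, 128)], ["f32", "f32"], dimension=0)  # passes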

169 changes: 131 additions & 38 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -52,6 +52,7 @@
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
@@ -2509,54 +2510,146 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_OP(
llvm::all_of(layouts_in, [](const Layout &l) { return l.has_value(); }));
TPU_ASSERT_OP(layouts_out.front().has_value());
OpBuilder builder(&op);
auto concatenate_op = cast<tpu::ConcatenateOp>(op);
const VectorType res_ty = concatenate_op.getResult().getType();
uint32_t dimension = concatenate_op.getDimension();
SmallVector<xla::Array<Value>> operand_vregs;
operand_vregs.reserve(op.getNumOperands());

std::optional<int64_t> tiling_dim;
auto res_layout = layouts_out.front();

TPU_ASSERT_OP(res_layout.has_value());
if (res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) {
return op.emitOpError("Not implemented: implicit dim");
}
auto num_untiled_dims = res_ty.getRank() - res_layout->layout_rank();

if (dimension >= num_untiled_dims) {
tiling_dim = dimension - num_untiled_dims;
}
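// Illustrative note: with a rank-3 result and 2-D native tiling,
// layout_rank() is 2, so num_untiled_dims is 1. A concat along dim 0 takes
// the untiled (batch) path below; dim 1 maps to tiling_dim 0 (sublanes) and
// dim 2 maps to tiling_dim 1 (lanes).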

// Op-level invariants on layouts; other op-level invariants are checked in
// the verifier.
for (int i = 0; i < op.getNumOperands(); ++i) {
auto operand = op.getOperand(i);
if (!layouts_in[i].has_value()) {
return op.emitOpError("Not implemented: Expected input layout");
}
auto const &layout = *layouts_in[i];

if (layout.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
return op.emitOpError("Not implemented: implicit dim");
}
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> vreg_array,
disassemble(builder, layout, cast<TypedValue<VectorType>>(operand),
ctx.target_shape));
operand_vregs.push_back(std::move(vreg_array));
}

CHECK_EQ(operand_vregs.size(), op.getNumOperands());
SmallVector<int64_t> vreg_array_shape =
res_layout->tileArrayShape(res_ty.getShape(), ctx.target_shape);

// Fill out_vregs with nulls to avoid blending with a vreg that has not been
// written to yet.
xla::Array<Value> out_vregs(vreg_array_shape, nullptr);

auto boundIdxConst =
std::bind(IdxConst, std::placeholders::_1, builder, op.getLoc());

// Handle the untiled concatenation case.
if (!tiling_dim.has_value()) {
out_vregs = concatenate(operand_vregs, dimension);
} else {
if (res_layout->offsets()[tiling_dim.value()] != 0) {
return op.emitOpError("Not implemented: result non-zero offset.");
}
if (!res_layout->hasNativeTiling(ctx.target_shape)) {
return op.emitOpError("Not implemented: Non native tiling in concat.");
}

int64_t offset_at_dim = 0;
{
for (int i = 0; i < op.getNumOperands(); ++i) {
auto operand = op.getOperand(i);
auto const &layout = *layouts_in[i];
auto vty = cast<VectorType>(operand.getType());
auto shape = vty.getShape();

auto starting_point = offset_at_dim;
auto offset_amount =
starting_point % layout.tiling()[tiling_dim.value()];
if (offset_amount != layout.offsets()[tiling_dim.value()]) {
return op.emitOpError(
"Not implemented: Relayout not called, unaligned dims concatenated "
"without proper offsets. Ensure that the infer_vector_layout pass was "
"called.");
}
offset_at_dim += shape[dimension];
}
}

// Tiled concatenation logic.
int64_t offset = 0;
for (size_t i = 0; i < operand_vregs.size(); ++i) {
auto &vreg = operand_vregs[i];
const auto &layout = layouts_in[i];
const int64_t operand_offset = *layout->offsets()[tiling_dim.value()];
if (operand_offset != 0) {
// We are offset, so we must blend with the previous vreg. To put it
// another way, the prior operand ended partway through its last vreg,
// writing that vreg only partially, so we share it.
offset -= 1;
}

const auto bitwidth = res_ty.getElementTypeBitWidth();
const int packing = res_layout->packing();

SmallVector<int64_t> out_idx;
vreg.Each([&](absl::Span<const int64_t> idx, Value *v) {
out_idx.assign(idx.begin(), idx.end());
out_idx[dimension] += offset;
if (idx[dimension] == 0 && operand_offset != 0) {
Value mask;
const VectorType vmask_ty = getNativeVregOrVmaskType(
builder.getI1Type(), bitwidth, ctx.target_shape);
if (tiling_dim.value() == 0) { // sublane
mask = builder.create<tpu::CreateMaskOp>(
op.getLoc(), vmask_ty,
ArrayRef<Value>{boundIdxConst(0), boundIdxConst(0)},
ArrayRef<Value>{boundIdxConst(operand_offset * packing),
boundIdxConst(layout->tiling()[1])});
} else { // lane
mask = builder.create<tpu::CreateMaskOp>(
op.getLoc(), vmask_ty,
ArrayRef<Value>{boundIdxConst(0), boundIdxConst(0)},
ArrayRef<Value>{boundIdxConst(layout->tiling()[0]),
boundIdxConst(operand_offset * packing)});
}
// Blend the current value with the existing value in the output.
*v = builder.create<arith::SelectOp>(op.getLoc(), mask,
out_vregs(out_idx), *v);
}
out_vregs(out_idx) = *v;
});
offset += vreg.dim(dimension);
}
}
auto assembled =
assemble(builder, res_ty, *res_layout, out_vregs, ctx.target_shape);
op.replaceAllUsesWith(assembled);
op.erase();
return success();
}
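The boundary handling above builds a sublane (or lane) mask with tpu.CreateMaskOp and merges the overlapping vreg with arith.SelectOp. Below is a NumPy sketch of that blend; the helper name, the 8x128 vreg shape, and the offset value are illustrative assumptions, not code from this change:

import numpy as np

def blend_boundary_vreg(prev_vreg, cur_vreg, operand_offset, tiling_dim):
    """First operand_offset sublanes (or lanes) keep the previous operand's
    data; the rest come from the current operand."""
    axis = 0 if tiling_dim == 0 else 1
    idx = np.arange(prev_vreg.shape[axis])
    mask = idx < operand_offset                 # role of tpu.CreateMaskOp
    mask = np.expand_dims(mask, 1 - axis)
    return np.where(mask, prev_vreg, cur_vreg)  # role of arith.SelectOp

# The second operand starts at sublane offset 2, so its first vreg is blended
# with the vreg already written by the first operand (hence offset -= 1 above).
prev = np.full((8, 128), 1.0)
cur = np.full((8, 128), 2.0)
out = blend_boundary_vreg(prev, cur, operand_offset=2, tiling_dim=0)
assert (out[:2] == 1.0).all() and (out[2:] == 2.0).all()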
75 changes: 60 additions & 15 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -774,21 +774,66 @@ class VectorLayoutInferer {
}
auto res_ty = op.getResult().getType();
int8_t bitwidth = res_ty.getElementTypeBitWidth();

std::optional<int64_t> tiling_dim;
if (dimension == res_ty.getRank() - 1) {
tiling_dim = 1;
} else if (dimension == res_ty.getRank() - 2) {
tiling_dim = 0;
}

if (tiling_dim.has_value()) {
int64_t starting_point = 0;

auto first_layout = getLayout(op.getSources().front());
auto op_layouts = getLayoutFromOperands(op);
SmallVector<Layout> in_layouts;
in_layouts.reserve(op.getSources().size());

auto native_tiling = nativeTiling(bitwidth);

for (int i = 0; i < op.getSources().size(); ++i) {
// Compute the offset per source.
// Ex: for a concat of (10, 128) and (10, 128) on dim 0, where the vreg
// slice for that dim is 8, the first source starts at offset 0 and
// overflows the vreg by 2, so the offset for the second input is 2.
auto op_shape =
cast<VectorType>(op.getSources()[i].getType()).getShape();
auto offset_amount = starting_point % native_tiling[tiling_dim.value()];
auto op_layout = op_layouts[i];
SmallVector<int64_t> in_idx{op_layout->offsets()[0].value_or(0),
op_layout->offsets()[1].value_or(0)};
in_idx[tiling_dim.value()] = offset_amount;
starting_point += op_shape[dimension];
in_layouts.push_back(VectorLayout(bitwidth, {in_idx[0], in_idx[1]},
native_tiling, ImplicitDim::kNone));
}
auto res_layout_offsets =
std::vector<int64_t>({first_layout->offsets()[0].value_or(0),
first_layout->offsets()[1].value_or(0)});
res_layout_offsets[tiling_dim.value()] = 0;
// TODO(mvoz): A tiny optimization we could do here later is to make
// setting the tiling a no-op when the sublane dim size is aligned to the
// sublane tiling.
auto res_layout =
VectorLayout(bitwidth, {res_layout_offsets[0], res_layout_offsets[1]},
native_tiling, ImplicitDim::kNone);
setLayout(op, in_layouts, res_layout);
return success();
} else {
auto layout = getLayout(op.getSources().front());
// When concatenating vectors with replicated offsets, we want to reset
// the replicated offset to zero, because we are not sure the replicated
// values from each vector are the same.
layout = VectorLayout(
layout->bitwidth(),
{layout->offsets()[0].value_or(0), layout->offsets()[1].value_or(0)},
layout->tiling(), layout->implicit_dim());
SmallVector<Layout> in_layouts(op->getNumOperands(), layout);
setLayout(op, in_layouts, layout);
return success();
}
}

LogicalResult infer(tpu::LoadOp op) {
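The per-source offset assignment in the tiled branch above can be sketched in Python; the concat_offsets helper and the vreg slice of 8 (f32 sublanes) are our assumptions for illustration:

def concat_offsets(dim_sizes, vreg_slice=8):
    """Layout offset assigned to each source along the concat dim (sketch)."""
    offsets, start = [], 0
    for size in dim_sizes:
        offsets.append(start % vreg_slice)
        start += size
    return offsets

# (10, 128) concat (10, 128) on dim 0: the first source ends 2 rows into a
# vreg, so the second source is assigned offset 10 % 8 == 2.
assert concat_offsets([10, 10]) == [0, 2]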
2 changes: 0 additions & 2 deletions tests/pallas/tpu_pallas_test.py
@@ -2279,9 +2279,7 @@ def wrapper(self):
class MiscellaneousTest(PallasBaseTest):
"""Tests for reported bugs. Only pass in interpret mode unless fixed."""

def test_float32_stack(self):
"""b/347761105"""
x = np.arange(128, dtype=jnp.float32).reshape(1, 128)
y = x + 128

