Skip to content

Commit

Permalink
local_sec_info_t needs n_idx and d_idx
Browse files Browse the repository at this point in the history
Change-Id: I40cac7d890169ab65b6b265c19175003187d990b
  • Loading branch information
Watesoyan committed Sep 18, 2024
1 parent f0fa53c commit 98b84cf
Show file tree
Hide file tree
Showing 18 changed files with 44 additions and 32 deletions.
10 changes: 7 additions & 3 deletions include/tpu_mlir/Interfaces/LocalGenInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,16 @@ typedef struct local_sec_info {
int32_t c_idx;
int32_t c_slice;

int32_t n_idx;
int32_t d_idx;

void print() {
printf("group_type:%d, n_slice:%d, out_n_slice:%d, d_slice:%d, \
printf("group_type:%d, n_idx: %d, n_slice:%d, out_n_slice:%d, d_idx:%d, d_slice:%d, \
>>>>>>>>is_h_split:%d, h_idx:%d, h_slice:%d, out_h_idx:%d, out_h_slice:%d, \
>>>>>>>>is_w_split:%d, w_idx:%d, w_slice:%d, out_w_idx:%d, out_w_slice:%d, \
>>>>>>>>is_c_split:%d, c_idx:%d, c_slice:%d\n", group_type, n_slice, out_n_slice, \
d_slice, is_h_split, h_idx, h_slice, out_h_idx, out_h_slice, is_w_split, \
>>>>>>>>is_c_split:%d, c_idx:%d, c_slice:%d\n",
group_type, n_idx, n_slice, out_n_slice, d_idx, d_slice, \
is_h_split, h_idx, h_slice, out_h_idx, out_h_slice, is_w_split, \
w_idx, w_slice, out_w_idx, out_w_slice, is_c_split, c_idx, c_slice);
}
} local_sec_info_t;
Expand Down
2 changes: 2 additions & 0 deletions include/tpu_mlir/Interfaces/LocalGenInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def LocalGenInterface : OpInterface<"LocalGenInterface"> {
sec_info.h_slice = in_gi.h_slice;
sec_info.w_slice = in_gi.w_slice;
sec_info.c_slice = gi.c_slice;
sec_info.n_idx = in_gi.n_idx;
sec_info.d_idx = in_gi.d_idx;
sec_info.h_idx = in_gi.h_idx;
sec_info.is_h_split = !(in_gi.h_idx == 0 && in_gi.h_slice == h);
sec_info.w_idx = in_gi.w_idx;
Expand Down
27 changes: 0 additions & 27 deletions lib/Dialect/Tpu/Interfaces/BM1684X/Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,33 +92,6 @@ void tpu::AddOp::codegen_local_bm1684x(int64_t n_step, int64_t c_step,
local_sec_info_t &sec_info) {
auto op = getOperation();
auto input_spec = BM168x::get_input_spec(op, group_type);
auto in0_gi = //¡ä¨®add¦Ì?in_hslice_offset??¨¨??y¨¨¡¤¦Ì?hslice??¨°?
LocalGenInterface::getGroupInfo(getOperand(0), n_step, h_step, d_step, w_step, c_step, op);
auto in1_gi =
LocalGenInterface::getGroupInfo(getOperand(1), n_step, h_step, d_step, w_step, c_step, op);
auto in0_type = module::getStorageType(getInputs()[0]);
auto in1_type = module::getStorageType(getInputs()[1]);
if (in0_gi.h_idx_offset > 0) {
int bytes = 4;
if (in0_type.isInteger(8)) {
bytes = 1;
} else if (in0_type.isF16()) {
bytes = 2;
}
llvm::errs() <<"add in0_gi.h_idx_offset:"<<in0_gi.h_idx_offset<<", old addr:"<<(*input_spec)[0].addr<<"\n";
(*input_spec)[0].addr += in0_gi.h_idx_offset*in0_gi.w_slice*bytes;
}
if (in1_gi.h_idx_offset > 0) {
int bytes = 4;
if (in1_type.isInteger(8)) {
bytes = 1;
} else if (in1_type.isF16()) {
bytes = 2;
}
llvm::errs() <<"add in1_gi.h_idx_offset:"<<in1_gi.h_idx_offset<<", old addr:"<<(*input_spec)[1].addr<<"\n";
(*input_spec)[1].addr += in1_gi.h_idx_offset*in1_gi.w_slice*bytes;
}

auto output_spec = BM168x::get_output_spec(op, group_type);
auto gi = getGroupInfo(n_step, h_step, d_step, w_step, c_step);
bcbinary_local_param_t param = {0};
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ void tpu::AddOp::assign_sec_info(int64_t n_step, int64_t c_step, int64_t h_step,
sec_info.d_slice = std::max(in0_gi.d_slice, in1_gi.d_slice);
sec_info.h_slice = std::max(in0_gi.h_slice, in1_gi.h_slice);
sec_info.w_slice = std::max(in0_gi.w_slice, in1_gi.w_slice);
sec_info.n_idx = std::max(in0_gi.n_idx, in1_gi.n_idx);
sec_info.d_idx = std::max(in0_gi.d_idx, in1_gi.d_idx);
sec_info.c_idx = std::max(in0_gi.c_idx, in1_gi.c_idx);
sec_info.is_c_split =
!(std::max(in0_gi.c_idx, in1_gi.c_idx) == 0 &&
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Conv2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ void tpu::Conv2DOp::assign_sec_info(int64_t n_step, int64_t c_step,
sec_info.d_slice = in_gi.d_slice;
sec_info.h_slice = in_gi.h_slice;
sec_info.w_slice = in_gi.w_slice;
sec_info.n_idx = in_gi.n_idx;
sec_info.c_idx = in_gi.c_idx;
sec_info.h_idx = in_gi.h_idx;
sec_info.w_idx = in_gi.w_idx;
sec_info.is_h_split = !(in_gi.h_idx == 0 && in_gi.h_slice == attr.ih);
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Conv3D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ void tpu::Conv3DOp::assign_sec_info(int64_t n_step, int64_t c_step,
sec_info.d_slice = in_gi.d_slice;
sec_info.h_slice = in_gi.h_slice;
sec_info.w_slice = in_gi.w_slice;
sec_info.n_idx = in_gi.n_idx;
sec_info.c_idx = in_gi.c_idx;
sec_info.d_idx = in_gi.d_idx;
sec_info.h_idx = in_gi.h_idx;
sec_info.is_h_split = !(in_gi.h_idx == 0 && in_gi.h_slice == attr.ih);
sec_info.w_idx = in_gi.w_idx;
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ void tpu::DivOp::assign_sec_info(int64_t n_step, int64_t c_step, int64_t h_step,
sec_info.d_slice = std::max(in0_gi.d_slice, in1_gi.d_slice);
sec_info.h_slice = std::max(in0_gi.h_slice, in1_gi.h_slice);
sec_info.w_slice = std::max(in0_gi.w_slice, in1_gi.w_slice);
sec_info.n_idx = std::max(in0_gi.n_idx, in1_gi.n_idx);
sec_info.d_idx = std::max(in0_gi.d_idx, in1_gi.d_idx);
sec_info.c_idx = std::max(in0_gi.c_idx, in1_gi.c_idx);
sec_info.is_c_split =
!(std::max(in0_gi.c_idx, in1_gi.c_idx) == 0 &&
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ void tpu::LoadOp::assign_sec_info(int64_t n_step, int64_t c_step,
sec_info.d_slice = gi.d_slice;
sec_info.h_slice = gi.h_slice;
sec_info.w_slice = gi.w_slice;
sec_info.n_idx = gi.n_idx;
sec_info.d_idx = gi.d_idx;
sec_info.c_idx = gi.c_idx;
sec_info.is_c_split = !(gi.c_idx == 0 && gi.c_slice == c);
sec_info.h_idx = gi.h_idx;
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ void tpu::MaxOp::assign_sec_info(int64_t n_step, int64_t c_step, int64_t h_step,
sec_info.d_slice = std::max(in0_gi.d_slice, in1_gi.d_slice);
sec_info.h_slice = std::max(in0_gi.h_slice, in1_gi.h_slice);
sec_info.w_slice = std::max(in0_gi.w_slice, in1_gi.w_slice);
sec_info.n_idx = std::max(in0_gi.n_idx, in1_gi.n_idx);
sec_info.d_idx = std::max(in0_gi.d_idx, in1_gi.d_idx);
sec_info.c_idx = std::max(in0_gi.c_idx, in1_gi.c_idx);
sec_info.is_c_split =
!(std::max(in0_gi.c_idx, in1_gi.c_idx) == 0 &&
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/MaxPoolWithMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ void tpu::MaxPoolWithMaskOp::assign_sec_info(int64_t n_step, int64_t c_step,
sec_info.d_slice = in_gi.d_slice;
sec_info.h_slice = in_gi.h_slice;
sec_info.w_slice = in_gi.w_slice;
sec_info.n_idx = in_gi.n_idx;
sec_info.c_idx = in_gi.c_idx;
sec_info.d_idx = in_gi.d_idx;
sec_info.h_idx = in_gi.h_idx;
sec_info.w_idx = in_gi.w_idx;
sec_info.is_h_split = !(in_gi.h_idx == 0 && in_gi.h_slice == attr.ih);
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ void tpu::MinOp::assign_sec_info(int64_t n_step, int64_t c_step, int64_t h_step,
sec_info.d_slice = std::max(in0_gi.d_slice, in1_gi.d_slice);
sec_info.h_slice = std::max(in0_gi.h_slice, in1_gi.h_slice);
sec_info.w_slice = std::max(in0_gi.w_slice, in1_gi.w_slice);
sec_info.n_idx = std::max(in0_gi.n_idx, in1_gi.n_idx);
sec_info.d_idx = std::max(in0_gi.d_idx, in1_gi.d_idx);
sec_info.c_idx = std::max(in0_gi.c_idx, in1_gi.c_idx);
sec_info.is_c_split =
!(std::max(in0_gi.c_idx, in1_gi.c_idx) == 0 &&
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ void tpu::MulOp::assign_sec_info(int64_t n_step, int64_t c_step, int64_t h_step,
sec_info.d_slice = std::max(in0_gi.d_slice, in1_gi.d_slice);
sec_info.h_slice = std::max(in0_gi.h_slice, in1_gi.h_slice);
sec_info.w_slice = std::max(in0_gi.w_slice, in1_gi.w_slice);
sec_info.n_idx = std::max(in0_gi.n_idx, in1_gi.n_idx);
sec_info.d_idx = std::max(in0_gi.d_idx, in1_gi.d_idx);
sec_info.c_idx = std::max(in0_gi.c_idx, in1_gi.c_idx);
sec_info.is_c_split =
!(std::max(in0_gi.c_idx, in1_gi.c_idx) == 0 &&
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Pool1D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ void tpu::Pool1DOp::assign_sec_info(int64_t n_step, int64_t c_step,
sec_info.d_slice = in_gi.d_slice;
sec_info.h_slice = in_gi.h_slice;
sec_info.w_slice = in_gi.w_slice;
sec_info.n_idx = in_gi.n_idx;
sec_info.c_idx = in_gi.c_idx;
sec_info.d_idx = in_gi.d_idx;
sec_info.h_idx = in_gi.h_idx;
sec_info.w_idx = in_gi.w_idx;
sec_info.is_h_split = !(in_gi.h_idx == 0 && in_gi.h_slice == attr.ih);
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Pool2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ void tpu::Pool2DOp::assign_sec_info(int64_t n_step, int64_t c_step,
sec_info.d_slice = in_gi.d_slice;
sec_info.h_slice = in_gi.h_slice;
sec_info.w_slice = in_gi.w_slice;
sec_info.n_idx = in_gi.n_idx;
sec_info.c_idx = in_gi.c_idx;
sec_info.d_idx = in_gi.d_idx;
sec_info.h_idx = in_gi.h_idx;
sec_info.w_idx = in_gi.w_idx;
sec_info.is_h_split = !(in_gi.h_idx == 0 && in_gi.h_slice == attr.ih);
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Pool3D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ void tpu::Pool3DOp::assign_sec_info(int64_t n_step, int64_t c_step,
sec_info.d_slice = in_gi.d_slice;
sec_info.h_slice = in_gi.h_slice;
sec_info.w_slice = in_gi.w_slice;
sec_info.n_idx = in_gi.n_idx;
sec_info.c_idx = in_gi.c_idx;
sec_info.d_idx = in_gi.d_idx;
sec_info.h_idx = in_gi.h_idx;
sec_info.is_h_split = !(in_gi.h_idx == 0 && in_gi.h_slice == attr.ih);
sec_info.w_idx = in_gi.w_idx;
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ void tpu::StoreOp::assign_sec_info(int64_t n_step, int64_t c_step,
sec_info.d_slice = gi.d_slice;
sec_info.h_slice = gi.h_slice;
sec_info.w_slice = gi.w_slice;
sec_info.n_idx = gi.n_idx;
sec_info.d_idx = gi.d_idx;
sec_info.c_idx = gi.c_idx;
sec_info.is_c_split = !(gi.c_idx == 0 && gi.c_slice == c);
sec_info.h_idx = gi.h_idx;
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Tpu/Interfaces/Common/Sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ void tpu::SubOp::assign_sec_info(int64_t n_step, int64_t c_step, int64_t h_step,
sec_info.d_slice = std::max(in0_gi.d_slice, in1_gi.d_slice);
sec_info.h_slice = std::max(in0_gi.h_slice, in1_gi.h_slice);
sec_info.w_slice = std::max(in0_gi.w_slice, in1_gi.w_slice);
sec_info.n_idx = std::max(in0_gi.n_idx, in1_gi.n_idx);
sec_info.d_idx = std::max(in0_gi.d_idx, in1_gi.d_idx);
sec_info.c_idx = std::max(in0_gi.c_idx, in1_gi.c_idx);
sec_info.is_c_split =
!(std::max(in0_gi.c_idx, in1_gi.c_idx) == 0 &&
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Tpu/Transforms/Codegen/BM168xEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,8 @@ void BM168xEvaluator::staging_results(LocalGenInterface& op, local_sec_info_t se
break;
}

const int nidx = ginfo.n_idx;
const int didx = ginfo.d_idx;
const int nidx = sec_info.n_idx;
const int didx = sec_info.d_idx;
const int cidx = sec_info.is_c_split ? sec_info.c_idx : 0;
const int hidx = sec_info.is_h_split ? sec_info.out_h_idx : 0;
const int widx = sec_info.is_w_split ? sec_info.out_w_idx : 0;
Expand Down

0 comments on commit 98b84cf

Please sign in to comment.