Commit

[cherry-pick] [ARM] Add int64 implement for gather and `greater_than` (#4444)

DannyIsFunny committed Sep 24, 2020
1 parent 038c07f commit 4b5b540
Showing 4 changed files with 46 additions and 15 deletions.
8 changes: 4 additions & 4 deletions lite/kernels/arm/gather_compute.cc
@@ -73,10 +73,10 @@ void GatherCompute<IndexType>::Run() {

 REGISTER_LITE_KERNEL(gather,
                      kARM,
-                     kAny,
+                     kFloat,
                      kNCHW,
                      paddle::lite::kernels::arm::GatherCompute<int32_t>,
-                     def)
+                     int32)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .BindInput("Index",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
@@ -85,10 +85,10 @@ REGISTER_LITE_KERNEL(gather,

 REGISTER_LITE_KERNEL(gather,
                      kARM,
-                     kAny,
+                     kFloat,
                      kNCHW,
                      paddle::lite::kernels::arm::GatherCompute<int64_t>,
-                     def_int64_idx)
+                     int64)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .BindInput("Index",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
2 changes: 1 addition & 1 deletion lite/kernels/arm/gather_compute.h
@@ -24,7 +24,7 @@ namespace kernels {
 namespace arm {

 template <typename IndexType>
-class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
+class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
  public:
   void Run() override;

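For orientation, a minimal standalone sketch of the gather semantics these ARM kernels implement: output slice i along axis 0 is a copy of input slice index[i], the same loop the updated reference test further down performs with int64 indices. The function and variable names here are illustrative only, not code from this patch.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Reference gather along axis 0: out slice i = x slice index[i].
void GatherRef(const float* x, const int64_t* index, float* out,
               int64_t slice_num, int64_t slice_size, int64_t batch_size) {
  for (int64_t i = 0; i < batch_size; ++i) {
    const int64_t id = index[i];
    assert(id >= 0 && id < slice_num);  // ids must address an existing slice
    std::memcpy(out + i * slice_size, x + id * slice_size,
                slice_size * sizeof(float));
  }
}

int main() {
  std::vector<float> x = {10.f, 11.f, 20.f, 21.f, 30.f, 31.f};  // 3 slices of size 2
  std::vector<int64_t> index = {2, 0};
  std::vector<float> out(index.size() * 2);
  GatherRef(x.data(), index.data(), out.data(), 3, 2,
            static_cast<int64_t>(index.size()));
  // out is now {30, 31, 10, 11}
  return 0;
}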
31 changes: 31 additions & 0 deletions lite/kernels/host/compare_compute.cc
@@ -230,6 +230,21 @@ REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_float, def)
                    TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
     .Finalize();

+using greater_than_int64 = paddle::lite::kernels::host::CompareCompute<
+    PRECISION(kInt64),
+    paddle::lite::kernels::host::_GreaterThanFunctor<int64_t>>;
+REGISTER_LITE_KERNEL(greater_than, kHost, kInt64, kAny, greater_than_int64, def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
+    .Finalize();
+
 using greater_equal_float = paddle::lite::kernels::host::CompareCompute<
     PRECISION(kFloat),
     paddle::lite::kernels::host::_GreaterEqualFunctor<float>>;
@@ -245,3 +260,19 @@ REGISTER_LITE_KERNEL(
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
     .Finalize();
+
+using greater_equal_int64 = paddle::lite::kernels::host::CompareCompute<
+    PRECISION(kInt64),
+    paddle::lite::kernels::host::_GreaterEqualFunctor<int64_t>>;
+REGISTER_LITE_KERNEL(
+    greater_equal, kHost, kInt64, kAny, greater_equal_int64, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
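For orientation, a minimal sketch of what the newly registered int64 comparison kernels compute, assuming the _GreaterThanFunctor / _GreaterEqualFunctor referenced above are plain element-wise comparators (their definitions live elsewhere in compare_compute and are not shown in this diff). The names below are illustrative, not part of the patch.

#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed shape of the comparison functors used by the registrations above.
template <typename T>
struct GreaterThan {
  bool operator()(const T& a, const T& b) const { return a > b; }
};
template <typename T>
struct GreaterEqual {
  bool operator()(const T& a, const T& b) const { return a >= b; }
};

// Element-wise comparison of two equally sized tensors, producing the bool
// output that the "Out" slot above is bound to (PRECISION(kBool)).
template <typename T, typename Functor>
std::vector<bool> CompareElementwise(const std::vector<T>& x,
                                     const std::vector<T>& y) {
  Functor cmp;
  std::vector<bool> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) out[i] = cmp(x[i], y[i]);
  return out;
}

int main() {
  auto out =
      CompareElementwise<int64_t, GreaterThan<int64_t>>({3, 1, 2}, {2, 2, 2});
  // out is {true, false, false}
  return 0;
}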
20 changes: 10 additions & 10 deletions lite/tests/kernels/gather_compute_test.cc
@@ -53,9 +53,9 @@ class GatherComputeTest : public arena::TestCase {
     out_dims[0] = batch_size;
     out->Resize(out_dims);

-    auto x_data = x->data<float>();
-    auto index_data = index->data<int>();
-    auto out_data = out->mutable_data<float>();
+    auto x_data = x->data<int64_t>();
+    auto index_data = index->data<int64_t>();
+    auto out_data = out->mutable_data<int64_t>();

     auto slice_num = x_dims[0];
     auto slice_size = x_dims.Slice(1, x_dims.size()).production();
@@ -66,7 +66,7 @@
       CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index;
       memcpy(out_data + i * slice_size,
              x_data + index * slice_size,
-             slice_size * sizeof(float));
+             slice_size * sizeof(int64_t));
     }
   }

@@ -78,11 +78,11 @@
   }

   void PrepareData() override {
-    std::vector<float> x(x_dims_.production());
-    fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
+    std::vector<int64_t> x(x_dims_.production());
+    fill_data_rand(x.data(), int64_t(-1), int64_t(1), x_dims_.production());

-    std::vector<int32_t> index(index_dims_.production());
-    fill_data_rand<int32_t>(
+    std::vector<int64_t> index(index_dims_.production());
+    fill_data_rand<int64_t>(
         index.data(), 0, x_dims_[0] - 1, index_dims_.production());

     SetCommonTensor(x_, x_dims_, x.data());
@@ -110,8 +110,8 @@ TEST(Gather, precision) {
   for (auto x_dims :
        std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) {
     for (auto index_dims : std::vector<std::vector<int64_t>>{{3}, {7}, {10}}) {
-      std::unique_ptr<arena::TestCase> tester(
-          new GatherComputeTest(place, "def", DDim(x_dims), DDim(index_dims)));
+      std::unique_ptr<arena::TestCase> tester(new GatherComputeTest(
+          place, "int64", DDim(x_dims), DDim(index_dims)));
       arena::Arena arena(std::move(tester), place, abs_error);
       arena.TestPrecision();
     }
