Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zip_iterator implementation #966

Merged
merged 7 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
635 changes: 307 additions & 328 deletions core/base/iterator_factory.hpp

Large diffs are not rendered by default.

181 changes: 128 additions & 53 deletions core/test/base/iterator_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <algorithm>
#include <complex>
#include <numeric>
#include <vector>


Expand All @@ -43,6 +45,24 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "core/test/utils.hpp"


namespace std {


// add a comparison for std::complex to allow comparisons without custom
// comparator
// HERE BE DRAGONS! This is technically UB, since we are adding things to
// namespace std, but since it is only inside a test, it should be fine :)
template <typename ValueType>
bool operator<(const std::complex<ValueType>& a,
const std::complex<ValueType>& b)
upsj marked this conversation as resolved.
Show resolved Hide resolved
{
return a.real() < b.real();
}


} // namespace std


namespace {


Expand All @@ -62,13 +82,6 @@ class IteratorFactory : public ::testing::Test {
9., 10., 11., 12., 13., 14., 15.}
{}

template <typename T1, typename T2>
void check_vector_equal(const std::vector<T1>& v1,
const std::vector<T2>& v2)
{
ASSERT_TRUE(std::equal(v1.begin(), v1.end(), v2.begin()));
}

// Require that Iterator has a `value_type` specified
template <typename Iterator, typename = typename Iterator::value_type>
bool is_sorted_iterator(Iterator begin, Iterator end)
Expand Down Expand Up @@ -104,11 +117,11 @@ TYPED_TEST(IteratorFactory, EmptyIterator)
{
using index_type = typename TestFixture::index_type;
using value_type = typename TestFixture::value_type;
auto test_iter = gko::detail::IteratorFactory<index_type, value_type>(
nullptr, nullptr, 0);

ASSERT_TRUE(test_iter.begin() == test_iter.end());
ASSERT_NO_THROW(std::sort(test_iter.begin(), test_iter.end()));
auto test_iter = gko::detail::make_zip_iterator<index_type*, value_type*>(
nullptr, nullptr);

ASSERT_NO_THROW(std::sort(test_iter, test_iter));
}


Expand All @@ -119,12 +132,11 @@ TYPED_TEST(IteratorFactory, SortingReversedWithIterator)
std::vector<index_type> vec1{this->reversed_index};
std::vector<value_type> vec2{this->ordered_value};

auto test_iter = gko::detail::IteratorFactory<index_type, value_type>(
vec1.data(), vec2.data(), vec1.size());
std::sort(test_iter.begin(), test_iter.end());
auto test_iter = gko::detail::make_zip_iterator(vec1.data(), vec2.data());
std::sort(test_iter, test_iter + vec1.size());

this->check_vector_equal(vec1, this->ordered_index);
this->check_vector_equal(vec2, this->reversed_value);
ASSERT_EQ(vec1, this->ordered_index);
ASSERT_EQ(vec2, this->reversed_value);
}


Expand All @@ -135,12 +147,11 @@ TYPED_TEST(IteratorFactory, SortingAlreadySortedWithIterator)
std::vector<index_type> vec1{this->ordered_index};
std::vector<value_type> vec2{this->ordered_value};

auto test_iter = gko::detail::IteratorFactory<index_type, value_type>(
vec1.data(), vec2.data(), vec1.size());
std::sort(test_iter.begin(), test_iter.end());
auto test_iter = gko::detail::make_zip_iterator(vec1.data(), vec2.data());
std::sort(test_iter, test_iter + vec1.size());

this->check_vector_equal(vec1, this->ordered_index);
this->check_vector_equal(vec2, this->ordered_value);
ASSERT_EQ(vec1, this->ordered_index);
ASSERT_EQ(vec2, this->ordered_value);
}


Expand All @@ -151,10 +162,9 @@ TYPED_TEST(IteratorFactory, IteratorReferenceOperatorSmaller)
std::vector<index_type> vec1{this->reversed_index};
std::vector<value_type> vec2{this->ordered_value};

auto test_iter = gko::detail::IteratorFactory<index_type, value_type>(
vec1.data(), vec2.data(), vec1.size());
auto test_iter = gko::detail::make_zip_iterator(vec1.data(), vec2.data());
bool is_sorted =
this->is_sorted_iterator(test_iter.begin(), test_iter.end());
this->is_sorted_iterator(test_iter, test_iter + vec1.size());

ASSERT_FALSE(is_sorted);
}
Expand All @@ -167,10 +177,9 @@ TYPED_TEST(IteratorFactory, IteratorReferenceOperatorSmaller2)
std::vector<index_type> vec1{this->ordered_index};
std::vector<value_type> vec2{this->ordered_value};

auto test_iter = gko::detail::IteratorFactory<index_type, value_type>(
vec1.data(), vec2.data(), vec1.size());
auto test_iter = gko::detail::make_zip_iterator(vec1.data(), vec2.data());
bool is_sorted =
this->is_sorted_iterator(test_iter.begin(), test_iter.end());
this->is_sorted_iterator(test_iter, test_iter + vec1.size());

ASSERT_TRUE(is_sorted);
}
Expand All @@ -183,10 +192,10 @@ TYPED_TEST(IteratorFactory, IncreasingIterator)
std::vector<index_type> vec1{this->reversed_index};
std::vector<value_type> vec2{this->ordered_value};

auto test_iter = gko::detail::IteratorFactory<index_type, value_type>(
vec1.data(), vec2.data(), vec1.size());
auto begin = test_iter.begin();
auto test_iter = gko::detail::make_zip_iterator(vec1.data(), vec2.data());
auto begin = test_iter;
auto plus_2 = begin + 2;
auto plus_2_rev = 2 + begin;
auto plus_minus_2 = plus_2 - 2;
auto increment_pre_2 = begin;
++increment_pre_2;
Expand All @@ -197,26 +206,95 @@ TYPED_TEST(IteratorFactory, IncreasingIterator)
auto increment_pre_test = begin;
auto increment_post_test = begin;

// check results for equality
ASSERT_TRUE(begin == plus_minus_2);
ASSERT_TRUE(plus_2 == increment_pre_2);
ASSERT_TRUE(plus_2_rev == increment_pre_2);
ASSERT_TRUE(increment_pre_2 == increment_post_2);
ASSERT_TRUE(begin == increment_post_test++);
ASSERT_TRUE(begin + 1 == ++increment_pre_test);
ASSERT_TRUE((*plus_2).dominant() == vec1[2]);
ASSERT_TRUE((*plus_2).secondary() == vec2[2]);
ASSERT_TRUE(std::get<0>(*plus_2) == vec1[2]);
ASSERT_TRUE(std::get<1>(*plus_2) == vec2[2]);
// check other comparison operators and difference
std::vector<gko::detail::zip_iterator<index_type*, value_type*>> its{
begin,
plus_2,
plus_2_rev,
plus_minus_2,
increment_pre_2,
increment_post_2,
increment_pre_test,
increment_post_test,
begin + 5,
begin + 9};
std::sort(its.begin(), its.end());
std::vector<int> dists;
std::vector<int> ref_dists{0, 1, 0, 1, 0, 0, 0, 3, 4};
for (int i = 0; i < its.size() - 1; i++) {
SCOPED_TRACE(i);
dists.push_back(its[i + 1] - its[i]);
auto equal = dists.back() > 0;
ASSERT_EQ(its[i + 1] > its[i], equal);
ASSERT_EQ(its[i] < its[i + 1], equal);
ASSERT_EQ(its[i] != its[i + 1], equal);
ASSERT_EQ(its[i] == its[i + 1], !equal);
ASSERT_EQ(its[i] >= its[i + 1], !equal);
ASSERT_EQ(its[i + 1] <= its[i], !equal);
ASSERT_TRUE(its[i + 1] >= its[i]);
ASSERT_TRUE(its[i] <= its[i + 1]);
}
ASSERT_EQ(dists, ref_dists);
}


#ifndef NDEBUG


bool check_assertion_exit_code(int exit_code)
{
#ifdef _MSC_VER
// MSVC picks up the exit code incorrectly,
// so we can only check that it exits
return true;
#else
return exit_code != 0;
#endif
}


TYPED_TEST(IteratorFactory, IncompatibleIteratorDeathTest)
{
using index_type = typename TestFixture::index_type;
using value_type = typename TestFixture::value_type;
auto it1 = gko::detail::make_zip_iterator(this->ordered_index.data(),
this->ordered_value.data());
auto it2 = gko::detail::make_zip_iterator(this->ordered_index.data() + 1,
this->ordered_value.data());

// a set of operations that return inconsistent results for the two
// different iterators
EXPECT_EXIT(it2 - it1, check_assertion_exit_code, "");
EXPECT_EXIT(it2 == it1, check_assertion_exit_code, "");
EXPECT_EXIT(it2 != it1, check_assertion_exit_code, "");
EXPECT_EXIT(it1 < it2, check_assertion_exit_code, "");
EXPECT_EXIT(it2 <= it1, check_assertion_exit_code, "");
EXPECT_EXIT(it2 > it1, check_assertion_exit_code, "");
EXPECT_EXIT(it1 >= it2, check_assertion_exit_code, "");
}


#endif


TYPED_TEST(IteratorFactory, DecreasingIterator)
{
using index_type = typename TestFixture::index_type;
using value_type = typename TestFixture::value_type;
std::vector<index_type> vec1{this->reversed_index};
std::vector<value_type> vec2{this->ordered_value};

auto test_iter = gko::detail::IteratorFactory<index_type, value_type>(
vec1.data(), vec2.data(), vec1.size());
auto iter = test_iter.begin() + 5;
auto test_iter = gko::detail::make_zip_iterator(vec1.data(), vec2.data());
auto iter = test_iter + 5;
auto minus_2 = iter - 2;
auto minus_plus_2 = minus_2 + 2;
auto decrement_pre_2 = iter;
Expand All @@ -233,8 +311,8 @@ TYPED_TEST(IteratorFactory, DecreasingIterator)
ASSERT_TRUE(decrement_pre_2 == decrement_post_2);
ASSERT_TRUE(iter == decrement_post_test--);
ASSERT_TRUE(iter - 1 == --decrement_pre_test);
ASSERT_TRUE((*minus_2).dominant() == vec1[3]);
ASSERT_TRUE((*minus_2).secondary() == vec2[3]);
ASSERT_TRUE(std::get<0>(*minus_2) == vec1[3]);
ASSERT_TRUE(std::get<1>(*minus_2) == vec2[3]);
}


Expand All @@ -246,17 +324,16 @@ TYPED_TEST(IteratorFactory, CorrectDereferencing)
std::vector<value_type_it> vec2{this->ordered_value};
constexpr int element_to_test = 3;

auto test_iter = gko::detail::IteratorFactory<index_type_it, value_type_it>(
vec1.data(), vec2.data(), vec1.size());
auto begin = test_iter.begin();
auto test_iter = gko::detail::make_zip_iterator(vec1.data(), vec2.data());
auto begin = test_iter;
using value_type = typename decltype(begin)::value_type;
auto to_test_ref = *(begin + element_to_test);
value_type to_test_pair = to_test_ref; // Testing implicit conversion

ASSERT_TRUE(to_test_pair.dominant == vec1[element_to_test]);
ASSERT_TRUE(to_test_pair.dominant == to_test_ref.dominant());
ASSERT_TRUE(to_test_pair.secondary == vec2[element_to_test]);
ASSERT_TRUE(to_test_pair.secondary == to_test_ref.secondary());
ASSERT_TRUE(std::get<0>(to_test_pair) == vec1[element_to_test]);
ASSERT_TRUE(std::get<0>(to_test_pair) == std::get<0>(to_test_ref));
ASSERT_TRUE(std::get<1>(to_test_pair) == vec2[element_to_test]);
ASSERT_TRUE(std::get<1>(to_test_pair) == std::get<1>(to_test_ref));
}


Expand All @@ -267,10 +344,9 @@ TYPED_TEST(IteratorFactory, CorrectSwapping)
std::vector<index_type> vec1{this->reversed_index};
std::vector<value_type> vec2{this->ordered_value};

auto test_iter = gko::detail::IteratorFactory<index_type, value_type>(
vec1.data(), vec2.data(), vec1.size());
auto first_el_reference = *test_iter.begin();
auto second_el_reference = *(test_iter.begin() + 1);
auto test_iter = gko::detail::make_zip_iterator(vec1.data(), vec2.data());
auto first_el_reference = *test_iter;
auto second_el_reference = *(test_iter + 1);
swap(first_el_reference, second_el_reference);

ASSERT_TRUE(vec1[0] == this->reversed_index[1]);
Expand All @@ -292,11 +368,10 @@ TYPED_TEST(IteratorFactory, CorrectHandWrittenSwapping)
std::vector<index_type> vec1{this->reversed_index};
std::vector<value_type> vec2{this->ordered_value};

auto test_iter = gko::detail::IteratorFactory<index_type, value_type>(
vec1.data(), vec2.data(), vec1.size());
auto first_el_reference = *test_iter.begin();
auto second_el_reference = *(test_iter.begin() + 1);
auto temp = static_cast<typename decltype(test_iter.begin())::value_type>(
auto test_iter = gko::detail::make_zip_iterator(vec1.data(), vec2.data());
auto first_el_reference = *test_iter;
auto second_el_reference = *(test_iter + 1);
auto temp = static_cast<typename decltype(test_iter)::value_type>(
first_el_reference);
first_el_reference = second_el_reference;
second_el_reference = temp;
Expand Down
5 changes: 2 additions & 3 deletions core/test/utils/unsort_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ void unsort_matrix(matrix::Csr<ValueType, IndexType>* mtx,
for (index_type row = 0; row < size[0]; ++row) {
auto start = row_ptrs[row];
auto end = row_ptrs[row + 1];
auto sort_wrapper = gko::detail::IteratorFactory<IndexType, ValueType>(
cols + start, vals + start, end - start);
std::shuffle(sort_wrapper.begin(), sort_wrapper.end(), engine);
auto it = gko::detail::make_zip_iterator(cols + start, vals + start);
std::shuffle(it, it + (end - start), engine);
}
}

Expand Down
8 changes: 5 additions & 3 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -903,9 +903,11 @@ void sort_by_column_index(std::shared_ptr<const OmpExecutor> exec,
for (size_type i = 0; i < number_rows; ++i) {
auto start_row_idx = row_ptrs[i];
auto row_nnz = row_ptrs[i + 1] - start_row_idx;
auto helper = detail::IteratorFactory<IndexType, ValueType>(
col_idxs + start_row_idx, values + start_row_idx, row_nnz);
std::sort(helper.begin(), helper.end());
auto it = detail::make_zip_iterator(col_idxs + start_row_idx,
values + start_row_idx);
std::sort(it, it + row_nnz, [](auto t1, auto t2) {
return std::get<0>(t1) < std::get<0>(t2);
});
}
}

Expand Down
7 changes: 4 additions & 3 deletions omp/matrix/fbcsr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,10 @@ void sort_by_column_index_impl(

std::vector<IndexType> col_permute(nbnz_brow);
std::iota(col_permute.begin(), col_permute.end(), 0);
auto helper = detail::IteratorFactory<IndexType, IndexType>(
brow_col_idxs, col_permute.data(), nbnz_brow);
std::sort(helper.begin(), helper.end());
auto it = detail::make_zip_iterator(brow_col_idxs, col_permute.data());
std::sort(it, it + nbnz_brow, [](auto a, auto b) {
return std::get<0>(a) < std::get<0>(b);
});

std::vector<ValueType> oldvalues(nbnz_brow * bs2);
std::copy(brow_vals, brow_vals + nbnz_brow * bs2, oldvalues.begin());
Expand Down
1 change: 0 additions & 1 deletion omp/matrix/sparsity_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/matrix/dense.hpp>


#include "core/base/iterator_factory.hpp"
#include "core/components/fill_array_kernels.hpp"
#include "core/components/prefix_sum_kernels.hpp"

Expand Down
8 changes: 5 additions & 3 deletions reference/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,9 +874,11 @@ void sort_by_column_index(std::shared_ptr<const ReferenceExecutor> exec,
for (size_type i = 0; i < number_rows; ++i) {
auto start_row_idx = row_ptrs[i];
auto row_nnz = row_ptrs[i + 1] - start_row_idx;
auto helper = detail::IteratorFactory<IndexType, ValueType>(
col_idxs + start_row_idx, values + start_row_idx, row_nnz);
std::sort(helper.begin(), helper.end());
auto it = detail::make_zip_iterator(col_idxs + start_row_idx,
values + start_row_idx);
std::sort(it, it + row_nnz, [](auto t1, auto t2) {
return std::get<0>(t1) < std::get<0>(t2);
});
}
}

Expand Down
7 changes: 4 additions & 3 deletions reference/matrix/fbcsr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,10 @@ void sort_by_column_index_impl(

std::vector<IndexType> col_permute(nbnz_brow);
std::iota(col_permute.begin(), col_permute.end(), 0);
auto helper = detail::IteratorFactory<IndexType, IndexType>(
brow_col_idxs, col_permute.data(), nbnz_brow);
std::sort(helper.begin(), helper.end());
auto it = detail::make_zip_iterator(brow_col_idxs, col_permute.data());
std::sort(it, it + nbnz_brow, [](auto a, auto b) {
return std::get<0>(a) < std::get<0>(b);
});

std::vector<ValueType> oldvalues(nbnz_brow * bs2);
std::copy(brow_vals, brow_vals + nbnz_brow * bs2, oldvalues.begin());
Expand Down
1 change: 0 additions & 1 deletion reference/matrix/sparsity_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/matrix/dense.hpp>


#include "core/base/iterator_factory.hpp"
#include "core/components/fill_array_kernels.hpp"
#include "core/components/prefix_sum_kernels.hpp"

Expand Down