Skip to content

Commit

Permalink
handle deferred factory generation by registration
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
  • Loading branch information
upsj and MarcelKoch committed Sep 25, 2023
1 parent 2d356be commit 4f8e21f
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 182 deletions.
30 changes: 16 additions & 14 deletions core/test/solver/multigrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,7 @@ TYPED_TEST(Multigrid, ApplyUsesInitialGuessReturnsFalseWhenZeroGuess)
using Solver = typename TestFixture::Solver;
auto multigrid_factory =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u))
.with_criteria(gko::stop::Iteration::build().with_max_iters(3u))
.with_max_levels(2u)
.with_coarsest_solver(this->lo_factory)
.with_pre_smoother(this->lo_factory)
Expand Down Expand Up @@ -426,25 +425,28 @@ TYPED_TEST(Multigrid, ThrowWhenNullMgLevel)
TYPED_TEST(Multigrid, ThrowWhenMgLevelContainsNullptr)
{
using Solver = typename TestFixture::Solver;
auto factory_parameters = Solver::build()
.with_max_levels(1u)
.with_min_coarse_rows(2u)
.with_criteria(this->criterion)
.with_mg_level(this->rp_factory, nullptr);
auto factory = Solver::build()
.with_max_levels(1u)
.with_min_coarse_rows(2u)
.with_criteria(this->criterion)
.with_mg_level(this->rp_factory, nullptr)
.on(this->exec);

ASSERT_THROW(factory_parameters.on(this->exec), gko::NotSupported);
ASSERT_THROW(factory->generate(this->mtx), gko::NotSupported);
}


TYPED_TEST(Multigrid, ThrowWhenEmptyMgLevelList)
{
using Solver = typename TestFixture::Solver;
auto factory = Solver::build()
.with_max_levels(1u)
.with_min_coarse_rows(2u)
.with_mg_level()
.with_criteria(this->criterion)
.on(this->exec);
auto factory =
Solver::build()
.with_max_levels(1u)
.with_min_coarse_rows(2u)
.with_mg_level(
std::vector<std::shared_ptr<const gko::LinOpFactory>>{})
.with_criteria(this->criterion)
.on(this->exec);

ASSERT_THROW(factory->generate(this->mtx), gko::NotSupported);
}
Expand Down
86 changes: 77 additions & 9 deletions include/ginkgo/core/base/abstract_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GKO_PUBLIC_CORE_BASE_ABSTRACT_FACTORY_HPP_


#include <unordered_map>


#include <ginkgo/core/base/polymorphic_object.hpp>


Expand Down Expand Up @@ -257,7 +260,11 @@ class enable_parameters_type {
*/
std::unique_ptr<Factory> on(std::shared_ptr<const Executor> exec) const
{
auto factory = std::unique_ptr<Factory>(new Factory(exec, *self()));
ConcreteParametersType copy = *self();
for (const auto& item : deferred_factories) {
item.second(exec, copy);
}
auto factory = std::unique_ptr<Factory>(new Factory(exec, copy));
for (auto& logger : loggers) {
factory->add_logger(logger);
};
Expand All @@ -271,9 +278,35 @@ class enable_parameters_type {
* Loggers to be attached to the factory and generated object.
*/
std::vector<std::shared_ptr<const log::Logger>> loggers{};

std::unordered_map<std::string,
std::function<void(std::shared_ptr<const Executor> exec,
ConcreteParametersType&)>>
deferred_factories;
};


/**
* This Macro will generate a new type containing the parameters for the factory
* `_factory_name`. For more details, see #GKO_ENABLE_LIN_OP_FACTORY().
* It is required to use this macro **before** calling the
* macro #GKO_ENABLE_LIN_OP_FACTORY().
* It is also required to use the same names for all parameters between both
* macros.
*
* @param _parameters_name name of the parameters member in the class
* @param _factory_name name of the generated factory type
*
* @ingroup LinOp
*/
#define GKO_CREATE_FACTORY_PARAMETERS(_parameters_name, _factory_name) \
public: \
class _factory_name; \
struct _parameters_name##_type \
: public ::gko::enable_parameters_type<_parameters_name##_type, \
_factory_name>


/**
* Represents a factory parameter of factory type that can either initialized by
* a pre-existing factory or by passing in a factory_parameters object whose
Expand All @@ -288,7 +321,7 @@ class deferred_factory_parameter {
deferred_factory_parameter() = default;

/** Creates an empty deferred factory parameter. */
explicit deferred_factory_parameter(std::nullptr_t)
deferred_factory_parameter(std::nullptr_t)
{
generator_ = [](std::shared_ptr<const Executor>) { return nullptr; };
}
Expand All @@ -301,8 +334,7 @@ class deferred_factory_parameter {
std::enable_if_t<std::is_base_of<
FactoryType,
std::remove_const_t<ConcreteFactoryType>>::value>* = nullptr>
explicit deferred_factory_parameter(
std::shared_ptr<ConcreteFactoryType> factory)
deferred_factory_parameter(std::shared_ptr<ConcreteFactoryType> factory)
{
generator_ =
[factory = std::shared_ptr<const FactoryType>(std::move(factory))](
Expand All @@ -317,7 +349,7 @@ class deferred_factory_parameter {
std::enable_if_t<std::is_base_of<
FactoryType,
std::remove_const_t<ConcreteFactoryType>>::value>* = nullptr>
explicit deferred_factory_parameter(
deferred_factory_parameter(
std::unique_ptr<ConcreteFactoryType, Deleter> factory)
{
generator_ =
Expand All @@ -333,7 +365,7 @@ class deferred_factory_parameter {
template <typename ParametersType,
typename = decltype(std::declval<ParametersType>().on(
std::shared_ptr<const Executor>{}))>
explicit deferred_factory_parameter(ParametersType parameters)
deferred_factory_parameter(ParametersType parameters)
{
generator_ = [parameters](std::shared_ptr<const Executor> exec)
-> std::shared_ptr<const FactoryType> {
Expand All @@ -351,8 +383,8 @@ class deferred_factory_parameter {
return generator_(exec);
}

/** Returns true iff the parameter contains a factory. */
bool is_empty() const { return bool(generator_); }
/** Returns true iff the parameter is empty. */
bool is_empty() const { return !bool(generator_); }

private:
std::function<std::shared_ptr<const FactoryType>(
Expand Down Expand Up @@ -499,6 +531,12 @@ public: \
parameters_type& with_##_name(deferred_factory_parameter<_type> factory) \
{ \
this->_name##_generator_ = std::move(factory); \
this->deferred_factories[#_name] = [](const auto& exec, \
auto& params) { \
if (!params._name##_generator_.is_empty()) { \
params._name = params._name##_generator_.on(exec); \
} \
}; \
return *this; \
} \
\
Expand All @@ -523,11 +561,41 @@ public: \
#define GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(_name, _type) \
public: \
std::vector<std::shared_ptr<const _type>> _name{}; \
template <typename... Args> \
template <typename... Args, \
typename = \
std::enable_if_t<xstd::conjunction<std::is_convertible< \
Args, deferred_factory_parameter<_type>>...>::value>> \
parameters_type& with_##_name(Args&&... factories) \
{ \
this->_name##_generator_ = {deferred_factory_parameter<_type>{ \
std::forward<Args>(factories)}...}; \
this->deferred_factories[#_name] = [](const auto& exec, \
auto& params) { \
if (!params._name##_generator_.empty()) { \
params._name.clear(); \
for (auto& generator : params._name##_generator_) { \
params._name.push_back(generator.on(exec)); \
} \
} \
}; \
return *this; \
} \
template <typename FactoryType> \
parameters_type& with_##_name(const std::vector<FactoryType>& factories) \
{ \
this->_name##_generator_.clear(); \
for (const auto& factory : factories) { \
this->_name##_generator_.push_back(factory); \
} \
this->deferred_factories[#_name] = [](const auto& exec, \
auto& params) { \
if (!params._name##_generator_.empty()) { \
params._name.clear(); \
for (auto& generator : params._name##_generator_) { \
params._name.push_back(generator.on(exec)); \
} \
} \
}; \
return *this; \
} \
\
Expand Down
20 changes: 0 additions & 20 deletions include/ginkgo/core/base/lin_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,26 +949,6 @@ using EnableDefaultLinOpFactory =
EnableDefaultFactory<ConcreteFactory, ConcreteLinOp, ParametersType,
PolymorphicBase>;

/**
* This Macro will generate a new type containing the parameters for the factory
* `_factory_name`. For more details, see #GKO_ENABLE_LIN_OP_FACTORY().
* It is required to use this macro **before** calling the
* macro #GKO_ENABLE_LIN_OP_FACTORY().
* It is also required to use the same names for all parameters between both
* macros.
*
* @param _parameters_name name of the parameters member in the class
* @param _factory_name name of the generated factory type
*
* @ingroup LinOp
*/
#define GKO_CREATE_FACTORY_PARAMETERS(_parameters_name, _factory_name) \
public: \
class _factory_name; \
struct _parameters_name##_type \
: public ::gko::enable_parameters_type<_parameters_name##_type, \
_factory_name>


/**
* This macro will generate a default implementation of a LinOpFactory for the
Expand Down
10 changes: 10 additions & 0 deletions include/ginkgo/core/base/std_extensions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ constexpr bool less_equal(const T&& lhs, const T&& rhs)
}


// available in <type_traits> with C++17
template <class...>
struct conjunction : std::true_type {};
template <class B1>
struct conjunction<B1> : B1 {};
template <class B1, class... Bn>
struct conjunction<B1, Bn...>
: std::conditional_t<bool(B1::value), conjunction<Bn...>, B1> {};


} // namespace xstd
} // namespace gko

Expand Down
8 changes: 0 additions & 8 deletions include/ginkgo/core/distributed/preconditioner/schwarz.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,6 @@ class Schwarz
* Local solver factory.
*/
GKO_DEFERRED_FACTORY_PARAMETER(local_solver, LinOpFactory);

std::unique_ptr<Factory> on(std::shared_ptr<const Executor> exec) const
{
auto copy = *this;
copy.local_solver = local_solver_generator_.on(exec);
return copy.enable_parameters_type<parameters_type, Factory>::on(
exec);
}
};
GKO_ENABLE_LIN_OP_FACTORY(Schwarz, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);
Expand Down
31 changes: 14 additions & 17 deletions include/ginkgo/core/preconditioner/ic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ class Ic : public EnableLinOp<Ic<LSolverType, IndexType>>, public Transposable {
deferred_factory_parameter<typename l_solver_type::Factory> solver)
{
this->l_solver_generator = std::move(solver);
this->deferred_factories["l_solver"] = [](const auto& exec,
auto& params) {
if (!params.l_solver_generator.is_empty()) {
params.l_solver_factory =
params.l_solver_generator.on(exec);
}
};
return *this;
}

Expand All @@ -159,26 +166,16 @@ class Ic : public EnableLinOp<Ic<LSolverType, IndexType>>, public Transposable {
deferred_factory_parameter<LinOpFactory> factorization)
{
this->factorization_generator = std::move(factorization);
this->deferred_factories["factorization"] = [](const auto& exec,
auto& params) {
if (!params.factorization_generator.is_empty()) {
params.factorization_factory =
params.factorization_generator.on(exec);
}
};
return *this;
}

/**
*
*/
std::unique_ptr<Factory> on(std::shared_ptr<const Executor> exec) const
{
auto parameters_copy = *this;
if (l_solver_generator) {
parameters_copy.l_solver_factory = l_solver_generator.on(exec);
}
if (factorization_generator) {
parameters_copy.factorization_factory =
factorization_generator.on(exec);
}
return parameters_copy
.enable_parameters_type<parameters_type, Factory>::on(exec);
}

private:
deferred_factory_parameter<typename l_solver_type::Factory>
l_solver_generator;
Expand Down
41 changes: 21 additions & 20 deletions include/ginkgo/core/preconditioner/ilu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ class Ilu : public EnableLinOp<
deferred_factory_parameter<typename l_solver_type::Factory> solver)
{
this->l_solver_generator = std::move(solver);
this->deferred_factories["l_solver"] = [](const auto& exec,
auto& params) {
if (!params.l_solver_generator.is_empty()) {
params.l_solver_factory =
params.l_solver_generator.on(exec);
}
};
return *this;
}

Expand All @@ -177,6 +184,13 @@ class Ilu : public EnableLinOp<
deferred_factory_parameter<typename u_solver_type::Factory> solver)
{
this->u_solver_generator = std::move(solver);
this->deferred_factories["u_solver"] = [](const auto& exec,
auto& params) {
if (!params.u_solver_generator.is_empty()) {
params.u_solver_factory =
params.u_solver_generator.on(exec);
}
};
return *this;
}

Expand All @@ -191,29 +205,16 @@ class Ilu : public EnableLinOp<
deferred_factory_parameter<LinOpFactory> factorization)
{
this->factorization_generator = std::move(factorization);
this->deferred_factories["factorization"] = [](const auto& exec,
auto& params) {
if (!params.factorization_generator.is_empty()) {
params.factorization_factory =
params.factorization_generator.on(exec);
}
};
return *this;
}

/**
*
*/
std::unique_ptr<Factory> on(std::shared_ptr<const Executor> exec) const
{
auto parameters_copy = *this;
if (l_solver_generator) {
parameters_copy.l_solver_factory = l_solver_generator.on(exec);
}
if (u_solver_generator) {
parameters_copy.u_solver_factory = u_solver_generator.on(exec);
}
if (factorization_generator) {
parameters_copy.factorization_factory =
factorization_generator.on(exec);
}
return parameters_copy
.enable_parameters_type<parameters_type, Factory>::on(exec);
}

private:
deferred_factory_parameter<typename l_solver_type::Factory>
l_solver_generator;
Expand Down
Loading

0 comments on commit 4f8e21f

Please sign in to comment.