From cc835bc3fca84ab8e7c7385ff7cc4f87e38418e7 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Mon, 14 Aug 2023 17:17:31 +0200 Subject: [PATCH] review updates - move parameter macros to abstract_factory.hpp - use macros for defining deferred parameters Co-authored-by: Yuhsiang M. Tsai --- include/ginkgo/core/base/abstract_factory.hpp | 203 ++++++++++++++++++ include/ginkgo/core/base/lin_op.hpp | 124 ----------- .../distributed/preconditioner/schwarz.hpp | 14 +- include/ginkgo/core/solver/direct.hpp | 38 +--- include/ginkgo/core/solver/ir.hpp | 32 +-- include/ginkgo/core/solver/multigrid.hpp | 89 ++------ 6 files changed, 227 insertions(+), 273 deletions(-) diff --git a/include/ginkgo/core/base/abstract_factory.hpp b/include/ginkgo/core/base/abstract_factory.hpp index e8ec803b480..e644bcdcd76 100644 --- a/include/ginkgo/core/base/abstract_factory.hpp +++ b/include/ginkgo/core/base/abstract_factory.hpp @@ -274,13 +274,26 @@ class enable_parameters_type { }; +/** + * Represents a factory parameter of factory type that can either initialized by + * a pre-existing factory or by passing in a factory_parameters object whose + * `.on(exec)` will be called to instantiate a factory. + * + * @tparam FactoryType the type of factory that can be instantiated from this + * object. + */ template class deferred_factory_parameter { public: deferred_factory_parameter() = default; + /** Creates an empty deferred factory parameter. */ deferred_factory_parameter(std::nullptr_t) {} + /** + * Creates a deferred factory parameter from a preexisting factory with + * shared ownership. + */ template ) { return factory; }; } + /** + * Creates a deferred factory parameter by taking ownership of a + * preexisting factory with unique ownership. + */ template ) { return factory; }; } + /** + * Creates a deferred factory parameter object from a + * factory_parameters-like object. To instantiate the actual factory, the + * parameter's `.on(exec)` function will be called. + */ template ().on( std::shared_ptr{}))> @@ -315,6 +337,7 @@ class deferred_factory_parameter { }; } + /** Instantiates the deferred parameter into an actual factory. */ std::shared_ptr on( std::shared_ptr exec) const { @@ -324,6 +347,7 @@ class deferred_factory_parameter { return generator_(exec); } + /** Returns true iff the parameter contains a factory. */ explicit operator bool() const { return bool(generator_); } private: @@ -333,6 +357,185 @@ class deferred_factory_parameter { }; +/** + * Defines a build method for the factory, simplifying its construction by + * removing the repetitive typing of factory's name. + * + * @param _factory_name the factory for which to define the method + * + * @ingroup LinOp + */ +#define GKO_ENABLE_BUILD_METHOD(_factory_name) \ + static auto build()->decltype(_factory_name::create()) \ + { \ + return _factory_name::create(); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + + +#if !(defined(__CUDACC__) || defined(__HIPCC__)) +/** + * Creates a factory parameter in the factory parameters structure. + * + * @param _name name of the parameter + * @param __VA_ARGS__ default value of the parameter + * + * @see GKO_ENABLE_LIN_OP_FACTORY for more details, and usage example + * + * @deprecated Use GKO_FACTORY_PARAMETER_SCALAR or GKO_FACTORY_PARAMETER_VECTOR + * + * @ingroup LinOp + */ +#define GKO_FACTORY_PARAMETER(_name, ...) \ + mutable _name{__VA_ARGS__}; \ + \ + template \ + auto with_##_name(Args&&... _value)->std::decay_t& \ + { \ + using type = decltype(this->_name); \ + this->_name = type{std::forward(_value)...}; \ + return *this; \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +/** + * Creates a scalar factory parameter in the factory parameters structure. + * + * Scalar in this context means that the constructor for this type only takes + * a single parameter. + * + * @param _name name of the parameter + * @param _default default value of the parameter + * + * @see GKO_ENABLE_LIN_OP_FACTORY for more details, and usage example + * + * @ingroup LinOp + */ +#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default) \ + GKO_FACTORY_PARAMETER(_name, _default) + +/** + * Creates a vector factory parameter in the factory parameters structure. + * + * Vector in this context means that the constructor for this type takes + * multiple parameters. + * + * @param _name name of the parameter + * @param _default default value of the parameter + * + * @see GKO_ENABLE_LIN_OP_FACTORY for more details, and usage example + * + * @ingroup LinOp + */ +#define GKO_FACTORY_PARAMETER_VECTOR(_name, ...) \ + GKO_FACTORY_PARAMETER(_name, __VA_ARGS__) +#else // defined(__CUDACC__) || defined(__HIPCC__) +// A workaround for the NVCC compiler - parameter pack expansion does not work +// properly, because while the assignment to a scalar value is translated by +// cudafe into a C-style cast, the parameter pack expansion is not removed and +// `Args&&... args` is still kept as a parameter pack. +#define GKO_FACTORY_PARAMETER(_name, ...) \ + mutable _name{__VA_ARGS__}; \ + \ + template \ + auto with_##_name(Args&&... _value)->std::decay_t& \ + { \ + GKO_NOT_IMPLEMENTED; \ + return *this; \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default) \ + mutable _name{_default}; \ + \ + template \ + auto with_##_name(Arg&& _value)->std::decay_t& \ + { \ + using type = decltype(this->_name); \ + this->_name = type{std::forward(_value)}; \ + return *this; \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +#define GKO_FACTORY_PARAMETER_VECTOR(_name, ...) \ + mutable _name{__VA_ARGS__}; \ + \ + template \ + auto with_##_name(Args&&... _value)->std::decay_t& \ + { \ + using type = decltype(this->_name); \ + this->_name = type{std::forward(_value)...}; \ + return *this; \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") +#endif // defined(__CUDACC__) || defined(__HIPCC__) + +/** + * Creates a factory parameter of factory type. The parameter can either be set + * directly, or its creation can be deferred until the executor is set in the + * `.on(exec)` function call, by using a deferred_factory_parameter. + * + * @param _name name of the parameter + * @param _type pointee type of the parameter, e.g. LinOpFactory + * + */ +#define GKO_DEFERRED_FACTORY_PARAMETER(_name, _type) \ +public: \ + std::shared_ptr _name{}; \ + parameters_type& with_##_name(deferred_factory_parameter<_type> factory) \ + { \ + this->_name##_generator_ = std::move(factory); \ + return *this; \ + } \ + \ +private: \ + deferred_factory_parameter<_type> _name##_generator_; \ + \ +public: \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + +/** + * Creates a factory parameter representing a vector of factories type. The + * parameter can either be set directly, or its creation can be deferred until + * the executor is set in the + * `.on(exec)` function call, by using a vector of deferred_factory_parameters. + * + * @param _name name of the parameter + * @param _type pointee type of the vector entries, e.g. LinOpFactory + * + */ +#define GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(_name, _type) \ +public: \ + std::vector> _name{}; \ + template \ + parameters_type& with_##_name(Args&&... factories) \ + { \ + this->_name##_generator_ = {deferred_factory_parameter<_type>{ \ + std::forward(factories)}...}; \ + return *this; \ + } \ + \ +private: \ + std::vector> _name##_generator_; \ + \ +public: \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + + } // namespace gko diff --git a/include/ginkgo/core/base/lin_op.hpp b/include/ginkgo/core/base/lin_op.hpp index 20d7771822f..e2660baff2e 100644 --- a/include/ginkgo/core/base/lin_op.hpp +++ b/include/ginkgo/core/base/lin_op.hpp @@ -1084,130 +1084,6 @@ public: \ "semi-colon warnings") -/** - * Defines a build method for the factory, simplifying its construction by - * removing the repetitive typing of factory's name. - * - * @param _factory_name the factory for which to define the method - * - * @ingroup LinOp - */ -#define GKO_ENABLE_BUILD_METHOD(_factory_name) \ - static auto build()->decltype(_factory_name::create()) \ - { \ - return _factory_name::create(); \ - } \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ - "semi-colon warnings") - - -#if !(defined(__CUDACC__) || defined(__HIPCC__)) -/** - * Creates a factory parameter in the factory parameters structure. - * - * @param _name name of the parameter - * @param __VA_ARGS__ default value of the parameter - * - * @see GKO_ENABLE_LIN_OP_FACTORY for more details, and usage example - * - * @deprecated Use GKO_FACTORY_PARAMETER_SCALAR or GKO_FACTORY_PARAMETER_VECTOR - * - * @ingroup LinOp - */ -#define GKO_FACTORY_PARAMETER(_name, ...) \ - mutable _name{__VA_ARGS__}; \ - \ - template \ - auto with_##_name(Args&&... _value)->std::decay_t& \ - { \ - using type = decltype(this->_name); \ - this->_name = type{std::forward(_value)...}; \ - return *this; \ - } \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ - "semi-colon warnings") - -/** - * Creates a scalar factory parameter in the factory parameters structure. - * - * Scalar in this context means that the constructor for this type only takes - * a single parameter. - * - * @param _name name of the parameter - * @param _default default value of the parameter - * - * @see GKO_ENABLE_LIN_OP_FACTORY for more details, and usage example - * - * @ingroup LinOp - */ -#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default) \ - GKO_FACTORY_PARAMETER(_name, _default) - -/** - * Creates a vector factory parameter in the factory parameters structure. - * - * Vector in this context means that the constructor for this type takes - * multiple parameters. - * - * @param _name name of the parameter - * @param _default default value of the parameter - * - * @see GKO_ENABLE_LIN_OP_FACTORY for more details, and usage example - * - * @ingroup LinOp - */ -#define GKO_FACTORY_PARAMETER_VECTOR(_name, ...) \ - GKO_FACTORY_PARAMETER(_name, __VA_ARGS__) -#else // defined(__CUDACC__) || defined(__HIPCC__) -// A workaround for the NVCC compiler - parameter pack expansion does not work -// properly, because while the assignment to a scalar value is translated by -// cudafe into a C-style cast, the parameter pack expansion is not removed and -// `Args&&... args` is still kept as a parameter pack. -#define GKO_FACTORY_PARAMETER(_name, ...) \ - mutable _name{__VA_ARGS__}; \ - \ - template \ - auto with_##_name(Args&&... _value)->std::decay_t& \ - { \ - GKO_NOT_IMPLEMENTED; \ - return *this; \ - } \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ - "semi-colon warnings") - -#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default) \ - mutable _name{_default}; \ - \ - template \ - auto with_##_name(Arg&& _value)->std::decay_t& \ - { \ - using type = decltype(this->_name); \ - this->_name = type{std::forward(_value)}; \ - return *this; \ - } \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ - "semi-colon warnings") - -#define GKO_FACTORY_PARAMETER_VECTOR(_name, ...) \ - mutable _name{__VA_ARGS__}; \ - \ - template \ - auto with_##_name(Args&&... _value)->std::decay_t& \ - { \ - using type = decltype(this->_name); \ - this->_name = type{std::forward(_value)...}; \ - return *this; \ - } \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ - "semi-colon warnings") -#endif // defined(__CUDACC__) || defined(__HIPCC__) - - } // namespace gko diff --git a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp index 3347828a55d..fe0539570ee 100644 --- a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp +++ b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp @@ -94,25 +94,15 @@ class Schwarz /** * Local solver factory. */ - std::shared_ptr local_solver{}; - - parameters_type& with_local_solver( - deferred_factory_parameter solver) - { - this->local_solver_generator = std::move(solver); - return *this; - } + GKO_DEFERRED_FACTORY_PARAMETER(local_solver, LinOpFactory); std::unique_ptr on(std::shared_ptr exec) const { auto copy = *this; - copy.local_solver = local_solver_generator.on(exec); + copy.local_solver = local_solver_generator_.on(exec); return copy.enable_parameters_type::on( exec); } - - private: - deferred_factory_parameter local_solver_generator; }; GKO_ENABLE_LIN_OP_FACTORY(Schwarz, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); diff --git a/include/ginkgo/core/solver/direct.hpp b/include/ginkgo/core/solver/direct.hpp index f66546cd2ec..dcd6fd189a6 100644 --- a/include/ginkgo/core/solver/direct.hpp +++ b/include/ginkgo/core/solver/direct.hpp @@ -87,36 +87,7 @@ class Direct : public EnableLinOp>, gko::size_type GKO_FACTORY_PARAMETER_SCALAR(num_rhs, 1u); /** The factorization factory to use for generating the factors. */ - std::shared_ptr factorization; - - /** - * - */ - parameters_type& with_factorization( - std::shared_ptr factorization) - { - this->factorization_generator = - [factorization](std::shared_ptr) - -> std::shared_ptr { - return factorization; - }; - return *this; - } - - template < - typename FactorizationParameters, - typename = decltype(std::declval().on( - std::shared_ptr{}))> - parameters_type& with_factorization( - FactorizationParameters factorization_parameters) - { - this->factorization_generator = - [factorization_parameters](std::shared_ptr exec) - -> std::shared_ptr { - return factorization_parameters.on(exec); - }; - return *this; - } + GKO_DEFERRED_FACTORY_PARAMETER(factorization, LinOpFactory); /** * @@ -124,15 +95,10 @@ class Direct : public EnableLinOp>, std::unique_ptr on(std::shared_ptr exec) const { auto parameters_copy = *this; - parameters_copy.factorization = factorization_generator(exec); + parameters_copy.factorization = factorization_generator_.on(exec); return parameters_copy .enable_parameters_type::on(exec); } - - private: - std::function( - std::shared_ptr)> - factorization_generator; }; GKO_ENABLE_LIN_OP_FACTORY(Direct, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); diff --git a/include/ginkgo/core/solver/ir.hpp b/include/ginkgo/core/solver/ir.hpp index d30fd9d69bc..1f04c8b75d2 100644 --- a/include/ginkgo/core/solver/ir.hpp +++ b/include/ginkgo/core/solver/ir.hpp @@ -184,13 +184,14 @@ class Ir : public EnableLinOp>, /** * Inner solver factory. */ - std::shared_ptr solver{}; + GKO_DEFERRED_FACTORY_PARAMETER(solver, LinOpFactory); /** * Already generated solver. If one is provided, the factory `solver` * will be ignored. */ - std::shared_ptr generated_solver{}; + std::shared_ptr GKO_FACTORY_PARAMETER_SCALAR( + generated_solver, nullptr); /** * Relaxation factor for Richardson iteration @@ -205,41 +206,18 @@ class Ir : public EnableLinOp>, initial_guess_mode GKO_FACTORY_PARAMETER_SCALAR( default_initial_guess, initial_guess_mode::provided); - /** - * - */ - parameters_type& with_solver( - deferred_factory_parameter solver) - { - this->solver_generator = std::move(solver); - return *this; - } - - /** - * - */ - parameters_type& with_generated_solver( - std::shared_ptr generated_solver) - { - this->generated_solver = std::move(generated_solver); - return *this; - } - /** * */ std::unique_ptr on(std::shared_ptr exec) const { auto parameters_copy = *this; - if (solver_generator) { - parameters_copy.solver = solver_generator.on(exec); + if (solver_generator_) { + parameters_copy.solver = solver_generator_.on(exec); } return parameters_copy.enable_iterative_solver_factory_parameters< parameters_type, Factory>::on(exec); } - - private: - deferred_factory_parameter solver_generator; }; GKO_ENABLE_LIN_OP_FACTORY(Ir, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); diff --git a/include/ginkgo/core/solver/multigrid.hpp b/include/ginkgo/core/solver/multigrid.hpp index 5aab788f71f..0a0a6fdd191 100644 --- a/include/ginkgo/core/solver/multigrid.hpp +++ b/include/ginkgo/core/solver/multigrid.hpp @@ -225,16 +225,7 @@ class Multigrid : public EnableLinOp, /** * MultigridLevel Factory list */ - std::vector> mg_level{nullptr}; - - template - parameters_type& with_mg_level(Args&&... level) - { - this->mg_level_generator = { - deferred_factory_parameter{ - std::forward(level)}...}; - return *this; - } + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mg_level, LinOpFactory); /** * Custom selector size_type (size_type level, const LinOp* fine_matrix) @@ -265,7 +256,6 @@ class Multigrid : public EnableLinOp, std::function GKO_FACTORY_PARAMETER_SCALAR(level_selector, nullptr); - using smoother_list = std::vector>; /** * Pre-smooth Factory list. * Its size must be 0, 1 or be the same as mg_level's. @@ -280,14 +270,14 @@ class Multigrid : public EnableLinOp, * If any element in the vector is a `nullptr` then the smoother * application at the corresponding level is skipped. */ - smoother_list pre_smoother{}; + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(pre_smoother, LinOpFactory); /** * Post-smooth Factory list. * It is similar to Pre-smooth Factory list. It is ignored if * the factory parameter post_uses_pre is set to true. */ - smoother_list post_smoother{}; + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(post_smoother, LinOpFactory); /** * Mid-smooth Factory list. If it contains available elements, multigrid @@ -296,34 +286,7 @@ class Multigrid : public EnableLinOp, * Pre-smooth Factory list. It is ignored if the factory parameter * mid_case is not mid. */ - smoother_list mid_smoother{}; - - template - parameters_type& with_pre_smoother(Args&&... smoother) - { - this->pre_smoother_generator = { - deferred_factory_parameter{ - std::forward(smoother)}...}; - return *this; - } - - template - parameters_type& with_post_smoother(Args&&... smoother) - { - this->post_smoother_generator = { - deferred_factory_parameter{ - std::forward(smoother)}...}; - return *this; - } - - template - parameters_type& with_mid_smoother(Args&&... smoother) - { - this->mid_smoother_generator = { - deferred_factory_parameter{ - std::forward(smoother)}...}; - return *this; - } + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mid_smoother, LinOpFactory); /** * Whether post-smoothing-related calls use corresponding @@ -363,17 +326,7 @@ class Multigrid : public EnableLinOp, * If not set, then a direct LU solver will be used as solver on the * coarsest level. */ - std::vector> coarsest_solver{ - nullptr}; - - template - parameters_type& with_coarsest_solver(Args&&... solver) - { - this->coarsest_solver_generator = { - deferred_factory_parameter{ - std::forward(solver)}...}; - return *this; - } + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(coarsest_solver, LinOpFactory); /** * Custom coarsest_solver selector @@ -449,36 +402,36 @@ class Multigrid : public EnableLinOp, std::unique_ptr on(std::shared_ptr exec) const { auto copy = *this; - if (!copy.mg_level_generator.empty()) { + if (!copy.mg_level_generator_.empty()) { copy.mg_level.clear(); - for (auto& generator : copy.mg_level_generator) { + for (auto& generator : copy.mg_level_generator_) { copy.mg_level.push_back(generator.on(exec)); } } - if (!copy.pre_smoother_generator.empty()) { + if (!copy.pre_smoother_generator_.empty()) { copy.pre_smoother.clear(); - for (auto& generator : copy.pre_smoother_generator) { + for (auto& generator : copy.pre_smoother_generator_) { copy.pre_smoother.push_back(generator ? generator.on(exec) : nullptr); } } - if (!copy.mid_smoother_generator.empty()) { + if (!copy.mid_smoother_generator_.empty()) { copy.mid_smoother.clear(); - for (auto& generator : copy.mid_smoother_generator) { + for (auto& generator : copy.mid_smoother_generator_) { copy.mid_smoother.push_back(generator ? generator.on(exec) : nullptr); } } - if (!copy.post_smoother_generator.empty()) { + if (!copy.post_smoother_generator_.empty()) { copy.post_smoother.clear(); - for (auto& generator : copy.post_smoother_generator) { + for (auto& generator : copy.post_smoother_generator_) { copy.post_smoother.push_back(generator ? generator.on(exec) : nullptr); } } - if (!copy.coarsest_solver_generator.empty()) { + if (!copy.coarsest_solver_generator_.empty()) { copy.coarsest_solver.clear(); - for (auto& generator : copy.coarsest_solver_generator) { + for (auto& generator : copy.coarsest_solver_generator_) { copy.coarsest_solver.push_back( generator ? generator.on(exec) : nullptr); } @@ -486,18 +439,6 @@ class Multigrid : public EnableLinOp, return copy.enable_iterative_solver_factory_parameters< parameters_type, Factory>::on(exec); } - - private: - std::vector> - mg_level_generator; - std::vector> - pre_smoother_generator; - std::vector> - mid_smoother_generator; - std::vector> - post_smoother_generator; - std::vector> - coarsest_solver_generator; }; GKO_ENABLE_LIN_OP_FACTORY(Multigrid, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory);