Skip to content

Commit

Permalink
Simplify Form constructor by adding a struct for integral data (#…
Browse files Browse the repository at this point in the history
…3045)

* Add struct

* Docs

* Add default ctor

* Docs

* Simplify

* Docs

* Docs

* Update demo

* Simplify Form

* Update docs

* Use emplace_back

* Store as vector

* Update demo

* Revert "Update demo"

This reverts commit d3b0ac2.

* Revert "Store as vector"

This reverts commit 2a707a7.

* Attept to avoid copy

* Try templating

* Tidy

* Remove sort

* Sort

* Code review suggestions

* Change template name

* Shorten docs

* Add braces

* consts

* Docs

* const

* Muck about with traits

* Update demo

* Simplifications

* Simplify

* Simplification

---------

Co-authored-by: Garth N. Wells <gnw20@cam.ac.uk>
Co-authored-by: Matthew Scroggs <matthew.w.scroggs@gmail.com>
  • Loading branch information
3 people committed Feb 13, 2024
1 parent 35e66b5 commit 4f57596
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 71 deletions.
7 changes: 2 additions & 5 deletions cpp/demo/custom_kernel/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ template <typename T, std::size_t n0, std::size_t n1>
using mdspan2_t
= MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<T,
std::extents<std::size_t, n0, n1>>;
template <typename T>
using kernel_t = std::function<void(T*, const T*, const T*, const T*,
const int*, const uint8_t*)>;

// .. code-block:: cpp

Expand Down Expand Up @@ -77,7 +74,7 @@ double assemble_matrix0(std::shared_ptr<fem::FunctionSpace<T>> V, auto kernel,
std::span<const std::int32_t> cells)
{
// Kernel data (ID, kernel function, cell indices to execute over)
std::vector kernel_data{std::tuple{-1, kernel_t<T>(kernel), cells}};
std::vector kernel_data{fem::integral_data<T>(-1, kernel, cells)};

// Associate kernel with cells (as opposed to facets, etc)
std::map integrals{std::pair{fem::IntegralType::cell, kernel_data}};
Expand Down Expand Up @@ -107,7 +104,7 @@ double assemble_vector0(std::shared_ptr<fem::FunctionSpace<T>> V, auto kernel,
std::span<const std::int32_t> cells)
{
auto mesh = V->mesh();
std::vector kernal_data{std::tuple{-1, kernel_t<T>(kernel), cells}};
std::vector kernal_data{fem::integral_data<T>(-1, kernel, cells)};
std::map integrals{std::pair{fem::IntegralType::cell, kernal_data}};
fem::Form<T> L({V}, integrals, {}, {}, false, mesh);
auto dofmap = V->dofmap();
Expand Down
1 change: 1 addition & 0 deletions cpp/dolfinx/fem/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ set(HEADERS_fem
${CMAKE_CURRENT_SOURCE_DIR}/interpolate.h
${CMAKE_CURRENT_SOURCE_DIR}/petsc.h
${CMAKE_CURRENT_SOURCE_DIR}/sparsitybuild.h
${CMAKE_CURRENT_SOURCE_DIR}/traits.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
PARENT_SCOPE
)
Expand Down
128 changes: 88 additions & 40 deletions cpp/dolfinx/fem/Form.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#pragma once

#include "FunctionSpace.h"
#include "traits.h"
#include <algorithm>
#include <array>
#include <concepts>
Expand Down Expand Up @@ -37,6 +38,46 @@ enum class IntegralType : std::int8_t
vertex = 3 ///< Vertex
};

/// @brief Represents integral data, containing the integral ID, the
/// kernel, and a list of entities to integrate over.
template <dolfinx::scalar T,
FEkernel<T> Kern = std::function<void(
T*, const T*, const T*, const T*, const int*, const uint8_t*)>>
struct integral_data
{
/// @brief Kernel type
using kern_t = Kern;

/// @brief Create a structure to hold integral data.
/// @tparam U `std::vector<std::int32_t>` holding entity indices.
/// @param id Domain ID.
/// @param kernel Integration kernel.
/// @param entities Entities to integrate over.
template <typename U>
integral_data(int id, kern_t kernel, U&& entities)
: id(id), kernel(kernel), entities(std::forward<U>(entities))
{
}

/// @brief Create a structure to hold integral data.
/// @param id Domain ID
/// @param kernel Integration kernel.
/// @param e Entities to integrate over.
integral_data(int id, kern_t kernel, std::span<const std::int32_t> e)
: id(id), kernel(kernel), entities(e.begin(), e.end())
{
}

/// Integral ID
int id;

/// The integration kernel
kern_t kernel;

/// The entities to integrate over
std::vector<std::int32_t> entities;
};

/// @brief A representation of finite element variational forms.
///
/// A note on the order of trial and test spaces: FEniCS numbers
Expand All @@ -60,10 +101,15 @@ enum class IntegralType : std::int8_t
/// (the variable `function_spaces` in the constructors below), the list
/// of spaces should start with space number 0 (the test space) and then
/// space number 1 (the trial space).
template <dolfinx::scalar T,
std::floating_point U = dolfinx::scalar_value_type_t<T>>
template <
dolfinx::scalar T, std::floating_point U = dolfinx::scalar_value_type_t<T>,
FEkernel<T> Kern
= std::function<void(T*, const T*, const T*, const scalar_value_type_t<T>*,
const int*, const std::uint8_t*)>>
class Form
{
using kern_t = Kern;

public:
/// Scalar type
using scalar_type = T;
Expand All @@ -74,26 +120,24 @@ class Form
/// rather using this interface directly.
///
/// @param[in] V Function spaces for the form arguments
/// @param[in] integrals Integrals in the form. The first key is the
/// domain type. For each key there is a list of tuples (domain id,
/// integration kernel, entities).
/// @param[in] integrals The integrals in the form. For each
/// integral type, there is a list of integral data
/// @param[in] coefficients
/// @param[in] constants Constants in the Form
/// @param[in] needs_facet_permutations Set to true is any of the
/// integration kernels require cell permutation data
/// @param[in] mesh Mesh of the domain. This is required when there
/// are no argument functions from which the mesh can be extracted,
/// e.g. for functionals.
///
/// @pre The integral data in integrals must be sorted by domain
template <typename X>
Form(const std::vector<std::shared_ptr<const FunctionSpace<U>>>& V,
const std::map<IntegralType,
std::vector<std::tuple<
int,
std::function<void(T*, const T*, const T*,
const scalar_value_type_t<T>*,
const int*, const std::uint8_t*)>,
std::span<const std::int32_t>>>>& integrals,
const std::vector<std::shared_ptr<const Function<T, U>>>& coefficients,
const std::vector<std::shared_ptr<const Constant<T>>>& constants,
X&& integrals,
const std::vector<std::shared_ptr<const Function<scalar_type, U>>>&
coefficients,
const std::vector<std::shared_ptr<const Constant<scalar_type>>>&
constants,
bool needs_facet_permutations,
std::shared_ptr<const mesh::Mesh<U>> mesh = nullptr)
: _function_spaces(V), _coefficients(coefficients), _constants(constants),
Expand All @@ -103,19 +147,24 @@ class Form
if (!_mesh and !V.empty())
_mesh = V[0]->mesh();
for (auto& space : V)
{
if (_mesh != space->mesh())
throw std::runtime_error("Incompatible mesh");
}
if (!_mesh)
throw std::runtime_error("No mesh could be associated with the Form.");

// Store kernels, looping over integrals by domain type (dimension)
for (auto& [type, kernels] : integrals)
for (auto&& [domain_type, data] : integrals)
{
auto& itg = _integrals[static_cast<std::size_t>(type)];
for (auto& [id, kern, e] : kernels)
itg.insert({id, {kern, std::vector(e.begin(), e.end())}});
if (!std::is_sorted(data.begin(), data.end(),
[](auto& a, auto& b) { return a.id < b.id; }))
{
throw std::runtime_error("Integral IDs not sorted");
}

std::vector<integral_data<T, kern_t>>& itg
= _integrals[static_cast<std::size_t>(domain_type)];
for (auto&& [id, kern, e] : data)
itg.emplace_back(id, kern, std::move(e));
}
}

Expand Down Expand Up @@ -150,13 +199,14 @@ class Form
/// @param[in] type Integral type
/// @param[in] i Domain identifier (index)
/// @return Function to call for tabulate_tensor
std::function<void(T*, const T*, const T*, const scalar_value_type_t<T>*,
const int*, const std::uint8_t*)>
kernel(IntegralType type, int i) const
kern_t kernel(IntegralType type, int i) const
{
auto integrals = _integrals[static_cast<std::size_t>(type)];
if (auto it = integrals.find(i); it != integrals.end())
return it->second.first;
const auto& integrals = _integrals[static_cast<std::size_t>(type)];
auto it = std::lower_bound(integrals.begin(), integrals.end(), i,
[](auto& itg_data, int i)
{ return itg_data.id < i; });
if (it != integrals.end() and it->id == i)
return it->kernel;
else
throw std::runtime_error("No kernel for requested domain index.");
}
Expand Down Expand Up @@ -192,9 +242,9 @@ class Form
std::vector<int> integral_ids(IntegralType type) const
{
std::vector<int> ids;
auto& integrals = _integrals[static_cast<std::size_t>(type)];
const auto& integrals = _integrals[static_cast<std::size_t>(type)];
std::transform(integrals.begin(), integrals.end(), std::back_inserter(ids),
[](auto& integral) { return integral.first; });
[](auto& integral) { return integral.id; });
return ids;
}

Expand All @@ -217,9 +267,12 @@ class Form
/// @return List of active cell entities for the given integral (kernel)
std::span<const std::int32_t> domain(IntegralType type, int i) const
{
auto& integral = _integrals[static_cast<std::size_t>(type)];
if (auto it = integral.find(i); it != integral.end())
return it->second.second;
const auto& integrals = _integrals[static_cast<std::size_t>(type)];
auto it = std::lower_bound(integrals.begin(), integrals.end(), i,
[](auto& itg_data, int i)
{ return itg_data.id < i; });
if (it != integrals.end() and it->id == i)
return it->entities;
else
throw std::runtime_error("No mesh entities for requested domain index.");
}
Expand Down Expand Up @@ -257,13 +310,13 @@ class Form
}

private:
using kern_t = std::function<void(T*, const T*, const T*,
const scalar_value_type_t<T>*, const int*,
const std::uint8_t*)>;

// Function spaces (one for each argument)
std::vector<std::shared_ptr<const FunctionSpace<U>>> _function_spaces;

// Integrals. Array index is
// static_cast<std::size_t(IntegralType::foo)
std::array<std::vector<integral_data<T, kern_t>>, 4> _integrals;

// Form coefficients
std::vector<std::shared_ptr<const Function<T, U>>> _coefficients;

Expand All @@ -273,11 +326,6 @@ class Form
// The mesh
std::shared_ptr<const mesh::Mesh<U>> _mesh;

// Integrals. Array index is
// static_cast<std::size_t(IntegralType::foo)
std::array<std::map<int, std::pair<kern_t, std::vector<std::int32_t>>>, 4>
_integrals;

// True if permutation data needs to be passed into these integrals
bool _needs_facet_permutations;
};
Expand Down
2 changes: 1 addition & 1 deletion cpp/dolfinx/fem/assemble_matrix_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#include "DofMap.h"
#include "Form.h"
#include "FunctionSpace.h"
#include "traits.h"
#include "utils.h"
#include <algorithm>
#include <concepts>
#include <dolfinx/la/utils.h>
#include <dolfinx/mesh/Geometry.h>
#include <dolfinx/mesh/Mesh.h>
Expand Down
4 changes: 2 additions & 2 deletions cpp/dolfinx/fem/assembler.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include "assemble_matrix_impl.h"
#include "assemble_scalar_impl.h"
#include "assemble_vector_impl.h"
#include "traits.h"
#include "utils.h"
#include <concepts>
#include <cstdint>
#include <dolfinx/common/types.h>
#include <memory>
Expand All @@ -21,7 +21,7 @@ namespace dolfinx::fem
{
template <dolfinx::scalar T, std::floating_point U>
class DirichletBC;
template <dolfinx::scalar T, std::floating_point U>
template <dolfinx::scalar T, std::floating_point U, FEkernel<T> Kern>
class Form;
template <std::floating_point T>
class FunctionSpace;
Expand Down
25 changes: 25 additions & 0 deletions cpp/dolfinx/fem/traits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2024 Joseph P. Dean and Garth N. Wells
// This file is part of DOLFINx (https://www.fenicsproject.org)
//
// SPDX-License-Identifier: LGPL-3.0-or-later

#pragma once

#include <concepts>
#include <cstdint>
#include <dolfinx/common/types.h>
#include <type_traits>

namespace dolfinx::fem
{

/// @brief Finite element cell kernel concept.
///
/// Kernel functions that can be passed to an assembler for execution
/// must satisfy this concept.
template <class U, class T>
concept FEkernel = std::is_invocable_v<U, T*, const T*, const T*,
const scalar_value_type_t<T>*,
const int*, const std::uint8_t*>;

} // namespace dolfinx::fem
21 changes: 5 additions & 16 deletions cpp/dolfinx/fem/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,6 @@ compute_integration_domains(IntegralType integral_type,
std::span<const std::int32_t> entities, int dim,
std::span<const int> values);

/// @brief Finite element cell kernel concept.
///
/// Kernel functions that can be passed to an assembler for execution
/// must satisfy this concept.
template <class U, class T>
concept FEkernel = std::is_invocable_v<U, T*, const T*, const T*,
const scalar_value_type_t<T>*,
const int*, const std::uint8_t*>;

/// @brief Extract test (0) and trial (1) function spaces pairs for each
/// bilinear form for a rectangular array of forms.
///
Expand Down Expand Up @@ -356,9 +347,7 @@ Form<T, U> create_form_factory(
using kern = std::function<void(
T*, const T*, const T*, const typename scalar_value_type<T>::value_type*,
const int*, const std::uint8_t*)>;
std::map<IntegralType,
std::vector<std::tuple<int, kern, std::span<const std::int32_t>>>>
integral_data;
std::map<IntegralType, std::vector<integral_data<T, kern>>> integrals;

bool needs_facet_permutations = false;

Expand All @@ -368,7 +357,7 @@ Form<T, U> create_form_factory(
std::span<const int> ids(ufcx_form.form_integral_ids
+ integral_offsets[cell],
num_integrals_type[cell]);
auto itg = integral_data.insert({IntegralType::cell, {}});
auto itg = integrals.insert({IntegralType::cell, {}});
auto sd = subdomains.find(IntegralType::cell);
for (int i = 0; i < num_integrals_type[cell]; ++i)
{
Expand Down Expand Up @@ -428,7 +417,7 @@ Form<T, U> create_form_factory(
std::span<const int> ids(ufcx_form.form_integral_ids
+ integral_offsets[exterior_facet],
num_integrals_type[exterior_facet]);
auto itg = integral_data.insert({IntegralType::exterior_facet, {}});
auto itg = integrals.insert({IntegralType::exterior_facet, {}});
auto sd = subdomains.find(IntegralType::exterior_facet);
for (int i = 0; i < num_integrals_type[exterior_facet]; ++i)
{
Expand Down Expand Up @@ -499,7 +488,7 @@ Form<T, U> create_form_factory(
std::span<const int> ids(ufcx_form.form_integral_ids
+ integral_offsets[interior_facet],
num_integrals_type[interior_facet]);
auto itg = integral_data.insert({IntegralType::interior_facet, {}});
auto itg = integrals.insert({IntegralType::interior_facet, {}});
auto sd = subdomains.find(IntegralType::interior_facet);
for (int i = 0; i < num_integrals_type[interior_facet]; ++i)
{
Expand Down Expand Up @@ -577,7 +566,7 @@ Form<T, U> create_form_factory(
sd.insert({itg, std::move(x)});
}

return Form<T, U>(spaces, integral_data, coefficients, constants,
return Form<T, U>(spaces, integrals, coefficients, constants,
needs_facet_permutations, mesh);
}

Expand Down
Loading

0 comments on commit 4f57596

Please sign in to comment.