Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy #28

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Mypy #28

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/dolfinx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
del _cpp, sys


def get_include(user=False):
def get_include(user: bool = False) -> str:
import os

d = os.path.dirname(__file__)
Expand Down
50 changes: 32 additions & 18 deletions python/dolfinx/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@

import functools
import typing
from types import TracebackType

from dolfinx import cpp as _cpp
from mpi4py import MPI as _MPI

from dolfinx.cpp import Reduction as _Reduction
from dolfinx.cpp import Timer as _Timer
from dolfinx.cpp import TimingTyp as _TimingType
from dolfinx.cpp.common import (
IndexMap,
git_commit_hash,
Expand All @@ -23,6 +28,8 @@
has_slepc,
ufcx_signature,
)
from dolfinx.cpp.common import list_timings as _list_timings
from dolfinx.cpp.common import timing as _timing

__all__ = [
"IndexMap",
Expand All @@ -41,20 +48,22 @@
"ufcx_signature",
]

TimingType = _cpp.common.TimingType
Reduction = _cpp.common.Reduction
TimingType = _TimingType
Reduction = _Reduction


def timing(task: str):
return _cpp.common.timing(task)
def timing(task: str) -> tuple[int, float, float, float]:
return _timing(task)


def list_timings(comm, timing_types: list, reduction=Reduction.max):
def list_timings(
comm: _MPI.Comm, timing_types: list, reduction: _Reduction = Reduction.max
) -> None:
"""Print out a summary of all Timer measurements, with a choice of
wall time, system time or user time. When used in parallel, a
reduction is applied across all processes. By default, the maximum
time is shown."""
_cpp.common.list_timings(comm, timing_types, reduction)
_list_timings(comm, timing_types, reduction)


class Timer:
Expand Down Expand Up @@ -91,37 +100,42 @@ class Timer:
list_timings(comm, [TimingType.wall, TimingType.user])
"""

_cpp_object: _cpp.common.Timer
_cpp_object: _Timer

def __init__(self, name: typing.Optional[str] = None):
def __init__(self, name: typing.Optional[str] = None) -> None:
self._cpp_object = _cpp.common.Timer(name)

def __enter__(self):
def __enter__(self) -> typing.Self:
self._cpp_object.start()
return self

def __exit__(self, *args):
def __exit__(
self,
exception_type: typing.Optional[BaseException],
exception_value: typing.Optional[BaseException],
traceback: typing.Optional[TracebackType],
) -> None:
self._cpp_object.stop()

def start(self):
def start(self) -> None:
self._cpp_object.start()

def stop(self):
def stop(self) -> float:
return self._cpp_object.stop()

def resume(self):
def resume(self) -> None:
self._cpp_object.resume()

def elapsed(self):
def elapsed(self) -> float:
return self._cpp_object.elapsed()


def timed(task: str):
def timed(task: str) -> typing.Callable[..., typing.Any]:
"""Decorator for timing functions."""

def decorator(func):
def decorator(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
with Timer(task):
return func(*args, **kwargs)

Expand Down
20 changes: 15 additions & 5 deletions python/dolfinx/fem/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,15 @@ class DirichletBC:
_cpp.fem.DirichletBC_float64,
]

def __init__(self, bc):
def __init__(
self,
bc: typing.Union[
_cpp.fem.DirichletBC_complex64,
_cpp.fem.DirichletBC_complex128,
_cpp.fem.DirichletBC_float32,
_cpp.fem.DirichletBC_float64,
],
):
"""Representation of Dirichlet boundary condition which is imposed on
a linear system.

Expand Down Expand Up @@ -210,8 +218,8 @@ def dirichletbc(


def bcs_by_block(
spaces: typing.Iterable[typing.Union[dolfinx.fem.FunctionSpace, None]],
bcs: typing.Iterable[DirichletBC],
spaces: list[typing.Union[dolfinx.fem.FunctionSpace, None]],
bcs: list[DirichletBC],
) -> list[list[DirichletBC]]:
"""Arrange Dirichlet boundary conditions by the function space that
they constrain.
Expand All @@ -222,8 +230,10 @@ def bcs_by_block(
``space[i]``.
"""

def _bc_space(V, bcs):
def _bc_space(
V: list[typing.Union[dolfinx.fem.FunctionSpace, None]], bcs: list[DirichletBC]
) -> list[DirichletBC]:
"""Return list of bcs that have the same space as V"""
return [bc for bc in bcs if V.contains(bc.function_space)]
return [bc for bc in bcs if bc.function_space in V]

return [_bc_space(V, bcs) if V is not None else [] for V in spaces]
17 changes: 11 additions & 6 deletions python/dolfinx/fem/dofmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later

import numpy as np
import numpy.typing as npt

from dolfinx import cpp as _cpp
from dolfinx.common import IndexMap
from dolfinx.fem import ElementDofLayout


class DofMap:
Expand All @@ -19,7 +24,7 @@ class DofMap:
def __init__(self, dofmap: _cpp.fem.DofMap):
self._cpp_object = dofmap

def cell_dofs(self, cell_index: int):
def cell_dofs(self, cell_index: int) -> npt.NDArray[np.int32]:
"""Cell local-global dof map

Args:
Expand All @@ -31,26 +36,26 @@ def cell_dofs(self, cell_index: int):
return self._cpp_object.cell_dofs(cell_index)

@property
def bs(self):
def bs(self) -> int:
"""Returns the block size of the dofmap"""
return self._cpp_object.bs

@property
def dof_layout(self):
def dof_layout(self) -> ElementDofLayout:
"""Layout of dofs on an element."""
return self._cpp_object.dof_layout

@property
def index_map(self):
def index_map(self) -> IndexMap:
"""Index map that described the parallel distribution of the dofmap."""
return self._cpp_object.index_map

@property
def index_map_bs(self):
def index_map_bs(self) -> int:
"""Block size of the index map."""
return self._cpp_object.index_map_bs

@property
def list(self):
def list(self) -> npt.NDArray[np.int32]:
"""Adjacency list with dof indices for each cell."""
return self._cpp_object.map()
6 changes: 3 additions & 3 deletions python/dolfinx/fem/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def dtype(self) -> np.dtype:
def coordinate_element(
celltype: _cpp.mesh.CellType,
degree: int,
variant=int(basix.LagrangeVariant.unset),
variant: int = int(basix.LagrangeVariant.unset),
dtype: npt.DTypeLike = np.float64,
):
) -> CoordinateElement:
"""Create a Lagrange CoordinateElement from element metadata.

Coordinate elements are typically used to create meshes.
Expand All @@ -76,7 +76,7 @@ def coordinate_element(


@coordinate_element.register(basix.finite_element.FiniteElement)
def _(e: basix.finite_element.FiniteElement):
def _(e: basix.finite_element.FiniteElement) -> CoordinateElement:
"""Create a Lagrange CoordinateElement from a Basix finite element.

Coordinate elements are typically used when creating meshes.
Expand Down
21 changes: 9 additions & 12 deletions python/dolfinx/fem/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,25 +254,22 @@ def _form(form):
constants = [c._cpp_object for c in form.constants()]

# Make map from integral_type to subdomain id
subdomain_ids = {type: [] for type in sd.get(domain).keys()}
subdomain_ids: dict[str, list[int]] = {type: [] for type in sd.get(domain).keys()}
for integral in form.integrals():
if integral.subdomain_data() is not None:
# Subdomain ids can be strings, its or tuples with strings and ints
if integral.subdomain_id() != "everywhere":
try:
ids = [sid for sid in integral.subdomain_id() if sid != "everywhere"]
subdomain_ids[integral.integral_type()] += [
sid for sid in integral.subdomain_id() if sid != "everywhere"
]
except TypeError:
# If not tuple, but single integer id
ids = [integral.subdomain_id()]
else:
ids = []
subdomain_ids[integral.integral_type()].append(ids)

# Chain and sort subdomain ids
for itg_type, marker_ids in subdomain_ids.items():
flattened_ids = list(chain.from_iterable(marker_ids))
flattened_ids.sort()
subdomain_ids[itg_type] = flattened_ids
subdomain_ids[integral.integral_type()].append(integral.subdomain_id())

# Sort subdomain ids
for val in subdomain_ids.values():
val.sort()

# Subdomain markers (possibly empty list for some integral types)
subdomains = {
Expand Down
Loading
Loading