diff --git a/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py b/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py index 248034a..ae72923 100644 --- a/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py +++ b/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py @@ -1,6 +1,8 @@ import os from pathlib import Path + import jax.numpy as jnp + from diffmpm import MPM diff --git a/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py b/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py index b880ffa..356d0a3 100644 --- a/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py +++ b/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py @@ -1,6 +1,8 @@ import os from pathlib import Path + import jax.numpy as jnp + from diffmpm import MPM diff --git a/benchmarks/2d/uniaxial_stress/test_benchmark.py b/benchmarks/2d/uniaxial_stress/test_benchmark.py index 0a6d10c..f04e820 100644 --- a/benchmarks/2d/uniaxial_stress/test_benchmark.py +++ b/benchmarks/2d/uniaxial_stress/test_benchmark.py @@ -1,6 +1,8 @@ import os from pathlib import Path + import jax.numpy as jnp + from diffmpm import MPM diff --git a/diffmpm/__init__.py b/diffmpm/__init__.py index bf2f251..faa8316 100644 --- a/diffmpm/__init__.py +++ b/diffmpm/__init__.py @@ -40,7 +40,7 @@ def __init__(self, filepath): raise ValueError("Wrong type of solver specified.") def solve(self): - """Solve the MPM simulation.""" + """Solve the MPM simulation using JIT solver.""" arrays = self.solver.solve_jit( self._config.parsed_config["external_loading"]["gravity"], ) diff --git a/diffmpm/cli/mpm.py b/diffmpm/cli/mpm.py index 3e621cf..063140d 100644 --- a/diffmpm/cli/mpm.py +++ b/diffmpm/cli/mpm.py @@ -3,11 +3,12 @@ from diffmpm import MPM -@click.command() +@click.command() # type: ignore @click.option( "-f", "--file", "filepath", required=True, type=str, help="Input TOML file" ) @click.version_option(package_name="diffmpm") def mpm(filepath): + """CLI utility for MPM.""" solver = MPM(filepath) solver.solve() diff --git a/diffmpm/constraint.py b/diffmpm/constraint.py index cba836a..93f75bd 100644 --- a/diffmpm/constraint.py +++ b/diffmpm/constraint.py @@ -3,7 +3,18 @@ @register_pytree_node_class class Constraint: - def __init__(self, dir, velocity): + """Generic velocity constraints to apply on nodes or particles.""" + + def __init__(self, dir: int, velocity: float): + """Contains 2 govering parameters. + + Attributes + ---------- + dir : int + Direction in which constraint is applied. + velocity : float + Constrained velocity to be applied. + """ self.dir = dir self.velocity = velocity @@ -16,16 +27,15 @@ def tree_unflatten(cls, aux_data, children): return cls(*aux_data) def apply(self, obj, ids): - """ - Apply constraint values to the passed object. + """Apply constraint values to the passed object. - Arguments - --------- + Parameters + ---------- obj : diffmpm.node.Nodes, diffmpm.particle.Particles Object on which the constraint is applied ids : array_like The indices of the container `obj` on which the constraint - will be applied. + will be applied. """ obj.velocity = obj.velocity.at[ids, :, self.dir].set(self.velocity) obj.momentum = obj.momentum.at[ids, :, self.dir].set( diff --git a/diffmpm/element.py b/diffmpm/element.py index f0b9d4f..3eeff67 100644 --- a/diffmpm/element.py +++ b/diffmpm/element.py @@ -1,52 +1,79 @@ +from __future__ import annotations + import abc import itertools -from typing import Sequence, Tuple, List +from typing import TYPE_CHECKING, Optional, Sequence, Tuple + +if TYPE_CHECKING: + from diffmpm.particle import Particles import jax.numpy as jnp -from jax import jacobian, jit, lax, vmap +from jax import Array, jacobian, jit, lax, vmap from jax.tree_util import register_pytree_node_class +from jax.typing import ArrayLike -from diffmpm.node import Nodes from diffmpm.constraint import Constraint +from diffmpm.node import Nodes + +__all__ = ["_Element", "Linear1D", "Quadrilateral4Node"] class _Element(abc.ABC): + """Base element class that is inherited by all types of Elements.""" + + nodes: Nodes + total_elements: int + concentrated_nodal_forces: Sequence + volume: Array + @abc.abstractmethod - def id_to_node_ids(self): - ... + def id_to_node_ids(self, id: ArrayLike) -> Array: + """Node IDs corresponding to element `id`. + + This method is implemented by each of the subclass. - def id_to_node_loc(self, id: int): + Parameters + ---------- + id : int + Element ID. + + Returns + ------- + ArrayLike + Nodal IDs of the element. """ - Node locations corresponding to element `id`. + ... + + def id_to_node_loc(self, id: ArrayLike) -> Array: + """Node locations corresponding to element `id`. - Arguments - --------- + Parameters + ---------- id : int Element ID. Returns ------- - jax.numpy.ndarray + ArrayLike Nodal locations for the element. Shape of returned - array is (nodes_in_element, 1, ndim) + array is `(nodes_in_element, 1, ndim)` """ node_ids = self.id_to_node_ids(id).squeeze() return self.nodes.loc[node_ids] - def id_to_node_vel(self, id: int): - """ - Node velocities corresponding to element `id`. + def id_to_node_vel(self, id: ArrayLike) -> Array: + """Node velocities corresponding to element `id`. - Arguments - --------- + Parameters + ---------- id : int Element ID. Returns ------- - jax.numpy.ndarray + ArrayLike Nodal velocities for the element. Shape of returned - array is (nodes_in_element, 1, ndim) + array is `(nodes_in_element, 1, ndim)` """ node_ids = self.id_to_node_ids(id).squeeze() return self.nodes.velocity[node_ids] @@ -77,29 +104,33 @@ def tree_unflatten(cls, aux_data, children): ) @abc.abstractmethod - def shapefn(self): + def shapefn(self, xi: ArrayLike): + """Evaluate Shape function for element type.""" ... @abc.abstractmethod - def shapefn_grad(self): + def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): + """Evaluate gradient of shape function for element type.""" ... @abc.abstractmethod - def set_particle_element_ids(self): + def set_particle_element_ids(self, particles: Particles): + """Set the element IDs that particles are present in.""" ... # Mapping from particles to nodes (P2G) - def compute_nodal_mass(self, particles): - r""" - Compute the nodal mass based on particle mass. + def compute_nodal_mass(self, particles: Particles): + r"""Compute the nodal mass based on particle mass. The nodal mass is updated as a sum of particle mass for all particles mapped to the node. - :math:`(m)_i = \sum_p N_i(x_p) m_p` + \[ + (m)_i = \sum_p N_i(x_p) m_p + \] - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -120,17 +151,18 @@ def _step(pid, args): ) _, self.nodes.mass, _, _ = lax.fori_loop(0, len(particles), _step, args) - def compute_nodal_momentum(self, particles): - r""" - Compute the nodal mass based on particle mass. + def compute_nodal_momentum(self, particles: Particles): + r"""Compute the nodal mass based on particle mass. The nodal mass is updated as a sum of particle mass for all particles mapped to the node. - :math:`(mv)_i = \sum_p N_i(x_p) (mv)_p` + \[ + (mv)_i = \sum_p N_i(x_p) (mv)_p + \] - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -156,7 +188,8 @@ def _step(pid, args): self.nodes.momentum, ) - def compute_velocity(self, particles): + def compute_velocity(self, particles: Particles): + """Compute velocity using momentum.""" self.nodes.velocity = jnp.where( self.nodes.mass == 0, self.nodes.velocity, @@ -168,17 +201,18 @@ def compute_velocity(self, particles): self.nodes.velocity, ) - def compute_external_force(self, particles): - r""" - Update the nodal external force based on particle f_ext. + def compute_external_force(self, particles: Particles): + r"""Update the nodal external force based on particle f_ext. The nodal force is updated as a sum of particle external force for all particles mapped to the node. - :math:`(f_{ext})_i = \sum_p N_i(x_p) f_{ext}` + \[ + f_{ext})_i = \sum_p N_i(x_p) f_{ext} + \] - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -199,17 +233,18 @@ def _step(pid, args): ) self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args) - def compute_body_force(self, particles, gravity: float | jnp.ndarray): - r""" - Update the nodal external force based on particle mass. + def compute_body_force(self, particles: Particles, gravity: ArrayLike): + r"""Update the nodal external force based on particle mass. The nodal force is updated as a sum of particle body force for all particles mapped to th - :math:`(f_{ext})_i += \sum_p N_i(x_p) m_p g` + \[ + (f_{ext})_i = (f_{ext})_i + \sum_p N_i(x_p) m_p g + \] - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -232,14 +267,31 @@ def _step(pid, args): ) self.nodes.f_ext, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args) - def apply_concentrated_nodal_forces(self, particles, curr_time): + def apply_concentrated_nodal_forces(self, particles: Particles, curr_time: float): + """Apply concentrated nodal forces. + + Parameters + ---------- + particles: Particles + Particles in the simulation. + curr_time: float + Current time in the simulation. + """ for cnf in self.concentrated_nodal_forces: factor = cnf.function.value(curr_time) self.nodes.f_ext = self.nodes.f_ext.at[cnf.node_ids, 0, cnf.dir].add( factor * cnf.force ) - def apply_particle_traction_forces(self, particles): + def apply_particle_traction_forces(self, particles: Particles): + """Apply concentrated nodal forces. + + Parameters + ---------- + particles: Particles + Particles in the simulation. + """ + def _step(pid, args): f_ext, ptraction, mapped_pos, el_nodes = args f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ ptraction[pid]) @@ -250,7 +302,9 @@ def _step(pid, args): args = (self.nodes.f_ext, particles.traction, mapped_positions, mapped_nodes) self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args) - def update_nodal_acceleration_velocity(self, particles, dt: float, *args): + def update_nodal_acceleration_velocity( + self, particles: Particles, dt: float, *args + ): """Update the nodal momentum based on total force on nodes.""" total_force = self.nodes.get_total_force() self.nodes.acceleration = self.nodes.acceleration.at[:].set( @@ -288,38 +342,54 @@ def apply_force_boundary_constraints(self, *args): @register_pytree_node_class class Linear1D(_Element): - """ - Container for 1D line elements (and nodes). + """Container for 1D line elements (and nodes). + + Element ID: 0 1 2 3 + Mesh: +-----+-----+-----+-----+ + Node IDs: 0 1 2 3 4 - Element ID: 0 1 2 3 - Mesh: +-----+-----+-----+-----+ - Node IDs: 0 1 2 3 4 + where + + + : Nodes + +-----+ : An element - + : Nodes - +-----+ : An element """ def __init__( self, nelements: int, - total_elements, + total_elements: int, el_len: float, - constraints: List[Tuple[jnp.ndarray, Constraint]], - nodes: Nodes = None, - concentrated_nodal_forces=[], - initialized=None, - volume=None, + constraints: Sequence[Tuple[ArrayLike, Constraint]], + nodes: Optional[Nodes] = None, + concentrated_nodal_forces: Sequence = [], + initialized: Optional[bool] = None, + volume: Optional[ArrayLike] = None, ): """Initialize Linear1D. - Arguments - --------- + Parameters + ---------- nelements : int Number of elements. + total_elements : int + Total number of elements (same as `nelements` for 1D) el_len : float Length of each element. - boundary_nodes : Sequence - IDs of nodes that are supposed to be fixed (boundary). + constraints: list + A list of constraints where each element is a tuple of type + `(node_ids, diffmpm.Constraint)`. Here, `node_ids` correspond to + the node IDs where `diffmpm.Constraint` should be applied. + nodes : Nodes, Optional + Nodes in the element object. + concentrated_nodal_forces: list + A list of `diffmpm.forces.NodalForce`s that are to be + applied. + initialized: bool, None + `True` if the class has been initialized, `None` if not. + This is required like this for using JAX flattening. + volume: ArrayLike + Volume of the elements. """ self.nelements = nelements self.total_elements = nelements @@ -338,72 +408,71 @@ def __init__( if initialized is None: self.volume = jnp.ones((self.total_elements, 1, 1)) else: - self.volume = volume + self.volume = jnp.asarray(volume) self.initialized = True - def id_to_node_ids(self, id: int): - """ - Node IDs corresponding to element `id`. + def id_to_node_ids(self, id: ArrayLike): + """Node IDs corresponding to element `id`. - Arguments - --------- + Parameters + ---------- id : int Element ID. Returns ------- - jax.numpy.ndarray + ArrayLike Nodal IDs of the element. Shape of returned - array is (2, 1) + array is `(2, 1)` """ return jnp.array([id, id + 1]).reshape(2, 1) - def shapefn(self, xi: float | jnp.ndarray): - """ - Evaluate linear shape function. + def shapefn(self, xi: ArrayLike): + """Evaluate linear shape function. - Arguments - --------- + Parameters + ---------- xi : float, array_like Locations of particles in natural coordinates to evaluate - the function at. Expected shape is (npoints, 1, ndim) + the function at. Expected shape is `(npoints, 1, ndim)` Returns ------- array_like Evaluated shape function values. The shape of the returned - array will depend on the input shape. For example, in the linear - case, if the input is a scalar, the returned array will be of - the shape (1, 2, 1) but if the input is a vector then the output will - be of the shape (len(x), 2, 1). + array will depend on the input shape. For example, in the linear + case, if the input is a scalar, the returned array will be of + the shape `(1, 2, 1)` but if the input is a vector then the output will + be of the shape `(len(x), 2, 1)`. """ - if len(xi.shape) != 3: + xi = jnp.asarray(xi) + if xi.ndim != 3: raise ValueError( f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}" ) result = jnp.array([0.5 * (1 - xi), 0.5 * (1 + xi)]).transpose(1, 0, 2, 3) return result - def _shapefn_natural_grad(self, xi: float | jnp.ndarray): - """ - Calculate the gradient of shape function. + def _shapefn_natural_grad(self, xi: ArrayLike): + """Calculate the gradient of shape function. This calculation is done in the natural coordinates. - Arguments - --------- + Parameters + ---------- x : float, array_like Locations of particles in natural coordinates to evaluate - the function at. + the function at. Returns ------- array_like Evaluated gradient values of the shape function. The shape of - the returned array will depend on the input shape. For example, - in the linear case, if the input is a scalar, the returned array - will be of the shape (2, 1). + the returned array will depend on the input shape. For example, + in the linear case, if the input is a scalar, the returned array + will be of the shape `(2, 1)`. """ + xi = jnp.asarray(xi) result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze() # TODO: The following code tries to evaluate vmap even if @@ -416,25 +485,26 @@ def _shapefn_natural_grad(self, xi: float | jnp.ndarray): # ) return result.reshape(2, 1) - def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray): - """ - Gradient of shape function in physical coordinates. + def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): + """Gradient of shape function in physical coordinates. - Arguments - --------- + Parameters + ---------- xi : float, array_like Locations of particles to evaluate in natural coordinates. - Expected shape (npoints, 1, ndim). + Expected shape `(npoints, 1, ndim)`. coords : array_like Nodal coordinates to transform by. Expected shape - (npoints, 1, ndim) + `(npoints, 1, ndim)` Returns ------- array_like Gradient of the shape function in physical coordinates at `xi` """ - if len(xi.shape) != 3: + xi = jnp.asarray(xi) + coords = jnp.asarray(coords) + if xi.ndim != 3: raise ValueError( f"`x` should be of size (npoints, 1, ndim); found {xi.shape}" ) @@ -445,8 +515,7 @@ def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray): return result def set_particle_element_ids(self, particles): - """ - Set the element IDs for the particles. + """Set the element IDs for the particles. If the particle doesn't lie between the boundaries of any element, it sets the element index to -1. @@ -472,20 +541,24 @@ def f(x): ) def compute_volume(self, *args): + """Compute volume of all elements.""" vol = jnp.ediff1d(self.nodes.loc) self.volume = jnp.ones((self.total_elements, 1, 1)) * vol def compute_internal_force(self, particles): - r""" - Update the nodal internal force based on particle mass. + r"""Update the nodal internal force based on particle mass. The nodal force is updated as a sum of internal forces for all particles mapped to the node. - :math:`(f_{int})_i = -\sum_p V_p * stress_p * \nabla N_i(x_p)` + \[ + (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p) + \] - Arguments - --------- + where \(\sigma_p\) is the stress at particle \(p\). + + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -529,45 +602,63 @@ def _step(pid, args): @register_pytree_node_class class Quadrilateral4Node(_Element): - """ - Container for 2D quadrilateral elements with 4 nodes. + r"""Container for 2D quadrilateral elements with 4 nodes. Nodes and elements are numbered as - 15 0---0---0---0---0 19 + 15 +---+---+---+---+ 19 | 8 | 9 | 10| 11| - 10 0---0---0---0---0 14 + 10 +---+---+---+---+ 14 | 4 | 5 | 6 | 7 | - 5 0---0---0---0---0 9 + 5 +---+---+---+---+ 9 | 0 | 1 | 2 | 3 | - 0---0---0---0---0 + +---+---+---+---+ 0 1 2 3 4 - + : Nodes - +---+ - | | : An element - +---+ + where + + + : Nodes + +---+ + | | : An element + +---+ """ def __init__( self, - nelements: Tuple[int, int], + nelements: int, total_elements: int, - el_len: Tuple[float, float], - constraints: List[Tuple[jnp.ndarray, Constraint]], - nodes: Nodes = None, - concentrated_nodal_forces=[], - initialized: bool = None, - volume: jnp.ndarray = None, - ): - """Initialize Quadrilateral4Node. - - Arguments - --------- - nelements : (int, int) - Number of elements in X and Y direction. - el_len : (float, float) - Length of each element in X and Y direction. + el_len: float, + constraints: Sequence[Tuple[ArrayLike, Constraint]], + nodes: Optional[Nodes] = None, + concentrated_nodal_forces: Sequence = [], + initialized: Optional[bool] = None, + volume: Optional[ArrayLike] = None, + ) -> None: + """Initialize Linear1D. + + Parameters + ---------- + nelements : int + Number of elements. + total_elements : int + Total number of elements (product of all elements of `nelements`) + el_len : float + Length of each element. + constraints: list + A list of constraints where each element is a tuple of + type `(node_ids, diffmpm.Constraint)`. Here, `node_ids` + correspond to the node IDs where `diffmpm.Constraint` + should be applied. + nodes : Nodes, Optional + Nodes in the element object. + concentrated_nodal_forces: list + A list of `diffmpm.forces.NodalForce`s that are to be + applied. + initialized: bool, None + `True` if the class has been initialized, `None` if not. + This is required like this for using JAX flattening. + volume: ArrayLike + Volume of the elements. """ self.nelements = jnp.asarray(nelements) self.el_len = jnp.asarray(el_len) @@ -578,15 +669,15 @@ def __init__( coords = jnp.asarray( list( itertools.product( - jnp.arange(nelements[1] + 1), - jnp.arange(nelements[0] + 1), + jnp.arange(self.nelements[1] + 1), + jnp.arange(self.nelements[0] + 1), ) ) ) node_locations = ( jnp.asarray([coords[:, 1], coords[:, 0]]).T * self.el_len ).reshape(-1, 1, 2) - self.nodes = Nodes(total_nodes, node_locations) + self.nodes = Nodes(int(total_nodes), node_locations) else: self.nodes = nodes @@ -595,12 +686,11 @@ def __init__( if initialized is None: self.volume = jnp.ones((self.total_elements, 1, 1)) else: - self.volume = volume + self.volume = jnp.asarray(volume) self.initialized = True - def id_to_node_ids(self, id: int): - """ - Node IDs corresponding to element `id`. + def id_to_node_ids(self, id: ArrayLike): + """Node IDs corresponding to element `id`. 3----2 | | @@ -608,16 +698,16 @@ def id_to_node_ids(self, id: int): Node ids are returned in the order as shown in the figure. - Arguments - --------- + Parameters + ---------- id : int Element ID. Returns ------- - jax.numpy.ndarray + ArrayLike Nodal IDs of the element. Shape of returned - array is (4, 1) + array is (4, 1) """ lower_left = (id // self.nelements[0]) * ( self.nelements[0] + 1 @@ -632,26 +722,26 @@ def id_to_node_ids(self, id: int): ) return result.reshape(4, 1) - def shapefn(self, xi: Sequence[float]): - """ - Evaluate linear shape function. + def shapefn(self, xi: ArrayLike): + """Evaluate linear shape function. - Arguments - --------- + Parameters + ---------- xi : float, array_like Locations of particles in natural coordinates to evaluate - the function at. Expected shape is (npoints, 1, ndim) + the function at. Expected shape is (npoints, 1, ndim) Returns ------- array_like Evaluated shape function values. The shape of the returned - array will depend on the input shape. For example, in the linear - case, if the input is a scalar, the returned array will be of - the shape (1, 4, 1) but if the input is a vector then the output will - be of the shape (len(x), 4, 1). + array will depend on the input shape. For example, in the linear + case, if the input is a scalar, the returned array will be of + the shape `(1, 4, 1)` but if the input is a vector then the output will + be of the shape `(len(x), 4, 1)`. """ - if len(xi.shape) != 3: + xi = jnp.asarray(xi) + if xi.ndim != 3: raise ValueError( f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}" ) @@ -666,27 +756,27 @@ def shapefn(self, xi: Sequence[float]): result = result.transpose(1, 0, 2)[..., jnp.newaxis] return result - def _shapefn_natural_grad(self, xi: float | jnp.ndarray): - """ - Calculate the gradient of shape function. + def _shapefn_natural_grad(self, xi: ArrayLike): + """Calculate the gradient of shape function. This calculation is done in the natural coordinates. - Arguments - --------- + Parameters + ---------- x : float, array_like Locations of particles in natural coordinates to evaluate - the function at. + the function at. Returns ------- array_like Evaluated gradient values of the shape function. The shape of - the returned array will depend on the input shape. For example, - in the linear case, if the input is a scalar, the returned array - will be of the shape (4, 2). + the returned array will depend on the input shape. For example, + in the linear case, if the input is a scalar, the returned array + will be of the shape `(4, 2)`. """ # result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze() + xi = jnp.asarray(xi) xi = xi.squeeze() result = jnp.array( [ @@ -698,25 +788,26 @@ def _shapefn_natural_grad(self, xi: float | jnp.ndarray): ) return result - def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray): - """ - Gradient of shape function in physical coordinates. + def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): + """Gradient of shape function in physical coordinates. - Arguments - --------- + Parameters + ---------- xi : float, array_like Locations of particles to evaluate in natural coordinates. - Expected shape (npoints, 1, ndim). + Expected shape `(npoints, 1, ndim)`. coords : array_like Nodal coordinates to transform by. Expected shape - (npoints, 1, ndim) + `(npoints, 1, ndim)` Returns ------- array_like Gradient of the shape function in physical coordinates at `xi` """ - if len(xi.shape) != 3: + xi = jnp.asarray(xi) + coords = jnp.asarray(coords) + if xi.ndim != 3: raise ValueError( f"`x` should be of size (npoints, 1, ndim); found {xi.shape}" ) @@ -726,9 +817,8 @@ def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray): result = grad_sf @ jnp.linalg.inv(_jacobian).T return result - def set_particle_element_ids(self, particles): - """ - Set the element IDs for the particles. + def set_particle_element_ids(self, particles: Particles): + """Set the element IDs for the particles. If the particle doesn't lie between the boundaries of any element, it sets the element index to -1. @@ -749,17 +839,20 @@ def f(x): ids = vmap(f)(particles.loc) particles.element_ids = ids - def compute_internal_force(self, particles): - r""" - Update the nodal internal force based on particle mass. + def compute_internal_force(self, particles: Particles): + r"""Update the nodal internal force based on particle mass. The nodal force is updated as a sum of internal forces for all particles mapped to the node. - :math:`(f_{int})_i = -\sum_p V_p * stress_p * \nabla N_i(x_p)` + \[ + (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p) + \] + + where \(\sigma_p\) is the stress at particle \(p\). - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -808,14 +901,9 @@ def _step(pid, args): self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args) def compute_volume(self, *args): + """Compute volume of all elements.""" a = c = self.el_len[1] b = d = self.el_len[0] p = q = jnp.sqrt(a**2 + b**2) vol = 0.25 * jnp.sqrt(4 * p * p * q * q - (a * a + c * c - b * b - d * d) ** 2) self.volume = self.volume.at[:].set(vol) - - -if __name__ == "__main__": - from diffmpm.utils import _show_example - - _show_example(Linear1D(2, 1, jnp.array([0]))) diff --git a/diffmpm/forces.py b/diffmpm/forces.py index eb6d27f..6740462 100644 --- a/diffmpm/forces.py +++ b/diffmpm/forces.py @@ -1,17 +1,61 @@ -from collections import namedtuple +from typing import Annotated, NamedTuple, get_type_hints + +from jax import Array from jax.tree_util import register_pytree_node -NodalForce = namedtuple("NodalForce", ("node_ids", "function", "dir", "force")) -ParticleTraction = namedtuple( - "ParticleTraction", ("pset", "pids", "function", "dir", "traction") -) +from diffmpm.functions import Function + + +class NodalForce(NamedTuple): + """Nodal Force being applied constantly on a set of nodes.""" + + node_ids: Annotated[Array, "Array of Node IDs to which force is applied."] + function: Annotated[ + Function, + "Mathematical function that governs time-varying changes in the force.", + ] + dir: Annotated[int, "Direction in which force is applied."] + force: Annotated[float, "Amount of force to be applied."] + + +nfhints = get_type_hints(NodalForce, include_extras=True) +for attr in nfhints: + getattr(NodalForce, attr).__doc__ = "".join(nfhints[attr].__metadata__) + + +class ParticleTraction(NamedTuple): + """Traction being applied on a set of particles.""" + + pset: Annotated[ + int, "The particle set in which traction is applied to the particles." + ] + pids: Annotated[ + Array, + "Array of Particle IDs to which traction is applied inside the particle set.", + ] + function: Annotated[ + Function, + "Mathematical function that governs time-varying changes in the traction.", + ] + dir: Annotated[int, "Direction in which traction is applied."] + traction: Annotated[float, "Amount of traction to be applied."] + + +pthints = get_type_hints(ParticleTraction, include_extras=True) +for attr in pthints: + getattr(ParticleTraction, attr).__doc__ = "".join(pthints[attr].__metadata__) + register_pytree_node( NodalForce, - lambda xs: (tuple(xs), None), # tell JAX how to unpack to an iterable - lambda _, xs: NodalForce(*xs), # tell JAX how to pack back into a NodalForce + # tell JAX how to unpack to an iterable + lambda xs: (tuple(xs), None), # type: ignore + # tell JAX how to pack back into a NodalForce + lambda _, xs: NodalForce(*xs), # type: ignore ) register_pytree_node( ParticleTraction, - lambda xs: (tuple(xs), None), # tell JAX how to unpack to an iterable - lambda _, xs: ParticleTraction(*xs), # tell JAX how to pack back + # tell JAX how to unpack to an iterable + lambda xs: (tuple(xs), None), # type: ignore + # tell JAX how to pack back + lambda _, xs: ParticleTraction(*xs), # type: ignore ) diff --git a/diffmpm/functions.py b/diffmpm/functions.py index 44880dc..90b55c4 100644 --- a/diffmpm/functions.py +++ b/diffmpm/functions.py @@ -1,4 +1,5 @@ import abc + import jax.numpy as jnp from jax.tree_util import register_pytree_node_class diff --git a/diffmpm/io.py b/diffmpm/io.py index b9e08d3..bd10930 100644 --- a/diffmpm/io.py +++ b/diffmpm/io.py @@ -9,7 +9,7 @@ from diffmpm import mesh as mpmesh from diffmpm.constraint import Constraint from diffmpm.forces import NodalForce, ParticleTraction -from diffmpm.functions import Unit, Linear +from diffmpm.functions import Linear, Unit from diffmpm.particle import Particles diff --git a/diffmpm/material.py b/diffmpm/material.py index 2dd8487..09230d4 100644 --- a/diffmpm/material.py +++ b/diffmpm/material.py @@ -1,19 +1,20 @@ -from jax.tree_util import register_pytree_node_class import abc +from typing import Tuple + import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class class Material(abc.ABC): """Base material class.""" - _props = () + _props: Tuple[str, ...] def __init__(self, material_properties): - """ - Initialize material properties. + """Initialize material properties. - Arguments - --------- + Parameters + ---------- material_properties: dict A key-value map for various material properties. """ @@ -57,11 +58,10 @@ class LinearElastic(Material): _props = ("density", "youngs_modulus", "poisson_ratio") def __init__(self, material_properties): - """ - Create a Linear Elastic material. + """Create a Linear Elastic material. - Arguments - --------- + Parameters + ---------- material_properties: dict Dictionary with material properties. For linear elastic materials, 'density' and 'youngs_modulus' are required keys. @@ -111,9 +111,7 @@ def _compute_elastic_tensor(self): ) def compute_stress(self, dstrain): - """ - Compute material stress. - """ + """Compute material stress.""" dstress = self.de @ dstrain return dstress @@ -131,9 +129,3 @@ def __repr__(self): def compute_stress(self, dstrain): return dstrain * self.properties["E"] - - -if __name__ == "__main__": - from diffmpm.utils import _show_example - - _show_example(SimpleMaterial({"E": 2, "density": 1})) diff --git a/diffmpm/mesh.py b/diffmpm/mesh.py index 1a31c60..ed8b27b 100644 --- a/diffmpm/mesh.py +++ b/diffmpm/mesh.py @@ -1,5 +1,5 @@ import abc -from typing import Iterable +from typing import Callable, Sequence, Tuple import jax.numpy as jnp from jax.tree_util import register_pytree_node_class @@ -7,40 +7,64 @@ from diffmpm.element import _Element from diffmpm.particle import Particles +__all__ = ["_MeshBase", "Mesh1D", "Mesh2D"] + class _MeshBase(abc.ABC): - """ - Base class for Meshes. + """Base class for Meshes. - Note: If attributes other than elements and particles are added - then the child class should also implement `tree_flatten` and - `tree_unflatten` correctly or that information will get lost. + .. note:: + If attributes other than elements and particles are added + then the child class should also implement `tree_flatten` and + `tree_unflatten` correctly or that information will get lost. """ + ndim: int + def __init__(self, config: dict): """Initialize mesh using configuration.""" - self.particles: Iterable[Particles, ...] = config["particles"] + self.particles: Sequence[Particles] = config["particles"] self.elements: _Element = config["elements"] self.particle_tractions = config["particle_surface_traction"] - @property - @abc.abstractmethod - def ndim(self): - ... - # TODO: Convert to using jax directives for loop - def apply_on_elements(self, function, args=()): + def apply_on_elements(self, function: str, args: Tuple = ()): + """Apply a given function to elements. + + Parameters + ---------- + function: str + A string corresponding to a function name in `_Element`. + args: tuple + Parameters to be passed to the function. + """ f = getattr(self.elements, function) for particle_set in self.particles: f(particle_set, *args) # TODO: Convert to using jax directives for loop - def apply_on_particles(self, function, args=()): + def apply_on_particles(self, function: str, args: Tuple = ()): + """Apply a given function to particles. + + Parameters + ---------- + function: str + A string corresponding to a function name in `Particles`. + args: tuple + Parameters to be passed to the function. + """ for particle_set in self.particles: f = getattr(particle_set, function) f(self.elements, *args) - def apply_traction_on_particles(self, curr_time): + def apply_traction_on_particles(self, curr_time: float): + """Apply tractions on particles. + + Parameters + ---------- + curr_time: float + Current time in the simulation. + """ self.apply_on_particles("zero_traction") for ptraction in self.particle_tractions: factor = ptraction.function.value(curr_time) @@ -50,7 +74,6 @@ def apply_traction_on_particles(self, curr_time): ptraction.pids[i], ptraction.dir, traction_val ) - # breakpoint() self.apply_on_elements("apply_particle_traction_forces") def tree_flatten(self): @@ -74,52 +97,30 @@ class Mesh1D(_MeshBase): """1D Mesh class with nodes, elements, and particles.""" def __init__(self, config: dict): - """ - Initialize a 1D Mesh. + """Initialize a 1D Mesh. - Arguments - --------- + Parameters + ---------- config: dict Configuration to be used for initialization. It _should_ - contain `elements` and `particles` keys. + contain `elements` and `particles` keys. """ + self.ndim = 1 super().__init__(config) - @property - def ndim(self): - return 1 - @register_pytree_node_class class Mesh2D(_MeshBase): """1D Mesh class with nodes, elements, and particles.""" def __init__(self, config: dict): - """ - Initialize a 2D Mesh. + """Initialize a 2D Mesh. - Arguments - --------- + Parameters + ---------- config: dict Configuration to be used for initialization. It _should_ - contain `elements` and `particles` keys. + contain `elements` and `particles` keys. """ + self.ndim = 2 super().__init__(config) - - @property - def ndim(self): - return 2 - - -if __name__ == "__main__": - from diffmpm.element import Linear1D - from diffmpm.material import SimpleMaterial - from diffmpm.utils import _show_example - - particles = Particles( - jnp.array([[[1]]]), - SimpleMaterial({"E": 2, "density": 1}), - jnp.array([0]), - ) - elements = Linear1D(2, 1, jnp.array([0])) - _show_example(Mesh1D({"particles": [particles], "elements": elements})) diff --git a/diffmpm/node.py b/diffmpm/node.py index 14396a4..46e2a60 100644 --- a/diffmpm/node.py +++ b/diffmpm/node.py @@ -1,13 +1,13 @@ -from typing import Tuple +from typing import Optional, Sized, Tuple import jax.numpy as jnp from jax.tree_util import register_pytree_node_class +from jax.typing import ArrayLike @register_pytree_node_class -class Nodes: - """ - Nodes container class. +class Nodes(Sized): + """Nodes container class. Keeps track of all values required for nodal points. @@ -15,50 +15,51 @@ class Nodes: ---------- nnodes : int Number of nodes stored. - loc : array_like + loc : ArrayLike Location of all the nodes. velocity : array_like Velocity of all the nodes. - mass : array_like + mass : ArrayLike Mass of all the nodes. momentum : array_like Momentum of all the nodes. - f_int : array_like + f_int : ArrayLike Internal forces on all the nodes. - f_ext : array_like + f_ext : ArrayLike External forces present on all the nodes. - f_damp : array_like + f_damp : ArrayLike Damping forces on the nodes. """ def __init__( self, nnodes: int, - loc: jnp.ndarray, - initialized: bool = None, - data: Tuple[jnp.ndarray, ...] = tuple(), + loc: ArrayLike, + initialized: Optional[bool] = None, + data: Tuple[ArrayLike, ...] = tuple(), ): - """ - Initialize container for Nodes. + """Initialize container for Nodes. Parameters ---------- nnodes : int Number of nodes stored. - loc : array_like + loc : ArrayLike Locations of all the nodes. Expected shape (nnodes, 1, ndim) initialized: bool - False if node property arrays like mass need to be initialized. - If True, they are set to values from `data`. + `False` if node property arrays like mass need to be initialized. + If `True`, they are set to values from `data`. data: tuple Tuple of length 7 that sets arrays for mass, density, volume, + and forces. Mainly used by JAX while unflattening. """ self.nnodes = nnodes - if len(loc.shape) != 3: + loc = jnp.asarray(loc, dtype=jnp.float32) + if loc.ndim != 3: raise ValueError( f"`loc` should be of size (nnodes, 1, ndim); found {loc.shape}" ) - self.loc = jnp.asarray(loc, dtype=jnp.float32) + self.loc = loc if initialized is None: self.velocity = jnp.zeros_like(self.loc, dtype=jnp.float32) @@ -77,11 +78,11 @@ def __init__( self.f_int, self.f_ext, self.f_damp, - ) = data + ) = data # type: ignore self.initialized = True def tree_flatten(self): - """Helper method for registering class as Pytree type.""" + """Flatten class as Pytree type.""" children = ( self.loc, self.initialized, @@ -98,9 +99,8 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): - return cls( - aux_data[0], children[0], initialized=children[1], data=children[2:] - ) + """Unflatten class from Pytree type.""" + return cls(aux_data[0], children[0], initialized=children[1], data=children[2:]) def reset_values(self): """Reset nodal parameter values except location.""" @@ -123,9 +123,3 @@ def __repr__(self): def get_total_force(self): """Calculate total force on the nodes.""" return self.f_int + self.f_ext + self.f_damp - - -if __name__ == "__main__": - from diffmpm.utils import _show_example - - _show_example(Nodes(2, jnp.array([1, 2]).reshape(2, 1, 1))) diff --git a/diffmpm/particle.py b/diffmpm/particle.py index 0df598c..1bb3d70 100644 --- a/diffmpm/particle.py +++ b/diffmpm/particle.py @@ -1,49 +1,50 @@ -from typing import Tuple +from typing import Optional, Sized, Tuple import jax.numpy as jnp -from jax import vmap, lax +from jax import lax, vmap from jax.tree_util import register_pytree_node_class +from jax.typing import ArrayLike from diffmpm.element import _Element from diffmpm.material import Material @register_pytree_node_class -class Particles: +class Particles(Sized): """Container class for a set of particles.""" def __init__( self, - loc: jnp.ndarray, + loc: ArrayLike, material: Material, - element_ids: jnp.ndarray, - initialized: bool = None, - data: Tuple[jnp.ndarray, ...] = None, + element_ids: ArrayLike, + initialized: Optional[bool] = None, + data: Optional[Tuple[ArrayLike, ...]] = None, ): - """ - Initialize a container of particles. + """Initialize a container of particles. - Arguments - --------- - loc: jax.numpy.ndarray + Parameters + ---------- + loc: ArrayLike Location of the particles. Expected shape (nparticles, 1, ndim) material: diffmpm.material.Material Type of material for the set of particles. - element_ids: jax.numpy.ndarray + element_ids: ArrayLike The element ids that the particles belong to. This contains - information that will make sense only with the information of - the mesh that is being considered. + information that will make sense only with the information of + the mesh that is being considered. initialized: bool - False if particle property arrays like mass need to be initialized. - If True, they are set to values from `data`. + `False` if particle property arrays like mass need to be initialized. + If `True`, they are set to values from `data`. data: tuple Tuple of length 13 that sets arrays for mass, density, volume, - velocity, acceleration, momentum, strain, stress, strain_rate, - dstrain, f_ext, reference_loc and volumetric_strain_centroid. + velocity, acceleration, momentum, strain, stress, strain_rate, + dstrain, f_ext, reference_loc and volumetric_strain_centroid. """ self.material = material self.element_ids = element_ids - if len(loc.shape) != 3: + loc = jnp.asarray(loc, dtype=jnp.float32) + if loc.ndim != 3: raise ValueError( f"`loc` should be of size (nparticles, 1, ndim); " f"found {loc.shape}" ) @@ -86,10 +87,11 @@ def __init__( self.reference_loc, self.dvolumetric_strain, self.volumetric_strain_centroid, - ) = data + ) = data # type: ignore self.initialized = True def tree_flatten(self): + """Flatten class as Pytree type.""" children = ( self.loc, self.element_ids, @@ -116,6 +118,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): + """Unflatten class from Pytree type.""" return cls( children[0], aux_data[0], @@ -124,24 +127,24 @@ def tree_unflatten(cls, aux_data, children): data=children[3:], ) - def __len__(self): + def __len__(self) -> int: """Set length of the class as number of particles.""" return self.loc.shape[0] - def __repr__(self): + def __repr__(self) -> str: """Informative repr showing number of particles.""" return f"Particles(nparticles={len(self)})" - def set_mass_volume(self, m: float | jnp.ndarray): - """ - Set particle mass. + def set_mass_volume(self, m: ArrayLike): + """Set particle mass. - Arguments - --------- + Parameters + ---------- m: float, array_like Mass to be set for particles. If scalar, mass for all - particles is set to this value. + particles is set to this value. """ + m = jnp.asarray(m) if jnp.isscalar(m): self.mass = jnp.ones_like(self.loc) * m elif m.shape == self.mass.shape: @@ -152,12 +155,22 @@ def set_mass_volume(self, m: float | jnp.ndarray): ) self.volume = jnp.divide(self.mass, self.material.properties["density"]) - def compute_volume(self, elements, total_elements): + def compute_volume(self, elements: _Element, total_elements: int): + """Compute volume of all particles. + + Parameters + ---------- + elements: diffmpm._Element + Elements that the particles are present in, and are used to + compute the particles' volumes. + total_elements: int + Total elements present in `elements`. + """ particles_per_element = jnp.bincount( self.element_ids, length=elements.total_elements ) vol = ( - elements.volume.squeeze((1, 2))[self.element_ids] + elements.volume.squeeze((1, 2))[self.element_ids] # type: ignore / particles_per_element[self.element_ids] ) self.volume = self.volume.at[:, 0, 0].set(vol) @@ -165,24 +178,26 @@ def compute_volume(self, elements, total_elements): self.mass = self.mass.at[:, 0, 0].set(vol * self.density.squeeze()) def update_natural_coords(self, elements: _Element): - """ - Update natural coordinates for the particles. + r"""Update natural coordinates for the particles. Whenever the particles' physical coordinates change, their natural coordinates need to be updated. This function updates the natural coordinates of the particles based on the element a particle is a part of. The update formula is - :math:`xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e)` + \[ + \xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e) + \] - If a particle is not in any element (element_id = -1), its - natural coordinate is set to 0. + where \(x_i^e\) are the nodal coordinates of the element the + particle is in. If a particle is not in any element + (element_id = -1), its natural coordinate is set to 0. - Arguments - --------- + Parameters + ---------- elements: diffmpm.element._Element Elements based on which to update the natural coordinates - of the particles. + of the particles. """ t = vmap(elements.id_to_node_loc)(self.element_ids) xi_coords = (self.loc - (t[:, 0, ...] + t[:, 2, ...]) / 2) * ( @@ -193,21 +208,20 @@ def update_natural_coords(self, elements: _Element): def update_position_velocity( self, elements: _Element, dt: float, velocity_update: bool ): - """ - Transfer nodal velocity to particles and update particle position. + """Transfer nodal velocity to particles and update particle position. The velocity is calculated based on the total force at nodes. - Arguments - --------- + Parameters + ---------- elements: diffmpm.element._Element Elements whose nodes are used to transfer the velocity. dt: float Timestep. velocity_update: bool If True, velocity is directly used as nodal velocity, else - velocity is calculated is interpolated nodal acceleration - multiplied by dt. Default is False. + velocity is calculated is interpolated nodal acceleration + multiplied by dt. Default is False. """ mapped_positions = elements.shapefn(self.reference_loc) mapped_ids = vmap(elements.id_to_node_ids)(self.element_ids).squeeze(-1) @@ -233,14 +247,13 @@ def update_position_velocity( self.momentum = self.momentum.at[:].set(self.mass * self.velocity) def compute_strain(self, elements: _Element, dt: float): - """ - Compute the strain on all particles. + """Compute the strain on all particles. This is done by first calculating the strain rate for the particles - and then calculating strain as strain += strain rate * dt. + and then calculating strain as `strain += strain rate * dt`. - Arguments - --------- + Parameters + ---------- elements: diffmpm.element._Element Elements whose nodes are used to calculate the strain. dt : float @@ -265,18 +278,18 @@ def compute_strain(self, elements: _Element, dt: float): self.dvolumetric_strain ) - def _compute_strain_rate(self, dn_dx: jnp.ndarray, elements: _Element): - """ - Compute the strain rate for particles. + def _compute_strain_rate(self, dn_dx: ArrayLike, elements: _Element): + """Compute the strain rate for particles. - Arguments - --------- - dn_dx: jnp.ndarray - The gradient of the shape function. - Expected shape (nparticles, 1, ndim) + Parameters + ---------- + dn_dx: ArrayLike + The gradient of the shape function. Expected shape + `(nparticles, 1, ndim)` elements: diffmpm.element._Element Elements whose nodes are used to calculate the strain rate. """ + dn_dx = jnp.asarray(dn_dx) strain_rate = jnp.zeros((dn_dx.shape[0], 6, 1)) # (nparticles, 6, 1) mapped_vel = vmap(elements.id_to_node_vel)( self.element_ids @@ -300,8 +313,7 @@ def _step(pid, args): return strain_rate def compute_stress(self, *args): - """ - Compute the strain on all particles. + """Compute the strain on all particles. This calculation is governed by the material of the particles. The stress calculated by the material is then @@ -314,23 +326,22 @@ def update_volume(self, *args): self.volume = self.volume.at[:, 0, :].multiply(1 + self.dvolumetric_strain) self.density = self.density.at[:, 0, :].divide(1 + self.dvolumetric_strain) - def assign_traction(self, pids, dir, traction_): + def assign_traction(self, pids: ArrayLike, dir: int, traction_: float): + """Assign traction to particles. + + Parameters + ---------- + pids: ArrayLike + IDs of the particles to which traction should be applied. + dir: int + The direction in which traction should be applied. + traction_: float + Traction value to be applied in the direction. + """ self.traction = self.traction.at[pids, 0, dir].add( traction_ * self.volume[pids, 0, 0] / self.size[pids, 0, dir] ) def zero_traction(self, *args): + """Set all traction values to 0.""" self.traction = self.traction.at[:].set(0) - - -if __name__ == "__main__": - from diffmpm.material import SimpleMaterial - from diffmpm.utils import _show_example - - _show_example( - Particles( - jnp.array([[[1]]]), - SimpleMaterial({"E": 2, "density": 1}), - jnp.array([0]), - ) - ) diff --git a/diffmpm/scheme.py b/diffmpm/scheme.py index 83e35ca..61a062e 100644 --- a/diffmpm/scheme.py +++ b/diffmpm/scheme.py @@ -1,3 +1,13 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from jax.typing import ArrayLike + +if TYPE_CHECKING: + import jax.numpy as jnp + from diffmpm.mesh import _MeshBase + import abc _schemes = ("usf", "usl") @@ -10,6 +20,7 @@ def __init__(self, mesh, dt, velocity_update): self.dt = dt def compute_nodal_kinematics(self): + """Compute nodal kinematics - map mass and momentum to mesh nodes.""" self.mesh.apply_on_elements("set_particle_element_ids") self.mesh.apply_on_particles("update_natural_coords") self.mesh.apply_on_elements("compute_nodal_mass") @@ -18,11 +29,23 @@ def compute_nodal_kinematics(self): self.mesh.apply_on_elements("apply_boundary_constraints") def compute_stress_strain(self): + """Compute stress and strain on the particles.""" self.mesh.apply_on_particles("compute_strain", args=(self.dt,)) self.mesh.apply_on_particles("update_volume") self.mesh.apply_on_particles("compute_stress") - def compute_forces(self, gravity, step): + def compute_forces(self, gravity: ArrayLike, step: int): + """Compute the forces acting in the system. + + Parameters + ---------- + gravity: ArrayLike + Gravity present in the system. This should be an array equal + with shape `(1, ndim)` where `ndim` is the dimension of the + simulation. + step: int + Current step being simulated. + """ self.mesh.apply_on_elements("compute_external_force") self.mesh.apply_on_elements("compute_body_force", args=(gravity,)) self.mesh.apply_traction_on_particles(step * self.dt) @@ -33,6 +56,7 @@ def compute_forces(self, gravity, step): # self.mesh.apply_on_elements("apply_force_boundary_constraints") def compute_particle_kinematics(self): + """Compute particle location, acceleration and velocity.""" self.mesh.apply_on_elements( "update_nodal_acceleration_velocity", args=(self.dt,) ) @@ -43,24 +67,40 @@ def compute_particle_kinematics(self): # TODO: Apply particle velocity constraints. @abc.abstractmethod - def precompute_stress_strain(): + def precompute_stress_strain(self): ... @abc.abstractmethod - def postcompute_stress_strain(): + def postcompute_stress_strain(self): ... class USF(_MPMScheme): """USF Scheme solver.""" - def __init__(self, mesh, dt, velocity_update): + def __init__(self, mesh: _MeshBase, dt: float, velocity_update: bool): + """Initialize USF Scheme solver. + + Parameters + ---------- + mesh: _MeshBase + A `diffmpm.Mesh` object that contains the elements that form + the underlying mesh used to solve the simulation. + dt: float + Timestep used in the simulation. + velocity_update: bool + Flag to control if velocity should be updated using nodal + velocity or interpolated nodal acceleration. If `True`, nodal + velocity is used, else nodal acceleration. Default `False`. + """ super().__init__(mesh, dt, velocity_update) def precompute_stress_strain(self): + """Compute stress and strain on particles.""" self.compute_stress_strain() def postcompute_stress_strain(self): + """Compute stress and strain on particles. (Empty call for USF).""" pass @@ -68,10 +108,26 @@ class USL(_MPMScheme): """USL Scheme solver.""" def __init__(self, mesh, dt, velocity_update): + """Initialize USL Scheme solver. + + Parameters + ---------- + mesh: _MeshBase + A `diffmpm.Mesh` object that contains the elements that form + the underlying mesh used to solve the simulation. + dt: float + Timestep used in the simulation. + velocity_update: bool + Flag to control if velocity should be updated using nodal + velocity or interpolated nodal acceleration. If `True`, nodal + velocity is used, else nodal acceleration. Default `False`. + """ super().__init__(mesh, dt, velocity_update) def precompute_stress_strain(self): + """Compute stress and strain on particles. (Empty call for USL).""" pass def postcompute_stress_strain(self): + """Compute stress and strain on particles.""" self.compute_stress_strain() diff --git a/diffmpm/solver.py b/diffmpm/solver.py index de4624e..3b1ae01 100644 --- a/diffmpm/solver.py +++ b/diffmpm/solver.py @@ -1,33 +1,73 @@ +from __future__ import annotations + import functools -from pathlib import Path +from typing import TYPE_CHECKING, Callable, Optional import jax.numpy as jnp from jax import lax from jax.experimental.host_callback import id_tap from jax.tree_util import register_pytree_node_class +from jax.typing import ArrayLike + +from diffmpm.scheme import USF, USL, _MPMScheme, _schemes -from diffmpm.scheme import USF, USL, _schemes +if TYPE_CHECKING: + from diffmpm.mesh import _MeshBase @register_pytree_node_class class MPMExplicit: + """A class to implement the fully explicit MPM.""" + __particle_props = ("loc", "velocity", "stress", "strain") def __init__( self, - mesh, - dt, - scheme="usf", - velocity_update=False, - sim_steps=1, - out_steps=1, - out_dir="results/", - writer_func=None, - ): + mesh: _MeshBase, + dt: float, + scheme: str = "usf", + velocity_update: bool = False, + sim_steps: int = 1, + out_steps: int = 1, + out_dir: str = "results/", + writer_func: Optional[Callable] = None, + ) -> None: + """Create an `MPMExplicit` object. + + This can be used to solve a given configuration of an MPM + problem. + + Parameters + ---------- + mesh: _MeshBase + A `diffmpm.Mesh` object that contains the elements that form + the underlying mesh used to solve the simulation. + dt: float + Timestep used in the simulation. + scheme: str + The MPM Scheme type used for the simulation. Can be one of + `"usl"` or `"usf"`. Default set to `"usf"`. + velocity_update: bool + Flag to control if velocity should be updated using nodal + velocity or interpolated nodal acceleration. If `True`, nodal + velocity is used, else nodal acceleration. Default `False`. + sim_steps: int + Number of steps to run the simulation for. Default set to 1. + out_steps: int + Frequency with which to store the results. For example, if + set to 5, the result at every 5th step will be stored. Default + set to 1. + out_dir: str + Path to the output directory where results are stored. + writer_func: Callable, None + Function that is used to write the state in the output + directory. + """ + if scheme == "usf": - self.mpm_scheme = USF(mesh, dt, velocity_update) + self.mpm_scheme: _MPMScheme = USF(mesh, dt, velocity_update) # type: ignore elif scheme == "usl": - self.mpm_scheme = USL(mesh, dt, velocity_update) + self.mpm_scheme: _MPMScheme = USL(mesh, dt, velocity_update) # type: ignore else: raise ValueError(f"Please select scheme from {_schemes}. Found {scheme}") self.mesh = mesh @@ -70,13 +110,35 @@ def tree_unflatten(cls, aux_data, children): writer_func=aux_data["writer_func"], ) - def jax_writer(self, func, args): + def _jax_writer(self, func, args): id_tap(func, args) - def solve(self, gravity: float | jnp.ndarray): + def solve(self, gravity: ArrayLike): + """Non-JIT solve method. + + This method runs the entire simulation for the defined number + of steps. + + .. note:: + This is mainly used for debugging and might be removed in + future versions or moved to the JIT solver. + + Parameters + ---------- + gravity: ArrayLike + Gravity present in the system. This should be an array equal + with shape `(1, ndim)` where `ndim` is the dimension of the + simulation. + + Returns + ------- + dict + A dictionary of `ArrayLike` arrays corresponding to the + all states of the simulation after completing all steps. + """ from collections import defaultdict - from tqdm import tqdm + from tqdm import tqdm # type: ignore result = defaultdict(list) for step in tqdm(range(self.sim_steps)): @@ -91,10 +153,29 @@ def solve(self, gravity: float | jnp.ndarray): result["stress"].append(pset.stress[:, :2, 0]) result["strain"].append(pset.strain[:, :2, 0]) - result = {k: jnp.asarray(v) for k, v in result.items()} - return result + result_arr = {k: jnp.asarray(v) for k, v in result.items()} + return result_arr + + def solve_jit(self, gravity: ArrayLike) -> dict: + """Solver method that runs the simulation. + + This method runs the entire simulation for the defined number + of steps. + + Parameters + ---------- + gravity: ArrayLike + Gravity present in the system. This should be an array equal + with shape `(1, ndim)` where `ndim` is the dimension of the + simulation. + + Returns + ------- + dict + A dictionary of `jax.numpy` arrays corresponding to the + final state of the simulation after completing all steps. + """ - def solve_jit(self, gravity: float | jnp.ndarray): def _step(i, data): self = data self.mpm_scheme.compute_nodal_kinematics() @@ -112,7 +193,7 @@ def _write(self, i): for j in range(len(self.mesh.particles)) ] ) - self.jax_writer( + self._jax_writer( functools.partial( self.writer_func, out_dir=self.out_dir, max_steps=self.sim_steps ), diff --git a/diffmpm/utils.py b/diffmpm/utils.py deleted file mode 100644 index 6559036..0000000 --- a/diffmpm/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -from jax.tree_util import tree_flatten, tree_unflatten - - -def _show_example(structured): - flat, tree = tree_flatten(structured) - unflattened = tree_unflatten(tree, flat) - print(f"{structured=}\n {flat=}\n {tree=}\n {unflattened=}") diff --git a/diffmpm/writers.py b/diffmpm/writers.py index 4038b52..fdc5cd2 100644 --- a/diffmpm/writers.py +++ b/diffmpm/writers.py @@ -1,24 +1,45 @@ import abc import logging -import numpy as np from pathlib import Path +from typing import Tuple, Annotated, Any +from jax.typing import ArrayLike +import numpy as np + logger = logging.getLogger(__file__) +__all__ = ["_Writer", "EmptyWriter", "NPZWriter"] + + +class _Writer(abc.ABC): + """Base writer class.""" -class Writer(abc.ABC): @abc.abstractmethod def write(self): ... -class EmptyWriter(Writer): +class EmptyWriter(_Writer): + """Empty writer used when output is not to be written.""" + def write(self, args, transforms, **kwargs): + """Empty function.""" pass -class NPZWriter(Writer): - def write(self, args, transforms, **kwargs): +class NPZWriter(_Writer): + """Writer to write output in `.npz` format.""" + + def write( + self, + args: Tuple[ + Annotated[ArrayLike, "JAX arrays to be written"], + Annotated[int, "step number of the simulation"], + ], + transforms: Any, + **kwargs, + ): + """Writes the output arrays as `.npz` files.""" arrays, step = args max_digits = int(np.log10(kwargs["max_steps"])) + 1 if step == 0: diff --git a/examples/optim_1d.py b/examples/optim_1d.py index ceb9b7b..7c3a366 100644 --- a/examples/optim_1d.py +++ b/examples/optim_1d.py @@ -1,14 +1,15 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import optax +from jax import grad, jit, value_and_grad +from tqdm import tqdm + +from diffmpm.constraint import Constraint from diffmpm.element import Linear1D from diffmpm.material import SimpleMaterial from diffmpm.mesh import Mesh1D from diffmpm.particle import Particles -from diffmpm.constraint import Constraint from diffmpm.solver import MPMExplicit -from jax import value_and_grad, grad, jit -from tqdm import tqdm E_true = 100 material = SimpleMaterial({"E": E_true, "density": 1}) diff --git a/examples/simple_1d.py b/examples/simple_1d.py index 739ac16..63929d4 100644 --- a/examples/simple_1d.py +++ b/examples/simple_1d.py @@ -1,10 +1,11 @@ import jax.numpy as jnp import matplotlib.pyplot as plt + +from diffmpm.constraint import Constraint from diffmpm.element import Linear1D from diffmpm.material import SimpleMaterial from diffmpm.mesh import Mesh1D from diffmpm.particle import Particles -from diffmpm.constraint import Constraint from diffmpm.solver import MPMExplicit E = 100 diff --git a/examples/simple_1d_file.py b/examples/simple_1d_file.py index 9048da6..2edd181 100644 --- a/examples/simple_1d_file.py +++ b/examples/simple_1d_file.py @@ -1,6 +1,8 @@ import sys + import jax.numpy as jnp import matplotlib.pyplot as plt + from diffmpm.solver import MPM mpm = MPM(sys.argv[1]) diff --git a/examples/simple_2d.py b/examples/simple_2d.py index a9206a2..2fbb1c3 100644 --- a/examples/simple_2d.py +++ b/examples/simple_2d.py @@ -1,14 +1,15 @@ from collections import namedtuple import jax.numpy as jnp + from diffmpm.constraint import Constraint from diffmpm.element import Quadrilateral4Node +from diffmpm.forces import NodalForce from diffmpm.functions import Linear from diffmpm.material import LinearElastic, SimpleMaterial from diffmpm.mesh import Mesh2D from diffmpm.particle import Particles from diffmpm.solver import MPMExplicit -from diffmpm.forces import NodalForce particles = Particles( jnp.array([[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75]]).reshape( diff --git a/examples/simple_2d_file.py b/examples/simple_2d_file.py index e6908f7..6cf41f9 100644 --- a/examples/simple_2d_file.py +++ b/examples/simple_2d_file.py @@ -1,6 +1,8 @@ import sys + import jax.numpy as jnp import matplotlib.pyplot as plt + from diffmpm.solver import MPM mpm = MPM(sys.argv[1]) diff --git a/tests/test_element.py b/tests/test_element.py index 2b72460..50881d9 100644 --- a/tests/test_element.py +++ b/tests/test_element.py @@ -1,11 +1,12 @@ -import pytest import jax.numpy as jnp +import pytest + +from diffmpm.constraint import Constraint from diffmpm.element import Quadrilateral4Node -from diffmpm.particle import Particles -from diffmpm.material import SimpleMaterial from diffmpm.forces import NodalForce from diffmpm.functions import Unit -from diffmpm.constraint import Constraint +from diffmpm.material import SimpleMaterial +from diffmpm.particle import Particles class TestLinear1D: diff --git a/tests/test_material.py b/tests/test_material.py index 8fa78f3..2e041d7 100644 --- a/tests/test_material.py +++ b/tests/test_material.py @@ -1,7 +1,7 @@ -import pytest import jax.numpy as jnp +import pytest -from diffmpm.material import SimpleMaterial, LinearElastic +from diffmpm.material import LinearElastic, SimpleMaterial material_dstrain_stress_targets = [ ( diff --git a/tests/test_particle.py b/tests/test_particle.py index 9dd9a5e..d7dedaa 100644 --- a/tests/test_particle.py +++ b/tests/test_particle.py @@ -1,5 +1,6 @@ import jax.numpy as jnp import pytest + from diffmpm.element import Quadrilateral4Node from diffmpm.material import SimpleMaterial from diffmpm.particle import Particles