Skip to content

Commit

Permalink
Restructure code and add materials subdir
Browse files Browse the repository at this point in the history
To make the imports easier from the materials subdir, also restructured
other files. This moves `MPM` to a separate file so as to remove
circular imports for materials module.
  • Loading branch information
chahak13 committed Aug 7, 2023
1 parent f6946b6 commit efd7843
Show file tree
Hide file tree
Showing 16 changed files with 126 additions and 120 deletions.
2 changes: 1 addition & 1 deletion benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax.numpy as jnp

from diffmpm import MPM
from diffmpm.mpm import MPM


def test_benchmarks():
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/2d/uniaxial_particle_traction/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax.numpy as jnp

from diffmpm import MPM
from diffmpm.mpm import MPM


def test_benchmarks():
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/2d/uniaxial_stress/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax.numpy as jnp

from diffmpm import MPM
from diffmpm.mpm import MPM


def test_benchmarks():
Expand Down
44 changes: 1 addition & 43 deletions diffmpm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,5 @@
from importlib.metadata import version
from pathlib import Path

import diffmpm.writers as writers
from diffmpm.io import Config
from diffmpm.solver import MPMExplicit

__all__ = ["MPM", "__version__"]
__all__ = ["__version__"]

__version__ = version("diffmpm")


class MPM:
def __init__(self, filepath):
self._config = Config(filepath)
mesh = self._config.parse()
out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath(
self._config.parsed_config["meta"]["title"],
)

write_format = self._config.parsed_config["output"].get("format", None)
if write_format is None or write_format.lower() == "none":
writer_func = None
elif write_format == "npz":
writer_func = writers.NPZWriter().write
else:
raise ValueError(f"Specified output format not supported: {write_format}")

if self._config.parsed_config["meta"]["type"] == "MPMExplicit":
self.solver = MPMExplicit(
mesh,
self._config.parsed_config["meta"]["dt"],
velocity_update=self._config.parsed_config["meta"]["velocity_update"],
sim_steps=self._config.parsed_config["meta"]["nsteps"],
out_steps=self._config.parsed_config["output"]["step_frequency"],
out_dir=out_dir,
writer_func=writer_func,
)
else:
raise ValueError("Wrong type of solver specified.")

def solve(self):
"""Solve the MPM simulation using JIT solver."""
arrays = self.solver.solve_jit(
self._config.parsed_config["external_loading"]["gravity"],
)
return arrays
2 changes: 1 addition & 1 deletion diffmpm/cli/mpm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import click

from diffmpm import MPM
from diffmpm.mpm import MPM


@click.command() # type: ignore
Expand Down
3 changes: 1 addition & 2 deletions diffmpm/io.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import tomllib as tl
from collections import namedtuple

import jax.numpy as jnp

from diffmpm import element as mpel
from diffmpm import material as mpmat
from diffmpm import materials as mpmat
from diffmpm import mesh as mpmesh
from diffmpm.constraint import Constraint
from diffmpm.forces import NodalForce, ParticleTraction
Expand Down
3 changes: 3 additions & 0 deletions diffmpm/materials/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from diffmpm.materials._base import _Material
from diffmpm.materials.simple import SimpleMaterial
from diffmpm.materials.linear_elastic import LinearElastic
48 changes: 48 additions & 0 deletions diffmpm/materials/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import abc
from typing import Tuple


class _Material(abc.ABC):
"""Base material class."""

_props: Tuple[str, ...]

def __init__(self, material_properties):
"""Initialize material properties.
Parameters
----------
material_properties: dict
A key-value map for various material properties.
"""
self.properties = material_properties

# @abc.abstractmethod
def tree_flatten(self):
"""Flatten this class as PyTree Node."""
return (tuple(), self.properties)

# @abc.abstractmethod
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Unflatten this class as PyTree Node."""
del children
return cls(aux_data)

@abc.abstractmethod
def __repr__(self):
"""Repr for Material class."""
...

@abc.abstractmethod
def compute_stress(self):
"""Compute stress for the material."""
...

def validate_props(self, material_properties):
for key in self._props:
if key not in material_properties:
raise KeyError(
f"'{key}' should be present in `material_properties` "
f"for {self.__class__.__name__} materials."
)
66 changes: 2 additions & 64 deletions diffmpm/material.py → diffmpm/materials/linear_elastic.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,11 @@
import abc
from typing import Tuple

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class Material(abc.ABC):
"""Base material class."""

_props: Tuple[str, ...]

def __init__(self, material_properties):
"""Initialize material properties.
Parameters
----------
material_properties: dict
A key-value map for various material properties.
"""
self.properties = material_properties

# @abc.abstractmethod
def tree_flatten(self):
"""Flatten this class as PyTree Node."""
return (tuple(), self.properties)

# @abc.abstractmethod
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Unflatten this class as PyTree Node."""
del children
return cls(aux_data)

@abc.abstractmethod
def __repr__(self):
"""Repr for Material class."""
...

@abc.abstractmethod
def compute_stress(self):
"""Compute stress for the material."""
...

def validate_props(self, material_properties):
for key in self._props:
if key not in material_properties:
raise KeyError(
f"'{key}' should be present in `material_properties` "
f"for {self.__class__.__name__} materials."
)
from ._base import _Material


@register_pytree_node_class
class LinearElastic(Material):
class LinearElastic(_Material):
"""Linear Elastic Material."""

_props = ("density", "youngs_modulus", "poisson_ratio")
Expand Down Expand Up @@ -114,18 +67,3 @@ def compute_stress(self, dstrain):
"""Compute material stress."""
dstress = self.de @ dstrain
return dstress


@register_pytree_node_class
class SimpleMaterial(Material):
_props = ("E", "density")

def __init__(self, material_properties):
self.validate_props(material_properties)
self.properties = material_properties

def __repr__(self):
return f"SimpleMaterial(props={self.properties})"

def compute_stress(self, dstrain):
return dstrain * self.properties["E"]
1 change: 1 addition & 0 deletions diffmpm/materials/newtonian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#!/usr/bin/env python3
18 changes: 18 additions & 0 deletions diffmpm/materials/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from jax.tree_util import register_pytree_node_class

from ._base import _Material


@register_pytree_node_class
class SimpleMaterial(_Material):
_props = ("E", "density")

def __init__(self, material_properties):
self.validate_props(material_properties)
self.properties = material_properties

def __repr__(self):
return f"SimpleMaterial(props={self.properties})"

def compute_stress(self, dstrain):
return dstrain * self.properties["E"]
42 changes: 42 additions & 0 deletions diffmpm/mpm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from pathlib import Path

import diffmpm.writers as writers
from diffmpm.io import Config
from diffmpm.solver import MPMExplicit


class MPM:
def __init__(self, filepath):
self._config = Config(filepath)
mesh = self._config.parse()
out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath(
self._config.parsed_config["meta"]["title"],
)

write_format = self._config.parsed_config["output"].get("format", None)
if write_format is None or write_format.lower() == "none":
writer_func = None
elif write_format == "npz":
writer_func = writers.NPZWriter().write
else:
raise ValueError(f"Specified output format not supported: {write_format}")

if self._config.parsed_config["meta"]["type"] == "MPMExplicit":
self.solver = MPMExplicit(
mesh,
self._config.parsed_config["meta"]["dt"],
velocity_update=self._config.parsed_config["meta"]["velocity_update"],
sim_steps=self._config.parsed_config["meta"]["nsteps"],
out_steps=self._config.parsed_config["output"]["step_frequency"],
out_dir=out_dir,
writer_func=writer_func,
)
else:
raise ValueError("Wrong type of solver specified.")

def solve(self):
"""Solve the MPM simulation using JIT solver."""
arrays = self.solver.solve_jit(
self._config.parsed_config["external_loading"]["gravity"],
)
return arrays
6 changes: 3 additions & 3 deletions diffmpm/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jax.typing import ArrayLike

from diffmpm.element import _Element
from diffmpm.material import Material
from diffmpm.materials import _Material


@register_pytree_node_class
Expand All @@ -16,7 +16,7 @@ class Particles(Sized):
def __init__(
self,
loc: ArrayLike,
material: Material,
material: _Material,
element_ids: ArrayLike,
initialized: Optional[bool] = None,
data: Optional[Tuple[ArrayLike, ...]] = None,
Expand All @@ -27,7 +27,7 @@ def __init__(
----------
loc: ArrayLike
Location of the particles. Expected shape (nparticles, 1, ndim)
material: diffmpm.material.Material
material: diffmpm.materials._Material
Type of material for the set of particles.
element_ids: ArrayLike
The element ids that the particles belong to. This contains
Expand Down
2 changes: 1 addition & 1 deletion tests/test_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from diffmpm.element import Quadrilateral4Node
from diffmpm.forces import NodalForce
from diffmpm.functions import Unit
from diffmpm.material import SimpleMaterial
from diffmpm.materials import SimpleMaterial
from diffmpm.particle import Particles


Expand Down
3 changes: 1 addition & 2 deletions tests/test_material.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import jax.numpy as jnp
import pytest

from diffmpm.material import LinearElastic, SimpleMaterial
from diffmpm.materials import LinearElastic, SimpleMaterial

material_dstrain_stress_targets = [
(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from diffmpm.element import Quadrilateral4Node
from diffmpm.material import SimpleMaterial
from diffmpm.materials import SimpleMaterial
from diffmpm.particle import Particles


Expand Down

0 comments on commit efd7843

Please sign in to comment.