diff --git a/diffmpm/materials/_base.py b/diffmpm/materials/_base.py index d30b15b..896b206 100644 --- a/diffmpm/materials/_base.py +++ b/diffmpm/materials/_base.py @@ -6,6 +6,7 @@ class _Material(abc.ABC): """Base material class.""" _props: Tuple[str, ...] + properties: dict def __init__(self, material_properties): """Initialize material properties. @@ -35,7 +36,7 @@ def __repr__(self): ... @abc.abstractmethod - def compute_stress(self): + def compute_stress(self, particles): """Compute stress for the material.""" ... diff --git a/diffmpm/materials/linear_elastic.py b/diffmpm/materials/linear_elastic.py index 098c10e..5a008d4 100644 --- a/diffmpm/materials/linear_elastic.py +++ b/diffmpm/materials/linear_elastic.py @@ -9,6 +9,7 @@ class LinearElastic(_Material): """Linear Elastic Material.""" _props = ("density", "youngs_modulus", "poisson_ratio") + state_vars = () def __init__(self, material_properties): """Create a Linear Elastic material. @@ -63,7 +64,7 @@ def _compute_elastic_tensor(self): ] ) - def compute_stress(self, dstrain): + def compute_stress(self, particles): """Compute material stress.""" - dstress = self.de @ dstrain + dstress = self.de @ particles.dstrain return dstress diff --git a/diffmpm/materials/simple.py b/diffmpm/materials/simple.py index d9cf15d..77b57ca 100644 --- a/diffmpm/materials/simple.py +++ b/diffmpm/materials/simple.py @@ -6,6 +6,7 @@ @register_pytree_node_class class SimpleMaterial(_Material): _props = ("E", "density") + state_vars = () def __init__(self, material_properties): self.validate_props(material_properties) @@ -14,5 +15,5 @@ def __init__(self, material_properties): def __repr__(self): return f"SimpleMaterial(props={self.properties})" - def compute_stress(self, dstrain): - return dstrain * self.properties["E"] + def compute_stress(self, particles): + return particles.dstrain * self.properties["E"] diff --git a/diffmpm/particle.py b/diffmpm/particle.py index 586fec0..04f2581 100644 --- a/diffmpm/particle.py +++ b/diffmpm/particle.py @@ -69,6 +69,11 @@ def __init__( self.reference_loc = jnp.zeros_like(self.loc) self.dvolumetric_strain = jnp.zeros((self.loc.shape[0], 1)) self.volumetric_strain_centroid = jnp.zeros((self.loc.shape[0], 1)) + self.state_vars = {} + if self.material.state_vars: + self.state_vars = self.material.initialize_state_variables( + self.loc.shape[0] + ) else: ( self.mass, @@ -87,6 +92,7 @@ def __init__( self.reference_loc, self.dvolumetric_strain, self.volumetric_strain_centroid, + self.state_vars, ) = data # type: ignore self.initialized = True @@ -112,6 +118,7 @@ def tree_flatten(self): self.reference_loc, self.dvolumetric_strain, self.volumetric_strain_centroid, + self.state_vars, ) aux_data = (self.material,) return (children, aux_data) @@ -319,7 +326,7 @@ def compute_stress(self, *args): particles. The stress calculated by the material is then added to the particles current stress values. """ - self.stress = self.stress.at[:].add(self.material.compute_stress(self.dstrain)) + self.stress = self.stress.at[:].add(self.material.compute_stress(self)) def update_volume(self, *args): """Update volume based on central strain rate.""" diff --git a/tests/test_material.py b/tests/test_material.py index 66cb4dc..f81cfb0 100644 --- a/tests/test_material.py +++ b/tests/test_material.py @@ -1,27 +1,48 @@ import jax.numpy as jnp import pytest from diffmpm.materials import LinearElastic, SimpleMaterial +from diffmpm.particle import Particles -material_dstrain_stress_targets = [ +particles_dstrain_stress_targets = [ ( - SimpleMaterial({"E": 10, "density": 1}), + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + SimpleMaterial({"E": 10, "density": 1}), + jnp.array([0]), + ), jnp.ones((1, 6, 1)), jnp.ones((1, 6, 1)) * 10, ), ( - LinearElastic({"density": 1, "youngs_modulus": 10, "poisson_ratio": 1}), + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + LinearElastic({"density": 1, "youngs_modulus": 10, "poisson_ratio": 1}), + jnp.array([0]), + ), jnp.ones((1, 6, 1)), jnp.array([-10, -10, -10, 2.5, 2.5, 2.5]).reshape(1, 6, 1), ), ( - LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}), + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + LinearElastic( + {"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3} + ), + jnp.array([0]), + ), jnp.array([0.001, 0.0005, 0, 0, 0, 0]).reshape(1, 6, 1), jnp.array([1.63461538461538e4, 12500, 0.86538461538462e4, 0, 0, 0]).reshape( 1, 6, 1 ), ), ( - LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}), + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + LinearElastic( + {"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3} + ), + jnp.array([0]), + ), jnp.array([0.001, 0.0005, 0, 0.00001, 0, 0]).reshape(1, 6, 1), jnp.array( [1.63461538461538e4, 12500, 0.86538461538462e4, 3.84615384615385e01, 0, 0] @@ -30,7 +51,8 @@ ] -@pytest.mark.parametrize("material, dstrain, target", material_dstrain_stress_targets) -def test_compute_stress(material, dstrain, target): - stress = material.compute_stress(dstrain) +@pytest.mark.parametrize("particles, dstrain, target", particles_dstrain_stress_targets) +def test_compute_stress(particles, dstrain, target): + particles.dstrain = dstrain + stress = particles.material.compute_stress(particles) assert jnp.allclose(stress, target)