Skip to content

Commit

Permalink
Pass particles as arg to compute stress
Browse files Browse the repository at this point in the history
  • Loading branch information
chahak13 committed Aug 7, 2023
1 parent efd7843 commit 1ec1156
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 14 deletions.
3 changes: 2 additions & 1 deletion diffmpm/materials/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class _Material(abc.ABC):
"""Base material class."""

_props: Tuple[str, ...]
properties: dict

def __init__(self, material_properties):
"""Initialize material properties.
Expand Down Expand Up @@ -35,7 +36,7 @@ def __repr__(self):
...

@abc.abstractmethod
def compute_stress(self):
def compute_stress(self, particles):
"""Compute stress for the material."""
...

Expand Down
5 changes: 3 additions & 2 deletions diffmpm/materials/linear_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class LinearElastic(_Material):
"""Linear Elastic Material."""

_props = ("density", "youngs_modulus", "poisson_ratio")
state_vars = ()

def __init__(self, material_properties):
"""Create a Linear Elastic material.
Expand Down Expand Up @@ -63,7 +64,7 @@ def _compute_elastic_tensor(self):
]
)

def compute_stress(self, dstrain):
def compute_stress(self, particles):
"""Compute material stress."""
dstress = self.de @ dstrain
dstress = self.de @ particles.dstrain
return dstress
5 changes: 3 additions & 2 deletions diffmpm/materials/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
@register_pytree_node_class
class SimpleMaterial(_Material):
_props = ("E", "density")
state_vars = ()

def __init__(self, material_properties):
self.validate_props(material_properties)
Expand All @@ -14,5 +15,5 @@ def __init__(self, material_properties):
def __repr__(self):
return f"SimpleMaterial(props={self.properties})"

def compute_stress(self, dstrain):
return dstrain * self.properties["E"]
def compute_stress(self, particles):
return particles.dstrain * self.properties["E"]
9 changes: 8 additions & 1 deletion diffmpm/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def __init__(
self.reference_loc = jnp.zeros_like(self.loc)
self.dvolumetric_strain = jnp.zeros((self.loc.shape[0], 1))
self.volumetric_strain_centroid = jnp.zeros((self.loc.shape[0], 1))
self.state_vars = {}
if self.material.state_vars:
self.state_vars = self.material.initialize_state_variables(
self.loc.shape[0]
)
else:
(
self.mass,
Expand All @@ -87,6 +92,7 @@ def __init__(
self.reference_loc,
self.dvolumetric_strain,
self.volumetric_strain_centroid,
self.state_vars,
) = data # type: ignore
self.initialized = True

Expand All @@ -112,6 +118,7 @@ def tree_flatten(self):
self.reference_loc,
self.dvolumetric_strain,
self.volumetric_strain_centroid,
self.state_vars,
)
aux_data = (self.material,)
return (children, aux_data)
Expand Down Expand Up @@ -319,7 +326,7 @@ def compute_stress(self, *args):
particles. The stress calculated by the material is then
added to the particles current stress values.
"""
self.stress = self.stress.at[:].add(self.material.compute_stress(self.dstrain))
self.stress = self.stress.at[:].add(self.material.compute_stress(self))

def update_volume(self, *args):
"""Update volume based on central strain rate."""
Expand Down
38 changes: 30 additions & 8 deletions tests/test_material.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,48 @@
import jax.numpy as jnp
import pytest
from diffmpm.materials import LinearElastic, SimpleMaterial
from diffmpm.particle import Particles

material_dstrain_stress_targets = [
particles_dstrain_stress_targets = [
(
SimpleMaterial({"E": 10, "density": 1}),
Particles(
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
SimpleMaterial({"E": 10, "density": 1}),
jnp.array([0]),
),
jnp.ones((1, 6, 1)),
jnp.ones((1, 6, 1)) * 10,
),
(
LinearElastic({"density": 1, "youngs_modulus": 10, "poisson_ratio": 1}),
Particles(
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
LinearElastic({"density": 1, "youngs_modulus": 10, "poisson_ratio": 1}),
jnp.array([0]),
),
jnp.ones((1, 6, 1)),
jnp.array([-10, -10, -10, 2.5, 2.5, 2.5]).reshape(1, 6, 1),
),
(
LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}),
Particles(
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
LinearElastic(
{"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}
),
jnp.array([0]),
),
jnp.array([0.001, 0.0005, 0, 0, 0, 0]).reshape(1, 6, 1),
jnp.array([1.63461538461538e4, 12500, 0.86538461538462e4, 0, 0, 0]).reshape(
1, 6, 1
),
),
(
LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}),
Particles(
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
LinearElastic(
{"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}
),
jnp.array([0]),
),
jnp.array([0.001, 0.0005, 0, 0.00001, 0, 0]).reshape(1, 6, 1),
jnp.array(
[1.63461538461538e4, 12500, 0.86538461538462e4, 3.84615384615385e01, 0, 0]
Expand All @@ -30,7 +51,8 @@
]


@pytest.mark.parametrize("material, dstrain, target", material_dstrain_stress_targets)
def test_compute_stress(material, dstrain, target):
stress = material.compute_stress(dstrain)
@pytest.mark.parametrize("particles, dstrain, target", particles_dstrain_stress_targets)
def test_compute_stress(particles, dstrain, target):
particles.dstrain = dstrain
stress = particles.material.compute_stress(particles)
assert jnp.allclose(stress, target)

0 comments on commit 1ec1156

Please sign in to comment.