Skip to content

Commit

Permalink
Merge pull request #791 from dstl/log_weight_particle_state
Browse files Browse the repository at this point in the history
Add log weight property to Particle State
  • Loading branch information
sdhiscocks committed Apr 11, 2023
2 parents 819efb0 + b2c3da1 commit 3df69be
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 78 deletions.
14 changes: 7 additions & 7 deletions stonesoup/predictor/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def predict(self, prior, timestamp=None, **kwargs):
time_interval=time_interval,
**kwargs)

return Prediction.from_state(prior, state_vector=new_state_vector, weight=prior.weight,
timestamp=timestamp, particle_list=None,
return Prediction.from_state(prior,
state_vector=new_state_vector,
timestamp=timestamp,
transition_model=self.transition_model)


Expand Down Expand Up @@ -92,10 +93,11 @@ def predict(self, prior, *args, **kwargs):
GaussianState(prior.state_vector, prior.covar, prior.timestamp),
*args, **kwargs)

return Prediction.from_state(prior, state_vector=particle_prediction.state_vector,
weight=particle_prediction.weight,
return Prediction.from_state(prior,
state_vector=particle_prediction.state_vector,
log_weight=particle_prediction.log_weight,
timestamp=particle_prediction.timestamp,
fixed_covar=kalman_prediction.covar, particle_list=None,
fixed_covar=kalman_prediction.covar,
transition_model=self.transition_model)


Expand Down Expand Up @@ -144,7 +146,6 @@ def predict(self, prior, timestamp=None, **kwargs):
prior,
state_vector=copy.copy(prior.state_vector),
parent=prior,
particle_list=None,
dynamic_model=copy.copy(prior.dynamic_model),
timestamp=timestamp)

Expand Down Expand Up @@ -214,7 +215,6 @@ def predict(self, prior, timestamp=None, **kwargs):
prior,
state_vector=copy.copy(prior.state_vector),
parent=prior,
particle_list=None,
timestamp=timestamp)

# Change the value of the dynamic value randomly according to the
Expand Down
11 changes: 4 additions & 7 deletions stonesoup/resampler/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from .base import Resampler
from ..base import Property
from ..types.numeric import Probability
from ..types.state import ParticleState


Expand All @@ -29,9 +28,7 @@ def resample(self, particles, nparts=None):
if nparts is None:
nparts = len(particles)

weight = Probability(1 / nparts)

log_weights = np.asfarray(np.log(particles.weight))
log_weights = particles.log_weight
weight_order = np.argsort(log_weights, kind='stable')
max_log_value = log_weights[weight_order[-1]]
with np.errstate(divide='ignore'):
Expand All @@ -47,8 +44,7 @@ def resample(self, particles, nparts=None):
index = weight_order[np.searchsorted(cdf, np.log(u_j))]

new_particles = particles[index]
new_particles.weight = np.full((nparts, ), weight)

new_particles.log_weight = np.full((nparts, ), np.log(1/nparts))
return new_particles


Expand Down Expand Up @@ -89,7 +85,8 @@ def resample(self, particles):
particles = ParticleState(None, particle_list=particles)
if self.threshold is None:
self.threshold = len(particles) / 2
if 1 / np.sum(np.square(particles.weight)) < self.threshold: # If ESS too small, resample
# If ESS too small, resample
if 1 / np.sum(np.exp(2*particles.log_weight)) < self.threshold:
return self.resampler.resample(self.resampler, particles)
else:
return particles
134 changes: 91 additions & 43 deletions stonesoup/types/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def from_state(state: 'State', *args: Any, target_type: Optional[typing.Type] =
new_kwargs = {
name: getattr(state, name)
for name in type(state).properties.keys() & target_type.properties.keys()
if name not in args_property_names}
if name not in args_property_names and name not in kwargs}

new_kwargs.update(kwargs)

Expand Down Expand Up @@ -163,7 +163,7 @@ def from_state(
if target_type is None:
target_type = CreatableFromState.class_mapping[cls][state_type]

return State.from_state(state, *args, **kwargs, target_type=target_type)
return target_type.from_state(state, *args, **kwargs, target_type=target_type)


class ASDState(Type):
Expand Down Expand Up @@ -617,6 +617,7 @@ class ParticleState(State):

state_vector: StateVectors = Property(doc='State vectors.')
weight: MutableSequence[Probability] = Property(default=None, doc='Weights of particles')
log_weight: np.ndarray = Property(default=None, doc='Log weights of particles')
parent: 'ParticleState' = Property(default=None, doc='Parent particles')
particle_list: MutableSequence[Particle] = Property(default=None,
doc='List of Particle objects')
Expand All @@ -625,6 +626,22 @@ class ParticleState(State):
'weighted sample covariance is then used.')

def __init__(self, *args, **kwargs):
weight = next(
(val for name, val in zip(type(self).properties, args) if name == 'weight'),
kwargs.get('weight', None))
log_weight, idx = next(
((val, idx) for idx, (name, val) in enumerate(zip(type(self).properties, args))
if name == 'log_weight'),
(kwargs.get('log_weight', None), None))

if weight is not None and log_weight is not None:
raise ValueError("Cannot provide both weight and log weight")
elif log_weight is None and weight is not None:
log_weight = np.log(np.asfarray(weight))
if idx is not None:
args[idx] = log_weight
else:
kwargs['log_weight'] = log_weight
super().__init__(*args, **kwargs)

if (self.particle_list is not None) and \
Expand All @@ -650,34 +667,53 @@ def __init__(self, *args, **kwargs):

if self.state_vector is not None and not isinstance(self.state_vector, StateVectors):
self.state_vector = StateVectors(self.state_vector)
if self.weight is not None and not isinstance(self.weight, np.ndarray):
self.weight = np.array(self.weight)

def __getitem__(self, item):
if self.parent is not None:
parent = self.parent[item]
else:
parent = None

if self.weight is not None:
weight = self.weight[item]
if self.log_weight is not None:
log_weight = self.log_weight[item]
else:
weight = None
log_weight = None

if isinstance(item, int):
result = Particle(state_vector=self.state_vector[:, item],
weight=weight,
weight=self.weight[item] if self.weight is not None else None,
parent=parent)
else:
# Allow for Prediction/Update sub-types
result = type(self).from_state(self,
state_vector=self.state_vector[:, item],
weight=weight,
parent=parent,
particle_list=None)
log_weight=log_weight,
parent=parent)
return result

@clearable_cached_property('state_vector', 'weight')
@classmethod
def from_state(cls, state: 'State', *args: Any, target_type: Optional[typing.Type] = None,
**kwargs: Any) -> 'State':

# Handle default presence of both particle_list and weight once class has been created by
# ignoring particle_list and weight (setting to None) if not provided.
particle_list, particle_list_idx = next(
((val, idx) for idx, (name, val) in enumerate(zip(cls.properties, args))
if name == 'particle_list'),
(kwargs.get('particle_list', None), None))
if particle_list_idx is None:
kwargs['particle_list'] = particle_list

weight, weight_idx = next(
((val, idx) for idx, (name, val) in enumerate(zip(cls.properties, args))
if name == 'weight'),
(kwargs.get('weight', None), None))
if weight_idx is None:
kwargs['weight'] = weight

return super().from_state(state, *args, target_type=target_type, **kwargs)

@clearable_cached_property('state_vector', 'log_weight')
def particles(self):
"""Sequence of individual :class:`~.Particle` objects."""
if self.particle_list is not None:
Expand All @@ -692,22 +728,42 @@ def ndim(self):
"""The number of dimensions represented by the state."""
return self.state_vector.shape[0]

@clearable_cached_property('state_vector', 'weight')
@clearable_cached_property('state_vector', 'log_weight')
def mean(self):
"""Sample mean for particles"""
if len(self) == 1: # No need to calculate mean
return self.state_vector
return np.average(self.state_vector, axis=1, weights=np.asfarray(self.weight))
return np.average(self.state_vector, axis=1, weights=np.exp(self.log_weight))

@clearable_cached_property('state_vector', 'weight', 'fixed_covar')
@clearable_cached_property('state_vector', 'log_weight', 'fixed_covar')
def covar(self):
"""Sample covariance matrix for particles"""
if self.fixed_covar is not None:
return self.fixed_covar
return np.cov(self.state_vector, ddof=0, aweights=np.asfarray(self.weight))
return np.cov(self.state_vector, ddof=0, aweights=np.exp(self.log_weight))

@weight.setter
def weight(self, value):
if value is None:
self.log_weight = None
else:
self.log_weight = np.log(np.asfarray(value))
self.__dict__['weight'] = np.asanyarray(value)

@weight.getter
def weight(self):
try:
return self.__dict__['weight']
except KeyError:
log_weight = self.log_weight
if log_weight is None:
return None
weight = Probability.from_log_ufunc(log_weight)
self.__dict__['weight'] = weight
return weight

State.register(ParticleState) # noqa: E305
ParticleState.log_weight._clear_cached.add('weight')


class MultiModelParticleState(ParticleState):
Expand All @@ -734,28 +790,28 @@ def __getitem__(self, item):
else:
parent = None

if self.weight is not None:
weight = self.weight[item]
if self.log_weight is not None:
log_weight = self.log_weight[item]
else:
weight = None
log_weight = None

if self.dynamic_model is not None:
dynamic_model = self.dynamic_model[item]
else:
dynamic_model = None

if isinstance(item, int):
result = MultiModelParticle(state_vector=self.state_vector[:, item],
weight=weight,
parent=parent,
dynamic_model=dynamic_model)
result = MultiModelParticle(
state_vector=self.state_vector[:, item],
weight=self.weight[item] if self.weight is not None else None,
parent=parent,
dynamic_model=dynamic_model)
else:
# Allow for Prediction/Update sub-types
result = type(self).from_state(self,
state_vector=self.state_vector[:, item],
weight=weight,
log_weight=log_weight,
parent=parent,
particle_list=None,
dynamic_model=dynamic_model)
return result

Expand All @@ -780,33 +836,33 @@ def __getitem__(self, item):
else:
parent = None

if self.weight is not None:
weight = self.weight[item]
if self.log_weight is not None:
log_weight = self.log_weight[item]
else:
weight = None
log_weight = None

if self.model_probabilities is not None:
model_probabilities = self.model_probabilities[:, item]
else:
model_probabilities = None

if isinstance(item, int):
result = RaoBlackwellisedParticle(state_vector=self.state_vector[:, item],
weight=weight,
parent=parent,
model_probabilities=model_probabilities)
result = RaoBlackwellisedParticle(
state_vector=self.state_vector[:, item],
weight=self.weight[item] if self.weight is not None else None,
parent=parent,
model_probabilities=model_probabilities)
else:
# Allow for Prediction/Update sub-types
result = type(self).from_state(self,
state_vector=self.state_vector[:, item],
weight=weight,
log_weight=log_weight,
parent=parent,
particle_list=None,
model_probabilities=model_probabilities)
return result


class EnsembleState(Type):
class EnsembleState(State):
r"""Ensemble State type
This is an Ensemble state object which describes the system state as a
Expand Down Expand Up @@ -889,11 +945,6 @@ def generate_ensemble(mean, covar, num_vectors):

return ensemble

@property
def ndim(self):
"""Number of dimensions in state vectors"""
return np.shape(self.state_vector)[0]

@property
def num_vectors(self):
"""Number of columns in state ensemble"""
Expand All @@ -917,9 +968,6 @@ def sqrt_covar(self):
/ np.sqrt(self.num_vectors - 1))


State.register(EnsembleState) # noqa: E305


class CategoricalState(State):
r"""CategoricalState type.
Expand Down
Loading

0 comments on commit 3df69be

Please sign in to comment.