Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saveload custom func #934

Merged
merged 4 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lmfit/jsonutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def decode4js(obj):
return obj
out = obj
classname = obj.pop('__class__', None)
if classname is None and isinstance(obj, dict):
classname = 'dict'
if classname is None:
return obj
if classname == 'Complex':
Expand Down
4 changes: 1 addition & 3 deletions lmfit/minimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,11 +906,9 @@ def scalar_minimize(self, method='Nelder-Mead', params=None, max_nfev=None,
fmin_kws['options']['maxfun'] = 2*self.max_nfev
elif method == 'COBYLA':
# for this method, we explicitly let the solver reach
# the users max nfev, and do not abort in _residual.
# the users max nfev, and do not abort in _residual
fmin_kws['options']['maxiter'] = self.max_nfev
self.max_nfev = 5*self.max_nfev

# fmin_kws = dict(method=method, options={'maxfun': 2*self.max_nfev})
fmin_kws.update(self.kws)

if 'maxiter' in kws:
Expand Down
48 changes: 29 additions & 19 deletions lmfit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,8 @@ def setpar(par, val):
# val is expected to be float-like or a dict: must have 'value' or 'expr' key
if isinstance(val, dict):
dval = val
elif np.iscomplex(val) or isinstance(val, complex):
dval = {'value': val.real}
else:
dval = {'value': float(val)}
if len(dval) < 1 or not ('value' in dval or 'expr' in dval):
Expand Down Expand Up @@ -1808,12 +1810,11 @@ def dumps(self, **kws):
loads, json.dumps

"""
out = {'__class__': 'lmfit.ModelResult', '__version__': '1',
out = {'__class__': 'lmfit.ModelResult', '__version__': '2',
'model': encode4js(self.model._get_state())}
pasteval = self.params._asteval
out['params'] = [p.__getstate__() for p in self.params.values()]
out['unique_symbols'] = {key: encode4js(pasteval.symtable[key])
for key in pasteval.user_defined_symbols()}

for attr in ('params', 'init_params'):
out[attr] = getattr(self, attr).dumps()

for attr in ('aborted', 'aic', 'best_values', 'bic', 'chisqr',
'ci_out', 'col_deriv', 'covar', 'errorbars', 'flatchain',
Expand All @@ -1828,7 +1829,6 @@ def dumps(self, **kws):
continue
if isinstance(val, np.bool_):
val = bool(val)

out[attr] = encode4js(val)

val = out.get('message', '')
Expand Down Expand Up @@ -1896,19 +1896,28 @@ def loads(self, s, funcdefs=None, **kws):
if funcdefs:
# Remove model function so as not pass it into the _asteval.symtable
funcdefs.pop(self.model.func.__name__, None)
# params
if funcdefs:
# Remove model function so as not pass it into the _asteval.symtable
funcdefs.pop(self.model.func.__name__, None)
for target in ('params', 'init_params'):
state = {'unique_symbols': modres['unique_symbols'], 'params': []}
for parstate in modres['params']:
_par = Parameter(name='')
_par.__setstate__(parstate)
state['params'].append(_par)
_params = Parameters(usersyms=funcdefs)
_params.__setstate__(state)
setattr(self, target, _params)

# how params are saved was changed with version 2:
modres_vers = modres.get('__version__', '1')
if modres_vers == '1':
for target in ('params', 'init_params'):
state = {'unique_symbols': modres['unique_symbols'], 'params': []}
for parstate in modres['params']:
_par = Parameter(name='')
_par.__setstate__(parstate)
state['params'].append(_par)
_params = Parameters(usersyms=funcdefs)
_params.__setstate__(state)
setattr(self, target, _params)

elif modres_vers == '2':
for target in ('params', 'init_params'):
_pars = Parameters()
_pars.loads(modres[target])
if funcdefs:
for key, val in funcdefs.items():
_pars._asteval.symtable[key] = val
setattr(self, target, _pars)

for attr in ('aborted', 'aic', 'best_fit', 'best_values', 'bic',
'chisqr', 'ci_out', 'col_deriv', 'covar', 'data',
Expand All @@ -1930,6 +1939,7 @@ def loads(self, s, funcdefs=None, **kws):
if par is not None:
par.correl = par.stderr = None
par.value = par.init_value = self.init_values[parname]

self.init_fit = self.model.eval(self.init_params, **self.userkws)
self.result = MinimizerResult()
self.result.params = self.params
Expand Down
4 changes: 4 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,10 @@ def test_composite_has_bestvalues(self):
def test_composite_plotting(self):
# test that a composite model has non-empty best_values
import matplotlib
try:
matplotlib.pyplot.close('all')
except ValueError:
pass
matplotlib.use('Agg')

model1 = models.GaussianModel(prefix='g1_')
Expand Down
Loading