[RLlib] Increase backward compatibility of checkpoints. #47708
base: master
Conversation
…nitialization of 'Algorithm' to always initialize the config from state. Furthermore, added getter and setter to 'PolicySpec'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
```diff
@@ -455,7 +455,7 @@ def __init__(
             object. If unspecified, a default logger is created.
             **kwargs: Arguments passed to the Trainable base class.
         """
-        config = config or self.get_default_config()
+        config = config  # or self.get_default_config()
```
nit: why remove the default config path?
I see. Yes, this is important when a user does not pass in any config. I will change this. Good catch!
```python
config = AlgorithmConfig.from_dict(
    config_dict=self.merge_algorithm_configs(
        default_config, config, True
# ...
if "class" in config:
```
remove this entire if-block: there are no more algos left that return a dict from their `get_default_config()` method.
```python
# Default config is an AlgorithmConfig -> update its properties
# from the given config dict.
else:
    config = default_config.update_from_dict(config)
# ...
if isinstance(config, dict) and "class" in config:
```
Actually, let's do the following with this PR:

- Make `AlgorithmConfig` a `Checkpointable` and override the `get_state`/`set_state` and `get_ctor_args_and_kwargs` methods first:

```python
def get_state(self, *, components=...):
    return ...  # basically the AlgoConfig as a dict

def set_state(self, state):
    # Here, we can conveniently check whether keys in `state` exist as
    # properties and override the correct properties as wanted.
    ...

def get_ctor_args_and_kwargs(self):
    return ()  # <- empty tuple; AlgorithmConfigs are always constructed w/o any args/kwargs
```
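To make the suggested interface concrete, here is a minimal, self-contained sketch of a config class with this `get_state`/`set_state`/`get_ctor_args_and_kwargs` shape. All class and attribute names (`MiniAlgorithmConfig`, `lr`, `gamma`) are made up for illustration; RLlib's real `Checkpointable` interface has more to it:

```python
# Hypothetical sketch; not RLlib's actual AlgorithmConfig/Checkpointable classes.
from typing import Any, Dict, Tuple


class MiniAlgorithmConfig:
    def __init__(self):
        self.lr = 0.001
        self.gamma = 0.99

    def get_state(self) -> Dict[str, Any]:
        # The state is basically the config as a plain dict.
        return dict(vars(self))

    def set_state(self, state: Dict[str, Any]) -> None:
        # Only override keys that exist as attributes on this (possibly
        # newer) config class; silently skip unknown/removed keys.
        for key, value in state.items():
            if hasattr(self, key):
                setattr(self, key, value)

    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        # Configs are always constructed without any args/kwargs.
        return (), {}
```

The key point is the `hasattr` check in `set_state`: attributes that no longer exist in the current class version are dropped, and attributes added since the state was written keep their current defaults.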
Then we can do this here:

```python
if isinstance(config, dict):
    config_state = config
    config = default_config
    config.set_state(config_state)
# else: keep `config` as-is (it's already a proper AlgorithmConfig object).
```

So, basically, from here on we treat user-provided config dicts as "AlgorithmConfig state dicts".
```diff
@@ -781,6 +776,53 @@ def update_from_dict(

         return self

+    def get_state(self) -> Dict[str, Any]:
```
Add `@override(Checkpointable)` and subclass `AlgorithmConfig` from `Checkpointable`.
Yeah. This looks sound now. I am so glad when the old stack is gone. Everything will become so clean then.
```python
    return state

@classmethod
```
Add `@override(Checkpointable)` and subclass `AlgorithmConfig` from `Checkpointable`. Also, rename this method to `set_state()` and make it not a classmethod.
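For contrast, here is a toy sketch of the difference between the classmethod-style `from_state()` and the instance-level `set_state()` being requested. The class `SpecLike` and its attributes are hypothetical stand-ins, not RLlib's actual code:

```python
from typing import Any, Dict


class SpecLike:
    """Toy stand-in for a config/spec class; not RLlib's actual code."""

    def __init__(self):
        self.observation_space = None
        self.config = {}

    # Classmethod style: always constructs a *fresh* instance of `cls`
    # from a state dict.
    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "SpecLike":
        obj = cls()
        obj.set_state(state)
        return obj

    # Instance style (as requested in the review): mutates an existing,
    # already default-initialized object in place.
    def set_state(self, state: Dict[str, Any]) -> None:
        for key, value in state.items():
            if hasattr(self, key):
                setattr(self, key, value)
```

The instance style is what the checkpoint-loading path wants: the caller can start from an algorithm-specific default config and layer the stored state on top, rather than constructing a new bare instance.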
```diff
@@ -124,6 +124,23 @@ def __eq__(self, other: "PolicySpec"):
             and self.config == other.config
         )

+    def get_state(self) -> Dict[str, Any]:
```
this is ok. Leave as-is :)
Puh! :D
```python
if isinstance(config, dict) and "class" in config:
    config = default_config.from_state(config)
else:
    config = default_config.update_from_dict(config)
else:
    default_config = self.get_default_config()
```
Ah, I see. This answers my question above: we already do this `get_default_config()` call here. Ok.
Haha, I hadn't kept the whole logic in view myself. So, yes, here we set the default if nothing else is provided.
```python
default_config = self.get_default_config()
# Given AlgorithmConfig is not of the same type as the default config:
# This could be the case e.g. if the user is building an algo from a
# generic AlgorithmConfig() object.
if not isinstance(config, type(default_config)):
    config = default_config.update_from_dict(config.to_dict())
else:
    config = default_config.from_state(config.get_state())
```
```diff
-default_config = self.get_default_config()
-# Given AlgorithmConfig is not of the same type as the default config:
-# This could be the case e.g. if the user is building an algo from a
-# generic AlgorithmConfig() object.
-if not isinstance(config, type(default_config)):
-    config = default_config.update_from_dict(config.to_dict())
-else:
-    config = default_config.from_state(config.get_state())
+default_config = self.get_default_config()
+config_state = config.get_state()
+config = default_config
+config.set_state(config_state)
```
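A small, self-contained sketch of why a single `get_state()`/`set_state()` roundtrip can replace the two-branch logic: it behaves the same whether the user passed a generic config or already the right subclass. All names (`GenericConfig`, `PPOStyleConfig`, `resolve_config`) are illustrative only:

```python
from typing import Any, Dict


class GenericConfig:
    """Toy generic config; not RLlib's AlgorithmConfig."""

    def __init__(self):
        self.lr = 0.001

    def get_state(self) -> Dict[str, Any]:
        return dict(vars(self))

    def set_state(self, state: Dict[str, Any]) -> None:
        for key, value in state.items():
            if hasattr(self, key):
                setattr(self, key, value)


class PPOStyleConfig(GenericConfig):
    """Toy algorithm-specific subclass with an extra field."""

    def __init__(self):
        super().__init__()
        self.clip_param = 0.3


def resolve_config(user_config: GenericConfig,
                   default_config: GenericConfig) -> GenericConfig:
    # Works whether `user_config` is a generic config or already the
    # right subclass: always start from the (subclass) default config
    # and overlay whatever state the user provided.
    state = user_config.get_state()
    config = default_config
    config.set_state(state)
    return config
```

If the user passes a generic config, the subclass-only attributes simply keep their defaults; if the user passes the matching subclass, the overlay is a full state transfer.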
I think we can simplify here to the above suggestion ^
```diff
@@ -2899,7 +2908,7 @@ def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
     @override(Checkpointable)
     def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
         return (
-            (self.config,),  # *args,
+            (self.config.get_state(),),  # *args,
```
Perfect!
Btw, we should probably do the same everywhere else `config` is part of the c'tor args/kwargs: Learner, LearnerGroup, and EnvRunner.
Then, we also have to make sure their c'tors also accept config states(!), not just AlgorithmConfig objects (the same as how Algorithm does it now).
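As a sketch of what "c'tors accepting config states" could look like, here is a toy constructor that handles both a config object and a plain state dict. The classes `TinyConfig` and `TinyEnvRunner` are hypothetical, not RLlib's actual EnvRunner/Learner code:

```python
from typing import Any, Dict, Union


class TinyConfig:
    """Toy config; not RLlib's AlgorithmConfig."""

    def __init__(self):
        self.num_workers = 0

    def set_state(self, state: Dict[str, Any]) -> None:
        for key, value in state.items():
            if hasattr(self, key):
                setattr(self, key, value)


class TinyEnvRunner:
    """Toy sketch of a c'tor that accepts a config object OR a state dict."""

    def __init__(self, config: Union[TinyConfig, Dict[str, Any]]):
        if isinstance(config, dict):
            # Treat plain dicts as config *state* dicts.
            state = config
            config = TinyConfig()
            config.set_state(state)
        self.config = config
```

With this shape, `get_ctor_args_and_kwargs()` can safely return the config's state dict instead of the config object, and restoring from a checkpoint still produces a properly typed config.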
I really like this PR. It solves so many problems at once!
A few design change requests and nits, but overall already in very good shape! Thanks @simonsays1980 .
Oh, sorry, another thing. Can we also get rid of (rename):
Does this make sense? ^
Why are these changes needed?

The `AlgorithmConfig` is still a pain point when trying to load older checkpoints (of the new stack specifically). The reason for this is usually attributes that were added between storing the checkpoint and loading it again (lately, e.g., the `_torch_grad_scaler_class` attribute). This PR suggests a logic that enables loading older checkpoints with a newer version (of the new stack). The `AlgorithmConfig` class is provided with `set_state` and `from_state` to receive a dictionary state and initiate a config from it. `Algorithm` now always creates its `config` by calling the `from_state` method and thereby adds all attributes of the present `AlgorithmConfig` class version to it.

Related issue number
Closes #47426
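To illustrate the intended effect, here is a toy demonstration (not RLlib's actual classes) of loading a state dict written by an older config version: the newly added attribute keeps its current default, and stale keys are ignored instead of raising:

```python
from typing import Any, Dict


class ConfigV2:
    """Toy 'newer' config version; not RLlib's AlgorithmConfig."""

    def __init__(self):
        self.lr = 0.001
        # Attribute added *after* the checkpoint below was written
        # (analogous to e.g. `_torch_grad_scaler_class`).
        self.grad_scaler_class = None

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "ConfigV2":
        config = cls()
        for key, value in state.items():
            # Skip keys that no longer exist in the current class version.
            if hasattr(config, key):
                setattr(config, key, value)
        return config


# State dict written by an older config version: it lacks
# `grad_scaler_class` and contains a since-removed attribute.
old_checkpoint_state = {"lr": 0.01, "removed_attr": "stale"}
```

This is the backward-compatibility property the PR description is after: old state dicts load cleanly into newer config classes.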
Checks

- I've signed off every commit (using `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.