
[RLlib] Increase backward compatibility of checkpoints. #47708

Open · wants to merge 2 commits into master
Conversation

simonsays1980 (Collaborator) commented Sep 17, 2024

Why are these changes needed?

The AlgorithmConfig is still a pain point when trying to load older checkpoints (of the new stack specifically). The reason is usually attributes that were added between storing the checkpoint and loading it again (recently, e.g., the _torch_grad_scaler_class attribute). This PR proposes a logic that enables loading older checkpoints with a newer version (of the new stack):

  • The AlgorithmConfig class gets set_state and from_state methods that receive a state dictionary and initialize a config from it.
  • The Algorithm now always creates its config by calling the from_state method, thereby adding all attributes of the currently present AlgorithmConfig class version (see the sketch below).
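A minimal sketch of this idea, using a simplified stand-in class rather than RLlib's actual implementation: set_state only overrides attributes that the current config version defines, so attributes added after a checkpoint was written keep their new defaults when that checkpoint is restored.

from typing import Any, Dict


class AlgorithmConfig:
    def __init__(self):
        # Attribute that already existed when the checkpoint was written.
        self.lr = 0.001
        # Attribute added in a newer version; old checkpoints don't contain it,
        # so it keeps this default when such a checkpoint is restored.
        self._torch_grad_scaler_class = None

    def set_state(self, state: Dict[str, Any]) -> "AlgorithmConfig":
        # Only copy over keys the current config version actually defines;
        # unknown (removed or renamed) keys in old checkpoints are skipped.
        for key, value in state.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "AlgorithmConfig":
        # Build a fresh config with all current defaults, then overwrite
        # whatever the checkpointed state provides.
        return cls().set_state(state)


# Restoring a state that predates `_torch_grad_scaler_class`:
old_state = {"lr": 0.0005}
config = AlgorithmConfig.from_state(old_state)
assert config.lr == 0.0005
assert config._torch_grad_scaler_class is None  # new attribute keeps its default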

Related issue number

Closes #47426

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…nitialization of 'Algorithm' to always initialize the config from state. Furthermore, added getter and setter to 'PolicySpec'.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@simonsays1980 added labels on Sep 17, 2024: rllib (RLlib related issues), rllib-checkpointing-or-recovery (An issue related to checkpointing/recovering RLlib Trainers).
@@ -455,7 +455,7 @@ def __init__(
object. If unspecified, a default logger is created.
**kwargs: Arguments passed to the Trainable base class.
"""
-        config = config or self.get_default_config()
+        config = config  # or self.get_default_config()
Contributor:
nit: why remove the default config path?

Collaborator (author):
I see. Yes, this is important when a user does not pass in any config. I will change this. Good catch!

        config = AlgorithmConfig.from_dict(
            config_dict=self.merge_algorithm_configs(
                default_config, config, True
            )
        )
        if "class" in config:
Contributor:
remove this entire if-block: There are no more algos left that return a dict from their get_default_config() method.

            # Default config is an AlgorithmConfig -> update its properties
            # from the given config dict.
            else:
                config = default_config.update_from_dict(config)
if isinstance(config, dict) and "class" in config:
Contributor:
Actually, let's do the following with this PR:

  • Make AlgorithmConfig a Checkpointable and override the get/set_state and get_ctor_args_and_kwargs methods first:
def get_state(self, *, components=...):
    return [basically the AlgoConfig as a dict]

def set_state(self, state):
    # Here, we can conveniently check whether keys in `state` exist as properties
    # and override the correct properties as wanted.

def get_ctor_args_and_kwargs(self):
    return ()  # <- empty tuple; AlgorithmConfigs are always constructed w/o any args/kwargs
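Read literally, that proposal could look roughly like the following. The Checkpointable base class here is a reduced stand-in for RLlib's interface, and the dict-building details are assumptions rather than the final implementation.

from typing import Any, Dict, Optional, Tuple


class Checkpointable:
    # Reduced stand-in for RLlib's Checkpointable interface, limited to the
    # three hooks discussed in this thread.
    def get_state(self, *, components: Optional[Any] = None) -> Dict[str, Any]:
        raise NotImplementedError

    def set_state(self, state: Dict[str, Any]) -> None:
        raise NotImplementedError

    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        raise NotImplementedError


class AlgorithmConfig(Checkpointable):
    def __init__(self):
        self.lr = 0.001
        self.gamma = 0.99

    def get_state(self, *, components: Optional[Any] = None) -> Dict[str, Any]:
        # "Basically the AlgoConfig as a dict".
        return dict(vars(self))

    def set_state(self, state: Dict[str, Any]) -> None:
        # Check whether keys in `state` exist as properties and only then
        # override them; unknown keys from other versions are skipped.
        for key, value in state.items():
            if hasattr(self, key):
                setattr(self, key, value)

    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        # AlgorithmConfigs are always constructed without any args/kwargs.
        return (), {}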

Contributor:
Then we can do this here:

if isinstance(config, dict):
    config_dict = config
    config = default_config
    config.set_state(config_dict)
# else:
#     keep config as is (it's already a proper AlgorithmConfig object)

So, basically, from here on we treat user-provided config dicts as "AlgorithmConfig state dicts".

@@ -781,6 +776,53 @@ def update_from_dict(

return self

def get_state(self) -> Dict[str, Any]:
Contributor:
Add @override(Checkpointable) and subclass AlgorithmConfig from Checkpointable.

Collaborator (author):
Yeah, this looks sound now. I will be so glad when the old stack is gone. Everything will become so clean then.


return state

@classmethod
Contributor:
Add @override(Checkpointable) and subclass AlgorithmConfig from Checkpointable.

Rename this method to set_state() and make it not a class method.

@@ -124,6 +124,23 @@ def __eq__(self, other: "PolicySpec"):
and self.config == other.config
)

def get_state(self) -> Dict[str, Any]:
Contributor:
this is ok. Leave as-is :)

Collaborator (author):
Phew! :D
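Since the PR also adds a getter and setter to PolicySpec, a rough illustration of what such a pair could look like is shown below. The field names follow PolicySpec's usual attributes, but this is a simplified stand-in, not the PR's exact code.

from typing import Any, Dict, Optional


class PolicySpec:
    # Simplified stand-in for RLlib's PolicySpec, for illustration only.
    def __init__(
        self,
        policy_class=None,
        observation_space=None,
        action_space=None,
        config: Optional[Dict[str, Any]] = None,
    ):
        self.policy_class = policy_class
        self.observation_space = observation_space
        self.action_space = action_space
        self.config = config

    def get_state(self) -> Dict[str, Any]:
        # Return the spec's fields as a plain dict.
        return {
            "policy_class": self.policy_class,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "config": self.config,
        }

    def set_state(self, state: Dict[str, Any]) -> "PolicySpec":
        # Only set fields this PolicySpec version knows about.
        for key, value in state.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self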

            if isinstance(config, dict) and "class" in config:
                config = default_config.from_state(config)
            else:
                config = default_config.update_from_dict(config)
        else:
            default_config = self.get_default_config()
Contributor:
Ah, I see. This answers my question above: We already do this get_default_config() call here. Ok.

Collaborator (author):
Haha, I hadn't kept the whole logic in view myself. So, yes, here we set the default if nothing else is provided.

Comment on lines 486 to +493

        default_config = self.get_default_config()
        # Given AlgorithmConfig is not of the same type as the default config:
        # This could be the case e.g. if the user is building an algo from a
        # generic AlgorithmConfig() object.
        if not isinstance(config, type(default_config)):
            config = default_config.update_from_dict(config.to_dict())
        else:
            config = default_config.from_state(config.get_state())
Contributor:
Suggested change
-        default_config = self.get_default_config()
-        # Given AlgorithmConfig is not of the same type as the default config:
-        # This could be the case e.g. if the user is building an algo from a
-        # generic AlgorithmConfig() object.
-        if not isinstance(config, type(default_config)):
-            config = default_config.update_from_dict(config.to_dict())
-        else:
-            config = default_config.from_state(config.get_state())
+        default_config = self.get_default_config()
+        config_state = config.get_state()
+        config = default_config
+        config.set_state(config_state)

Contributor:
I think we can simplify here to the above suggestion ^

@@ -2899,7 +2908,7 @@ def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
    @override(Checkpointable)
    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        return (
-            (self.config,),  # *args,
+            (self.config.get_state(),),  # *args,
Contributor:
Perfect!

Contributor:
Btw, we should probably do the same everywhere else config is part of the c'tor args/kwargs: Learner, LearnerGroup, and EnvRunner.

Contributor:
Then, we also have to make sure their c'tors also accept config states(!), not just AlgorithmConfig objects (the same as how Algorithm does it now).
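A sketch of what that could mean for such a constructor is below. EnvRunner here is only a stand-in, and the dual handling of config objects and state dicts is an assumption about how this could be wired up, mirroring the Algorithm change in this PR.

from typing import Any, Dict, Union


class AlgorithmConfig:
    # Minimal stand-in for the real AlgorithmConfig.
    def __init__(self):
        self.lr = 0.001

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "AlgorithmConfig":
        config = cls()
        for key, value in state.items():
            if hasattr(config, key):
                setattr(config, key, value)
        return config


class EnvRunner:
    # Stand-in constructor that accepts a config object *or* a config state dict.
    def __init__(self, *, config: Union[AlgorithmConfig, Dict[str, Any]]):
        if isinstance(config, dict):
            config = AlgorithmConfig.from_state(config)
        self.config = config


# Both of these would then work when restoring from ctor args/kwargs:
runner_from_object = EnvRunner(config=AlgorithmConfig())
runner_from_state = EnvRunner(config={"lr": 0.0005})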

@sven1977 (Contributor) left a comment:
I really like this PR. It solves so many problems at once!

A few design change requests and nits, but overall already in very good shape! Thanks @simonsays1980.

@sven1977 (Contributor):
Oh, sorry, another thing. Can we also get rid of (rename):

  • AlgorithmConfig.to_dict() -> get_state()
  • AlgorithmConfig.update_from_dict() -> set_state()

Does this make sense? ^
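One way the rename could be done without immediately breaking existing callers is via thin deprecation aliases. This is a sketch under that assumption, not something the comment above prescribes.

import warnings
from typing import Any, Dict


class AlgorithmConfig:
    def __init__(self):
        self.lr = 0.001

    # New names, as proposed above.
    def get_state(self) -> Dict[str, Any]:
        return dict(vars(self))

    def set_state(self, state: Dict[str, Any]) -> "AlgorithmConfig":
        for key, value in state.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self

    # Old names kept as thin, deprecated aliases for a transition period.
    def to_dict(self) -> Dict[str, Any]:
        warnings.warn("to_dict() is deprecated; use get_state().", DeprecationWarning)
        return self.get_state()

    def update_from_dict(self, config_dict: Dict[str, Any]) -> "AlgorithmConfig":
        warnings.warn(
            "update_from_dict() is deprecated; use set_state().", DeprecationWarning
        )
        return self.set_state(config_dict)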

Labels: rllib (RLlib related issues), rllib-checkpointing-or-recovery (An issue related to checkpointing/recovering RLlib Trainers)
Development

Successfully merging this pull request may close these issues.

[RLlib] Rename of SingleAgentRLModuleSpec to RLModuleSpec breaks restoring old checkpoints