Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Increase backward compatibility of checkpoints. #47708

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def __init__(
object. If unspecified, a default logger is created.
**kwargs: Arguments passed to the Trainable base class.
"""
config = config or self.get_default_config()
config = config # or self.get_default_config()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: why remove the default config path?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Yes, this is impüortant when a user does not pass in any config. I will change this. Good catch!


# Translate possible dict into an AlgorithmConfig object, as well as,
# resolving generic config objects into specific ones (e.g. passing
Expand All @@ -466,22 +466,31 @@ def __init__(
# `self.get_default_config()` also returned a dict ->
# Last resort: Create core AlgorithmConfig from merged dicts.
if isinstance(default_config, dict):
config = AlgorithmConfig.from_dict(
config_dict=self.merge_algorithm_configs(
default_config, config, True
if "class" in config:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this entire if-block: There are no more algos left that return a dict from their get_default_config() method.

AlgorithmConfig.from_state(config)
else:
config = AlgorithmConfig.from_dict(
config_dict=self.merge_algorithm_configs(
default_config, config, True
)
)
)

# Default config is an AlgorithmConfig -> update its properties
# from the given config dict.
else:
config = default_config.update_from_dict(config)
if isinstance(config, dict) and "class" in config:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, let's do the following with this PR:

  • Make AlgorithmConfig a Checkpointable and override the get/set_state and get_ctor_args_and_kwargs methods first:
def get_state(self, *, components=...):
    return [basically the AlgoConfig as a dict]

def set_state(self, state):
    # Here, we can conveniently check, whether keys in `state` exist as properties and override the correct properties as wanted.

def get_ctor_args_and_kwargs(self):
    return ()  # <- empty tuple AlgorithmConfigs are always constructed w/o any args/kwargs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we can do this here:

if isinstance(config, dict):
    config = default_config
    config.set_state(config)
# else:
#     keep config as is (it's already a proper AlgorithmConfig object)

So, basically, we from here on treat user-provided config dicts as "AlgorithmConfig state dicts".

config = default_config.from_state(config)
else:
config = default_config.update_from_dict(config)
else:
default_config = self.get_default_config()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. This answers my question above: We already do this get_default_dict() call here. Ok.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, now I haven't overviewed the whole logic myself. So, yes, here we set the default, if nothing else is provided.

# Given AlgorithmConfig is not of the same type as the default config:
# This could be the case e.g. if the user is building an algo from a
# generic AlgorithmConfig() object.
if not isinstance(config, type(default_config)):
config = default_config.update_from_dict(config.to_dict())
else:
config = default_config.from_state(config.get_state())
Comment on lines 486 to +493
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
default_config = self.get_default_config()
# Given AlgorithmConfig is not of the same type as the default config:
# This could be the case e.g. if the user is building an algo from a
# generic AlgorithmConfig() object.
if not isinstance(config, type(default_config)):
config = default_config.update_from_dict(config.to_dict())
else:
config = default_config.from_state(config.get_state())
default_config = self.get_default_config()
config_state = config.get_state()
config = default_config
config.set_state(config_state)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can simplify here to the above suggestion ^


# In case this algo is using a generic config (with no algo_class set), set it
# here.
Expand Down Expand Up @@ -2899,7 +2908,7 @@ def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
@override(Checkpointable)
def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
return (
(self.config,), # *args,
(self.config.get_state(),), # *args,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, we should probably do the same everywhere else config is part of the c'tor args/kwargs: Learner, LearnerGroup, and EnvRunner.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, we also have to make sure their c'tors also accept config states(!), not just AlgorithmConfig objects (the same as how Algorithm does it now).

{}, # **kwargs
)

Expand Down
54 changes: 48 additions & 6 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,12 +633,7 @@ def to_dict(self) -> AlgorithmConfigDict:
policies_dict = {}
for policy_id, policy_spec in config.pop("policies").items():
if isinstance(policy_spec, PolicySpec):
policies_dict[policy_id] = (
policy_spec.policy_class,
policy_spec.observation_space,
policy_spec.action_space,
policy_spec.config,
)
policies_dict[policy_id] = policy_spec.get_state()
else:
policies_dict[policy_id] = policy_spec
config["policies"] = policies_dict
Expand Down Expand Up @@ -781,6 +776,53 @@ def update_from_dict(

return self

def get_state(self) -> Dict[str, Any]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add @override(Checkpointable) and subclass AlgorithmConfig from Checkpointable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. This looks sound now. I am so glad when the old stack is gone. Everything will become so clean then.

"""Returns a dict state that can be pickled.

Returns:
A dictionary containing all attributes of the instance.
"""

state = self.__dict__.copy()
state["class"] = type(self)
state.pop("algo_class")
state.pop("_is_frozen")

# Convert `policies` (PolicySpecs?) into dict.
# Convert policies dict such that each policy ID maps to a old-style.
# 4-tuple: class, obs-, and action space, config.
# TODO (simon, sven): Remove when deprecating old stack.
if "policies" in state and isinstance(state["policies"], dict):
policies_dict = {}
for policy_id, policy_spec in state.pop("policies").items():
if isinstance(policy_spec, PolicySpec):
policies_dict[policy_id] = policy_spec.get_state()
else:
policies_dict[policy_id] = policy_spec
state["policies"] = policies_dict

return state

@classmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add @override(Checkpointable) and subclass AlgorithmConfig from Checkpointable.

Rename this method to set_state() and make it not a class method.

def from_state(cls, state: Dict[str, Any]) -> "AlgorithmConfig":
"""Returns an instance constructed from the state.

Args:
cls: An `AlgorithmConfig` class.
state: A dictionary containing the state of an `AlgorithmConfig`.
See `AlgorithmConfig.get_state` for creating a state.

Returns:
An `AlgorithmConfig` instance with attributes from the `state`.
"""

ctor = state["class"]
config = ctor()

config.__dict__.update(state)

return config

# TODO(sven): We might want to have a `deserialize` method as well. Right now,
# simply using the from_dict() API works in this same (deserializing) manner,
# whether the dict used is actually code-free (already serialized) or not
Expand Down
17 changes: 17 additions & 0 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,23 @@ def __eq__(self, other: "PolicySpec"):
and self.config == other.config
)

def get_state(self) -> Dict[str, Any]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is ok. Leave as-is :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Puh! :D

"""Returns the state of a `PolicyDict` as a dict."""
return (
self.policy_class,
self.observation_space,
self.action_space,
self.config,
)

@classmethod
def from_state(cls, state: Dict[str, Any]) -> "PolicySpec":
"""Builds a `PolicySpec` from a state."""
policy_spec = PolicySpec()
policy_spec.__dict__.update(state)

return policy_spec

def serialize(self) -> Dict:
from ray.rllib.algorithms.registry import get_policy_class_name

Expand Down
1 change: 0 additions & 1 deletion rllib/utils/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ def from_checkpoint(
"an implementer of the `Checkpointable` API!"
)

# Construct an initial object.
obj = ctor(
*ctor_info["ctor_args_and_kwargs"][0],
**ctor_info["ctor_args_and_kwargs"][1],
Expand Down