[RLlib] Increase backward compatibility of checkpoints. #47708
base: master
@@ -455,7 +455,7 @@ def __init__(
                object. If unspecified, a default logger is created.
            **kwargs: Arguments passed to the Trainable base class.
        """
        config = config or self.get_default_config()
        config = config  # or self.get_default_config()

        # Translate possible dict into an AlgorithmConfig object, as well as,
        # resolving generic config objects into specific ones (e.g. passing
@@ -466,22 +466,31 @@ def __init__(
            # `self.get_default_config()` also returned a dict ->
            # Last resort: Create core AlgorithmConfig from merged dicts.
            if isinstance(default_config, dict):
                config = AlgorithmConfig.from_dict(
                    config_dict=self.merge_algorithm_configs(
                        default_config, config, True
                if "class" in config:
Review comment: Remove this entire if-block: There are no more algos left that return a dict from their
                    AlgorithmConfig.from_state(config)
                else:
                    config = AlgorithmConfig.from_dict(
                        config_dict=self.merge_algorithm_configs(
                            default_config, config, True
                        )
                    )
                )

            # Default config is an AlgorithmConfig -> update its properties
            # from the given config dict.
            else:
                config = default_config.update_from_dict(config)
                if isinstance(config, dict) and "class" in config:
Review comment: Actually, let's do the following with this PR:
Review comment: Then we can do this here:
So, basically, from here on we treat user-provided config dicts as "AlgorithmConfig state dicts".
                    config = default_config.from_state(config)
                else:
                    config = default_config.update_from_dict(config)
        else:
            default_config = self.get_default_config()
Review comment: Ah, I see. This answers my question above: We already do this
Review comment: Haha, I had not looked over the whole logic myself. So, yes, here we set the default if nothing else is provided.
            # Given AlgorithmConfig is not of the same type as the default config:
            # This could be the case e.g. if the user is building an algo from a
            # generic AlgorithmConfig() object.
            if not isinstance(config, type(default_config)):
                config = default_config.update_from_dict(config.to_dict())
            else:
                config = default_config.from_state(config.get_state())
Comment on lines 486 to +493
Suggested change
Review comment: I think we can simplify here to the above suggestion ^

        # In case this algo is using a generic config (with no algo_class set), set it
        # here.
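
The branching in this hunk can be summarised with a small, self-contained sketch. It is not RLlib code: `ToyConfig` and `resolve_config` are hypothetical stand-ins, and the only idea taken from the diff is that a dict carrying a "class" key is treated as a config state, while any other dict is treated as overrides on top of the default config.

```python
from typing import Any, Dict, Optional, Union


class ToyConfig:
    """Hypothetical stand-in for an AlgorithmConfig-like class."""

    def __init__(self) -> None:
        self.lr = 0.001
        self.gamma = 0.99

    def update_from_dict(self, overrides: Dict[str, Any]) -> "ToyConfig":
        # Overwrite only the keys the user provided.
        for key, value in overrides.items():
            setattr(self, key, value)
        return self

    def get_state(self) -> Dict[str, Any]:
        # The "class" key marks this dict as a state (assumption from the diff).
        state = self.__dict__.copy()
        state["class"] = type(self)
        return state

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "ToyConfig":
        state = dict(state)  # Work on a copy; leave the caller's dict intact.
        config = state.pop("class")()
        config.__dict__.update(state)
        return config


def resolve_config(config: Optional[Union[Dict[str, Any], ToyConfig]]) -> ToyConfig:
    """Turns None, a plain dict, a state dict, or a config into a config."""
    default_config = ToyConfig()
    if config is None:
        # Nothing provided: fall back to the default config.
        return default_config
    if isinstance(config, dict):
        if "class" in config:
            # State dict: rebuild the config from it.
            return ToyConfig.from_state(config)
        # Plain dict: apply it as overrides on the default config.
        return default_config.update_from_dict(config)
    return config
```

The `None` branch corresponds to the "default config path" that the review asks to keep.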
@@ -2899,7 +2908,7 @@ def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
    @override(Checkpointable)
    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        return (
            (self.config,),  # *args,
            (self.config.get_state(),),  # *args,
Review comment: Perfect!
Review comment: Btw, we should probably do the same everywhere else
Review comment: Then, we also have to make sure their c'tors also accept config states(!), not just AlgorithmConfig objects (the same as how Algorithm does it now).
            {},  # **kwargs
        )

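
The review thread asks that constructors accept config states as well as config objects. A minimal sketch of that pattern follows. `ToyConfig` and `ToyAlgo` are hypothetical names, not RLlib classes, and the sketch only assumes that the stored constructor arguments are later replayed to rebuild the object.

```python
from typing import Any, Dict, Tuple, Union


class ToyConfig:
    """Hypothetical config with a state round trip."""

    def __init__(self) -> None:
        self.lr = 0.001

    def get_state(self) -> Dict[str, Any]:
        return {"class": type(self), **self.__dict__}

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "ToyConfig":
        state = dict(state)  # Copy so the caller's dict is not modified.
        config = state.pop("class")()
        config.__dict__.update(state)
        return config


class ToyAlgo:
    """Hypothetical algorithm whose constructor takes a config or its state."""

    def __init__(self, config: Union[ToyConfig, Dict[str, Any]]) -> None:
        if isinstance(config, dict):
            config = ToyConfig.from_state(config)
        self.config = config

    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        # Store the plain state dict rather than the config object itself.
        return ((self.config.get_state(),), {})


algo = ToyAlgo(ToyConfig())
algo.config.lr = 0.5

# Replay the stored constructor arguments to rebuild an equivalent object.
args, kwargs = algo.get_ctor_args_and_kwargs()
rebuilt = ToyAlgo(*args, **kwargs)
```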
@@ -633,12 +633,7 @@ def to_dict(self) -> AlgorithmConfigDict:
            policies_dict = {}
            for policy_id, policy_spec in config.pop("policies").items():
                if isinstance(policy_spec, PolicySpec):
                    policies_dict[policy_id] = (
                        policy_spec.policy_class,
                        policy_spec.observation_space,
                        policy_spec.action_space,
                        policy_spec.config,
                    )
                    policies_dict[policy_id] = policy_spec.get_state()
                else:
                    policies_dict[policy_id] = policy_spec
            config["policies"] = policies_dict
@@ -781,6 +776,53 @@ def update_from_dict(

        return self

    def get_state(self) -> Dict[str, Any]:
Review comment: Add
Review comment: Yeah. This looks sound now. I will be so glad when the old stack is gone. Everything will become so clean then.
        """Returns a dict state that can be pickled.

        Returns:
            A dictionary containing all attributes of the instance.
        """

        state = self.__dict__.copy()
        state["class"] = type(self)
        state.pop("algo_class")
        state.pop("_is_frozen")

        # Convert `policies` (PolicySpecs?) into dict.
        # Convert policies dict such that each policy ID maps to a old-style.
        # 4-tuple: class, obs-, and action space, config.
        # TODO (simon, sven): Remove when deprecating old stack.
        if "policies" in state and isinstance(state["policies"], dict):
            policies_dict = {}
            for policy_id, policy_spec in state.pop("policies").items():
                if isinstance(policy_spec, PolicySpec):
                    policies_dict[policy_id] = policy_spec.get_state()
                else:
                    policies_dict[policy_id] = policy_spec
            state["policies"] = policies_dict

        return state

    @classmethod
Review comment: Add Rename this method to
    def from_state(cls, state: Dict[str, Any]) -> "AlgorithmConfig":
        """Returns an instance constructed from the state.

        Args:
            cls: An `AlgorithmConfig` class.
            state: A dictionary containing the state of an `AlgorithmConfig`.
                See `AlgorithmConfig.get_state` for creating a state.

        Returns:
            An `AlgorithmConfig` instance with attributes from the `state`.
        """

        ctor = state["class"]
        config = ctor()

        config.__dict__.update(state)

        return config

    # TODO(sven): We might want to have a `deserialize` method as well. Right now,
    # simply using the from_dict() API works in this same (deserializing) manner,
    # whether the dict used is actually code-free (already serialized) or not
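
To make the behaviour of this `get_state()`/`from_state()` pair concrete, the sketch below mirrors the two method bodies from the diff on a hypothetical `ToyConfig` class (the class and its attributes are assumptions, not RLlib code). Under these assumptions, attributes popped in `get_state()` come back with their constructor defaults, and the "class" entry ends up as an instance attribute because the whole state is passed to `__dict__.update`.

```python
from typing import Any, Dict


class ToyConfig:
    """Hypothetical config mirroring the state methods shown in the diff."""

    def __init__(self) -> None:
        self.algo_class = None
        self.lr = 0.001
        self._is_frozen = False

    def get_state(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state["class"] = type(self)
        # These two attributes are deliberately left out of the state.
        state.pop("algo_class")
        state.pop("_is_frozen")
        return state

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "ToyConfig":
        ctor = state["class"]
        config = ctor()
        # Copies every key of the state, including "class", onto the instance.
        config.__dict__.update(state)
        return config


original = ToyConfig()
original.lr = 0.5
original._is_frozen = True

restored = ToyConfig.from_state(original.get_state())

# True here: the "class" key was copied into the instance's attributes.
leaked_class_key = "class" in restored.__dict__
```

If the extra "class" attribute is unwanted, one option is to pop it from a copy of the state before calling `__dict__.update`.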
@@ -124,6 +124,23 @@ def __eq__(self, other: "PolicySpec"):
            and self.config == other.config
        )

    def get_state(self) -> Dict[str, Any]:
Review comment: This is ok. Leave as-is :)
Review comment: Phew! :D
        """Returns the state of a `PolicyDict` as a dict."""
        return (
            self.policy_class,
            self.observation_space,
            self.action_space,
            self.config,
        )

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "PolicySpec":
        """Builds a `PolicySpec` from a state."""
        policy_spec = PolicySpec()
        policy_spec.__dict__.update(state)

        return policy_spec

    def serialize(self) -> Dict:
        from ray.rllib.algorithms.registry import get_policy_class_name

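
In the hunk above, `get_state()` returns a 4-tuple while `from_state()` passes its argument to `__dict__.update`, which expects a mapping or an iterable of key/value pairs. One way to make the pair symmetric is to use a dict on both sides. The sketch below shows that idea on a hypothetical `ToyPolicySpec`. It is an illustration under that assumption, not what the PR implements (the review comment on this method asks to leave it as-is).

```python
from typing import Any, Dict, Optional


class ToyPolicySpec:
    """Hypothetical spec with a dict-based state on both sides."""

    def __init__(
        self,
        policy_class: Optional[type] = None,
        observation_space: Any = None,
        action_space: Any = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.policy_class = policy_class
        self.observation_space = observation_space
        self.action_space = action_space
        self.config = config

    def get_state(self) -> Dict[str, Any]:
        # Keys match the constructor's parameter names.
        return {
            "policy_class": self.policy_class,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "config": self.config,
        }

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "ToyPolicySpec":
        # Because keys match the constructor, the state can be unpacked directly.
        return cls(**state)


# Placeholder values stand in for a real policy class and real spaces.
spec = ToyPolicySpec(
    policy_class=dict,
    observation_space="obs",
    action_space="act",
    config={"lr": 0.1},
)
restored = ToyPolicySpec.from_state(spec.get_state())
```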
Review comment: nit: why remove the default config path?
Review comment: I see. Yes, this is important when a user does not pass in any `config`. I will change this. Good catch!