-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Dynamics randomization for MuJoCo (#51)
Add support for dynamics randomization in MuJoCo environments. This includes a simple data structure consisting of a list of variation objects, and a wrapper environment for MuJoCo that performs dynamics randomization. Each variation object is an instance of the Variation class, which acts as a container for the fields used to randomize a dynamic parameter within the simulation environment. The MuJoCo wrapper class performs dynamics randomization on each reset(). The data structure and the wrapper class are tested in test_dynamics_rand.py. Refer to: #14
- Loading branch information
1 parent
38f5f97
commit 438f9df
Showing
4 changed files
with
426 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from rllab.envs.mujoco.randomization.randomized_env import randomize | ||
from rllab.envs.mujoco.randomization.variation import Distribution | ||
from rllab.envs.mujoco.randomization.variation import Method | ||
from rllab.envs.mujoco.randomization.variation import Variations |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import os.path as osp | ||
|
||
from mujoco_py import load_model_from_xml | ||
from mujoco_py import MjSim | ||
|
||
from rllab.core import Serializable | ||
from rllab.envs import Env | ||
from rllab.envs.mujoco.mujoco_env import MODEL_DIR | ||
|
||
|
||
class RandomizedEnv(Env, Serializable):
    """
    Wrapper around a MuJoCo environment that performs dynamics
    randomization.

    The randomization logic lives in __init__ (binding the variations to
    the model XML file) and reset (rebuilding the simulation from a
    randomized model). Every other method and property simply delegates
    to the wrapped environment.
    """

    def __init__(self, mujoco_env, variations):
        """
        Wrap mujoco_env and initialize variations against its XML model.

        :param mujoco_env: environment to wrap; its FILE attribute is
            joined with MODEL_DIR to locate the model XML file.
        :param variations: object exposing initialize_variations(path)
            and get_randomized_xml_model().
        """
        Serializable.quick_init(self, locals())
        self._wrapped_env = mujoco_env
        self._variations = variations
        self._file_path = osp.join(MODEL_DIR, mujoco_env.FILE)
        self._variations.initialize_variations(self._file_path)

    def reset(self):
        """
        Request a new model with randomized parameters, install it (and
        a fresh simulation) on the wrapped environment, then reset the
        wrapped environment and return its result.
        """
        env = self._wrapped_env
        env.model = load_model_from_xml(
            self._variations.get_randomized_xml_model())
        # Discard any action space stored on the instance so it is looked
        # up again after the model change. pop() with a default is used
        # instead of `hasattr` + `del __dict__[...]`: hasattr is also true
        # for a class-level property, in which case the del would raise
        # KeyError, and hasattr would needlessly evaluate the property.
        env.__dict__.pop('action_space', None)
        env.sim = MjSim(env.model)
        sim_data = env.sim.data
        env.data = sim_data
        # Initial-state fields are taken from the newly created simulation.
        env.init_qpos = sim_data.qpos
        env.init_qvel = sim_data.qvel
        env.init_qacc = sim_data.qacc
        env.init_ctrl = sim_data.ctrl
        return env.reset()

    def step(self, action):
        """Forward action to the wrapped environment and return its result."""
        return self._wrapped_env.step(action)

    def render(self, *args, **kwargs):
        """Delegate rendering to the wrapped environment."""
        return self._wrapped_env.render(*args, **kwargs)

    def log_diagnostics(self, paths, *args, **kwargs):
        """Delegate diagnostics logging to the wrapped environment."""
        self._wrapped_env.log_diagnostics(paths, *args, **kwargs)

    def get_param_values(self):
        """Return the parameter values of the wrapped environment."""
        return self._wrapped_env.get_param_values()

    def set_param_values(self, params):
        """Set the parameter values of the wrapped environment."""
        self._wrapped_env.set_param_values(params)

    def terminate(self):
        """Delegate termination to the wrapped environment."""
        self._wrapped_env.terminate()

    @property
    def wrapped_env(self):
        """The environment instance being wrapped."""
        return self._wrapped_env

    @property
    def action_space(self):
        """Action space of the wrapped environment."""
        return self._wrapped_env.action_space

    @property
    def observation_space(self):
        """Observation space of the wrapped environment."""
        return self._wrapped_env.observation_space

    @property
    def horizon(self):
        """Horizon of the wrapped environment."""
        return self._wrapped_env.horizon


# Convenience alias: randomize(env, variations) builds a RandomizedEnv.
randomize = RandomizedEnv
Oops, something went wrong.