diff --git a/configs/commands/rlbench/_train.yaml b/configs/commands/rlbench/_train.yaml index d2fe858..42dfa1d 100644 --- a/configs/commands/rlbench/_train.yaml +++ b/configs/commands/rlbench/_train.yaml @@ -12,7 +12,7 @@ pretraining: mode: none training: - max_epochs: 500 + max_epochs: 2000 batch_size: 8 check_val_every_n_epoch: 10 log_every_n_steps: 100 diff --git a/configs/dataset/rlbench/_default.yaml b/configs/dataset/rlbench/_default.yaml index da53382..e0d3cc5 100644 --- a/configs/dataset/rlbench/_default.yaml +++ b/configs/dataset/rlbench/_default.yaml @@ -15,11 +15,13 @@ train_dset: cached: true phase: ${task.phase.name} teleport_initial_to_final: true - with_symmetry: True + with_symmetry: False occlusion_cfg: ??? num_points: 512 action_mode: "gripper_and_object" anchor_mode: "background_robot_removed" + include_wrist_cam: True + gripper_in_first_phase: False val_dset: demo_dset: @@ -32,8 +34,10 @@ val_dset: cached: true phase: ${task.phase.name} teleport_initial_to_final: true - with_symmetry: True + with_symmetry: False occlusion_cfg: ??? num_points: ${...train_dset.demo_dset.num_points} action_mode: ${...train_dset.demo_dset.action_mode} anchor_mode: ${...train_dset.demo_dset.anchor_mode} + include_wrist_cam: ${...train_dset.demo_dset.include_wrist_cam} + gripper_in_first_phase: ${...train_dset.demo_dset.gripper_in_first_phase} diff --git a/configs/eval_rlbench.yaml b/configs/eval_rlbench.yaml index 6e2c49d..0c334b8 100644 --- a/configs/eval_rlbench.yaml +++ b/configs/eval_rlbench.yaml @@ -23,6 +23,8 @@ policy_spec: model: ${model} include_rgb_features: False add_random_jitter: True + include_wrist_cam: True + gripper_in_first_phase: False # Usually only a single checkpoint, but we could have multiple for each phase. diff --git a/pyproject.toml b/pyproject.toml index cfa6b66..99ff721 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ rlbench = [ # These should be installed manually... # "pyrep", # "rlbench @ git+https://github.com/stepjam/RLBench.git", - "rpad-rlbench-utils @ git+https://github.com/r-pad/rlbench_utils.git@11722e4d803581cc79097f333c1f2c2eb536e15f", + "rpad-rlbench-utils @ git+https://github.com/r-pad/rlbench_utils.git@5d6167cbb07c25242e3a87a5d5a896c05a7370a7", ] diff --git a/scripts/README.md b/scripts/README.md index 97d34ff..ff98640 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -167,25 +167,25 @@ RLBENCH_10_TASKS = [ ## Training -./launch.sh autobot 5 python scripts/train_residual_flow.py --config-name commands/rlbench/pick_and_lift/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 0 python scripts/train_residual_flow.py --config-name commands/rlbench/pick_and_lift/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions -./launch.sh autobot 1 python scripts/train_residual_flow.py --config-name commands/rlbench/pick_up_cup/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 1 python scripts/train_residual_flow.py --config-name commands/rlbench/pick_up_cup/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions -./launch.sh autobot 2 python scripts/train_residual_flow.py --config-name commands/rlbench/put_knife_on_chopping_board/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 2 python scripts/train_residual_flow.py --config-name commands/rlbench/put_knife_on_chopping_board/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions -./launch.sh autobot 3 python scripts/train_residual_flow.py --config-name commands/rlbench/put_money_in_safe/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 3 python scripts/train_residual_flow.py --config-name commands/rlbench/put_money_in_safe/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions -./launch.sh autobot 4 python scripts/train_residual_flow.py --config-name commands/rlbench/push_button/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 4 python scripts/train_residual_flow.py --config-name commands/rlbench/push_button/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions -./launch.sh autobot 0 python scripts/train_residual_flow.py --config-name commands/rlbench/reach_target/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 5 python scripts/train_residual_flow.py --config-name commands/rlbench/reach_target/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions -./launch.sh autobot 6 python scripts/train_residual_flow.py --config-name commands/rlbench/slide_block_to_target/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 6 python scripts/train_residual_flow.py --config-name commands/rlbench/slide_block_to_target/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions -./launch.sh autobot 7 python scripts/train_residual_flow.py --config-name commands/rlbench/stack_wine/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 7 python scripts/train_residual_flow.py --config-name commands/rlbench/stack_wine/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions -./launch.sh autobot 8 python scripts/train_residual_flow.py --config-name commands/rlbench/take_money_out_safe/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 8 python scripts/train_residual_flow.py --config-name commands/rlbench/take_money_out_safe/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions -./launch.sh autobot 9 python scripts/train_residual_flow.py --config-name commands/rlbench/take_umbrella_out_of_umbrella_stand/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 +./launch.sh autobot 9 python scripts/train_residual_flow.py --config-name commands/rlbench/take_umbrella_out_of_umbrella_stand/train_taxpose_tc.yaml dm.train_dset.include_rgb_features=True model.feature_channels=3 benchmark.dataset_root=/data/rlbench10_collisions ## Checkpoints from this training run @@ -293,3 +293,67 @@ take_umbrella_out_of_umbrella_stand: r-pad/taxpose/model-txvpna0v:v0 ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/take_umbrella_out_of_umbrella_stand/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-txvpna0v:v0 wandb.group=rlbench_take_umbrella_out_of_umbrella_stand ``` + +# Wrist Cam + Gripper Out + +## Training + +See the branch. + +## Precision Eval + +Trained models: +pick_and_lift: hbdcvydg +pick_up_cup: gwfjqpfk +push_button: mmc5fhzu +put_knife_on_chopping_board: rjscih24 +put_money_in_safe: gj3h3o3c +reach_target: v8vv53tx +slide_block_to_target: sejd7pz0 +stack_wine: 3hyo3r7q +take_money_out_safe: 69gka1ew +take_umbrella_out_of_umbrella_stand: b48mz8e1 + +``` + +# pick_and_lift + +./configs/commands/rlbench/pick_and_lift/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-hbdcvydg:v0 benchmark.dataset_root=/data/rlbench10_collisions + +# pick_up_cup + +./configs/commands/rlbench/pick_up_cup/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-gwfjqpfk:v0 benchmark.dataset_root=/data/rlbench10_collisions + +# push_button + +./configs/commands/rlbench/push_button/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-mmc5fhzu:v0 benchmark.dataset_root=/data/rlbench10_collisions + +# put_knife_on_chopping_board + +./configs/commands/rlbench/put_knife_on_chopping_board/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-rjscih24:v0 benchmark.dataset_root=/data/rlbench10_collisions + +# put_money_in_safe + +./configs/commands/rlbench/put_money_in_safe/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-gj3h3o3c:v0 benchmark.dataset_root=/data/rlbench10_collisions + +# reach_target + +./configs/commands/rlbench/reach_target/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-v8vv53tx:v0 benchmark.dataset_root=/data/rlbench10_collisions + +# slide_block_to_target + +./configs/commands/rlbench/slide_block_to_target/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-sejd7pz0:v0 benchmark.dataset_root=/data/rlbench10_collisions + +# stack_wine + +./configs/commands/rlbench/stack_wine/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-3hyo3r7q:v0 benchmark.dataset_root=/data/rlbench10_collisions + +# take_money_out_safe + +./configs/commands/rlbench/take_money_out_safe/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-69gka1ew:v0 benchmark.dataset_root=/data/rlbench10_collisions + +# take_umbrella_out_of_umbrella_stand + +./configs/commands/rlbench/take_umbrella_out_of_umbrella_stand/taxpose_tc/precision_eval/precision_eval.sh dm.train_dset.include_rgb_features=True model.feature_channels=3 checkpoint=r-pad/taxpose/model-b48mz8e1:v0 benchmark.dataset_root=/data/rlbench10_collisions + +``` diff --git a/scripts/eval_rlbench.py b/scripts/eval_rlbench.py index 6aaddbc..aa41227 100644 --- a/scripts/eval_rlbench.py +++ b/scripts/eval_rlbench.py @@ -347,6 +347,8 @@ def predict(self, obs, phase: str) -> Tuple[np.ndarray, Dict[str, Any]]: self.action_mode, self.anchor_mode, n_pts=self.policy_spec.num_points, + include_wrist_cam=self.policy_spec.include_wrist_cam, + gripper_in_first_phase=self.policy_spec.gripper_in_first_phase, ) model = self.models[phase] @@ -551,8 +553,16 @@ def obs_to_input( action_mode: ActionMode, anchor_mode: AnchorMode, n_pts: Optional[int] = 1024, + include_wrist_cam=False, + gripper_in_first_phase=False, ): - rgb, pc, mask = obs_to_rgb_point_cloud(obs) + rgb, pc, mask = obs_to_rgb_point_cloud(obs, include_wrist_cam) + + # Filter out any points with the mask == 16777215 + mask_ixs = (mask != 16777215).squeeze() + mask = mask[mask_ixs] + pc = pc[mask_ixs] + rgb = rgb[mask_ixs] ######################################## # Separate the action and anchor points. @@ -576,6 +586,7 @@ def obs_to_input( task_name, phase, use_from_simulator=True, + gripper_in_first_phase=gripper_in_first_phase, ) ############################## diff --git a/scripts/run_rlbench10_rollouts.sh b/scripts/run_rlbench10_rollouts.sh index be02579..471a06c 100755 --- a/scripts/run_rlbench10_rollouts.sh +++ b/scripts/run_rlbench10_rollouts.sh @@ -2,42 +2,97 @@ ########################### RGB experiments ############################# + +# # pick_and_lift + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/pick_and_lift/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-9tx1uje9:v0 wandb.group=rlbench_pick_and_lift + +# # pick_up_cup + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/pick_up_cup/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-9pfjeq0j:v0 wandb.group=rlbench_pick_up_cup + +# # put_knife_on_chopping_board + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/put_knife_on_chopping_board/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-c9u0u4np:v0 wandb.group=rlbench_put_knife_on_chopping_board + +# # put_money_in_safe + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/put_money_in_safe/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-lq4b953m:v0 wandb.group=rlbench_put_money_in_safe + +# # push_button + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/push_button/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-0kxpww1x:v0 wandb.group=rlbench_push_button + +# # reach_target + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/reach_target/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-w5kjqoph:v0 wandb.group=rlbench_reach_target + +# # slide_block_to_target + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/slide_block_to_target/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-fct8vbrq:v0 wandb.group=rlbench_slide_block_to_target + +# # stack_wine + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/stack_wine/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-cbe1hgx4:v0 wandb.group=rlbench_stack_wine + +# # take_money_out_safe + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/take_money_out_safe/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-j3swo5k7:v0 wandb.group=rlbench_take_money_out_safe + +# # take_umbrella_out_of_umbrella_stand + +# ./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/take_umbrella_out_of_umbrella_stand/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-txvpna0v:v0 wandb.group=rlbench_take_umbrella_out_of_umbrella_stand + + +########################### No gripper, wrist cam ############################# + +# pick_and_lift: hbdcvydg +# pick_up_cup: gwfjqpfk +# push_button: mmc5fhzu +# put_knife_on_chopping_board: rjscih24 +# put_money_in_safe: gj3h3o3c +# reach_target: v8vv53tx +# slide_block_to_target: sejd7pz0 +# stack_wine: 3hyo3r7q +# take_money_out_safe: 69gka1ew +# take_umbrella_out_of_umbrella_stand: b48mz8e1 + # pick_and_lift -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/pick_and_lift/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-9tx1uje9:v0 wandb.group=rlbench_pick_and_lift +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/pick_and_lift/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-hbdcvydg:v0 wandb.group=rlbench_pick_and_lift # pick_up_cup -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/pick_up_cup/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-9pfjeq0j:v0 wandb.group=rlbench_pick_up_cup +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/pick_up_cup/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-gwfjqpfk:v0 wandb.group=rlbench_pick_up_cup # put_knife_on_chopping_board -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/put_knife_on_chopping_board/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-c9u0u4np:v0 wandb.group=rlbench_put_knife_on_chopping_board +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/put_knife_on_chopping_board/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-rjscih24:v0 wandb.group=rlbench_put_knife_on_chopping_board # put_money_in_safe -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/put_money_in_safe/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-lq4b953m:v0 wandb.group=rlbench_put_money_in_safe +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/put_money_in_safe/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-gj3h3o3c:v0 wandb.group=rlbench_put_money_in_safe # push_button -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/push_button/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-0kxpww1x:v0 wandb.group=rlbench_push_button +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/push_button/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-mmc5fhzu:v0 wandb.group=rlbench_push_button # reach_target -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/reach_target/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-w5kjqoph:v0 wandb.group=rlbench_reach_target +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/reach_target/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-v8vv53tx:v0 wandb.group=rlbench_reach_target # slide_block_to_target -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/slide_block_to_target/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-fct8vbrq:v0 wandb.group=rlbench_slide_block_to_target +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/slide_block_to_target/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-sejd7pz0:v0 wandb.group=rlbench_slide_block_to_target # stack_wine -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/stack_wine/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-cbe1hgx4:v0 wandb.group=rlbench_stack_wine +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/stack_wine/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-3hyo3r7q:v0 wandb.group=rlbench_stack_wine # take_money_out_safe -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/take_money_out_safe/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-j3swo5k7:v0 wandb.group=rlbench_take_money_out_safe +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/take_money_out_safe/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-69gka1ew:v0 wandb.group=rlbench_take_money_out_safe # take_umbrella_out_of_umbrella_stand -./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/take_umbrella_out_of_umbrella_stand/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-txvpna0v:v0 wandb.group=rlbench_take_umbrella_out_of_umbrella_stand +./launch.sh local-docker 0 python scripts/eval_rlbench.py --config-name commands/rlbench/take_umbrella_out_of_umbrella_stand/taxpose_tc/eval_rlbench.yaml num_trials=100 policy_spec.include_rgb_features=True model.feature_channels=3 checkpoints.ckpt_file=r-pad/taxpose/model-b48mz8e1:v0 wandb.group=rlbench_take_umbrella_out_of_umbrella_stand diff --git a/taxpose/datasets/rlbench.py b/taxpose/datasets/rlbench.py index f80da66..0048df0 100644 --- a/taxpose/datasets/rlbench.py +++ b/taxpose/datasets/rlbench.py @@ -261,6 +261,10 @@ class RLBenchPointCloudDatasetConfig: anchor_mode: AnchorMode = AnchorMode.SINGLE_OBJECT action_mode: ActionMode = ActionMode.OBJECT + # Whether to include the wrist camera. + include_wrist_cam: bool = False + gripper_in_first_phase: bool = False + class RLBenchPointCloudDataset(Dataset[PlacementPointCloudData]): def __init__(self, cfg: RLBenchPointCloudDatasetConfig): @@ -273,6 +277,8 @@ def __init__(self, cfg: RLBenchPointCloudDatasetConfig): use_first_as_init_keyframe=cfg.use_first_as_init_keyframe, anchor_mode=cfg.anchor_mode, action_mode=cfg.action_mode, + include_wrist_cam=cfg.include_wrist_cam, + gripper_in_first_phase=cfg.gripper_in_first_phase, ) self.cfg = cfg