From a62f5f22fd41fe9097f66a56d79013fe3dd38cf5 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 12 Sep 2022 17:31:59 -0700 Subject: [PATCH 01/42] Add post_wrappers entry in common config --- src/imitation/scripts/common/common.py | 10 ++++++++-- src/imitation/scripts/train_rl.py | 13 +++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index c215460e2..938572d55 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -3,9 +3,10 @@ import contextlib import logging import os -from typing import Any, Mapping, Sequence, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union import sacred +from gym import Env from stable_baselines3.common import vec_env from imitation.scripts.common import wb @@ -33,6 +34,7 @@ def config(): num_vec = 8 # number of environments in VecEnv parallel = True # Use SubprocVecEnv rather than DummyVecEnv max_episode_steps = None # Set to positive int to limit episode horizons + post_wrappers = [] # Sequence of wrappers to apply to each env in the VecEnv env_make_kwargs = {} # The kwargs passed to `spec.make`. locals() # quieten flake8 @@ -136,6 +138,7 @@ def make_venv( parallel: bool, log_dir: str, max_episode_steps: int, + post_wrappers: Optional[Sequence[Callable[[Env, int], Env]]], env_make_kwargs: Mapping[str, Any], **kwargs, ) -> vec_env.VecEnv: @@ -149,6 +152,8 @@ def make_venv( max_episode_steps: If not None, then a TimeLimit wrapper is applied to each environment to artificially limit the maximum number of timesteps in an episode. + post_wrappers: If specified, iteratively wraps each environment with each + of the wrappers specified in the sequence. log_dir: Logs episode return statistics to a subdirectory 'monitor`. env_make_kwargs: The kwargs passed to `spec.make` of a gym environment. kwargs: Passed through to `util.make_vec_env`. @@ -163,8 +168,9 @@ def make_venv( num_vec, seed=_seed, parallel=parallel, - max_episode_steps=max_episode_steps, log_dir=log_dir, + max_episode_steps=max_episode_steps, + post_wrappers=post_wrappers, env_make_kwargs=env_make_kwargs, **kwargs, ) diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 7122cd701..89c05efc3 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -14,6 +14,7 @@ import warnings from typing import Any, Mapping, Optional +from sacred.config.custom_containers import ReadOnlyDict from sacred.observers import FileStorageObserver from stable_baselines3.common import callbacks from stable_baselines3.common.vec_env import VecNormalize @@ -22,7 +23,8 @@ from imitation.policies import serialize from imitation.rewards.reward_wrapper import RewardVecEnvWrapper from imitation.rewards.serialize import load_reward -from imitation.scripts.common import common, rl, train +from imitation.scripts.common import common as scripts_common +from imitation.scripts.common import rl, train from imitation.scripts.config.train_rl import train_rl_ex @@ -41,6 +43,7 @@ def train_rl( policy_save_interval: int, policy_save_final: bool, agent_path: Optional[str], + common: ReadOnlyDict, ) -> Mapping[str, float]: """Trains an expert policy from scratch and saves the rollouts and policy. @@ -87,14 +90,16 @@ def train_rl( Returns: The return value of `rollout_stats()` using the final policy. 
""" - custom_logger, log_dir = common.setup_logging() + custom_logger, log_dir = scripts_common.setup_logging() rollout_dir = osp.join(log_dir, "rollouts") policy_dir = osp.join(log_dir, "policies") os.makedirs(rollout_dir, exist_ok=True) os.makedirs(policy_dir, exist_ok=True) - post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] - with common.make_venv(post_wrappers=post_wrappers) as venv: + all_post_wrappers = common["post_wrappers"] + [ + lambda env, idx: wrappers.RolloutInfoWrapper(env) + ] + with scripts_common.make_venv(post_wrappers=all_post_wrappers) as venv: callback_objs = [] if reward_type is not None: reward_fn = load_reward( From ec3755b27cf2a29e17f5ef44a1c9324fa7778dab Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 12 Sep 2022 17:32:22 -0700 Subject: [PATCH 02/42] Add config to train on atari, use post_wrappers --- src/imitation/scripts/config/train_rl.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index ae3add76f..4ab37e3d3 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -1,6 +1,9 @@ """Configuration settings for train_rl, training a policy with RL.""" import sacred +from gym.wrappers import TimeLimit +from seals.util import AutoResetWrapper +from stable_baselines3.common.atari_wrappers import AtariWrapper from imitation.scripts.common import common, rl, train @@ -48,6 +51,18 @@ def acrobot(): common = dict(env_name="Acrobot-v1") +@train_rl_ex.named_config +def asteroids(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), + ], + ) + + @train_rl_ex.named_config def ant(): common = dict(env_name="Ant-v2") From 2ad0864b0dc3d52c54f4e194ce533846e8a288fe Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 12 Sep 2022 17:55:03 -0700 Subject: [PATCH 03/42] Add config option for CNN reward net --- src/imitation/scripts/common/reward.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/imitation/scripts/common/reward.py b/src/imitation/scripts/common/reward.py index 548bca855..1cfd5b1e1 100644 --- a/src/imitation/scripts/common/reward.py +++ b/src/imitation/scripts/common/reward.py @@ -61,6 +61,11 @@ def reward_ensemble(): locals() +@reward_ingredient.named_config +def cnn_reward(): + net_cls = reward_nets.CnnRewardNet # noqa: F841 + + @reward_ingredient.config_hook def config_hook(config, command_name, logger): """Sets default values for `net_cls` and `net_kwargs`.""" @@ -72,7 +77,10 @@ def config_hook(config, command_name, logger): default_net = reward_nets.BasicShapedRewardNet res["net_cls"] = default_net - if "normalize_input_layer" not in config["reward"]["net_kwargs"]: + if ( + "normalize_input_layer" not in config["reward"]["net_kwargs"] + and config["reward"]["net_cls"] != reward_nets.CnnRewardNet + ): res["net_kwargs"] = {"normalize_input_layer": networks.RunningNorm} if "net_cls" in res and issubclass(res["net_cls"], reward_nets.RewardEnsemble): From 70f17e5c6fdcc55dd7c52c790293b05b1d1cedad Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 12 Sep 2022 17:55:53 -0700 Subject: [PATCH 04/42] Add config to train preference comparisons on Asteroids --- .../config/train_preference_comparisons.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git 
a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index ba4e9483c..c11084645 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -1,6 +1,9 @@ """Configuration for imitation.scripts.train_preference_comparisons.""" import sacred +from gym.wrappers import TimeLimit +from seals.util import AutoResetWrapper +from stable_baselines3.common.atari_wrappers import AtariWrapper from imitation.algorithms import preference_comparisons from imitation.scripts.common import common, reward, rl, train @@ -115,6 +118,30 @@ def seals_mountain_car(): common = dict(env_name="seals/MountainCar-v0") +@train_preference_comparisons_ex.named_config +def asteroids_fast(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100), + ], + ) + + +@train_preference_comparisons_ex.named_config +def asteroids(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), + ], + ) + + @train_preference_comparisons_ex.named_config def fast(): # Minimize the amount of computation. Useful for test cases. From 262a35afda35c881a030edb78f0b8d9307767f9c Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 12 Sep 2022 17:58:30 -0700 Subject: [PATCH 05/42] Quell flake8 complaint --- src/imitation/scripts/train_rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 89c05efc3..9ab00b56d 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -97,7 +97,7 @@ def train_rl( os.makedirs(policy_dir, exist_ok=True) all_post_wrappers = common["post_wrappers"] + [ - lambda env, idx: wrappers.RolloutInfoWrapper(env) + lambda env, idx: wrappers.RolloutInfoWrapper(env), ] with scripts_common.make_venv(post_wrappers=all_post_wrappers) as venv: callback_objs = [] From 74fb856c4bda58e00883e858baca6a229725df7d Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 13 Sep 2022 11:17:51 -0700 Subject: [PATCH 06/42] Add short-episode asteroids config --- src/imitation/scripts/config/train_rl.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index 4ab37e3d3..d19b4458d 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -63,6 +63,18 @@ def asteroids(): ) +@train_rl_ex.named_config +def asteroids_short_episodes(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100), + ], + ) + + @train_rl_ex.named_config def ant(): common = dict(env_name="Ant-v2") From ea508bae3f6ea2fa5524dc23b41ff488c45b76b0 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 13 Sep 2022 11:33:06 -0700 Subject: [PATCH 07/42] Add config to train CNN policies --- src/imitation/policies/base.py | 8 ++++++++ src/imitation/scripts/common/train.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git 
a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 3f9b0d919..83ce16786 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -88,6 +88,14 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, net_arch=[1024, 1024]) +class CnnPolicy(policies.ActorCriticCnnPolicy): + """A CNN Actor-Critic policy.""" + + def __init__(self, *args, **kwargs): + """Builds CnnPolicy; arguments passed to `CnnActorCriticPolicy`.""" + super().__init__(*args, **kwargs) + + class NormalizeFeaturesExtractor(torch_layers.FlattenExtractor): """Feature extractor that flattens then normalizes input.""" diff --git a/src/imitation/scripts/common/train.py b/src/imitation/scripts/common/train.py index f5aa3c1bb..e4ffb2b2d 100644 --- a/src/imitation/scripts/common/train.py +++ b/src/imitation/scripts/common/train.py @@ -36,6 +36,15 @@ def sac(): policy_cls = base.SAC1024Policy # noqa: F841 +@train_ingredient.named_config +def cnn(): + policy_cls = base.CnnPolicy # noqa: F841 + # If features_extractor_class is not set, it will be set to a + # NormalizeFeaturesExtractor by default via the config hook, which implements an MLP. + # Therefore, to actually get this to implement a CNN, we need to set it here. + policy_kwargs = {"features_extractor_class": torch_layers.NatureCNN} # noqa: F841 + + @train_ingredient.named_config def normalize_disable(): policy_kwargs = { # noqa: F841 From 8ac9ae5925d3e607fcc9fd7e17f67e370b833ebd Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 13 Sep 2022 11:36:21 -0700 Subject: [PATCH 08/42] Fix line length --- src/imitation/scripts/common/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/scripts/common/train.py b/src/imitation/scripts/common/train.py index e4ffb2b2d..572f95e22 100644 --- a/src/imitation/scripts/common/train.py +++ b/src/imitation/scripts/common/train.py @@ -40,8 +40,8 @@ def sac(): def cnn(): policy_cls = base.CnnPolicy # noqa: F841 # If features_extractor_class is not set, it will be set to a - # NormalizeFeaturesExtractor by default via the config hook, which implements an MLP. - # Therefore, to actually get this to implement a CNN, we need to set it here. + # NormalizeFeaturesExtractor by default via the config hook, which implements an + # MLP. Therefore, to actually get this to implement a CNN, we need to set it here. 
policy_kwargs = {"features_extractor_class": torch_layers.NatureCNN} # noqa: F841 From a9d24aeba7ec0b3d4a4e8efdd4d3c600f5a0cd78 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 13 Sep 2022 11:55:50 -0700 Subject: [PATCH 09/42] Make names of short-episode asteroids configs consistent --- src/imitation/scripts/config/train_preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index c11084645..d655dfcde 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -119,7 +119,7 @@ def seals_mountain_car(): @train_preference_comparisons_ex.named_config -def asteroids_fast(): +def asteroids_short_episodes(): common = dict( env_name="AsteroidsNoFrameskip-v4", post_wrappers=[ From ad61105f491eab7e893985832039c783a5a3f9da Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 13 Sep 2022 12:09:05 -0700 Subject: [PATCH 10/42] Add docstring for common config --- src/imitation/scripts/train_rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 9ab00b56d..36d604456 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -86,6 +86,7 @@ def train_rl( policy_save_final: If True, then save the policy right after training is finished. agent_path: Path to load warm-started agent. + common: Dummy argument for the `common` ingredient configuration. Returns: The return value of `rollout_stats()` using the final policy. From d90d7349971438d7058b541444beb7d213208ffd Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 13 Sep 2022 12:35:01 -0700 Subject: [PATCH 11/42] Add tests for training RL + pref comp on image envs --- tests/scripts/test_scripts.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index f461dc889..27e0c60bf 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -104,6 +104,8 @@ def test_main_console(script_mod): RL_SAC_NAMED_CONFIGS = ["rl.sac", "train.sac"] +ASTEROIDS_CNN_POLICY_CONFIG = ["asteroids_short_episodes", "train.cnn"] + @pytest.fixture( params=[ @@ -247,6 +249,20 @@ def test_train_preference_comparisons_reward_named_config(tmpdir, named_configs) assert isinstance(run.result, dict) +def test_train_preference_comparisons_image_env(tmpdir): + config_updates = dict(common=dict(log_root=tmpdir)) + run = train_preference_comparisons.train_preference_comparisons_ex.run( + named_configs=( + ["reward.cnn_reward"] + + ASTEROIDS_CNN_POLICY_CONFIG + + ALGO_FAST_CONFIGS["preference_comparison"] + ), + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + def test_train_dagger_main(tmpdir): with pytest.warns(None) as record: run = train_imitation.train_imitation_ex.run( @@ -420,6 +436,16 @@ def test_train_rl_sac(tmpdir): assert isinstance(run.result, dict) +def test_train_rl_image_env(tmpdir): + config_updates = dict(common=dict(log_root=tmpdir)) + run = train_rl.train_rl_ex.run( + named_configs=ASTEROIDS_CNN_POLICY_CONFIG + ALGO_FAST_CONFIGS["rl"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + # check if platform is macos EVAL_POLICY_CONFIGS: List[Dict] = [ From 
d92581cf374e00f1e6ce38008f28f2bed9ed4d29 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Thu, 15 Sep 2022 18:30:53 -0700 Subject: [PATCH 12/42] First draft of training CNNs with DAgger - with backup CNN implementation visible in comments --- src/imitation/policies/base.py | 107 +++++++++++++++++- src/imitation/rewards/reward_nets.py | 18 +-- .../scripts/config/train_imitation.py | 45 ++++++++ src/imitation/scripts/train_imitation.py | 14 ++- src/imitation/util/networks.py | 10 ++ tests/rewards/test_reward_nets.py | 4 +- tests/scripts/test_scripts.py | 19 ++++ 7 files changed, 197 insertions(+), 20 deletions(-) diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 83ce16786..f667e98a2 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -1,12 +1,13 @@ """Custom policy classes and convenience methods.""" import abc -from typing import Type +from typing import Tuple, Type import gym import numpy as np import torch as th from stable_baselines3.common import policies, torch_layers +from stable_baselines3.common.distributions import Distribution from stable_baselines3.sac import policies as sac_policies from torch import nn @@ -89,11 +90,111 @@ def __init__(self, *args, **kwargs): class CnnPolicy(policies.ActorCriticCnnPolicy): - """A CNN Actor-Critic policy.""" + """A CNN Actor-Critic policy. - def __init__(self, *args, **kwargs): + This policy optionally transposes its observation inputs. Note that if this is done, + the policy expects the observation space to be a Box with values ranging from 0 to + 255. Methods are copy-pasted from StableBaselines 3's ActorCriticPolicy, with an + initial check whether or not to transpose an observation input. + """ + + def __init__(self, *args, transpose_input: bool = False, **kwargs): """Builds CnnPolicy; arguments passed to `CnnActorCriticPolicy`.""" + self.transpose_input = transpose_input + if self.transpose_input: + kwargs.update( + { + "observation_space": self.transpose_space( + kwargs["observation_space"], + ), + }, + ) super().__init__(*args, **kwargs) + # self.base_policy = policies.ActorCriticCnnPolicy(*args, **kwargs) + + def transpose_space(self, observation_space: gym.spaces.Box) -> gym.spaces.Box: + if not isinstance(observation_space, gym.spaces.Box): + raise TypeError("This code assumes that observation spaces are gym Boxes.") + if not (observation_space.low == 0 and observation_space.high == 255): + error_msg = ( + "This code assumes the observation space values range from " + + "0 to 255." 
+ ) + raise ValueError(error_msg) + h, w, c = observation_space.shape + new_shape = (c, h, w) + return gym.spaces.Box( + low=0, + high=255, + shape=new_shape, + dtype=observation_space.dtype, + ) + + def forward( + self, + obs: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + if self.transpose_input: + obs_ = networks.cnn_transpose(obs) + else: + obs_ = obs + # return self.base_policy.forward(obs_, deterministic) + # Preprocess the observation if needed + features = self.extract_features(obs_) + latent_pi, latent_vf = self.mlp_extractor(features) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob + + # def _predict(self, observation: th.Tensor, deterministic: bool = False + # ) -> th.Tensor: + # if self.transpose_input: + # obs_ = networks.cnn_transpose(observation) + # else: + # obs_ = observation + # return self.base_policy._predict(obs_, deterministic) + + def evaluate_actions( + self, + obs: th.Tensor, + actions: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + if self.transpose_input: + obs_ = networks.cnn_transpose(obs) + else: + obs_ = obs + # return self.base_policy.evaluate_actions(obs_, actions) + # Preprocess the observation if needed + features = self.extract_features(obs_) + latent_pi, latent_vf = self.mlp_extractor(features) + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def get_distribution(self, obs: th.Tensor) -> Distribution: + if self.transpose_input: + obs_ = networks.cnn_transpose(obs) + else: + obs_ = obs + # return self.base_policy.get_distribution(obs_) + features = self.extract_features(obs_) + latent_pi = self.mlp_extractor.forward_actor(features) + return self._get_action_dist_from_latent(latent_pi) + + def predict_values(self, obs: th.Tensor) -> th.Tensor: + if self.transpose_input: + obs_ = networks.cnn_transpose(obs) + else: + obs_ = obs + # return self.base_policy.predict_values(obs_) + features = self.extract_features(obs_) + latent_vf = self.mlp_extractor.forward_critic(features) + return self.value_net(latent_vf) class NormalizeFeaturesExtractor(torch_layers.FlattenExtractor): diff --git a/src/imitation/rewards/reward_nets.py b/src/imitation/rewards/reward_nets.py index 41a7e6f82..4a983d8ba 100644 --- a/src/imitation/rewards/reward_nets.py +++ b/src/imitation/rewards/reward_nets.py @@ -557,10 +557,12 @@ def forward( """ inputs = [] if self.use_state: - state_ = cnn_transpose(state) if self.hwc_format else state + state_ = networks.cnn_transpose(state) if self.hwc_format else state inputs.append(state_) if self.use_next_state: - next_state_ = cnn_transpose(next_state) if self.hwc_format else next_state + next_state_ = ( + networks.cnn_transpose(next_state) if self.hwc_format else next_state + ) inputs.append(next_state_) inputs_concat = th.cat(inputs, dim=1) @@ -586,16 +588,6 @@ def forward( return rewards -def cnn_transpose(tens: th.Tensor) -> th.Tensor: - """Transpose a (b,h,w,c)-formatted tensor to (b,c,h,w) format.""" - if len(tens.shape) == 4: - return th.permute(tens, (0, 3, 1, 2)) - else: - raise ValueError( - f"Invalid input: len(tens.shape) = {len(tens.shape)} != 4.", - ) - - class 
NormalizedRewardNet(PredictProcessedWrapper): """A reward net that normalizes the output of its base network.""" @@ -861,7 +853,7 @@ def __init__( ) def forward(self, state: th.Tensor) -> th.Tensor: - state_ = cnn_transpose(state) if self.hwc_format else state + state_ = networks.cnn_transpose(state) if self.hwc_format else state return self._potential_net(state_) diff --git a/src/imitation/scripts/config/train_imitation.py b/src/imitation/scripts/config/train_imitation.py index c2466a936..cdb00fa8d 100644 --- a/src/imitation/scripts/config/train_imitation.py +++ b/src/imitation/scripts/config/train_imitation.py @@ -2,7 +2,12 @@ import sacred import torch as th +from gym.wrappers import TimeLimit +from seals.util import AutoResetWrapper +from stable_baselines3.common import torch_layers +from stable_baselines3.common.atari_wrappers import AtariWrapper +from imitation.policies import base from imitation.scripts.common import common from imitation.scripts.common import demonstrations as demos_common from imitation.scripts.common import expert, train @@ -105,6 +110,46 @@ def seals_humanoid(): common = dict(env_name="seals/Humanoid-v0") +@train_imitation_ex.named_config +def asteroids_short_episodes(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100), + ], + ) + train = dict( + policy_kwargs=dict(transpose_input=True), + ) + transpose_obs = True + + +@train_imitation_ex.named_config +def asteroids(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), + ], + ) + train = dict( + policy_kwargs=dict(transpose_input=True), + ) + transpose_obs = True + + +@train_imitation_ex.named_config +def cnn(): + train = dict( + policy_cls=base.CnnPolicy, + policy_kwargs=dict(features_extractor_class=torch_layers.NatureCNN), + ) + + @train_imitation_ex.named_config def fast(): dagger = dict(total_timesteps=50) diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 8d7085577..4a007c18d 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -69,6 +69,7 @@ def train_imitation( dagger: Mapping[str, Any], use_dagger: bool, agent_path: Optional[str], + transpose_obs: bool, ) -> Mapping[str, Mapping[str, float]]: """Runs DAgger (if `use_dagger`) or BC (otherwise) training. @@ -80,6 +81,9 @@ def train_imitation( agent_path: Path to serialized policy. If provided, then load the policy from this path. Otherwise, make a new policy. Specify only if policy_cls and policy_kwargs are not specified. + transpose_obs: Whether observations will need to be transposed to be fed into + the policy. Should usually be True for image environments, and usually be + False otherwise. Returns: Statistics for rollouts from the trained policy and demonstration data. 
@@ -93,9 +97,15 @@ def train_imitation( if not use_dagger or dagger["use_offline_rollouts"]: expert_trajs = demonstrations.get_expert_trajectories() + if transpose_obs: + # this modification only affects the observation space the BC trainer + # expects to deal with + bc_trainer_venv = vec_env.vec_transpose.VecTransposeImage(venv) + else: + bc_trainer_venv = venv bc_trainer = bc_algorithm.BC( - observation_space=venv.observation_space, - action_space=venv.action_space, + observation_space=bc_trainer_venv.observation_space, + action_space=bc_trainer_venv.action_space, policy=imit_policy, demonstrations=expert_trajs, custom_logger=custom_logger, diff --git a/src/imitation/util/networks.py b/src/imitation/util/networks.py index 664f5f081..521d8d589 100644 --- a/src/imitation/util/networks.py +++ b/src/imitation/util/networks.py @@ -196,6 +196,16 @@ def update_stats(self, batch: th.Tensor) -> None: self.num_batches += 1 +def cnn_transpose(tens: th.Tensor) -> th.Tensor: + """Transpose a (b,h,w,c)-formatted tensor to (b,c,h,w) format.""" + if len(tens.shape) == 4: + return th.permute(tens, (0, 3, 1, 2)) + else: + raise ValueError( + f"Invalid input: len(tens.shape) = {len(tens.shape)} != 4.", + ) + + def build_mlp( in_size: int, hid_sizes: Iterable[int], diff --git a/tests/rewards/test_reward_nets.py b/tests/rewards/test_reward_nets.py index c0e3212e0..e00a47cdc 100644 --- a/tests/rewards/test_reward_nets.py +++ b/tests/rewards/test_reward_nets.py @@ -214,10 +214,10 @@ def test_cnn_transpose_input_validation(dimensions: int): tens = th.zeros(shape) if dimensions == 4: # should succeed - reward_nets.cnn_transpose(tens) + networks.cnn_transpose(tens) else: # should fail with pytest.raises(ValueError, match="Invalid input: "): - reward_nets.cnn_transpose(tens) + networks.cnn_transpose(tens) def _sample(space, n): diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 27e0c60bf..a8c975ed2 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -61,6 +61,10 @@ PENDULUM_TEST_DATA_PATH = TEST_DATA_PATH / "expert_models/pendulum_0/" PENDULUM_TEST_ROLLOUT_PATH = PENDULUM_TEST_DATA_PATH / "rollouts/final.pkl" +ASTEROIDS_TEST_DATA_PATH = TEST_DATA_PATH / "expert_models/asteroids_short_episodes_0/" +ASTEROIDS_TEST_ROLLOUT_PATH = TEST_DATA_PATH / "rollouts/final.pkl" +ASTEROIDS_TEST_POLICY_PATH = TEST_DATA_PATH / "policies/final" + OLD_FMT_ROLLOUT_TEST_DATA_PATH = TEST_DATA_PATH / "old_format_rollout.pkl" @@ -310,6 +314,21 @@ def test_train_dagger_warmstart(tmpdir): assert isinstance(run_warmstart.result, dict) +def test_train_dagger_with_image_env(tmpdir): + run = train_imitation.train_imitation_ex.run( + command_name="dagger", + named_configs=( + ["asteroids_short_episodes", "cnn"] + ALGO_FAST_CONFIGS["imitation"] + ), + config_updates=dict( + common=dict(log_root=tmpdir), + demonstrations=dict(rollout_path=ASTEROIDS_TEST_ROLLOUT_PATH), + ), + ) + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + def test_train_bc_main_with_none_demonstrations_raises_value_error(tmpdir): with pytest.raises(ValueError, match=".*n_expert_demos.*rollout_path.*"): train_imitation.train_imitation_ex.run( From 8f0c7c00e7ddd9b74c48640d4898837df36a8e43 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Thu, 15 Sep 2022 18:32:22 -0700 Subject: [PATCH 13/42] Delete commented-out alternate implementation of CnnPolicy --- src/imitation/policies/base.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/imitation/policies/base.py 
b/src/imitation/policies/base.py index f667e98a2..36f7f1009 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -110,7 +110,6 @@ def __init__(self, *args, transpose_input: bool = False, **kwargs): }, ) super().__init__(*args, **kwargs) - # self.base_policy = policies.ActorCriticCnnPolicy(*args, **kwargs) def transpose_space(self, observation_space: gym.spaces.Box) -> gym.spaces.Box: if not isinstance(observation_space, gym.spaces.Box): @@ -139,7 +138,6 @@ def forward( obs_ = networks.cnn_transpose(obs) else: obs_ = obs - # return self.base_policy.forward(obs_, deterministic) # Preprocess the observation if needed features = self.extract_features(obs_) latent_pi, latent_vf = self.mlp_extractor(features) @@ -150,14 +148,6 @@ def forward( log_prob = distribution.log_prob(actions) return actions, values, log_prob - # def _predict(self, observation: th.Tensor, deterministic: bool = False - # ) -> th.Tensor: - # if self.transpose_input: - # obs_ = networks.cnn_transpose(observation) - # else: - # obs_ = observation - # return self.base_policy._predict(obs_, deterministic) - def evaluate_actions( self, obs: th.Tensor, @@ -167,7 +157,6 @@ def evaluate_actions( obs_ = networks.cnn_transpose(obs) else: obs_ = obs - # return self.base_policy.evaluate_actions(obs_, actions) # Preprocess the observation if needed features = self.extract_features(obs_) latent_pi, latent_vf = self.mlp_extractor(features) @@ -181,7 +170,6 @@ def get_distribution(self, obs: th.Tensor) -> Distribution: obs_ = networks.cnn_transpose(obs) else: obs_ = obs - # return self.base_policy.get_distribution(obs_) features = self.extract_features(obs_) latent_pi = self.mlp_extractor.forward_actor(features) return self._get_action_dist_from_latent(latent_pi) @@ -191,7 +179,6 @@ def predict_values(self, obs: th.Tensor) -> th.Tensor: obs_ = networks.cnn_transpose(obs) else: obs_ = obs - # return self.base_policy.predict_values(obs_) features = self.extract_features(obs_) latent_vf = self.mlp_extractor.forward_critic(features) return self.value_net(latent_vf) From e5be4ca20c28addc50b77a5df372a4affc3168d5 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Fri, 16 Sep 2022 13:17:06 -0700 Subject: [PATCH 14/42] Fix notebook bug I introduced --- examples/5a_train_preference_comparisons_with_cnn.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/5a_train_preference_comparisons_with_cnn.ipynb b/examples/5a_train_preference_comparisons_with_cnn.ipynb index 594de3745..293c05b27 100644 --- a/examples/5a_train_preference_comparisons_with_cnn.ipynb +++ b/examples/5a_train_preference_comparisons_with_cnn.ipynb @@ -147,7 +147,8 @@ "metadata": {}, "outputs": [], "source": [ - "from imitation.rewards.reward_nets import ShapedRewardNet, cnn_transpose\n", + "from imitation.util.networks import cnn_transpose\n", + "from imitation.rewards.reward_nets import ShapedRewardNet\n", "from imitation.rewards.reward_wrapper import RewardVecEnvWrapper\n", "\n", "\n", From 39c94ea896ad440cecbfe34576c8eaf91a75c7e7 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Fri, 16 Sep 2022 13:25:20 -0700 Subject: [PATCH 15/42] Fix problem where transpose_obs had to be specified --- src/imitation/scripts/train_imitation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 4a007c18d..a251c9a3a 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -69,7 +69,7 @@ def 
train_imitation( dagger: Mapping[str, Any], use_dagger: bool, agent_path: Optional[str], - transpose_obs: bool, + transpose_obs: bool = False, ) -> Mapping[str, Mapping[str, float]]: """Runs DAgger (if `use_dagger`) or BC (otherwise) training. From 209c61ac2fdc4645933eab94b475e43dfd4ad585 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Fri, 16 Sep 2022 14:16:17 -0700 Subject: [PATCH 16/42] Fix check for transposition in CnnPolicy --- src/imitation/policies/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 36f7f1009..3c58fa1e0 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -114,7 +114,9 @@ def __init__(self, *args, transpose_input: bool = False, **kwargs): def transpose_space(self, observation_space: gym.spaces.Box) -> gym.spaces.Box: if not isinstance(observation_space, gym.spaces.Box): raise TypeError("This code assumes that observation spaces are gym Boxes.") - if not (observation_space.low == 0 and observation_space.high == 255): + if not ( + np.all(observation_space.low == 0) and np.all(observation_space.high == 255) + ): error_msg = ( "This code assumes the observation space values range from " + "0 to 255." From e51eaa967b7b021bf35509b118124c01acd78714 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Fri, 16 Sep 2022 16:22:47 -0700 Subject: [PATCH 17/42] Reorder asteroids envs --- src/imitation/scripts/config/train_imitation.py | 8 ++++---- .../scripts/config/train_preference_comparisons.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/imitation/scripts/config/train_imitation.py b/src/imitation/scripts/config/train_imitation.py index cdb00fa8d..43a9c099e 100644 --- a/src/imitation/scripts/config/train_imitation.py +++ b/src/imitation/scripts/config/train_imitation.py @@ -111,13 +111,13 @@ def seals_humanoid(): @train_imitation_ex.named_config -def asteroids_short_episodes(): +def asteroids(): common = dict( env_name="AsteroidsNoFrameskip-v4", post_wrappers=[ lambda env, _: AutoResetWrapper(env), lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), - lambda env, _: TimeLimit(env, max_episode_steps=100), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), ], ) train = dict( @@ -127,13 +127,13 @@ def asteroids_short_episodes(): @train_imitation_ex.named_config -def asteroids(): +def asteroids_short_episodes(): common = dict( env_name="AsteroidsNoFrameskip-v4", post_wrappers=[ lambda env, _: AutoResetWrapper(env), lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), - lambda env, _: TimeLimit(env, max_episode_steps=100_000), + lambda env, _: TimeLimit(env, max_episode_steps=100), ], ) train = dict( diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index d655dfcde..735541fc3 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -119,25 +119,25 @@ def seals_mountain_car(): @train_preference_comparisons_ex.named_config -def asteroids_short_episodes(): +def asteroids(): common = dict( env_name="AsteroidsNoFrameskip-v4", post_wrappers=[ lambda env, _: AutoResetWrapper(env), lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), - lambda env, _: TimeLimit(env, max_episode_steps=100), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), ], ) @train_preference_comparisons_ex.named_config -def asteroids(): +def 
asteroids_short_episodes(): common = dict( env_name="AsteroidsNoFrameskip-v4", post_wrappers=[ lambda env, _: AutoResetWrapper(env), lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), - lambda env, _: TimeLimit(env, max_episode_steps=100_000), + lambda env, _: TimeLimit(env, max_episode_steps=100), ], ) From c91caf2f5d46766157b3db438617f4d94cbf0768 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Fri, 16 Sep 2022 17:04:50 -0700 Subject: [PATCH 18/42] Try to get AIRL/GAIL working for image envs --- .../algorithms/adversarial/common.py | 10 ++++-- .../scripts/config/train_adversarial.py | 32 +++++++++++++++++++ tests/scripts/test_scripts.py | 24 ++++++++++++-- 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index eee6937e4..fafbd047c 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -117,6 +117,7 @@ def __init__( disc_opt_kwargs: Optional[Mapping] = None, gen_train_timesteps: Optional[int] = None, gen_replay_buffer_capacity: Optional[int] = None, + transpose_obs: bool = False, custom_logger: Optional[logger.HierarchicalLogger] = None, init_tensorboard: bool = False, init_tensorboard_graph: bool = False, @@ -152,6 +153,9 @@ def __init__( the generator that can be stored). By default this is equal to `gen_train_timesteps`, meaning that we sample only from the most recent batch of generator samples. + transpose_obs: Whether observations will need to be transposed from (h,w,c) + format to be manually fed into the policy. Should usually be True for + image environments, and usually be False otherwise. custom_logger: Where to log to; if None (default), creates a new logger. init_tensorboard: If True, makes various discriminator TensorBoard summaries. @@ -206,6 +210,7 @@ def __init__( self._summary_writer = thboard.SummaryWriter(summary_dir) venv = self.venv_buffering = wrappers.BufferingWrapper(self.venv) + self.transpose_obs = transpose_obs if debug_use_ground_truth: # Would use an identity reward fn here, but RewardFns can't see rewards. @@ -447,18 +452,19 @@ def _get_log_policy_act_prob( Returns: A batch of log policy action probabilities. """ + obs_th_ = networks.cnn_transpose(obs_th) if self.transpose_obs else obs_th if isinstance(self.policy, policies.ActorCriticPolicy): # policies.ActorCriticPolicy has a concrete implementation of # evaluate_actions to generate log_policy_act_prob given obs and actions. _, log_policy_act_prob_th, _ = self.policy.evaluate_actions( - obs_th, + obs_th_, acts_th, ) elif isinstance(self.policy, sac_policies.SACPolicy): gen_algo_actor = self.policy.actor assert gen_algo_actor is not None # generate log_policy_act_prob from SAC actor. 
- mean_actions, log_std, _ = gen_algo_actor.get_action_dist_params(obs_th) + mean_actions, log_std, _ = gen_algo_actor.get_action_dist_params(obs_th_) distribution = gen_algo_actor.action_dist.proba_distribution( mean_actions, log_std, diff --git a/src/imitation/scripts/config/train_adversarial.py b/src/imitation/scripts/config/train_adversarial.py index 3183ac9f6..596b8b041 100644 --- a/src/imitation/scripts/config/train_adversarial.py +++ b/src/imitation/scripts/config/train_adversarial.py @@ -1,6 +1,9 @@ """Configuration for imitation.scripts.train_adversarial.""" import sacred +from gym.wrappers import TimeLimit +from seals.util import AutoResetWrapper +from stable_baselines3.common.atari_wrappers import AtariWrapper from imitation.rewards import reward_nets from imitation.scripts.common import common, demonstrations, expert, reward, rl, train @@ -172,6 +175,35 @@ def seals_walker(): common = dict(env_name="seals/Walker2d-v0") +# Atari configs + + +@train_adversarial_ex.named_config +def asteroids(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), + ], + ) + algorithm_kwargs = dict(transpose_obs=True) + + +@train_adversarial_ex.named_config +def asteroids_short_episodes(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100), + ], + ) + algorithm_kwargs = dict(transpose_obs=True) + + # Debug configs diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index a8c975ed2..b0743ad2e 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -62,8 +62,8 @@ PENDULUM_TEST_ROLLOUT_PATH = PENDULUM_TEST_DATA_PATH / "rollouts/final.pkl" ASTEROIDS_TEST_DATA_PATH = TEST_DATA_PATH / "expert_models/asteroids_short_episodes_0/" -ASTEROIDS_TEST_ROLLOUT_PATH = TEST_DATA_PATH / "rollouts/final.pkl" -ASTEROIDS_TEST_POLICY_PATH = TEST_DATA_PATH / "policies/final" +ASTEROIDS_TEST_ROLLOUT_PATH = ASTEROIDS_TEST_DATA_PATH / "rollouts/final.pkl" +ASTEROIDS_TEST_POLICY_PATH = ASTEROIDS_TEST_DATA_PATH / "policies/final" OLD_FMT_ROLLOUT_TEST_DATA_PATH = TEST_DATA_PATH / "old_format_rollout.pkl" @@ -637,6 +637,26 @@ def test_train_adversarial_algorithm_value_error(tmpdir): ) +@pytest.mark.parametrize("command", ("airl", "gail")) +def test_train_adversarial_image_env(tmpdir, command): + """Smoke test for imitation.scripts.train_adversarial on atari.""" + named_configs = ( + ASTEROIDS_CNN_POLICY_CONFIG + ALGO_FAST_CONFIGS["adversarial"] + + ["reward.cnn_reward"] + ) + config_updates = { + "common": dict(log_root=tmpdir), + "demonstrations": dict(rollout_path=ASTEROIDS_TEST_ROLLOUT_PATH), + } + run = train_adversarial.train_adversarial_ex.run( + command_name=command, + named_configs=named_configs, + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_train_ex_result(run.result) + + def test_transfer_learning(tmpdir: str) -> None: """Transfer learning smoke test. 
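
Note: the `transpose_obs` handling threaded through the last several patches ultimately relies on the `cnn_transpose` helper that now lives in `imitation.util.networks`. A minimal sketch of the shape convention it enforces follows; the batch/height/width/channel sizes are illustrative, not taken from the patch.

import torch as th

from imitation.util.networks import cnn_transpose

# Observations arrive in (batch, height, width, channels) format, e.g. a
# batch of grayscale AtariWrapper frames; SB3's NatureCNN expects
# (batch, channels, height, width).
obs = th.zeros((16, 84, 84, 1))
transposed = cnn_transpose(obs)
assert transposed.shape == (16, 1, 84, 84)

# Anything other than a 4-D batch of images is rejected.
try:
    cnn_transpose(th.zeros((84, 84, 1)))
except ValueError as err:
    print(err)  # "Invalid input: len(tens.shape) = 3 != 4."
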
From aae2711f4cae5a85b3d6151fc60d8ae37fb41235 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Fri, 16 Sep 2022 17:05:08 -0700 Subject: [PATCH 19/42] Improve documentation of transpose_obs arg --- src/imitation/scripts/train_imitation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index a251c9a3a..6ed93d212 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -81,9 +81,9 @@ def train_imitation( agent_path: Path to serialized policy. If provided, then load the policy from this path. Otherwise, make a new policy. Specify only if policy_cls and policy_kwargs are not specified. - transpose_obs: Whether observations will need to be transposed to be fed into - the policy. Should usually be True for image environments, and usually be - False otherwise. + transpose_obs: Whether observations will need to be transposed from (h,w,c) + format to be fed into the policy. Should usually be True for image + environments, and usually be False otherwise. Returns: Statistics for rollouts from the trained policy and demonstration data. From b3c9e3e8ce310dfeb19bcb6c21181ab75e31470d Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Wed, 5 Oct 2022 14:36:50 +0200 Subject: [PATCH 20/42] Get procgen RL training working --- setup.py | 7 ++++--- src/imitation/scripts/config/train_rl.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index f28dd5f7d..bbf8d602a 100644 --- a/setup.py +++ b/setup.py @@ -14,11 +14,12 @@ IS_NOT_WINDOWS = os.name != "nt" PARALLEL_REQUIRE = ["ray[debug,tune]~=2.0.0"] -ATARI_REQUIRE = [ +IMAGE_ENV_REQUIRE = [ "opencv-python", "ale-py==0.7.4", "pillow", "autorom[accept-rom-license]~=0.4.2", + "procgen==0.10.4", ] PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] if IS_NOT_WINDOWS: @@ -68,7 +69,7 @@ "setuptools_scm~=7.0.5", ] + PARALLEL_REQUIRE - + ATARI_REQUIRE + + IMAGE_ENV_REQUIRE + PYTYPE ) DOCS_REQUIRE = [ @@ -233,7 +234,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "mujoco": [ "gym[classic_control,mujoco]" + GYM_VERSION_SPECIFIER, ], - "atari": ATARI_REQUIRE, + "image_envs": IMAGE_ENV_REQUIRE, }, entry_points={ "console_scripts": [ diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index d19b4458d..547a85d96 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -155,6 +155,19 @@ def seals_walker(): common = dict(env_name="seals/Walker2d-v0") +# Procgen configs + + +@train_rl_ex.named_config +def coinrun(): + common = dict(env_name="procgen:procgen-coinrun-v0") + + +@train_rl_ex.named_config +def maze(): + common = dict(env_name="procgen:procgen-maze-v0") + + # Debug configs From 62148050ce1aabc2c4a7e2c213a167640d3cb89b Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Wed, 5 Oct 2022 14:40:24 +0200 Subject: [PATCH 21/42] Add right version of gym3 to get procgen working --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index bbf8d602a..094de75b6 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ "pillow", "autorom[accept-rom-license]~=0.4.2", "procgen==0.10.4", + "gym3@git+https://github.com/openai/gym3.git#4c3824680eaf9dd04dce224ee3d4856429878226", ] PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] if IS_NOT_WINDOWS: From ff1741211bbe845a00070e9d3f853b4540588067 Mon Sep 17 00:00:00 2001 From: 
Daniel Filan Date: Thu, 6 Oct 2022 16:46:19 +0200 Subject: [PATCH 22/42] Add all procgen environments --- src/imitation/scripts/config/train_rl.py | 69 ++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index 547a85d96..a3821cfb4 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -168,6 +168,75 @@ def maze(): common = dict(env_name="procgen:procgen-maze-v0") +@train_rl_ex.named_config +def bigfish(): + common = dict(env_name="procgen:procgen-bigfish-v0") + + +@train_rl_ex.named_config +def bossfight(): + common = dict(env_name="procgen:procgen-bossfight-v0") + + +@train_rl_ex.named_config +def caveflyer(): + common = dict(env_name="procgen:procgen-caveflyer-v0") + + +@train_rl_ex.named_config +def chaser(): + common = dict(env_name="procgen:procgen-chaser-v0") + + +@train_rl_ex.named_config +def climber(): + common = dict(env_name="procgen:procgen-climber-v0") + + +@train_rl_ex.named_config +def dodgeball(): + common = dict(env_name="procgen:procgen-dodgeball-v0") + + +@train_rl_ex.named_config +def fruitbot(): + common = dict(env_name="procgen:procgen-fruitbot-v0") + + +@train_rl_ex.named_config +def heist(): + common = dict(env_name="procgen:procgen-heist-v0") + + +@train_rl_ex.named_config +def jumper(): + common = dict(env_name="procgen:procgen-jumper-v0") + + +@train_rl_ex.named_config +def leaper(): + common = dict(env_name="procgen:procgen-leaper-v0") + + +@train_rl_ex.named_config +def miner(): + common = dict(env_name="procgen:procgen-miner-v0") + + +@train_rl_ex.named_config +def ninja(): + common = dict(env_name="procgen:procgen-ninja-v0") + + +@train_rl_ex.named_config +def plunder(): + common = dict(env_name="procgen:procgen-plunder-v0") + + +@train_rl_ex.named_config +def starpilot(): + common = dict(env_name="procgen:procgen-starpilot-v0") + # Debug configs From 97f16de4bc36e189424ad1bbf76bc1f56930e29a Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Thu, 6 Oct 2022 08:35:58 -0700 Subject: [PATCH 23/42] Remove awscli dependency --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index f15479dad..44eeb12bb 100644 --- a/setup.py +++ b/setup.py @@ -211,7 +211,6 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: # recommended packages for development "dev": [ "autopep8", - "awscli", "ipdb", "isort~=5.0", "codespell", From 6db67fe68de3e22adb67442a562a04321c98b2d5 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 11:45:50 +0200 Subject: [PATCH 24/42] Modify scripts to not deal with awscli --- experiments/README.md | 27 +++--- experiments/download_models.sh | 53 ------------ experiments/imit_benchmark.sh | 4 - experiments/imit_benchmark_config.csv | 4 - experiments/rollouts_from_policies_config.csv | 14 ++- experiments/train_experts.sh | 85 ------------------- 6 files changed, 16 insertions(+), 171 deletions(-) delete mode 100755 experiments/download_models.sh delete mode 100755 experiments/train_experts.sh diff --git a/experiments/README.md b/experiments/README.md index 893268054..5ba92e9bc 100644 --- a/experiments/README.md +++ b/experiments/README.md @@ -8,20 +8,19 @@ macOS to install some GNU-compatible binaries before all experiments scripts wil brew install coreutils gnu-getopt parallel ``` -## Setup +## Scripts -Phase 1 requires [AWS CLI](https://aws.amazon.com/cli/) because it downloads data from AWS S3. 
+### Phase 1: Download RL (PPO2) expert policies. -### Phase 1: Download RL (PPO2) expert policies and AIRL/GAIL reward models. - -Use `experiments/download_models.sh`. (Downloads to `data/{expert,reward}_models/`). -Expert policies are used in Phase 2 to generate demonstrations. -Reward models are used in Phase 4 for transfer learning. - -Want to use other policies or reward models for Phase 2 or 4? - * New policies can be trained using `experiments/train_experts.sh`. - * New reward models are generated in Phase 3 `experiments/imit_benchmark.sh`. - * For both scripts, you can enter the optional commands suggested at the end of the script to upload new files to S3 (will need write access to our S3 bucket). Or, you can manually patch `data/{expert,reward}_models` using the script's output files. +Models are saved in HuggingFace, and to work with these scripts should be downloaded to `data/expert_models/`. Environments to train with: +- [CartPole](https://huggingface.co/HumanCompatibleAI/ppo-seals-CartPole-v0) +- [MountainCar](https://huggingface.co/HumanCompatibleAI/ppo-seals-MountainCar-v0) +- [HalfCheetah](https://huggingface.co/HumanCompatibleAI/ppo-seals-HalfCheetah-v0) +- [Hopper](https://huggingface.co/HumanCompatibleAI/ppo-seals-Hopper-v0) +- [Walker](https://huggingface.co/HumanCompatibleAI/ppo-seals-Walker2d-v0) +- [Swimmer](https://huggingface.co/HumanCompatibleAI/ppo-seals-Swimmer-v0) +- [Ant](https://huggingface.co/HumanCompatibleAI/ppo-seals-Ant-v0) +- [Humanoid](https://huggingface.co/HumanCompatibleAI/ppo-seals-Humanoid-v0) ### Phase 2: Generate expert demonstrations from models. @@ -35,10 +34,6 @@ Run `experiments/imit_benchmark.sh --run_name RUN_NAME`. To choose AIRL or GAIL, To analyze these results, run `python -m imitation.scripts.analyze with run_name=RUN_NAME`. Analysis can be run even while training is midway (will only show completed imitation learner's results). [Example output.](https://gist.github.com/shwang/4049cd4fb5cab72f2eeb7f3d15a7ab47) -### Phase 4: Transfer learning. - -Run `experiments/transfer_benchmark.sh`. To choose AIRL or GAIL, add the `--airl` and `--gail` flags (default is GAIL). Transfer rewards are loaded from `data/reward_models`. - ## Hyperparameter tuning Add a named config containing the hyperparameter search space and other settings to `src/imitation/scripts/config/parallel.py`. (`def example_cartpole_rl():` is an example). diff --git a/experiments/download_models.sh b/experiments/download_models.sh deleted file mode 100755 index 4313cf418..000000000 --- a/experiments/download_models.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash - -# Always sync to ../data, relative to this script. -SCRIPT_DIR="$(cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd)" -PROJECT_DIR="$(dirname "${SCRIPT_DIR}")" -DATA_DIR="${PROJECT_DIR}/data" - -DRY_RUN_MODE=false -ALL_MODE=false - -if ! TEMP=$(getopt -o '' -l all,dryrun -- "$@"); then - exit 1 -fi -eval set -- "$TEMP" - -while true; do - case "$1" in - --dryrun) - DRY_RUN_MODE=true - shift - ;; - --all) - # Download all data (by default, skips large meta-data/log files). - # As of 2019.Nov.12, the difference in download size was 77MB vs ~800MB. 
- ALL_MODE=true - shift - ;; - --) - shift - break - ;; - *) - echo "Unrecognized flag $1" >&2 - exit 1 - ;; - esac -done - -FLAGS=() - -if [[ $ALL_MODE != "true" ]]; then - for excl_pat in '*/monitor/*' '*/parallel/*' '*/sacred/*' '*events.out.tfevents*'; do - FLAGS=("${FLAGS[@]}" "--exclude" "${excl_pat}") - done -fi - -if [[ $DRY_RUN_MODE == "true" ]]; then - FLAGS=("${FLAGS[@]}" "--dryrun") -elif [[ -d "${DATA_DIR}" ]]; then - rm -ir "${DATA_DIR}" -fi - -aws --no-sign-request s3 sync "${FLAGS[@]}" s3://shwang-chai/public/data/ "${DATA_DIR}" diff --git a/experiments/imit_benchmark.sh b/experiments/imit_benchmark.sh index c75126d9b..38b0684c2 100755 --- a/experiments/imit_benchmark.sh +++ b/experiments/imit_benchmark.sh @@ -118,9 +118,5 @@ pushd "${LOG_ROOT}/parallel" find . -name stderr -print0 | sort -z | xargs -0 tail -n 15 | grep -E '==|Result' popd -echo "[Optional] Upload new reward models to S3 (replacing old ones) using the commands:" -echo "aws s3 rm --recursive s3://shwang-chai/public/data/reward_models/${ALGORITHM}/" -echo "aws s3 sync --exclude '*/rollouts/*' --exclude '*/checkpoints/*' --include '*/checkpoints/final/*' '${LOG_ROOT}' s3://shwang-chai/public/data/reward_models/${ALGORITHM}/" - # shellcheck disable=SC2016 echo 'Generate results table using `python -m imitation.scripts.analyze`' diff --git a/experiments/imit_benchmark_config.csv b/experiments/imit_benchmark_config.csv index c1b3d5911..eb963003e 100644 --- a/experiments/imit_benchmark_config.csv +++ b/experiments/imit_benchmark_config.csv @@ -7,10 +7,6 @@ seals_mountain_car,5000,1 seals_mountain_car,5000,4 seals_mountain_car,5000,7 seals_mountain_car,5000,10 -acrobot,5000,1 -acrobot,5000,4 -acrobot,5000,7 -acrobot,5000,10 seals_half_cheetah,50000,4 seals_half_cheetah,50000,11 seals_half_cheetah,50000,18 diff --git a/experiments/rollouts_from_policies_config.csv b/experiments/rollouts_from_policies_config.csv index a9ae03f97..68887905b 100644 --- a/experiments/rollouts_from_policies_config.csv +++ b/experiments/rollouts_from_policies_config.csv @@ -1,13 +1,9 @@ env_config_name,n_demonstrations -cartpole,40 -mountain_car,40 seals_cartpole,40 seals_mountain_car,40 -acrobot,40 seals_half_cheetah,40 -hopper,40 -walker,40 -swimmer,40 -ant,40 -humanoid,240 -reacher,40 +seals_hopper,40 +seals_walker,40 +seals_swimmer,40 +seals_ant,40 +seals_humanoid,240 diff --git a/experiments/train_experts.sh b/experiments/train_experts.sh deleted file mode 100755 index 086242f66..000000000 --- a/experiments/train_experts.sh +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env bash -set -e - -# This script trains experts for experiments/imit_benchmark.sh. -# When training is finished, it reports the mean episode reward of each -# expert. - -source experiments/common.sh - -SEEDS=(0 1 2) -ENVS=(acrobot cartpole mountain_car reacher seals_half_cheetah) -ENVS=("${ENVS[@]}" seals_hopper ant seals_humanoid swimmer seals_walker) -OUTPUT_DIR="output/train_experts/${TIMESTAMP}" -RESULTS_FILE="results.txt" -extra_configs=() - -if ! 
TEMP=$($GNU_GETOPT -o frw -l fast,regenerate,wandb -- "$@"); then - exit 1 -fi -eval set -- "$TEMP" - -while true; do - case "$1" in - -f | --fast) - # Fast mode (debug) - ENVS=(cartpole pendulum) - SEEDS=(0) - extra_configs=("${extra_configs[@]}" common.fast rl.fast train.fast fast) - shift - ;; - -w | --wandb) - # activate wandb logging by adding 'wandb' format string to common.log_format_strs - extra_configs=("${extra_configs[@]}" "common.wandb_logging") - shift - ;; - -r | --regenerate) - # Regenerate test data (policies and rollouts). - # - # Combine with fast mode flag to generate low-computation versions of - # test data. - # Use `git clean -df tests/testdata` to remove extra log files. - ENVS=(cartpole pendulum) - SEEDS=(0) - OUTPUT_DIR="tests/testdata/expert_models" - extra_configs=("${extra_configs[@]}" "rollout_save_n_episodes=50") - - if [[ -d ${OUTPUT_DIR} ]]; then - rm -r ${OUTPUT_DIR} - fi - - shift - ;; - --) - shift - break - ;; - *) - echo "Unrecognized flag $1" >&2 - exit 1 - ;; - esac -done - -echo "Writing logs in ${OUTPUT_DIR}" -# Train experts. -parallel -j 25% --header : --progress --results ${OUTPUT_DIR}/parallel/ \ - python -m imitation.scripts.train_rl \ - --capture=sys \ - with \ - '{env}' "${extra_configs[@]}" \ - seed='{seed}' \ - common.log_dir="${OUTPUT_DIR}/{env}_{seed}" \ - ::: env "${ENVS[@]}" \ - ::: seed "${SEEDS[@]}" - -pushd "$OUTPUT_DIR" - -# Display and save mean episode reward to ${RESULTS_FILE}. -find . -name stdout -print0 | sort -z | xargs -0 tail -n 15 | grep -E '(==|return_mean)' | tee ${RESULTS_FILE} - -popd - -echo "[Optional] Upload new experts to S3 (replacing old ones) using the commands:" -echo "aws s3 rm --recursive s3://shwang-chai/public/data/expert_models" -echo "aws s3 sync --exclude '*/rollouts/*' --exclude '*/policies/*' --include '*/policies/final/*' '${OUTPUT_DIR}' s3://shwang-chai/public/data/expert_models" From bc422c6e2e276434b951b80e43e637d75bd32a88 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 11:55:57 +0200 Subject: [PATCH 25/42] No longer test deleted experiment --- tests/test_experiments.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_experiments.py b/tests/test_experiments.py index ff9842b41..06135b9d0 100644 --- a/tests/test_experiments.py +++ b/tests/test_experiments.py @@ -11,7 +11,6 @@ "benchmark_and_table.sh", "imit_benchmark.sh", "rollouts_from_policies.sh", - "train_experts.sh", "transfer_learn_benchmark.sh", ) From 387783ec6d05415113f513013cb7201ec20a6173 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 12:02:20 +0200 Subject: [PATCH 26/42] Improve wording --- experiments/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/experiments/README.md b/experiments/README.md index 5ba92e9bc..113e80c3d 100644 --- a/experiments/README.md +++ b/experiments/README.md @@ -12,7 +12,7 @@ brew install coreutils gnu-getopt parallel ### Phase 1: Download RL (PPO2) expert policies. -Models are saved in HuggingFace, and to work with these scripts should be downloaded to `data/expert_models/`. Environments to train with: +Expert policies have been saved in HuggingFace, and to work with these scripts should be downloaded to `data/expert_models/`. 
Environments with pre-trained models: - [CartPole](https://huggingface.co/HumanCompatibleAI/ppo-seals-CartPole-v0) - [MountainCar](https://huggingface.co/HumanCompatibleAI/ppo-seals-MountainCar-v0) - [HalfCheetah](https://huggingface.co/HumanCompatibleAI/ppo-seals-HalfCheetah-v0) @@ -22,6 +22,7 @@ Models are saved in HuggingFace, and to work with these scripts should be downlo - [Ant](https://huggingface.co/HumanCompatibleAI/ppo-seals-Ant-v0) - [Humanoid](https://huggingface.co/HumanCompatibleAI/ppo-seals-Humanoid-v0) +To download, clone the [rl-baselines3-zoo repository](https://github.com/DLR-RM/rl-baselines3-zoo), and run a command like `python rl_zoo3/load_from_hub.py --algo ppo --env seals/Ant-v0 -orga HumanCompatibleAI -f ../imitation/data/expert_models/` (but modifying the path if necessary to ensure the correct download location). ### Phase 2: Generate expert demonstrations from models. From 8654008a011c02af62d3026cdaa79d41116c6330 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 14:12:33 +0200 Subject: [PATCH 27/42] Oops huggingface models were auto-loaded --- experiments/README.md | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/experiments/README.md b/experiments/README.md index 113e80c3d..01a48a497 100644 --- a/experiments/README.md +++ b/experiments/README.md @@ -10,26 +10,12 @@ brew install coreutils gnu-getopt parallel ## Scripts -### Phase 1: Download RL (PPO2) expert policies. +### Phase 1: Generate expert demonstrations from models. -Expert policies have been saved in HuggingFace, and to work with these scripts should be downloaded to `data/expert_models/`. Environments with pre-trained models: -- [CartPole](https://huggingface.co/HumanCompatibleAI/ppo-seals-CartPole-v0) -- [MountainCar](https://huggingface.co/HumanCompatibleAI/ppo-seals-MountainCar-v0) -- [HalfCheetah](https://huggingface.co/HumanCompatibleAI/ppo-seals-HalfCheetah-v0) -- [Hopper](https://huggingface.co/HumanCompatibleAI/ppo-seals-Hopper-v0) -- [Walker](https://huggingface.co/HumanCompatibleAI/ppo-seals-Walker2d-v0) -- [Swimmer](https://huggingface.co/HumanCompatibleAI/ppo-seals-Swimmer-v0) -- [Ant](https://huggingface.co/HumanCompatibleAI/ppo-seals-Ant-v0) -- [Humanoid](https://huggingface.co/HumanCompatibleAI/ppo-seals-Humanoid-v0) +Run `experiments/rollouts_from_policies.sh`. (Rollouts saved in `output/train_experts/`). +Demonstrations are used in Phase 2 for imitation learning. -To download, clone the [rl-baselines3-zoo repository](https://github.com/DLR-RM/rl-baselines3-zoo), and run a command like `python rl_zoo3/load_from_hub.py --algo ppo --env seals/Ant-v0 -orga HumanCompatibleAI -f ../imitation/data/expert_models/` (but modifying the path if necessary to ensure the correct download location). - -### Phase 2: Generate expert demonstrations from models. - -Run `experiments/rollouts_from_policies.sh`. (Rollouts saved in `data/expert_models/`). -Demonstrations are used in Phase 3 for imitation learning. - -### Phase 3: Train imitation learning. +### Phase 2: Train imitation learning. Run `experiments/imit_benchmark.sh --run_name RUN_NAME`. To choose AIRL or GAIL, add the `--airl` and `--gail` flags (default is GAIL). 
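As revised above, the README reduces the benchmark workflow to two phases. A minimal sketch of an end-to-end run, using only the script names and flags that this README and `imit_benchmark.sh` themselves mention (`--run_name`, `--gail`, `--airl`; the run name `my_gail_run` below is just a placeholder, and any other arguments would need to be checked against the scripts), might look like:

    # Phase 1: generate expert demonstrations
    # (rollouts are saved under output/train_experts/)
    experiments/rollouts_from_policies.sh

    # Phase 2: imitation learning on those demonstrations;
    # GAIL is the default, pass --airl to use AIRL instead
    experiments/imit_benchmark.sh --run_name my_gail_run

    # Afterwards, generate a results table, as imit_benchmark.sh suggests
    python -m imitation.scripts.analyze

This is only a sketch: the exact arguments accepted by each script, and any extra options the analyze command needs, should be confirmed against the scripts themselves.
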
From dbcb8363cc4d4f8e9d3582b59d55f499575b0df4 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 16:12:06 +0200 Subject: [PATCH 28/42] Make imit_benchmark_config match with rollouts_from_policies config --- experiments/imit_benchmark_config.csv | 42 ++++++++++++--------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/experiments/imit_benchmark_config.csv b/experiments/imit_benchmark_config.csv index eb963003e..60e6bb371 100644 --- a/experiments/imit_benchmark_config.csv +++ b/experiments/imit_benchmark_config.csv @@ -11,26 +11,22 @@ seals_half_cheetah,50000,4 seals_half_cheetah,50000,11 seals_half_cheetah,50000,18 seals_half_cheetah,50000,25 -hopper,50000,4 -hopper,50000,11 -hopper,50000,18 -hopper,50000,25 -walker,50000,4 -walker,50000,11 -walker,50000,18 -walker,50000,25 -swimmer,50000,4 -swimmer,50000,11 -swimmer,50000,18 -swimmer,50000,25 -ant,50000,4 -ant,50000,11 -ant,50000,18 -ant,50000,25 -humanoid,50000,80 -humanoid,50000,160 -humanoid,50000,240 -reacher,50000,4 -reacher,50000,11 -reacher,50000,18 -reacher,50000,25 +seals_hopper,50000,4 +seals_hopper,50000,11 +seals_hopper,50000,18 +seals_hopper,50000,25 +seals_walker,50000,4 +seals_walker,50000,11 +seals_walker,50000,18 +seals_walker,50000,25 +seals_swimmer,50000,4 +seals_swimmer,50000,11 +seals_swimmer,50000,18 +seals_swimmer,50000,25 +seals_ant,50000,4 +seals_ant,50000,11 +seals_ant,50000,18 +seals_ant,50000,25 +seals_humanoid,50000,80 +seals_humanoid,50000,160 +seals_humanoid,50000,240 From 2b0922306c206cfcaf713b336d205632686e762c Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 16:12:33 +0200 Subject: [PATCH 29/42] Make script save things where it says it will save them --- experiments/rollouts_from_policies.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experiments/rollouts_from_policies.sh b/experiments/rollouts_from_policies.sh index bc18f98bf..9a1186122 100755 --- a/experiments/rollouts_from_policies.sh +++ b/experiments/rollouts_from_policies.sh @@ -49,7 +49,7 @@ parallel -j 25% --header : --results "${OUTPUT_DIR}/parallel/" --colsep , \ with \ '{env_config_name}' \ common.log_root="${OUTPUT_DIR}" \ - rollout_save_path="${OUTPUT_DIR}/{env_config_name}_0/rollouts/final.pkl" \ + rollout_save_path="${OUTPUT_DIR}/expert_models/{env_config_name}_0/rollouts/final.pkl" \ eval_n_episodes='{n_demonstrations}' \ eval_n_timesteps=None \ :::: ${CONFIG_CSV} From 3600703c3730f775a2ff50d3d2fb588ac3d70d8d Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 16:13:25 +0200 Subject: [PATCH 30/42] Add seals/HalfCheetah-v0 config to eval_policy --- src/imitation/scripts/config/eval_policy.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/imitation/scripts/config/eval_policy.py b/src/imitation/scripts/config/eval_policy.py index 3f3eab5e4..9a052f403 100644 --- a/src/imitation/scripts/config/eval_policy.py +++ b/src/imitation/scripts/config/eval_policy.py @@ -57,6 +57,11 @@ def half_cheetah(): common = dict(env_name="HalfCheetah-v2") +@eval_policy_ex.named_config +def seals_half_cheetah(): + common = dict(env_name="seals/HalfCheetah-v0") + + @eval_policy_ex.named_config def seals_hopper(): common = dict(env_name="seals/Hopper-v0") From b59e3a30bb9ec20029db322843294026f183119a Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 07:48:02 -0700 Subject: [PATCH 31/42] Fix install stuff --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index c79d8fcab..ad48dc26c 100644 --- 
a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ "ale-py==0.7.4", "pillow", "autorom[accept-rom-license]~=0.4.2", - "procgen==0.10.4", + "procgen==0.10.7", "gym3@git+https://github.com/openai/gym3.git#4c3824680eaf9dd04dce224ee3d4856429878226", ] PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] @@ -77,7 +77,7 @@ "myst-nb==0.16.0", "ipykernel~=6.15.2", "seals==0.1.2", -] + ATARI_REQUIRE +] + IMAGE_ENV_REQUIRE def get_readme() -> str: From bcd3f6a74c50b77a778f4cd5e023c26d75a17f6c Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 08:40:46 -0700 Subject: [PATCH 32/42] Fix merge error --- src/imitation/scripts/train_rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 53a005b65..6beedd9a0 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -91,7 +91,7 @@ def train_rl( Returns: The return value of `rollout_stats()` using the final policy. """ - rng = common.make_rng() + rng = scripts_common.make_rng() custom_logger, log_dir = scripts_common.setup_logging() rollout_dir = osp.join(log_dir, "rollouts") policy_dir = osp.join(log_dir, "policies") From de5207b9f494d15ed35600fd32706b764aa12c04 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Wed, 12 Oct 2022 16:04:27 +0200 Subject: [PATCH 33/42] Fix merge problem --- src/imitation/scripts/train_rl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 3ca2b7935..ab6907592 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -92,10 +92,10 @@ def train_rl( """ rng = scripts_common.make_rng() custom_logger, log_dir = scripts_common.setup_logging() - rollout_dir = osp.join(log_dir, "rollouts") - policy_dir = osp.join(log_dir, "policies") - os.makedirs(rollout_dir, exist_ok=True) - os.makedirs(policy_dir, exist_ok=True) + rollout_dir = log_dir / "rollouts" + policy_dir = log_dir / "policies") + rollout_dir.mkdir(parents=True, exist_ok=True) + policy_dir.mkdir(parents=True, exist_ok=True) all_post_wrappers = common["post_wrappers"] + [ lambda env, idx: wrappers.RolloutInfoWrapper(env), From c319744cb41a955227cff5a4cb3e9701f3287b4e Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Wed, 12 Oct 2022 16:05:33 +0200 Subject: [PATCH 34/42] Fix merge conflict --- src/imitation/scripts/common/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index 9c699c7fb..e58a3770c 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -3,10 +3,8 @@ import contextlib import logging import os -from typing import Any, Callable, Generator, Mapping, Optional, Sequence, Tuple, Union import pathlib from typing import Any, Callable, Generator, Mapping, Optional, Sequence, Tuple, Union ->>>>>>> master import numpy as np import sacred From 2275b629fd4072d2827736859df6954b6d7036ab Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Thu, 13 Oct 2022 12:07:54 +0200 Subject: [PATCH 35/42] Fix formatting, types --- setup.py | 2 +- src/imitation/policies/base.py | 4 ++-- src/imitation/scripts/common/common.py | 1 - src/imitation/scripts/config/train_rl.py | 1 + src/imitation/scripts/train_rl.py | 2 +- tests/scripts/test_scripts.py | 3 ++- 6 files changed, 7 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index ad48dc26c..66086988c 100644 --- a/setup.py +++ 
b/setup.py @@ -20,7 +20,7 @@ "pillow", "autorom[accept-rom-license]~=0.4.2", "procgen==0.10.7", - "gym3@git+https://github.com/openai/gym3.git#4c3824680eaf9dd04dce224ee3d4856429878226", + "gym3@git+https://github.com/openai/gym3.git#4c3824680eaf9dd04dce224ee3d4856429878226", # noqa: E501 ] PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] STABLE_BASELINES3 = "stable-baselines3>=1.6.1" diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 910d34bd6..f58a10a3c 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -1,7 +1,7 @@ """Custom policy classes and convenience methods.""" import abc -from typing import Tuple, Type +from typing import Optional, Tuple, Type import gym import numpy as np @@ -154,7 +154,7 @@ def evaluate_actions( self, obs: th.Tensor, actions: th.Tensor, - ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + ) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: if self.transpose_input: obs_ = networks.cnn_transpose(obs) else: diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index e58a3770c..6cbf0d510 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -2,7 +2,6 @@ import contextlib import logging -import os import pathlib from typing import Any, Callable, Generator, Mapping, Optional, Sequence, Tuple, Union diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index 75fac3071..87e32dd1c 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -241,6 +241,7 @@ def plunder(): def starpilot(): common = dict(env_name="procgen:procgen-starpilot-v0") + # Debug configs diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index ab6907592..f86dd14f8 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -93,7 +93,7 @@ def train_rl( rng = scripts_common.make_rng() custom_logger, log_dir = scripts_common.setup_logging() rollout_dir = log_dir / "rollouts" - policy_dir = log_dir / "policies") + policy_dir = log_dir / "policies" rollout_dir.mkdir(parents=True, exist_ok=True) policy_dir.mkdir(parents=True, exist_ok=True) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 16f9dfc0d..403301c4e 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -649,7 +649,8 @@ def test_train_adversarial_algorithm_value_error(tmpdir): def test_train_adversarial_image_env(tmpdir, command): """Smoke test for imitation.scripts.train_adversarial on atari.""" named_configs = ( - ASTEROIDS_CNN_POLICY_CONFIG + ALGO_FAST_CONFIGS["adversarial"] + ASTEROIDS_CNN_POLICY_CONFIG + + ALGO_FAST_CONFIGS["adversarial"] + ["reward.cnn_reward"] ) config_updates = { From a0682f6a0c1d6336d910839768c590674e9c58fc Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 17 Oct 2022 05:07:25 -0700 Subject: [PATCH 36/42] Add procgen configs for eval_policy --- src/imitation/scripts/config/eval_policy.py | 83 +++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/src/imitation/scripts/config/eval_policy.py b/src/imitation/scripts/config/eval_policy.py index 9a052f403..b82330844 100644 --- a/src/imitation/scripts/config/eval_policy.py +++ b/src/imitation/scripts/config/eval_policy.py @@ -107,6 +107,89 @@ def seals_walker(): common = dict(env_name="seals/Walker2d-v0") +# Procgen configs + + +@train_rl_ex.named_config +def coinrun(): + common = 
dict(env_name="procgen:procgen-coinrun-v0") + + +@train_rl_ex.named_config +def maze(): + common = dict(env_name="procgen:procgen-maze-v0") + + +@train_rl_ex.named_config +def bigfish(): + common = dict(env_name="procgen:procgen-bigfish-v0") + + +@train_rl_ex.named_config +def bossfight(): + common = dict(env_name="procgen:procgen-bossfight-v0") + + +@train_rl_ex.named_config +def caveflyer(): + common = dict(env_name="procgen:procgen-caveflyer-v0") + + +@train_rl_ex.named_config +def chaser(): + common = dict(env_name="procgen:procgen-chaser-v0") + + +@train_rl_ex.named_config +def climber(): + common = dict(env_name="procgen:procgen-climber-v0") + + +@train_rl_ex.named_config +def dodgeball(): + common = dict(env_name="procgen:procgen-dodgeball-v0") + + +@train_rl_ex.named_config +def fruitbot(): + common = dict(env_name="procgen:procgen-fruitbot-v0") + + +@train_rl_ex.named_config +def heist(): + common = dict(env_name="procgen:procgen-heist-v0") + + +@train_rl_ex.named_config +def jumper(): + common = dict(env_name="procgen:procgen-jumper-v0") + + +@train_rl_ex.named_config +def leaper(): + common = dict(env_name="procgen:procgen-leaper-v0") + + +@train_rl_ex.named_config +def miner(): + common = dict(env_name="procgen:procgen-miner-v0") + + +@train_rl_ex.named_config +def ninja(): + common = dict(env_name="procgen:procgen-ninja-v0") + + +@train_rl_ex.named_config +def plunder(): + common = dict(env_name="procgen:procgen-plunder-v0") + + +@train_rl_ex.named_config +def starpilot(): + common = dict(env_name="procgen:procgen-starpilot-v0") + + @eval_policy_ex.named_config def fast(): common = dict(env_name="seals/CartPole-v0", num_vec=1, parallel=False) From 4dfade3109b159e2123e9484cd1ca868797feeeb Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 18 Oct 2022 07:06:48 -0700 Subject: [PATCH 37/42] Optionally add exploration wrapper before rolling out --- src/imitation/scripts/eval_policy.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 06eab4820..11afafb88 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -10,6 +10,7 @@ from stable_baselines3.common.vec_env import VecEnvWrapper from imitation.data import rollout, types +from imitation.policies.exploration_wrapper import ExplorationWrapper from imitation.rewards import reward_wrapper from imitation.rewards.serialize import load_reward from imitation.scripts.common import common, expert @@ -61,6 +62,7 @@ def eval_policy( reward_type: Optional[str] = None, reward_path: Optional[str] = None, rollout_save_path: Optional[str] = None, + explore_kwargs: Optional[Mapping[str, Any]] = None, ): """Rolls a policy out in an environment, collecting statistics. @@ -79,6 +81,9 @@ def eval_policy( of `reward_type` to override the environment reward with. rollout_save_path: where to save rollouts used for computing stats to disk; if None, then do not save. + explore_kwargs: keyword arguments to an exploration wrapper to apply before + rolling out, not including policy_callable, venv, and rng; if None, then + do not wrap. Returns: Return value of `imitation.util.rollout.rollout_stats()`. 
@@ -96,12 +101,14 @@ def eval_policy( venv = reward_wrapper.RewardVecEnvWrapper(venv, reward_fn) logging.info(f"Wrapped env in reward {reward_type} from {reward_path}.") - trajs = rollout.generate_trajectories( - expert.get_expert_policy(venv), - venv, - sample_until, - rng=rng, - ) + policy = expert.get_expert_policy(venv) + if explore_kwargs is not None: + policy = ExplorationWrapper(policy, venv, rng=rng, **explore_kwargs) + log_str = ( + f"Wrapped policy in ExplorationWrapper with kwargs {explore_kwargs}" + ) + logging.info(log_str) + trajs = rollout.generate_trajectories(policy, venv, sample_until, rng=rng) if rollout_save_path: types.save(log_dir / rollout_save_path.replace("{log_dir}/", ""), trajs) From 8ed4722f4de5e88ee7783ee6accd3d16894616f5 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 18 Oct 2022 07:13:57 -0700 Subject: [PATCH 38/42] Fix eval_policies config --- src/imitation/scripts/config/eval_policy.py | 32 ++++++++++----------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/imitation/scripts/config/eval_policy.py b/src/imitation/scripts/config/eval_policy.py index b82330844..2f2a4c6d9 100644 --- a/src/imitation/scripts/config/eval_policy.py +++ b/src/imitation/scripts/config/eval_policy.py @@ -110,82 +110,82 @@ def seals_walker(): # Procgen configs -@train_rl_ex.named_config +@eval_policy_ex.named_config def coinrun(): common = dict(env_name="procgen:procgen-coinrun-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def maze(): common = dict(env_name="procgen:procgen-maze-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def bigfish(): common = dict(env_name="procgen:procgen-bigfish-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def bossfight(): common = dict(env_name="procgen:procgen-bossfight-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def caveflyer(): common = dict(env_name="procgen:procgen-caveflyer-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def chaser(): common = dict(env_name="procgen:procgen-chaser-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def climber(): common = dict(env_name="procgen:procgen-climber-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def dodgeball(): common = dict(env_name="procgen:procgen-dodgeball-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def fruitbot(): common = dict(env_name="procgen:procgen-fruitbot-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def heist(): common = dict(env_name="procgen:procgen-heist-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def jumper(): common = dict(env_name="procgen:procgen-jumper-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def leaper(): common = dict(env_name="procgen:procgen-leaper-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def miner(): common = dict(env_name="procgen:procgen-miner-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def ninja(): common = dict(env_name="procgen:procgen-ninja-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def plunder(): common = dict(env_name="procgen:procgen-plunder-v0") -@train_rl_ex.named_config +@eval_policy_ex.named_config def starpilot(): common = dict(env_name="procgen:procgen-starpilot-v0") From 61408793d0986859776f9ac2f1469c4eca9203be Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 18 Oct 2022 07:23:59 -0700 Subject: [PATCH 39/42] Add epsilon-greedy exploration named config --- src/imitation/scripts/config/eval_policy.py | 7 
+++++++ 1 file changed, 7 insertions(+) diff --git a/src/imitation/scripts/config/eval_policy.py b/src/imitation/scripts/config/eval_policy.py index 9a052f403..2d3ffd3b6 100644 --- a/src/imitation/scripts/config/eval_policy.py +++ b/src/imitation/scripts/config/eval_policy.py @@ -25,6 +25,13 @@ def replay_defaults(): rollout_save_path = None # where to save rollouts to -- if None, do not save + explore_kwargs = None # kwargs to feed to ExplorationWrapper -- if None, do not wrap + + +@eval_policy_ex.named_config +def explore_eps_greedy(): + explore_kwargs = dict(switch_prob=1.0, random_prob=0.1) + @eval_policy_ex.named_config def render(): From 6205c0857e547df39f07554efede366d894010bb Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 18 Oct 2022 08:26:42 -0700 Subject: [PATCH 40/42] Fix exploration wrapper --- src/imitation/scripts/config/eval_policy.py | 4 +++- src/imitation/scripts/eval_policy.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/imitation/scripts/config/eval_policy.py b/src/imitation/scripts/config/eval_policy.py index 2d3ffd3b6..9bc8e29a6 100644 --- a/src/imitation/scripts/config/eval_policy.py +++ b/src/imitation/scripts/config/eval_policy.py @@ -25,7 +25,9 @@ def replay_defaults(): rollout_save_path = None # where to save rollouts to -- if None, do not save - explore_kwargs = None # kwargs to feed to ExplorationWrapper -- if None, do not wrap + explore_kwargs = ( + None # kwargs to feed to ExplorationWrapper -- if None, do not wrap + ) @eval_policy_ex.named_config diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 11afafb88..70921577a 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -103,7 +103,12 @@ def eval_policy( policy = expert.get_expert_policy(venv) if explore_kwargs is not None: - policy = ExplorationWrapper(policy, venv, rng=rng, **explore_kwargs) + policy = ExplorationWrapper( + rollout._policy_to_callable(policy, venv), + venv, + rng=rng, + **explore_kwargs, + ) log_str = ( f"Wrapped policy in ExplorationWrapper with kwargs {explore_kwargs}" ) From 4fc8d86b710357f99198ab32a1a339f4e01447f8 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 18 Oct 2022 09:06:06 -0700 Subject: [PATCH 41/42] [EMPTY] run pre-commit hooks From 8f1d6e56507a6eb20f476a6ddf83de5d0a334cca Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 29 Nov 2022 14:45:40 -0800 Subject: [PATCH 42/42] Remove extraneous CNN policy config --- src/imitation/scripts/common/train.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/imitation/scripts/common/train.py b/src/imitation/scripts/common/train.py index edc331c2c..099d5404a 100644 --- a/src/imitation/scripts/common/train.py +++ b/src/imitation/scripts/common/train.py @@ -37,15 +37,6 @@ def sac(): policy_cls = base.SAC1024Policy # noqa: F841 -@train_ingredient.named_config -def cnn(): - policy_cls = base.CnnPolicy # noqa: F841 - # If features_extractor_class is not set, it will be set to a - # NormalizeFeaturesExtractor by default via the config hook, which implements an - # MLP. Therefore, to actually get this to implement a CNN, we need to set it here. - policy_kwargs = {"features_extractor_class": torch_layers.NatureCNN} # noqa: F841 - - NORMALIZE_RUNNING_POLICY_KWARGS = { "features_extractor_class": base.NormalizeFeaturesExtractor, "features_extractor_kwargs": {