
CNN scripts #563

Draft: dfilan wants to merge 53 commits into master from cnn_scripts

Changes from all commits (53 commits)
a62f5f2
Add post_wrappers entry in common config
dfilan Sep 13, 2022
ec3755b
Add config to train on atari, use post_wrappers
dfilan Sep 13, 2022
2ad0864
Add config option for CNN reward net
dfilan Sep 13, 2022
70f17e5
Add config to train preference comparisons on Asteroids
dfilan Sep 13, 2022
262a35a
Quell flake8 complaint
dfilan Sep 13, 2022
74fb856
Add short-episode asteroids config
dfilan Sep 13, 2022
ea508ba
Add config to train CNN policies
dfilan Sep 13, 2022
8ac9ae5
Fix line length
dfilan Sep 13, 2022
a9d24ae
Make names of short-episode asteroids configs consistent
dfilan Sep 13, 2022
ad61105
Add docstring for common config
dfilan Sep 13, 2022
d90d734
Add tests for training RL + pref comp on image envs
dfilan Sep 13, 2022
d92581c
First draft of training CNNs with DAgger - with backup CNN implementa…
dfilan Sep 16, 2022
8f0c7c0
Delete commented-out alternate implementation of CnnPolicy
dfilan Sep 16, 2022
e5be4ca
Fix notebook bug I introduced
dfilan Sep 16, 2022
39c94ea
Fix problem where transpose_obs had to be specified
dfilan Sep 16, 2022
209c61a
Fix check for transposition in CnnPolicy
dfilan Sep 16, 2022
e51eaa9
Reorder asteroids envs
dfilan Sep 16, 2022
c91caf2
Try to get AIRL/GAIL working for image envs
dfilan Sep 17, 2022
aae2711
Improve documentation of transpose_obs arg
dfilan Sep 17, 2022
b3c9e3e
Get procgen RL training working
dfilan Oct 5, 2022
6214805
Add right version of gym3 to get procgen working
dfilan Oct 5, 2022
ff17412
Add all procgen environments
dfilan Oct 6, 2022
97f16de
Remove awscli dependency
dfilan Oct 6, 2022
6db67fe
Modify scripts to not deal with awscli
dfilan Oct 10, 2022
bc422c6
No longer test deleted experiment
dfilan Oct 10, 2022
387783e
Improve wording
dfilan Oct 10, 2022
8654008
Oops huggingface models were auto-loaded
dfilan Oct 10, 2022
dbcb836
Make imit_benchmark_config match with rollouts_from_policies config
dfilan Oct 10, 2022
2b09223
Make script save things where it says it will save them
dfilan Oct 10, 2022
3600703
Add seals/HalfCheetah-v0 config to eval_policy
dfilan Oct 10, 2022
cc248fd
Merge branch 'master' into remove-awscli
dfilan Oct 10, 2022
593a29b
Fix merge conflict
dfilan Oct 10, 2022
b59e3a3
Fix install stuff
dfilan Oct 10, 2022
bcd3f6a
Fix merge error
dfilan Oct 10, 2022
801c4d9
Merge in master
dfilan Oct 12, 2022
de5207b
Fix merge problem
dfilan Oct 12, 2022
c319744
Fix merge conflict
dfilan Oct 12, 2022
2275b62
Fix formatting, types
dfilan Oct 13, 2022
a0682f6
Add procgen configs for eval_policy
dfilan Oct 17, 2022
051426d
Merge branch 'cnn_scripts' of github.com:HumanCompatibleAI/imitation …
dfilan Oct 17, 2022
be9391d
Merge in master
dfilan Oct 18, 2022
4dfade3
Optionally add exploration wrapper before rolling out
dfilan Oct 18, 2022
1151f48
Merge branch 'wrap_for_rollouts' into cnn_scripts
dfilan Oct 18, 2022
8ed4722
Fix eval_policies config
dfilan Oct 18, 2022
6140879
Add epsilon-greedy exploration named config
dfilan Oct 18, 2022
65bef2c
Merge new explore_kwarg config from wrap_for_rollouts
dfilan Oct 18, 2022
6205c08
Fix exploration wrapper
dfilan Oct 18, 2022
548dc2d
Fix implementation of wrapping rollouts in exploration wrapper
dfilan Oct 18, 2022
4fc8d86
[EMPTY] run pre-commit hooks
dfilan Oct 18, 2022
2bc4355
Merge branch 'wrap_for_rollouts' into cnn_scripts
dfilan Oct 18, 2022
12f2b39
Merge branch 'master' into cnn_scripts
dfilan Oct 27, 2022
daaac0c
Merge branch 'master' into cnn_scripts
dfilan Nov 29, 2022
8f1d6e5
Remove extraneous CNN policy config
dfilan Nov 29, 2022
(Jupyter notebook; filename not captured in this view)
@@ -151,7 +151,8 @@
"metadata": {},
"outputs": [],
"source": [
"from imitation.rewards.reward_nets import ShapedRewardNet, cnn_transpose\n",
"from imitation.util.networks import cnn_transpose\n",
"from imitation.rewards.reward_nets import ShapedRewardNet\n",
"from imitation.rewards.reward_wrapper import RewardVecEnvWrapper\n",
"\n",
"\n",
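The import change above reflects cnn_transpose moving from imitation.rewards.reward_nets to imitation.util.networks (see the reward_nets.py diff below). A minimal usage sketch, assuming the relocated helper keeps the (0, 3, 1, 2) permutation shown in that diff; the shapes are illustrative:

import torch as th

from imitation.util.networks import cnn_transpose

batch_hwc = th.zeros((8, 84, 84, 4), dtype=th.uint8)  # (b, h, w, c) image observations
batch_chw = cnn_transpose(batch_hwc)  # permuted to (b, c, h, w)
assert batch_chw.shape == (8, 4, 84, 84)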
10 changes: 6 additions & 4 deletions setup.py
@@ -14,11 +14,13 @@
IS_NOT_WINDOWS = os.name != "nt"

PARALLEL_REQUIRE = ["ray[debug,tune]~=2.0.0"]
ATARI_REQUIRE = [
IMAGE_ENV_REQUIRE = [
"opencv-python",
"ale-py==0.7.4",
"pillow",
"autorom[accept-rom-license]~=0.4.2",
"procgen==0.10.7",
"gym3@git+https://github.com/openai/gym3.git#4c3824680eaf9dd04dce224ee3d4856429878226", # noqa: E501
]
PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else []
STABLE_BASELINES3 = "stable-baselines3>=1.6.1"
@@ -61,7 +63,7 @@
"pre-commit>=2.20.0",
]
+ PARALLEL_REQUIRE
+ ATARI_REQUIRE
+ IMAGE_ENV_REQUIRE
+ PYTYPE
)
DOCS_REQUIRE = [
@@ -74,7 +76,7 @@
"sphinx-github-changelog~=1.2.0",
"myst-nb==0.16.0",
"ipykernel~=6.15.2",
] + ATARI_REQUIRE
] + IMAGE_ENV_REQUIRE


def get_readme() -> str:
@@ -231,7 +233,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"mujoco": [
"gym[classic_control,mujoco]" + GYM_VERSION_SPECIFIER,
],
"atari": ATARI_REQUIRE,
"image_envs": IMAGE_ENV_REQUIRE,
},
entry_points={
"console_scripts": [
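A hedged usage note on the diff above: since the extras key changes from "atari" to "image_envs", installing the image-environment dependencies from a source checkout would presumably become pip install -e ".[image_envs]" rather than pip install -e ".[atari]". The editable-install form is an assumption; only the renamed extras key appears in this diff.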
10 changes: 8 additions & 2 deletions src/imitation/algorithms/adversarial/common.py
@@ -118,6 +118,7 @@ def __init__(
disc_opt_kwargs: Optional[Mapping] = None,
gen_train_timesteps: Optional[int] = None,
gen_replay_buffer_capacity: Optional[int] = None,
transpose_obs: bool = False,
custom_logger: Optional[logger.HierarchicalLogger] = None,
init_tensorboard: bool = False,
init_tensorboard_graph: bool = False,
@@ -161,6 +162,9 @@ def __init__(
the generator that can be stored). By default this is equal to
`gen_train_timesteps`, meaning that we sample only from the most
recent batch of generator samples.
transpose_obs: Whether observations need to be transposed from (h,w,c)
format before being manually fed into the policy. Should usually be True
for image environments and False otherwise.
custom_logger: Where to log to; if None (default), creates a new logger.
init_tensorboard: If True, makes various discriminator
TensorBoard summaries.
@@ -221,6 +225,7 @@ def __init__(
self._summary_writer = thboard.SummaryWriter(str(summary_dir))

self.venv_buffering = wrappers.BufferingWrapper(self.venv)
self.transpose_obs = transpose_obs

if debug_use_ground_truth:
# Would use an identity reward fn here, but RewardFns can't see rewards.
@@ -481,18 +486,19 @@ def _get_log_policy_act_prob(
Returns:
A batch of log policy action probabilities.
"""
obs_th_ = networks.cnn_transpose(obs_th) if self.transpose_obs else obs_th
if isinstance(self.policy, policies.ActorCriticPolicy):
# policies.ActorCriticPolicy has a concrete implementation of
# evaluate_actions to generate log_policy_act_prob given obs and actions.
_, log_policy_act_prob_th, _ = self.policy.evaluate_actions(
obs_th,
obs_th_,
acts_th,
)
elif isinstance(self.policy, sac_policies.SACPolicy):
gen_algo_actor = self.policy.actor
assert gen_algo_actor is not None
# generate log_policy_act_prob from SAC actor.
mean_actions, log_std, _ = gen_algo_actor.get_action_dist_params(obs_th)
mean_actions, log_std, _ = gen_algo_actor.get_action_dist_params(obs_th_)
distribution = gen_algo_actor.action_dist.proba_distribution(
mean_actions,
log_std,
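A hedged sketch of how the new transpose_obs flag added above might be passed through a concrete adversarial trainer, assuming GAIL forwards extra keyword arguments to AdversarialTrainer as in the released package. The names rollouts, venv, learner, and reward_net are placeholders, not taken from this branch:

from imitation.algorithms.adversarial.gail import GAIL

trainer = GAIL(
    demonstrations=rollouts,  # expert trajectories collected elsewhere
    demo_batch_size=32,
    venv=venv,  # image env emitting (h, w, c) observations
    gen_algo=learner,  # e.g. an SB3 algorithm using the CnnPolicy from this PR
    reward_net=reward_net,  # e.g. a CNN reward net operating on (h, w, c) inputs
    transpose_obs=True,  # new flag: transpose obs before policy.evaluate_actions
)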
100 changes: 99 additions & 1 deletion src/imitation/policies/base.py
@@ -1,12 +1,13 @@
"""Custom policy classes and convenience methods."""

import abc
from typing import Type
from typing import Optional, Tuple, Type

import gym
import numpy as np
import torch as th
from stable_baselines3.common import policies, torch_layers
from stable_baselines3.common.distributions import Distribution
from stable_baselines3.sac import policies as sac_policies
from torch import nn

@@ -88,6 +89,103 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, net_arch=[1024, 1024])


class CnnPolicy(policies.ActorCriticCnnPolicy):
"""A CNN Actor-Critic policy.

This policy optionally transposes its observation inputs. Note that if this is done,
the policy expects the observation space to be a Box with values ranging from 0 to
255. Methods are copy-pasted from StableBaselines 3's ActorCriticPolicy, with an
initial check whether or not to transpose an observation input.
"""

def __init__(self, *args, transpose_input: bool = False, **kwargs):
"""Builds CnnPolicy; arguments passed to `CnnActorCriticPolicy`."""
self.transpose_input = transpose_input
if self.transpose_input:
kwargs.update(
{
"observation_space": self.transpose_space(
kwargs["observation_space"],
),
},
)
super().__init__(*args, **kwargs)

def transpose_space(self, observation_space: gym.spaces.Box) -> gym.spaces.Box:
if not isinstance(observation_space, gym.spaces.Box):
raise TypeError("This code assumes that observation spaces are gym Boxes.")
if not (
np.all(observation_space.low == 0) and np.all(observation_space.high == 255)
):
error_msg = (
"This code assumes the observation space values range from "
+ "0 to 255."
)
raise ValueError(error_msg)
h, w, c = observation_space.shape
new_shape = (c, h, w)
return gym.spaces.Box(
low=0,
high=255,
shape=new_shape,
dtype=observation_space.dtype,
)

def forward(
self,
obs: th.Tensor,
deterministic: bool = False,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
if self.transpose_input:
obs_ = networks.cnn_transpose(obs)
else:
obs_ = obs
# Preprocess the observation if needed
features = self.extract_features(obs_)
latent_pi, latent_vf = self.mlp_extractor(features)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
return actions, values, log_prob

def evaluate_actions(
self,
obs: th.Tensor,
actions: th.Tensor,
) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
if self.transpose_input:
obs_ = networks.cnn_transpose(obs)
else:
obs_ = obs
# Preprocess the observation if needed
features = self.extract_features(obs_)
latent_pi, latent_vf = self.mlp_extractor(features)
distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
return values, log_prob, distribution.entropy()

def get_distribution(self, obs: th.Tensor) -> Distribution:
if self.transpose_input:
obs_ = networks.cnn_transpose(obs)
else:
obs_ = obs
features = self.extract_features(obs_)
latent_pi = self.mlp_extractor.forward_actor(features)
return self._get_action_dist_from_latent(latent_pi)

def predict_values(self, obs: th.Tensor) -> th.Tensor:
if self.transpose_input:
obs_ = networks.cnn_transpose(obs)
else:
obs_ = obs
features = self.extract_features(obs_)
latent_vf = self.mlp_extractor.forward_critic(features)
return self.value_net(latent_vf)


class NormalizeFeaturesExtractor(torch_layers.FlattenExtractor):
"""Feature extractor that flattens then normalizes input."""

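A minimal construction sketch for the CnnPolicy added above, assuming it is instantiated directly rather than through an SB3 algorithm; the observation shape, action space, and learning-rate schedule are illustrative. Note that observation_space must be passed as a keyword argument so the transposition branch in __init__ can find it in kwargs:

import gym
import numpy as np

from imitation.policies.base import CnnPolicy

obs_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 4), dtype=np.uint8)  # (h, w, c)
policy = CnnPolicy(
    observation_space=obs_space,
    action_space=gym.spaces.Discrete(4),
    lr_schedule=lambda _: 3e-4,
    transpose_input=True,  # feature extractor then sees (c, h, w) = (4, 84, 84) inputs
)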
18 changes: 5 additions & 13 deletions src/imitation/rewards/reward_nets.py
@@ -568,10 +568,12 @@ def forward(
"""
inputs = []
if self.use_state:
state_ = cnn_transpose(state) if self.hwc_format else state
state_ = networks.cnn_transpose(state) if self.hwc_format else state
inputs.append(state_)
if self.use_next_state:
next_state_ = cnn_transpose(next_state) if self.hwc_format else next_state
next_state_ = (
networks.cnn_transpose(next_state) if self.hwc_format else next_state
)
inputs.append(next_state_)

inputs_concat = th.cat(inputs, dim=1)
@@ -597,16 +599,6 @@ def forward(
return rewards


def cnn_transpose(tens: th.Tensor) -> th.Tensor:
"""Transpose a (b,h,w,c)-formatted tensor to (b,c,h,w) format."""
if len(tens.shape) == 4:
return th.permute(tens, (0, 3, 1, 2))
else:
raise ValueError(
f"Invalid input: len(tens.shape) = {len(tens.shape)} != 4.",
)


class NormalizedRewardNet(PredictProcessedWrapper):
"""A reward net that normalizes the output of its base network."""

@@ -872,7 +864,7 @@ def __init__(
)

def forward(self, state: th.Tensor) -> th.Tensor:
state_ = cnn_transpose(state) if self.hwc_format else state
state_ = networks.cnn_transpose(state) if self.hwc_format else state
return self._potential_net(state_)


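A hedged construction sketch for the CNN reward net these networks.cnn_transpose calls serve. Only the self.hwc_format, self.use_state, and self.use_next_state attributes are visible in the diff above, so treating hwc_format as a constructor keyword is an assumption; the spaces below are illustrative:

import gym
import numpy as np

from imitation.rewards import reward_nets

obs_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 4), dtype=np.uint8)  # (h, w, c)
reward_net = reward_nets.CnnRewardNet(
    observation_space=obs_space,
    action_space=gym.spaces.Discrete(4),
    hwc_format=True,  # transpose state inputs with networks.cnn_transpose in forward()
)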
10 changes: 8 additions & 2 deletions src/imitation/scripts/common/common.py
@@ -3,10 +3,11 @@
import contextlib
import logging
import pathlib
from typing import Any, Generator, Mapping, Sequence, Tuple, Union
from typing import Any, Callable, Generator, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import sacred
from gym import Env
from stable_baselines3.common import vec_env

from imitation.data import types
@@ -35,6 +36,7 @@ def config():
num_vec = 8 # number of environments in VecEnv
parallel = True # Use SubprocVecEnv rather than DummyVecEnv
max_episode_steps = None # Set to positive int to limit episode horizons
post_wrappers = [] # Sequence of wrappers to apply to each env in the VecEnv
env_make_kwargs = {} # The kwargs passed to `spec.make`.

locals() # quieten flake8
@@ -141,6 +143,7 @@ def make_venv(
parallel: bool,
log_dir: str,
max_episode_steps: int,
post_wrappers: Optional[Sequence[Callable[[Env, int], Env]]],
env_make_kwargs: Mapping[str, Any],
**kwargs,
) -> Generator[vec_env.VecEnv, None, None]:
@@ -154,6 +157,8 @@
max_episode_steps: If not None, then a TimeLimit wrapper is applied to each
environment to artificially limit the maximum number of timesteps in an
episode.
post_wrappers: If specified, iteratively wraps each environment with each
of the wrappers specified in the sequence.
log_dir: Logs episode return statistics to a subdirectory `monitor`.
env_make_kwargs: The kwargs passed to `spec.make` of a gym environment.
kwargs: Passed through to `util.make_vec_env`.
@@ -169,8 +174,9 @@
rng=rng,
n_envs=num_vec,
parallel=parallel,
max_episode_steps=max_episode_steps,
log_dir=log_dir,
max_episode_steps=max_episode_steps,
post_wrappers=post_wrappers,
env_make_kwargs=env_make_kwargs,
**kwargs,
)
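A hedged sketch of a post_wrappers entry, given the Callable[[Env, int], Env] annotation above. Whether this PR uses SB3's AtariWrapper specifically is an assumption; any gym wrapper applied with this calling convention would do:

from stable_baselines3.common.atari_wrappers import AtariWrapper

# Each entry receives a freshly created env and its index within the VecEnv,
# and must return the wrapped env.
post_wrappers = [lambda env, env_idx: AtariWrapper(env, clip_reward=False)]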
10 changes: 9 additions & 1 deletion src/imitation/scripts/common/reward.py
@@ -60,6 +60,11 @@ def reward_ensemble():
locals()


@reward_ingredient.named_config
def cnn_reward():
net_cls = reward_nets.CnnRewardNet # noqa: F841


@reward_ingredient.config_hook
def config_hook(config, command_name, logger):
"""Sets default values for `net_cls` and `net_kwargs`."""
@@ -71,7 +76,10 @@ def config_hook(config, command_name, logger):
default_net = reward_nets.BasicShapedRewardNet
res["net_cls"] = default_net

if "normalize_input_layer" not in config["reward"]["net_kwargs"]:
if (
"normalize_input_layer" not in config["reward"]["net_kwargs"]
and config["reward"]["net_cls"] != reward_nets.CnnRewardNet
):
res["net_kwargs"] = {"normalize_input_layer": networks.RunningNorm}

if "net_cls" in res and issubclass(res["net_cls"], reward_nets.RewardEnsemble):
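A hedged usage note on the cnn_reward named config added above: because it lives in the reward ingredient, selecting it from the command line would presumably look like python -m imitation.scripts.train_preference_comparisons with <environment_named_config> reward.cnn_reward. The script entry point and the ingredient-prefixed named-config syntax follow standard imitation/Sacred conventions rather than anything shown in this diff.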