Add rgb observation to dagger #802

Status: Open · wants to merge 86 commits into master

Commits (86):
5182ecf
first pass of dict obs functionality
NixGD Sep 13, 2023
61d816b
cleanup DictObs
NixGD Sep 13, 2023
c3331f6
add dict space to test_types.py, fix some problems
NixGD Sep 14, 2023
fc9838d
add dict-obs test for rollout
NixGD Sep 14, 2023
fb9498b
add bc.py test
NixGD Sep 14, 2023
e54c36c
cleanup
NixGD Sep 14, 2023
ee04383
small fixes
NixGD Sep 14, 2023
6e2218a
small fixes
NixGD Sep 14, 2023
68fe666
fix type error in interactive.py
NixGD Sep 14, 2023
9ad2aaf
fix introduced error in mce_irl.py
NixGD Sep 14, 2023
67341d5
fix minor ci complaint
NixGD Sep 14, 2023
c497b56
add basic dictobs tests
NixGD Sep 14, 2023
d3f79bf
change default bc policy for dict obs space
NixGD Sep 14, 2023
2de9e49
refine rollout.py typechecks, comments
NixGD Sep 14, 2023
c47cca6
check rollout produces dictobs of correct shape
NixGD Sep 14, 2023
276294b
cleanup types and dictobs helpers
NixGD Sep 14, 2023
071d2a7
clean useless lines
NixGD Sep 14, 2023
a2ccd7e
clean up print statements
NixGD Sep 14, 2023
93baa2d
fix typos
NixGD Sep 15, 2023
54f33af
assert matching keys in from_obs_list
NixGD Sep 15, 2023
c711abf
move maybe_wrap, clean rollout
NixGD Sep 15, 2023
58a0d70
change policy callable to take dict[str, np.ndarray] not dictobs
NixGD Sep 15, 2023
0f080d4
rollout info wrapper supports dictobs
NixGD Sep 15, 2023
c4d3e11
fix from_obs_list key consistency check
NixGD Sep 15, 2023
b93294a
xfail save/load tests with dictobs
NixGD Sep 15, 2023
3f17ff2
doc for dictobs wrapper
NixGD Sep 15, 2023
0212e0e
don't error on int observations
NixGD Sep 15, 2023
070ebf9
lint fixes
NixGD Sep 15, 2023
657e17e
cleanup bc test for dict obs
NixGD Sep 15, 2023
1f8c12a
cleanup bc.py unwrapping
NixGD Sep 15, 2023
bd70ecd
cleanup rollout.py
NixGD Sep 15, 2023
bec464c
cleanup dictobs interface
NixGD Sep 15, 2023
bef19e6
small cleanups
NixGD Sep 15, 2023
9aaf73f
coverage fixes, test fix
NixGD Sep 15, 2023
5d6aa77
adjust error types
NixGD Sep 15, 2023
86fbcf1
docstrings for type helpers
NixGD Sep 15, 2023
8d1e0d6
add dict obs space support for density
NixGD Sep 15, 2023
96978d5
fix typos
NixGD Sep 15, 2023
e95df9d
Adam suggestions from code review
NixGD Sep 16, 2023
161ec95
small changes for code review
NixGD Sep 16, 2023
90bdf57
fix docstring
NixGD Sep 16, 2023
6aa25ff
remove FloatReward
ZiyueWang25 Oct 2, 2023
bf48c76
Merge remote-tracking branch 'origin/master' into support-dict-obs-space
ZiyueWang25 Oct 2, 2023
4ce1b57
Fix test_bc
ZiyueWang25 Oct 2, 2023
de1b1c8
Turn off GPU finding to avoid using gpu device
ZiyueWang25 Oct 2, 2023
1a1a458
Check None to ensure __add__ can work
ZiyueWang25 Oct 2, 2023
f7866f4
fix docstring
ZiyueWang25 Oct 2, 2023
daa838d
bypass pytype and lint test
ZiyueWang25 Oct 2, 2023
803eab0
format with black
ZiyueWang25 Oct 2, 2023
0ac6f54
Test dict space in density algo
ZiyueWang25 Oct 2, 2023
be9798b
black format
ZiyueWang25 Oct 2, 2023
c7e6809
small fix
ZiyueWang25 Oct 2, 2023
82fb558
Add DictObs into test_wrappers
ZiyueWang25 Oct 3, 2023
03714cc
fix format
ZiyueWang25 Oct 3, 2023
187e881
minor fix
ZiyueWang25 Oct 3, 2023
ae96521
type and lint fix
ZiyueWang25 Oct 3, 2023
535a986
Add policy training test
ZiyueWang25 Oct 3, 2023
de027c4
suppress line too long lint check on a line
ZiyueWang25 Oct 3, 2023
4caa151
acts to obs for clarity
ZiyueWang25 Oct 3, 2023
20c6f56
Add HumanReadableWrapper
ZiyueWang25 Oct 3, 2023
aaf94da
adjust wrapper and not set render_mode inside
ZiyueWang25 Oct 3, 2023
ef53690
fix dict env observation space
ZiyueWang25 Oct 3, 2023
df35e3a
add RemoveHumanReadableWrapper and update ob space
ZiyueWang25 Oct 4, 2023
a69e052
Remove some unnecessary helper functions
ZiyueWang25 Oct 4, 2023
c6ed675
include rgb obs to dagger algo
ZiyueWang25 Oct 4, 2023
d7d8db1
add wrappers tests and fix linter and typing
ZiyueWang25 Oct 5, 2023
5dd1699
change ob to obs
ZiyueWang25 Oct 5, 2023
8481890
allow not only dict type obs in dagger
ZiyueWang25 Oct 5, 2023
b0d8d4b
fix lint and test
ZiyueWang25 Oct 5, 2023
12b60b2
fix type and test
ZiyueWang25 Oct 5, 2023
afbbe46
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
ZiyueWang25 Oct 5, 2023
5036fcf
fix type
ZiyueWang25 Oct 5, 2023
3f23de1
resolve typing issue
ZiyueWang25 Oct 5, 2023
a44b193
Remove wrong type annotation in test
ZiyueWang25 Oct 5, 2023
1073967
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
ZiyueWang25 Oct 5, 2023
ae17588
resolve conflict
ZiyueWang25 Oct 5, 2023
f140035
add policy wrapper
ZiyueWang25 Oct 6, 2023
9a00e68
small fix
ZiyueWang25 Oct 6, 2023
052cf00
fix the data and policy wrappers
ZiyueWang25 Oct 6, 2023
c63761c
Use ObservationWrapper
ZiyueWang25 Oct 6, 2023
468f621
update naming
ZiyueWang25 Oct 6, 2023
3c6def5
update tests
ZiyueWang25 Oct 6, 2023
68d1ac2
update tests
ZiyueWang25 Oct 6, 2023
125d19d
update demo
ZiyueWang25 Oct 6, 2023
f8ebbc4
rgb to hr
ZiyueWang25 Oct 6, 2023
33162b2
small fix
ZiyueWang25 Oct 6, 2023
32 changes: 26 additions & 6 deletions examples/train_dagger_atari_interactive_policy.py
@@ -7,28 +7,48 @@

 import gymnasium as gym
 import numpy as np
-from stable_baselines3.common import vec_env
+import torch as th
+from stable_baselines3.common import torch_layers, vec_env

 from imitation.algorithms import bc, dagger
-from imitation.policies import interactive
+from imitation.data import wrappers as data_wrappers
+from imitation.policies import base as policy_base
+from imitation.policies import interactive, obs_update_wrapper
+
+
+def lr_schedule(_: float):
+    # Set lr_schedule to max value to force error if policy.optimizer
+    # is used by mistake (should use self.optimizer instead).
+    return th.finfo(th.float32).max


 if __name__ == "__main__":
     rng = np.random.default_rng(0)

-    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)])
-    env.seed(0)
+    env = gym.make("PongNoFrameskip-v4", render_mode="rgb_array")
+    hr_env = data_wrappers.HumanReadableWrapper(env)
+    venv = vec_env.DummyVecEnv([lambda: hr_env])
+    venv.seed(0)

-    expert = interactive.AtariInteractivePolicy(env)
+    expert = interactive.AtariInteractivePolicy(venv)
+    policy = policy_base.FeedForward32Policy(
+        observation_space=env.observation_space,
+        action_space=env.action_space,
+        lr_schedule=lr_schedule,
+        features_extractor_class=torch_layers.FlattenExtractor,
+    )
+    wrapped_policy = obs_update_wrapper.RemoveHR(policy, lr_schedule=lr_schedule)

     bc_trainer = bc.BC(
         observation_space=env.observation_space,
         action_space=env.action_space,
+        policy=wrapped_policy,
         rng=rng,
     )

     with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
         dagger_trainer = dagger.SimpleDAggerTrainer(
-            venv=env,
+            venv=venv,
             scratch_dir=tmpdir,
             expert_policy=expert,
             bc_trainer=bc_trainer,

Review comment (Member), on the copy-pasted lr_schedule comment: "Copy-and-pasted comment doesn't make sense out of context (what is self here?)"
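For intuition on why the policy is wrapped: the BC policy is defined over the original observation space, so the human-readable frame must be stripped before `predict`. A standalone sketch of the idea (the actual `obs_update_wrapper.RemoveHR` API appears only in this diff, so this helper is purely illustrative):

    import numpy as np

    HR_OBS_KEY = "HR_OBS"  # matches imitation.data.wrappers.HR_OBS_KEY

    def remove_hr_obs(obs):
        """Illustrative: drop the human-readable frame before the policy sees obs."""
        if not isinstance(obs, dict):
            return obs
        rest = {k: v for k, v in obs.items() if k != HR_OBS_KEY}
        if set(rest) == {"ORI_OBS"}:
            # HumanReadableWrapper stored a non-dict observation under
            # "ORI_OBS"; unwrap it back to a plain array.
            return rest["ORI_OBS"]
        return rest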
35 changes: 24 additions & 11 deletions src/imitation/algorithms/dagger.py
@@ -11,7 +11,7 @@
 import os
 import pathlib
 import uuid
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch as th
@@ -161,13 +161,16 @@ class InteractiveTrajectoryCollector(vec_env.VecEnvWrapper):
     """

     traj_accum: Optional[rollout.TrajectoryAccumulator]
-    _last_obs: Optional[np.ndarray]
+    _last_obs: Optional[Union[Dict[str, np.ndarray], np.ndarray]]
     _last_user_actions: Optional[np.ndarray]

     def __init__(
         self,
         venv: vec_env.VecEnv,
-        get_robot_acts: Callable[[np.ndarray], np.ndarray],
+        get_robot_acts: Callable[
+            [Union[Dict[str, np.ndarray], np.ndarray]],
+            np.ndarray,
+        ],
         beta: float,
         save_dir: types.AnyPath,
         rng: np.random.Generator,
@@ -213,16 +216,20 @@ def seed(self, seed: Optional[int] = None) -> List[Optional[int]]:
         self.rng = np.random.default_rng(seed=seed)
         return list(self.venv.seed(seed))

-    def reset(self) -> np.ndarray:
+    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
         """Resets the environment.

         Returns:
             obs: first observation of a new trajectory.
         """
         self.traj_accum = rollout.TrajectoryAccumulator()
         obs = self.venv.reset()
-        assert isinstance(obs, np.ndarray)
-        for i, ob in enumerate(obs):
+        assert isinstance(
+            obs,
+            (np.ndarray, dict),
+        ), "Tuple observations are not supported."
+        dictobs = types.maybe_wrap_in_dictobs(obs)
+        for i, ob in enumerate(dictobs):
             self.traj_accum.add_step({"obs": ob}, key=i)
         self._last_obs = obs
         self._is_reset = True
@@ -256,7 +263,9 @@ def step_async(self, actions: np.ndarray) -> None:

         mask = self.rng.uniform(0, 1, size=(self.num_envs,)) > self.beta
         if np.sum(mask) != 0:
-            actual_acts[mask] = self.get_robot_acts(self._last_obs[mask])
+            last_obs = types.maybe_wrap_in_dictobs(self._last_obs)
+            obs_for_robot = types.maybe_unwrap_dictobs(last_obs[mask])
+            actual_acts[mask] = self.get_robot_acts(obs_for_robot)

         self._last_user_actions = actions
         self.venv.step_async(actual_acts)
@@ -270,9 +279,13 @@ def step_wait(self) -> VecEnvStepReturn:
             Observation, reward, dones (is terminal?) and info dict.
         """
         next_obs, rews, dones, infos = self.venv.step_wait()
-        assert isinstance(next_obs, np.ndarray)
         assert self.traj_accum is not None
         assert self._last_user_actions is not None
+        assert isinstance(
+            next_obs,
+            (np.ndarray, dict),
+        ), "Tuple observations are not supported."

         self._last_obs = next_obs
         fresh_demos = self.traj_accum.add_steps_and_auto_finish(
             obs=next_obs,
@@ -508,7 +521,7 @@ def create_trajectory_collector(self) -> InteractiveTrajectoryCollector:
         beta = self.beta_schedule(self.round_num)
         collector = InteractiveTrajectoryCollector(
             venv=self.venv,
-            get_robot_acts=lambda acts: self.bc_trainer.policy.predict(acts)[0],
+            get_robot_acts=lambda obs: self.bc_trainer.policy.predict(obs)[0],
             beta=beta,
             save_dir=save_dir,
             rng=self.rng,
@@ -550,7 +563,7 @@ def save_trainer(self) -> Tuple[pathlib.Path, pathlib.Path]:


 class SimpleDAggerTrainer(DAggerTrainer):
-    """Simpler subclass of DAggerTrainer for training with synthetic feedback."""
+    """Simpler subclass of DAggerTrainer for training with feedback."""

     def __init__(
         self,
@@ -571,7 +584,7 @@ def __init__(
             simultaneously for that timestep.
         scratch_dir: Directory to use to store intermediate training
             information (e.g. for resuming training).
-        expert_policy: The expert policy used to generate synthetic demonstrations.
+        expert_policy: The expert policy used to generate demonstrations.
         rng: Random state to use for the random number generator.
         expert_trajs: Optional starting dataset that is inserted into the round 0
             dataset.
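To make the dict-observation handling in `step_async` concrete, here is a minimal standalone sketch of the beta-mixing step (illustrative only, not the library code; it inlines what `maybe_wrap_in_dictobs` plus masked `DictObs` indexing accomplish):

    import numpy as np

    def mix_actions(user_acts, last_obs, get_robot_acts, beta, rng):
        # With probability 1 - beta, each env's action comes from the robot
        # policy instead of the expert/user.
        actual_acts = np.copy(user_acts)
        mask = rng.uniform(0, 1, size=len(user_acts)) > beta
        if mask.any():
            if isinstance(last_obs, dict):
                # Dict observations: apply the same mask to every entry,
                # mirroring DictObs.__getitem__ with an ndarray key.
                obs_subset = {k: v[mask] for k, v in last_obs.items()}
            else:
                obs_subset = last_obs[mask]
            actual_acts[mask] = get_robot_acts(obs_subset)
        return actual_acts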
2 changes: 1 addition & 1 deletion src/imitation/data/types.py
@@ -94,7 +94,7 @@ def dict_len(self):

     def __getitem__(
         self,
-        key: Union[int, slice, Tuple[Union[int, slice], ...]],
+        key: Union[int, slice, Tuple[Union[int, slice], ...], np.ndarray],
     ) -> "DictObs":
         """Indexes or slices each array.
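The added `np.ndarray` key type is what allows the masked indexing used in dagger.py above. A short sketch of the expected behavior, assuming `DictObs` can be constructed directly from a dict of arrays (as in this PR's tests):

    import numpy as np
    from imitation.data import types

    obs = types.DictObs({"pos": np.arange(4), "img": np.zeros((4, 8, 8))})
    mask = np.array([True, False, True, False])
    subset = obs[mask]  # the boolean mask is applied to each underlying array
    # Each array in `subset` now has leading dimension 2.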
59 changes: 57 additions & 2 deletions src/imitation/data/wrappers.py
@@ -1,14 +1,18 @@
 """Environment wrappers for collecting rollouts."""

-from typing import List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple, Union

 import gymnasium as gym
 import numpy as np
 import numpy.typing as npt
+from gymnasium.core import Env
 from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper

 from imitation.data import rollout, types

+# The key for human readable data in the observation.
+HR_OBS_KEY = "HR_OBS"
+

 class BufferingWrapper(VecEnvWrapper):
     """Saves transitions of underlying VecEnv.
@@ -170,7 +174,7 @@ def pop_transitions(self) -> types.TransitionsWithRew:


 class RolloutInfoWrapper(gym.Wrapper):
-    """Add the entire episode's rewards and observations to `info` at episode end.
+    """Adds the entire episode's rewards and observations to `info` at episode end.

     Whenever done=True, `info["rollouts"]` is a dict with keys "obs" and "rews", whose
     corresponding values hold the NumPy arrays containing the raw observations and
@@ -206,3 +210,54 @@ def step(self, action):
                 "rews": np.stack(self._rews),
             }
         return obs, rew, terminated, truncated, info
+
+
+class HumanReadableWrapper(gym.ObservationWrapper):
+    """Adds human-readable observation to `obs` at every step."""
+
+    def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"):
+        """Builds HumanReadableWrapper.
+
+        Args:
+            env: Environment to wrap.
+            original_obs_key: The key for original observation if the original
+                observation is not in dict format.
+
+        Raises:
+            ValueError: If `env.render_mode` is not "rgb_array".
+        """
+        if env.render_mode != "rgb_array":
+            raise ValueError(
+                "HumanReadableWrapper requires render_mode='rgb_array', "
+                f"got {env.render_mode!r}",
+            )
+        self._original_obs_key = original_obs_key
+        super().__init__(env)
+
+    def observation(
+        self,
+        obs: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> Dict[str, np.ndarray]:
+        """Adds human-readable observation to obs.
+
+        Transforms obs into a dictionary if it is not already, and adds the
+        human-readable observation from `env.render()` under the key HR_OBS_KEY.
+
+        Args:
+            obs: Observation from environment.
+
+        Returns:
+            Observation dictionary with the human-readable data.
+
+        Raises:
+            KeyError: When the key HR_OBS_KEY already exists in the observation
+                dictionary.
+        """
+        if not isinstance(obs, Dict):
+            obs = {self._original_obs_key: obs}
+
+        if HR_OBS_KEY in obs:
+            raise KeyError(f"{HR_OBS_KEY!r} already exists in observation dict")
+        obs[HR_OBS_KEY] = self.env.render()  # type: ignore[assignment]
+        return obs
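A possible usage sketch for the new wrapper (the environment id is illustrative; any env created with render_mode="rgb_array" should work, otherwise `__init__` raises ValueError):

    import gymnasium as gym
    from imitation.data import wrappers

    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = wrappers.HumanReadableWrapper(env)

    obs, _ = env.reset()
    # The original Box observation is stored under the default "ORI_OBS" key,
    # and the rendered frame is added under wrappers.HR_OBS_KEY ("HR_OBS").
    assert set(obs) == {"ORI_OBS", wrappers.HR_OBS_KEY}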
22 changes: 15 additions & 7 deletions src/imitation/policies/interactive.py
@@ -11,6 +11,7 @@
 from stable_baselines3.common import vec_env

 import imitation.policies.base as base_policies
+from imitation.data import wrappers
 from imitation.util import util

@@ -64,9 +65,6 @@ def _choose_action(
         if self.clear_screen_on_query:
             util.clear_screen()

-        if isinstance(obs, dict):
-            raise ValueError("Dictionary observations are not supported here")
-
         context = self._render(obs)
         key = self._get_input_key()
         self._clean_up(context)
@@ -87,7 +85,10 @@ def _get_input_key(self) -> str:
         return key

     @abc.abstractmethod
-    def _render(self, obs: np.ndarray) -> Optional[object]:
+    def _render(
+        self,
+        obs: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> Optional[object]:
         """Renders an observation, optionally returns a context for later cleanup."""

     def _clean_up(self, context: object) -> None:
@@ -97,7 +98,7 @@ def _clean_up(self, context: object) -> None:

 class ImageObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
     """DiscreteInteractivePolicy that renders image observations."""

-    def _render(self, obs: np.ndarray) -> plt.Figure:
+    def _render(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> plt.Figure:
         img = self._prepare_obs_image(obs)

         fig, ax = plt.subplots()
@@ -110,9 +111,16 @@ def _render(self, obs: np.ndarray) -> plt.Figure:
     def _clean_up(self, context: plt.Figure) -> None:
         plt.close(context)

-    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
+    def _prepare_obs_image(
+        self,
+        obs: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> np.ndarray:
         """Applies any required observation processing to get an image to show."""
-        return obs
+        if not isinstance(obs, Dict):
+            return obs
+        if wrappers.HR_OBS_KEY not in obs:
+            raise KeyError(f"Observation does not contain {wrappers.HR_OBS_KEY!r}")
+        return obs[wrappers.HR_OBS_KEY]


 ATARI_ACTION_NAMES_TO_KEYS = {