Skip to content

Commit

Permalink
3D rl rpm example
Browse files Browse the repository at this point in the history
  • Loading branch information
JacopoPan committed Dec 8, 2023
1 parent dc71ddc commit 4843943
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 4 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ cd gym_pybullet_drones/examples/
python3 downwash.py
```

### Reinforcement learning 3'-training examples (SB3's PPO)
### Reinforcement learning 10'-training example (SB3's PPO)

```sh
cd gym_pybullet_drones/examples/
python learn.py # task: single drone hover at z == 1.0
python learn.py --multiagent true # task: 2-drone hover at z == 1.2 and 0.7
```

<img src="gym_pybullet_drones/assets/rl.gif" alt="rl example" width="350">

### Betaflight SITL example (Ubuntu only)

```sh
Expand Down
Binary file added gym_pybullet_drones/assets/rl.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions gym_pybullet_drones/envs/HoverAviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def _computeTruncated(self):
Whether the current episode timed out.
"""
state = self._getDroneStateVector(0)
if (abs(state[0]) > 2.0 or abs(state[1]) > 2.0 or state[2] > 2.0 # Truncate when the drone is too far away
or abs(state[7]) > .5 or abs(state[8]) > .5 # Truncate when the drone is too tilted
):
return True
if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC:
return True
else:
Expand Down
4 changes: 2 additions & 2 deletions gym_pybullet_drones/examples/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
DEFAULT_COLAB = False

DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb'
DEFAULT_ACT = ActionType('one_d_rpm') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid'
DEFAULT_ACT = ActionType('rpm') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid'
DEFAULT_AGENTS = 2
DEFAULT_MA = False

Expand Down Expand Up @@ -85,7 +85,7 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_
eval_freq=int(2000),
deterministic=True,
render=False)
model.learn(total_timesteps=3*int(1e5) if local else int(1e2), # shorter training in GitHub Actions pytest
model.learn(total_timesteps=int(1e6) if local else int(1e2), # shorter training in GitHub Actions pytest
callback=eval_callback,
log_interval=100)

Expand Down

0 comments on commit 4843943

Please sign in to comment.