From 71ac4f362d19702d2e4ab2301e9aaac77e0b261a Mon Sep 17 00:00:00 2001 From: Timo Kaufmann Date: Thu, 14 Mar 2024 10:37:45 +0100 Subject: [PATCH] Avoid keeping references to full trajectories Previously we stored a view into the trajectory in the preference comparison dataset. This view is a reference to the original trajectory, and therefore keeps it from getting garbage collected for as long as the view exists (i.e., however long the comparison is stored in the dataset). This is problematic when trajectories are large and long, e.g., in the case of atari (images) with SEALS (long episodes). It can cause the RAM to fill up quite quickly in that setting. We can fix it by copying the fragments we want to store. That avoids keeping a reference to the original trajectory alive, we only need to store the fragment. The tradeoff is that copying adds some overhead and overlapping fragments are no longer deduplicated. --- src/imitation/algorithms/preference_comparisons.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 14a8fad5b..5a7411b61 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -650,11 +650,12 @@ def __call__( start = self.rng.integers(0, n - fragment_length, endpoint=True) end = start + fragment_length terminal = (end == n) and traj.terminal + # Copy the slices to enable garbage collection of full trajectory. fragment = TrajectoryWithRew( - obs=traj.obs[start : end + 1], - acts=traj.acts[start:end], - infos=traj.infos[start:end] if traj.infos is not None else None, - rews=traj.rews[start:end], + obs=traj.obs[start : end + 1].copy(), + acts=traj.acts[start:end].copy(), + infos=traj.infos[start:end].copy() if traj.infos is not None else None, + rews=traj.rews[start:end].copy(), terminal=terminal, ) fragments.append(fragment)