-
Notifications
You must be signed in to change notification settings - Fork 9
/
once.py
72 lines (51 loc) · 1.94 KB
/
once.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# command line arguments
# [preset names, ...] checkpoint file
import numpy as np
def test_model(env, model, record=False):
    """Run one deterministic policy rollout in ``env`` until it reports done.

    When ``preset.vae.enable`` is set, the first ``preset.vae.latent_dim``
    entries of each action are treated as a normalised VAE latent. They are
    de-normalised, decoded by the VAE, and the decoded vector replaces them
    before the action is passed to ``env.step``.

    Args:
        env: environment exposing ``set_mode``, ``reset`` and ``step``;
            ``step`` is unpacked as a 4-tuple ``(ob, reward, done, info)``.
        model: mapping with ``"s_norm"`` (callable applied to the observation
            tensor) and ``"policy"`` (object with ``act_deterministic``).
        record: accepted for caller compatibility; currently unused here.
    """
    import torch
    from presets import preset

    enable_vae = preset.vae.enable
    # NOTE(review): the meaning of mode 1 is defined by the env; presumably
    # an evaluation/playback mode -- confirm against the env implementation.
    env.set_mode(1)
    s_norm = model["s_norm"]
    actor = model["policy"]

    def to_tensor(x):
        """Wrap ``x`` as a float32 torch tensor."""
        return torch.FloatTensor(x)

    if enable_vae:
        from utils.vae import VAE

        latent_dim = preset.vae.latent_dim
        vae = VAE(latent_dim, use_gpu=False)
        vae_path = 'results/vae/models/%s.pth' % preset.vae.model
        # The VAE is built for CPU (use_gpu=False); map the stored tensors to
        # CPU so a checkpoint saved from a GPU run also loads without CUDA.
        vae.load_state_dict(torch.load(vae_path, map_location="cpu"))
        # The sidecar file is read as two consecutive arrays: mean, then std.
        with open(vae_path + '.norm.npy', 'rb') as f:
            vae_mean = np.load(f)
            vae_std = np.load(f)

    ob = env.reset()
    while True:
        # Inference only: keep the VAE decode inside no_grad as well, since
        # .numpy() refuses tensors that still track gradients.
        with torch.no_grad():
            ac = actor.act_deterministic(s_norm(to_tensor(ob)))
            if enable_vae:
                # De-normalise the latent part, decode it, and keep the
                # remaining (non-latent) action entries unchanged.
                latent = ac[0:latent_dim] * vae_std + vae_mean
                decoded = vae.decode(to_tensor(latent)).cpu().numpy()
                ac = np.concatenate((decoded, ac[latent_dim:]))
        ob, _rwd, done, _info = env.step(ac)
        if done:
            break
def parse_cli_args(argv):
    """Split ``argv`` into ``(preset_names, checkpoint)``.

    Expected form: ``prog [preset names, ...] checkpoint``. Every argument
    between the program name and the last one is a preset name; the last
    argument is the checkpoint file.

    Raises:
        SystemExit: with a usage message when no checkpoint is given, instead
            of silently treating the program name as the checkpoint path.
    """
    if len(argv) < 2:
        prog = argv[0] if argv else "once.py"
        raise SystemExit("usage: %s [preset names, ...] checkpoint" % prog)
    return list(argv[1:-1]), argv[-1]


if __name__ == "__main__":
    import sys
    from presets import preset

    preset_names, checkpoint = parse_cli_args(sys.argv)

    preset.load_default()
    exp_settings = preset.experiment
    env_settings = preset.env
    # Enable on-screen rendering before applying custom presets.
    env_settings.enable_rendering = True
    for preset_name in preset_names:
        preset.load_custom(preset_name)
    preset.load_env_override()

    from env import get_env
    test_env = get_env(exp_settings.env)(seed=0, checkpoint=checkpoint, evaluate=True)

    from algorithm import algorithm_bundle_dict
    algorithm_name = preset.experiment.algorithm
    model_class = algorithm_bundle_dict[algorithm_name].model_class

    from model import model_dict
    load_model = model_dict[model_class].loader
    model = load_model(checkpoint)
    test_model(test_env, model, exp_settings.record_motion)