import os

import gradio as gr
from huggingface_sb3 import load_from_hub

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecFrameStack, VecVideoRecorder


def replay(model_id, filename, environment, evaluate):
    # Download the model checkpoint from the Hugging Face Hub
    checkpoint = load_from_hub(model_id, filename)

    # Custom objects to avoid pickle errors when the agent was trained with a
    # different Python version (e.g. trained on 3.8, loaded on 3.7 in Colab)
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }

    model = PPO.load(checkpoint, custom_objects=custom_objects)

    # Build the evaluation environment (Atari preprocessing + 4 stacked frames)
    eval_env = make_atari_env(environment, n_envs=1)
    eval_env = VecFrameStack(eval_env, n_stack=4)

    # Optionally evaluate the agent before recording
    evaluation = "Evaluation skipped"
    if evaluate:
        mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
        evaluation = f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}"
        print(evaluation)

    video_folder = "logs/videos/"
    video_length = 100

    # Record a video starting at the first step
    env = VecVideoRecorder(
        eval_env,
        video_folder,
        record_video_trigger=lambda x: x == 0,
        video_length=video_length,
        name_prefix="test",
    )

    obs = env.reset()
    for _ in range(video_length + 1):
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
    # Close the recorder so the video file is written to disk
    env.close()

    # VecVideoRecorder saves the clip as <prefix>-step-<start>-to-step-<end>.mp4
    video_path = os.path.join(video_folder, f"test-step-0-to-step-{video_length}.mp4")

    return evaluation, video_path


iface = gr.Interface(
    fn=replay,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Filename: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Environment: "),
        gr.inputs.Checkbox(default=False, label="Evaluate?: "),
    ],
    outputs=["text", "video"],
)
iface.launch()