Spaces:
Runtime error
Runtime error
File size: 2,134 Bytes
164ef55 d812b5e 3210277 164ef55 aa21d8d d812b5e 5d7151b d812b5e a53beac d812b5e 4a80994 d812b5e 5871174 d812b5e c044350 d812b5e 164ef55 90f838f aa21d8d 90f838f d38fa4c 90f838f c044350 164ef55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import gradio as gr
from huggingface_sb3 import load_from_hub
import gym
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
def replay(model_id, filename, environment, evaluate):
    """Load a PPO agent from the Hub, record a short rollout, optionally evaluate it.

    Args:
        model_id: Hub repository id handed to ``load_from_hub``.
        filename: Checkpoint file name inside that repository.
        environment: Environment id handed to ``make_atari_env``.
        evaluate: When truthy, run ``evaluate_policy`` for 10 episodes and
            report the mean and standard deviation of the reward. When falsy,
            evaluation is skipped.

    Returns:
        A ``(text, video_path)`` tuple matching the interface's
        ``["text", "video"]`` outputs: the evaluation summary (or a note that
        evaluation was skipped) and the path of the recorded video file.
    """
    # Local import: the file's import block does not bring this name into
    # scope, and the original call to it raised NameError.
    from stable_baselines3.common.evaluation import evaluate_policy

    # Load the model
    checkpoint = load_from_hub(model_id, filename)

    # Replace the pickled schedule objects so that a checkpoint saved under a
    # different Python version can be loaded without pickle errors.
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    model = PPO.load(checkpoint, custom_objects=custom_objects)

    eval_env = make_atari_env(environment, n_envs=1)
    eval_env = VecFrameStack(eval_env, n_stack=4)

    video_folder = 'logs/videos/'
    video_length = 100
    name_prefix = "test"

    # Record the video starting at the first step
    env = VecVideoRecorder(
        eval_env,
        video_folder,
        record_video_trigger=lambda x: x == 0,
        video_length=video_length,
        name_prefix=name_prefix,
    )
    try:
        obs = env.reset()
        for _ in range(video_length + 1):
            action, _states = model.predict(obs)
            obs, rewards, dones, info = env.step(action)

        if evaluate:
            # Evaluate on the unwrapped vec env, and do it BEFORE closing:
            # closing the recorder also closes the environment it wraps.
            mean_reward, std_reward = evaluate_policy(
                model, eval_env, n_eval_episodes=10
            )
            summary = f"mean_reward={mean_reward:.2f} +/- {std_reward}"
            print(summary)
        else:
            summary = "Evaluation skipped."
    finally:
        # Save the video and release the environment, even on failure.
        env.close()

    # NOTE(review): path rebuilt from VecVideoRecorder's naming scheme
    # "<prefix>-step-<start>-to-step-<end>.mp4"; the trigger fires at step 0.
    # Confirm against the installed stable-baselines3 version.
    video_path = os.path.join(
        video_folder, f"{name_prefix}-step-0-to-step-{video_length}.mp4"
    )
    return summary, video_path
# Build the web UI: three text fields and a checkbox are passed, in order, to
# `replay`; its two return values fill the text and video outputs.
iface = gr.Interface(
    fn=replay,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Filename: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Environment: "),
        gr.inputs.Checkbox(default=False, label="Evaluate?: "),
    ],
    outputs=["text", "video"],
)

# Start serving the interface (the stray trailing "|" that followed this call
# was not valid Python and has been removed).
iface.launch()