Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from huggingface_sb3 import load_from_hub | |
| import gym | |
| import os | |
| from stable_baselines3 import PPO | |
| from stable_baselines3.common.vec_env import VecNormalize | |
| from stable_baselines3.common.env_util import make_atari_env | |
| from stable_baselines3.common.vec_env import VecFrameStack | |
| from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv | |
| from stable_baselines3.common.evaluation import evaluate_policy | |
| from moviepy.editor import * | |
def _load_ppo(model_id, filename):
    """Download a PPO checkpoint from the Hugging Face Hub and load it."""
    checkpoint = load_from_hub(model_id, filename)
    # The agent was trained under Python 3.8; overriding these pickled schedule
    # objects avoids unpickling errors when loading under another Python
    # version (e.g. 3.7 on Colab). They are only needed for training anyway.
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    return PPO.load(checkpoint, custom_objects=custom_objects)


def _make_eval_env(environment):
    """Build a single Atari env with 4-frame stacking for replay/evaluation."""
    env = make_atari_env(environment, n_envs=1)
    return VecFrameStack(env, n_stack=4)


def _record_replay(model, environment, video_folder, video_length):
    """Roll out ``model`` while recording a video; return the raw video path."""
    env = VecVideoRecorder(
        _make_eval_env(environment),
        video_folder,
        # Record only the clip that starts at the very first step.
        record_video_trigger=lambda step: step == 0,
        video_length=video_length,
        name_prefix="test",
    )
    try:
        obs = env.reset()
        # One extra step past video_length, as in the original loop.
        for _ in range(video_length + 1):
            action, _states = model.predict(obs)
            obs, _rewards, _dones, _info = env.step(action)
    finally:
        # Closing stops the recorder (flushing the file) and the wrapped envs.
        env.close()
    # NOTE(review): relies on VecVideoRecorder exposing `video_recorder.path`,
    # as the original code did; newer SB3 releases may name this differently.
    return env.video_recorder.path


def _evaluate(model, environment, n_eval_episodes):
    """Evaluate ``model`` on a fresh environment and return a text summary."""
    # A fresh env is required: the recording env has already been closed.
    eval_env = _make_eval_env(environment)
    try:
        mean_reward, std_reward = evaluate_policy(
            model, eval_env, n_eval_episodes=n_eval_episodes
        )
    finally:
        eval_env.close()
    results = f"mean_reward={mean_reward:.2f} +/- {std_reward}"
    print(results)
    return results


def _reencode_video(src_path, dst_path):
    """Re-encode the recorder's video at ``src_path`` into ``dst_path``."""
    clip = VideoFileClip(src_path)
    try:
        clip.write_videofile(dst_path)
    finally:
        # Release the reader subprocess/file handles held by the clip.
        clip.close()


def replay(model_id, filename, environment, evaluate,
           video_folder="logs/videos/", video_length=1000,
           n_eval_episodes=10, output_path="new_filename.mp4"):
    """Replay a Hub-hosted PPO Atari agent as a video and optionally evaluate it.

    Args:
        model_id: Hugging Face Hub repository id holding the checkpoint.
        filename: Checkpoint file name inside that repository.
        environment: Atari environment id passed to ``make_atari_env``.
        evaluate: When truthy, run ``evaluate_policy`` for
            ``n_eval_episodes`` episodes; otherwise evaluation is skipped.
        video_folder: Directory where ``VecVideoRecorder`` writes its video.
        video_length: Number of steps captured in the recorded video.
        n_eval_episodes: Number of episodes used for evaluation.
        output_path: Path of the re-encoded video that is returned.

    Returns:
        ``(output_path, results)`` where ``results`` is the evaluation summary
        text, or a note that evaluation was skipped.
    """
    model = _load_ppo(model_id, filename)
    raw_video_path = _record_replay(model, environment, video_folder, video_length)
    if evaluate:
        results = _evaluate(model, environment, n_eval_episodes)
    else:
        results = "Evaluation skipped."
    _reencode_video(raw_video_path, output_path)
    return output_path, results
# Each example row supplies the four `replay` inputs in order:
# model id, checkpoint filename, environment id, evaluate flag.
examples = [
    [
        "ThomasSimonini/ppo-QbertNoFrameskip-v4",
        "ppo-QbertNoFrameskip-v4.zip",
        "QbertNoFrameskip-v4",
        True,
    ]
]

# NOTE(review): gr.inputs.* with `default=` and Interface(enable_queue=...) are
# the legacy Gradio API (deprecated in 3.x, removed in 4.x). If the Space's
# Gradio version is upgraded, migrate to gr.Textbox/gr.Checkbox(value=...) and
# iface.queue().
iface = gr.Interface(
    fn=replay,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Filename: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Environment: "),
        gr.inputs.Checkbox(default=False, label="Evaluate?: "),
    ],
    outputs=["video", "text"],
    enable_queue=True,
    examples=examples,
)

# Launch only when run as a script, so importing this module (tests, tooling,
# reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    iface.launch()