forked from anuragk1/procgen
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_ebigfish.py
More file actions
31 lines (23 loc) · 879 Bytes
/
train_ebigfish.py
File metadata and controls
31 lines (23 loc) · 879 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from statistics import mode
from procgen import ProcgenEnv, ProcgenGym3Env
import numpy as np
import os
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
# Output locations: trained model checkpoints go under models_dir,
# TensorBoard event files go under logdir.
models_dir = "models/PPO_MultipleFish_0.5/"
logdir = "logs"

# exist_ok=True creates the directory (and any missing parents) and is a
# no-op when it already exists, so no separate os.path.exists() check is
# needed and a concurrent run cannot race between the check and the create.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(logdir, exist_ok=True)
# Build the vectorised Procgen environment: 32 copies of the "ebigfishs"
# game stepped in parallel, wrapped in VecMonitor.
env_name = "ebigfishs"
env = ProcgenEnv(num_envs=32, env_name=env_name)
env = VecMonitor(venv=env)

# PPO agent using the "MultiInputPolicy" policy; training metrics are
# written to TensorBoard under logdir, and the model is placed on the GPU.
model = PPO("MultiInputPolicy", env, verbose=1, tensorboard_log=logdir, device='cuda')

# Total number of environment steps for this training run.
TIMESTEPS = 500000
model.learn(total_timesteps=TIMESTEPS, tb_log_name="PPO_MultipleFish_0.5")

# Save the checkpoint as a file *inside* models_dir, named after the number
# of timesteps trained. models_dir itself is an existing directory, so
# passing it directly as the save path cannot produce a file at that
# location.
model.save(os.path.join(models_dir, str(TIMESTEPS)))

# Alternative (disabled): train in 50 chunks of TIMESTEPS each, continuing
# the same timestep counter and saving a checkpoint after every chunk.
# iters = 0
# for i in range(50):
#     model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="PPO")
#     model.save(f"{models_dir}/{TIMESTEPS*i}")