
Checking Whether a Custom Reinforcement Learning Environment Conforms to the Gym API

Although OpenAI Gym provides many interesting environments, for real-world problems we usually need to build our own reinforcement learning environment: in essence, a function that takes an action as input and returns next_state, reward, done, and info. When building such an environment, it is best to follow the Gym API. This makes the code easier to write and debug, many algorithm implementations found online are designed around Gym-style environments, and the Gym API itself is sensibly designed.
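
To make that interaction pattern concrete, here is the canonical Gym loop (a minimal sketch using the built-in CartPole-v1 environment purely as an illustration):

import gym

# Canonical Gym interaction: reset once, then step until the episode ends
env = gym.make('CartPole-v1')
observation = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # stand-in for an agent's policy
    observation, reward, done, info = env.step(action)
env.close()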

This post summarizes the format of a Gym environment and introduces a tool that checks a custom reinforcement learning environment against it.

Designing a Custom Environment to the Gym API

To follow the Gym interface, your environment needs to inherit from gym.Env and must implement the following methods:

import gym
import numpy as np
from gym import spaces

class CustomEnv(gym.Env):
    """Custom Environment that follows gym interface"""
    metadata = {'render.modes': ['human']}

    def __init__(self, arg1, arg2, ...):
        super(CustomEnv, self).__init__()
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
        # Example for using image as input:
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)

    def step(self, action):
        ...
        return observation, reward, done, info

    def reset(self):
        ...
        return observation  # reward, done, info can't be included

    def render(self, mode='human'):
        ...

    def close(self):
        ...

Then you can define an RL agent:

from stable_baselines3 import A2C

# Instantiate the env
env = CustomEnv(arg1, ...)
# Define and Train the agent
model = A2C('CnnPolicy', env).learn(total_timesteps=1000)

Use the check_env utility from Stable Baselines3 to verify that the environment conforms to the Gym API:

from stable_baselines3.common.env_checker import check_env

env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)

Source: https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html

The check_env Utility

stable_baselines3.common.env_checker.check_env(env: gym.core.Env, warn: bool = True, skip_render_check: bool = True) → None
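
Per the signature above, warn toggles additional warnings (mainly about interaction with Stable Baselines) and skip_render_check, True by default, skips validation of the render method. A minimal usage sketch, assuming env is the custom environment instantiated earlier and implements render:

from stable_baselines3.common.env_checker import check_env

# Also exercise the render method by disabling the render-check skip
check_env(env, warn=True, skip_render_check=False)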


Example

The following is the example simulated in Y. Zhang, B. Zhao, and D. Liu, "Deterministic policy gradient adaptive dynamic programming for model-free optimal control," *Neurocomputing*, vol. 387, pp. 40–50, Apr. 2020, doi: [10.1016/j.neucom.2019.11.032](https://doi.org/10.1016/j.neucom.2019.11.032):

\[\mathbf{x}_{t+1}=\left[\begin{array}{c} \left(\mathbf{x}_{1, t}+\mathbf{x}_{2, t}^{2}+\mathbf{u}_{t}\right) \cos \left(\mathbf{x}_{2, t}\right) \\ 0.5\left(\mathbf{x}_{1, t}^{2}+\mathbf{x}_{2, t}+\mathbf{u}_{t}\right) \sin \left(\mathbf{x}_{2, t}\right) \end{array}\right]\]

We build a reinforcement learning environment for this non-affine nonlinear state-space system. The reward at each step is defined from the stage cost \(\mathbf{x}_{t}^{\top} Q \mathbf{x}_{t}+\mathbf{u}_{t}^{\top} R \mathbf{u}_{t}\) in the cost functional

\[\mathcal{J}\left(\mathbf{x}_{0}\right)=\sum_{t=0}^{\infty}\left(\mathbf{x}_{t}^{\top} Q \mathbf{x}_{t}+\mathbf{u}_{t}^{\top} R \mathbf{u}_{t}\right)\]

where \(Q=\left[\begin{array}{cc} 1.0 & 0 \\ 0 & 1.0 \end{array}\right]\) and \(R=1.0\). Since Gym agents maximize reward while the optimal-control objective minimizes cost, the implementation below returns the negated stage cost.
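
As a quick sanity check (a worked example, not taken from the paper): at the initial state \(\mathbf{x}_{0}=[0.2, 0.7]^{\top}\) with \(\mathbf{u}_{0}=0\), the stage cost is

\[\mathbf{x}_{0}^{\top} Q \mathbf{x}_{0}+\mathbf{u}_{0}^{\top} R \mathbf{u}_{0}=0.2^{2}+0.7^{2}+0=0.53\]

so the environment implemented below should return a reward of \(-0.53\) on the first step.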

import numpy as np
# Define the nonlinear-system environment following the Gym API
import gym
from gym import spaces

class NonLinearEnv(gym.Env):
    """
    Description:
        A discrete-time non-affine nonlinear system.

    Source:
        Y. Zhang, B. Zhao, and D. Liu, "Deterministic policy gradient adaptive
        dynamic programming for model-free optimal control," Neurocomputing,
        vol. 387, pp. 40-50, Apr. 2020.

    State:
        x1, x2

    Action:
        Single-input system, u

    Reward:
        Negated quadratic stage cost: -(x'Qx + u'Ru)

    Initial state:
        x0 = [0.2, 0.7]'

    Episode termination:
        None (infinite-horizon problem), so done is always False.
    """
    def __init__(self, Q: np.ndarray, R: np.ndarray):
        self.Q = Q
        self.R = R
        self.state = np.array([0.2, 0.7])
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1, ), dtype=np.float64)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2, ), dtype=np.float64)

    def reset(self):
        self.state = np.array([0.2, 0.7])
        return self.state

    def step(self, action):
        u = float(action[0])
        # System dynamics: x_{t+1} = f(x_t, u_t)
        next_state_x1 = (self.state[0] + self.state[1]**2 + u) * np.cos(self.state[1])
        next_state_x2 = 0.5 * (self.state[0]**2 + self.state[1] + u) * np.sin(self.state[1])
        # Quadratic stage cost x'Qx + u'Ru at time t, negated so that
        # maximizing the Gym reward minimizes the cost functional J
        cost = float(self.state @ self.Q @ self.state) + u**2 * float(self.R)
        self.state = np.array([next_state_x1, next_state_x2])
        done = False
        info = {}
        return self.state, -cost, done, info

    def render(self, mode='human'):
        pass
from NLEnv import NonLinearEnv
from stable_baselines3.common.env_checker import check_env
import numpy as np

Q = np.array([[1.0, 0.0],
              [0.0, 1.0]])
R = np.array([[1.0]])
env = NonLinearEnv(Q, R)
check_env(env)

With this, the reinforcement learning environment is defined to the Gym standard and passes the API-consistency check. All that remains is the reinforcement learning algorithm itself; whether you design your own or use an off-the-shelf package, either route is now straightforward.
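
For example, with Stable Baselines3 (a minimal sketch; the choice of A2C with MlpPolicy and the timestep budget are illustrative assumptions, not settings from the paper):

from NLEnv import NonLinearEnv
from stable_baselines3 import A2C
import numpy as np

Q = np.array([[1.0, 0.0],
              [0.0, 1.0]])
R = np.array([[1.0]])
env = NonLinearEnv(Q, R)
# MlpPolicy fits the low-dimensional vector observations of this system
model = A2C('MlpPolicy', env).learn(total_timesteps=10_000)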
