zl程序教程

您现在的位置是:首页 >  后端

当前栏目

baselines算法库common/wrapper.py模块分析

算法模块 分析 py Common Wrapper
2023-09-11 14:19:19 时间

 

common/wrapper.py模块:

 

import gym

class TimeLimit(gym.Wrapper):
    def __init__(self, env, max_episode_steps=None):
        super(TimeLimit, self).__init__(env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, ac):
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            done = True
            info['TimeLimit.truncated'] = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)

class ClipActionsWrapper(gym.Wrapper):
    def step(self, action):
        import numpy as np
        action = np.nan_to_num(action)
        action = np.clip(action, self.action_space.low, self.action_space.high)
        return self.env.step(action)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

 

两个对gym环境类的包装类,TimeLimit限制环境类env的最大step数,如果到达最大step数后游戏还没有终止则强制返回终止状态done=True,并设置返回信息:info['TimeLimit.truncated'] = True

 

类ClipActionsWrapper对输入给gym环境的动作进行包装,如果输入的action(action为numpy向量)中含有np.nan则置为0,

如果action中的数值大小超过action_space.low和action_space.high则进行clip操作。

 

 

 

 

 

============================================