zl程序教程

Analysis of the common/vec_env/util.py module in the baselines algorithm library

2023-09-11 14:19:19

Source code of the util.py module:

"""
Helpers for dealing with vectorized environments.
"""

from collections import OrderedDict

import gym
import numpy as np


def copy_obs_dict(obs):
    """
    Deep-copy an observation dict.
    """
    return {k: np.copy(v) for k, v in obs.items()}


def dict_to_obs(obs_dict):
    """
    Convert an observation dict into a raw array if the
    original observation space was not a Dict space.
    """
    if set(obs_dict.keys()) == {None}:
        return obs_dict[None]
    return obs_dict


def obs_space_info(obs_space):
    """
    Get dict-structured information about a gym.Space.

    Returns:
      A tuple (keys, shapes, dtypes):
        keys: a list of dict keys.
        shapes: a dict mapping keys to shapes.
        dtypes: a dict mapping keys to dtypes.
    """
    if isinstance(obs_space, gym.spaces.Dict):
        assert isinstance(obs_space.spaces, OrderedDict)
        subspaces = obs_space.spaces
    elif isinstance(obs_space, gym.spaces.Tuple):
        assert isinstance(obs_space.spaces, tuple)
        subspaces = {i: obs_space.spaces[i] for i in range(len(obs_space.spaces))}
    else:
        subspaces = {None: obs_space}
    keys = []
    shapes = {}
    dtypes = {}
    for key, box in subspaces.items():
        keys.append(key)
        shapes[key] = box.shape
        dtypes[key] = box.dtype
    return keys, shapes, dtypes


def obs_to_dict(obs):
    """
    Convert an observation into a dict.
    """
    if isinstance(obs, dict):
        return obs
    return {None: obs}

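Functions:

def copy_obs_dict(obs):
def obs_to_dict(obs):

copy_obs_dict assumes that the observation passed in is already a dict. It returns a new dict with the same keys in which each value has been copied with np.copy, so writing into the returned arrays does not modify the original observation.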
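A minimal sketch of what this copy guarantees. It reproduces copy_obs_dict from the listing above; the observation dict and its keys ("image", "velocity") are hypothetical.

```python
import numpy as np


def copy_obs_dict(obs):
    """Same body as in util.py: copy every array value with np.copy."""
    return {k: np.copy(v) for k, v in obs.items()}


# Hypothetical observation dict with two array-valued entries.
obs = {"image": np.zeros((2, 2), dtype=np.uint8), "velocity": np.array([0.5, -1.0])}
obs_copy = copy_obs_dict(obs)

# Writing into the copy leaves the original arrays untouched.
obs_copy["image"][0, 0] = 255
print(obs["image"][0, 0], obs_copy["image"][0, 0])  # 0 255
print(obs_copy["velocity"] is obs["velocity"])       # False
```

Because the values are fresh arrays, a caller holding the copy is presumably protected from later in-place writes to the environment's own buffers.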

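obs_to_dict handles the opposite case: if the observation passed in is not a dict, it is wrapped into a dict whose only key is None. An observation that is already a dict is returned unchanged.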
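A short illustration of both branches, reusing obs_to_dict as listed above. The array values and the "position" key are hypothetical examples.

```python
import numpy as np


def obs_to_dict(obs):
    """Same body as in util.py: wrap a non-dict observation under the key None."""
    if isinstance(obs, dict):
        return obs
    return {None: obs}


raw_obs = np.array([1.0, 2.0, 3.0])  # e.g. an observation from a Box space
wrapped = obs_to_dict(raw_obs)
print(list(wrapped.keys()))  # [None]

dict_obs = {"position": np.zeros(2)}  # hypothetical Dict-space observation
print(obs_to_dict(dict_obs) is dict_obs)  # True: dicts are returned as-is
```

Since OrderedDict is a subclass of dict, observations from a Dict space also take the pass-through branch.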

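Function:

def dict_to_obs(obs_dict):

dict_to_obs reverses obs_to_dict. If the only key of the input dict is None, meaning the original observation space was not a Dict space, it returns the stored value (typically an np.ndarray) instead of the dict.

If the dict has any other set of keys, it is returned unchanged.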
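The sketch below checks that obs_to_dict followed by dict_to_obs gives back the original array, and that a dict with ordinary keys passes through. Both functions are copied from the listing; the keys "position" and "speed" are hypothetical.

```python
import numpy as np


def obs_to_dict(obs):
    """Same body as in util.py."""
    if isinstance(obs, dict):
        return obs
    return {None: obs}


def dict_to_obs(obs_dict):
    """Same body as in util.py: unwrap a dict whose only key is None."""
    if set(obs_dict.keys()) == {None}:
        return obs_dict[None]
    return obs_dict


raw_obs = np.arange(4)
restored = dict_to_obs(obs_to_dict(raw_obs))
print(restored is raw_obs)  # True: the round trip gives back the same array

dict_obs = {"position": np.zeros(2), "speed": np.ones(1)}  # hypothetical keys
print(dict_to_obs(dict_obs) is dict_obs)  # True: other dicts pass through unchanged
```

Note that the check is on the whole key set, so a dict containing None alongside other keys is not unwrapped.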

Function:

def obs_space_info(obs_space):

The argument is the environment's observation space, i.e. a gym.Space such as env.observation_space.

    if isinstance(obs_space, gym.spaces.Dict):
        assert isinstance(obs_space.spaces, OrderedDict)
        subspaces = obs_space.spaces
    elif isinstance(obs_space, gym.spaces.Tuple):
        assert isinstance(obs_space.spaces, tuple)
        subspaces = {i: obs_space.spaces[i] for i in range(len(obs_space.spaces))}
    else:
        subspaces = {None: obs_space}

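The function first inspects the type of obs_space and normalizes its subspaces into a dict: the OrderedDict of a Dict space is used directly, the subspaces of a Tuple space are keyed by their integer index, and any other space becomes a single entry under the key None.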
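To show the three branches without requiring gym, the sketch below swaps the gym space types for plain Python containers: an OrderedDict stands in for gym.spaces.Dict.spaces, a tuple for gym.spaces.Tuple.spaces, and any other object for a single space. The function name normalize_subspaces and the string placeholders are hypothetical; only the branching logic mirrors obs_space_info.

```python
from collections import OrderedDict


def normalize_subspaces(spaces):
    """Hypothetical gym-free stand-in for the first half of obs_space_info.

    An OrderedDict plays the role of gym.spaces.Dict.spaces, a tuple plays
    the role of gym.spaces.Tuple.spaces, and anything else is a single space.
    """
    if isinstance(spaces, OrderedDict):
        return spaces  # Dict space: keep its own keys
    if isinstance(spaces, tuple):
        return {i: spaces[i] for i in range(len(spaces))}  # Tuple space: integer keys
    return {None: spaces}  # any other space: single entry under None


print(normalize_subspaces(OrderedDict([("pos", "box_a"), ("vel", "box_b")])))
print(normalize_subspaces(("box_a", "box_b")))  # {0: 'box_a', 1: 'box_b'}
print(normalize_subspaces("box_a"))             # {None: 'box_a'}
```

After this step every case, including a plain Box space, can be processed by the same loop over key/subspace pairs.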

It then loops over the normalized subspaces, recording each key and reading each subspace's shape and dtype attributes, which gives:

    Returns:
      A tuple (keys, shapes, dtypes):
        keys: a list of dict keys.
        shapes: a dict mapping keys to shapes.
        dtypes: a dict mapping keys to dtypes.

The three results are returned together as a tuple.
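One typical use of (keys, shapes, dtypes) is preallocating one batched buffer per key for several parallel environments; vectorized environments such as baselines' DummyVecEnv appear to build their observation buffers in a similar way. This is a simplified sketch, not that class's code: the helper name space_info, the subspaces and num_envs are assumptions, and because the extraction loop only reads .shape and .dtype, NumPy arrays stand in for gym Box spaces.

```python
import numpy as np


def space_info(subspaces):
    """Second half of obs_space_info: collect keys, shapes and dtypes.

    Only .shape and .dtype are read, so NumPy arrays can stand in for
    gym Box spaces in this sketch.
    """
    keys, shapes, dtypes = [], {}, {}
    for key, box in subspaces.items():
        keys.append(key)
        shapes[key] = box.shape
        dtypes[key] = box.dtype
    return keys, shapes, dtypes


# Hypothetical subspaces, already normalized into a dict as obs_space_info does.
subspaces = {"image": np.zeros((84, 84, 3), dtype=np.uint8),
             "velocity": np.zeros((2,), dtype=np.float32)}
keys, shapes, dtypes = space_info(subspaces)

# Preallocate one batched buffer per key for num_envs parallel environments.
num_envs = 4
buf_obs = {k: np.zeros((num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in keys}
print(keys)                                            # ['image', 'velocity']
print(buf_obs["image"].shape, buf_obs["image"].dtype)  # (4, 84, 84, 3) uint8
print(buf_obs["velocity"].shape)                       # (4, 2)
```

Keeping keys as a list preserves the subspace order, so every environment's observation is written into the buffers in the same key order.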
