Code notes: Stacked Hourglass Networks for Human Pose Estimation


Stacked hourglass networks for human pose estimation
https://github.com/princeton-vl/pytorch_stacked_hourglass
This is a model for human pose estimation; it only handles a single person at a time.
The authors improve performance through repeated bottom-up (high resolution -> low resolution) and top-down (low resolution -> high resolution) processing, together with intermediate supervision (deep supervision).

The model

The residual module

All residual modules in the model keep the spatial resolution unchanged; only the channel count may change (a quick shape check follows the code below).

class Conv(nn.Module):
    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True):
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size - 1) // 2, bias=True)
        self.relu = None
        self.bn = None
        if relu:
            self.relu = nn.ReLU()
        if bn:
            self.bn = nn.BatchNorm2d(out_dim)

    def forward(self, x):
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Residual(nn.Module):
    def __init__(self, inp_dim, out_dim):
        super(Residual, self).__init__()
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = Conv(inp_dim, out_dim // 2, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(out_dim // 2)
        self.conv2 = Conv(out_dim // 2, out_dim // 2, 3, relu=False)
        self.bn3 = nn.BatchNorm2d(out_dim // 2)
        self.conv3 = Conv(out_dim // 2, out_dim, 1, relu=False)
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        if inp_dim == out_dim:
            self.need_skip = False
        else:
            self.need_skip = True

    def forward(self, x):  # ([1, inp_dim, H, W])
        if self.need_skip:
            residual = self.skip_layer(x)  # ([1, out_dim, H, W])
        else:
            residual = x  # ([1, out_dim, H, W])
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)  # ([1, out_dim / 2, H, W])

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)  # ([1, out_dim / 2, H, W])

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)  # ([1, out_dim, H, W])

        out += residual  # ([1, out_dim, H, W])
        return out  # ([1, out_dim, H, W])
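
A minimal sketch to verify the claim above, using only the two classes just shown: the residual changes the number of channels but not the resolution.

import torch
import torch.nn as nn

res_block = Residual(64, 128)
x = torch.randn(1, 64, 32, 32)
print(res_block(x).shape)   # torch.Size([1, 128, 32, 32]): channels change, resolution does not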

The stem

The model first applies a 7×7 convolution with stride 2, then a residual module and a max-pooling layer, bringing the input from 256×256 down to 64×64.
Two more residual modules follow.

This corresponds to the stem described in the paper.

self.pre = nn.Sequential(  # ([B, 3, 256, 256])
            Conv(3, 64, 7, 2, bn=True, relu=True),  # ([B, 64, 128, 128])
            Residual(64, 128),  # ([B, 128, 128, 128])
            Pool(2, 2),  # ([B, 128, 64, 64])
            Residual(128, 128),  # ([B, 128, 64, 64])
            Residual(128, inp_dim)  # ([B, 256, 64, 64])
        )


A single hourglass

Before each max-pooling step the network splits into two branches: one goes through max pooling (and is processed further at the lower resolution), while the other stays at the current resolution and passes through a residual module.
Before the two branches are merged, the pooled branch is upsampled once, and the two branches are added element-wise.
(This matches the corresponding description in the paper.)
The code corresponds to the hourglass figure in the paper (though I'm not sure how the very first residual module in the paper's figure is supposed to be accounted for...).

class Hourglass(nn.Module):
    def __init__(self, n, f, bn=None, increase=0):
        super(Hourglass, self).__init__()
        nf = f + increase
        self.up1 = Residual(f, f)
        # Lower branch
        self.pool1 = Pool(2, 2)
        self.low1 = Residual(f, nf)
        self.n = n
        # Recursive hourglass
        if self.n > 1:
            self.low2 = Hourglass(n - 1, nf, bn=bn)
        else:
            self.low2 = Residual(nf, nf)
        self.low3 = Residual(nf, f)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):  # ([1, f, H, W])
        up1 = self.up1(x)  # ([1, f, H, W])
        pool1 = self.pool1(x)  # ([1, f, H/2, W/2])
        low1 = self.low1(pool1)  # ([1, nf, H/2, W/2])
        low2 = self.low2(low1)  # ([1, nf, H/2, W/2])
        low3 = self.low3(low2)  # ([1, f, H/2, W/2])
        up2 = self.up2(low3)  # ([1, f, H, W])
        return up1 + up2  # ([1, f, H, W])
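
Pool is not defined in the snippets above; in the repo's layer definitions it is, to my recollection, just an alias for nn.MaxPool2d. Assuming that, a quick shape check confirms that a single hourglass preserves both resolution and channel count:

import torch
import torch.nn as nn

Pool = nn.MaxPool2d   # assumed alias for the Pool used above

hg = Hourglass(4, 256)              # 4 pooling levels: 64 -> 32 -> 16 -> 8 -> 4 and back up
x = torch.randn(1, 256, 64, 64)
print(hg(x).shape)                  # torch.Size([1, 256, 64, 64]): same shape as the input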

Heatmaps

After each hourglass, the model applies two 1×1 convolutions to produce the heatmaps (though I'm not sure why the code also adds an extra residual module here; a sketch of this head follows).
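The relevant layers are not quoted above, so here is a minimal sketch of what the per-stack head looks like, loosely following the repo's posenet.py; the class name HeatmapHead and the exact arguments are my own, so treat them as assumptions.

import torch.nn as nn

class HeatmapHead(nn.Module):
    # Sketch only: a Residual plus a 1x1 Conv refine the hourglass output ("features"),
    # then a final 1x1 Conv produces one heatmap channel per joint (16 for MPII).
    def __init__(self, inp_dim=256, num_parts=16):
        super(HeatmapHead, self).__init__()
        self.feature = nn.Sequential(
            Residual(inp_dim, inp_dim),                     # the extra residual noted above
            Conv(inp_dim, inp_dim, 1, bn=True, relu=True)   # first 1x1 conv
        )
        self.out = Conv(inp_dim, num_parts, 1, relu=False, bn=False)  # second 1x1 conv -> heatmaps

    def forward(self, hg_out):          # ([B, inp_dim, 64, 64])
        feature = self.feature(hg_out)  # ([B, inp_dim, 64, 64])
        preds = self.out(feature)       # ([B, num_parts, 64, 64])
        return feature, preds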

Intermediate supervision

The input to the previous hourglass, the predicted heatmaps, and the features just before the heatmaps are added together, with the heatmaps and the features each passed through a 1×1 convolution first (a sketch follows).
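A minimal sketch of how this merging could look, loosely based on the repo's posenet.py; the Merge layer and the loop variable names are quoted from memory and may differ from the actual code.

import torch.nn as nn

class Merge(nn.Module):
    # 1x1 convolution that projects features or heatmaps back to inp_dim channels
    def __init__(self, x_dim, y_dim):
        super(Merge, self).__init__()
        self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False)

    def forward(self, x):
        return self.conv(x)

# Sketch of the stacked forward pass (hypothetical names):
# for i in range(nstack):
#     hg = self.hgs[i](x)                    # one hourglass
#     feature, preds = self.heads[i](hg)     # heatmap head as sketched above
#     heatmap_preds.append(preds)            # every stack is supervised -> intermediate supervision
#     if i < nstack - 1:
#         x = x + self.merge_features[i](feature) + self.merge_preds[i](preds)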

Data augmentation

img.py

This requires some background on affine transformations (translation, rotation, and scaling matrices):
https://blog.csdn.net/qq_39942341/article/details/129440487?spm=1001.2014.3001.5501
First, get_transform.
Its main job is to build the matrix that crops a bbox of size scale*200 centered at center out of the original image, rescales it to res, and finally rotates it.

Because the scale stored in the annotations was divided by 200, it is multiplied by 200 here to recover the original bbox size.
https://blog.csdn.net/qq_39942341/article/details/129289591?spm=1001.2014.3001.5502
The first matrix t can be decomposed into three matrices.
With the affine-transformation background, you can see that it first translates the bbox so that its center lies at the origin,
then scales it to the output size,
and finally translates the bbox so that its top-left corner lies at the origin (its center at 0.5*res):
$$
\begin{pmatrix}
\mathrm{res}[1]/h & 0 & -\mathrm{res}[1]\cdot \mathrm{center}[0]/h + 0.5\,\mathrm{res}[1]\\
0 & \mathrm{res}[0]/h & -\mathrm{res}[0]\cdot \mathrm{center}[1]/h + 0.5\,\mathrm{res}[0]\\
0 & 0 & 1
\end{pmatrix}
=
\begin{pmatrix}
1 & 0 & 0.5\,\mathrm{res}[1]\\
0 & 1 & 0.5\,\mathrm{res}[0]\\
0 & 0 & 1
\end{pmatrix}
\begin{pmatrix}
\mathrm{res}[1]/h & 0 & 0\\
0 & \mathrm{res}[0]/h & 0\\
0 & 0 & 1
\end{pmatrix}
\begin{pmatrix}
1 & 0 & -\mathrm{center}[0]\\
0 & 1 & -\mathrm{center}[1]\\
0 & 0 & 1
\end{pmatrix}
$$
The rot matrix that follows performs a rotation about the center of the output image (a small numeric check follows the code below).

def get_transform(center, scale, res, rot=0):
    # Generate transformation matrix
    h = 200 * scale
    t = np.zeros((3, 3))
    # first: translate so that `center` sits at the origin
    # second: scale by res / h
    # third: translate so that the bbox's top-left corner is at the origin (its center at (0.5*res[1], 0.5*res[0]))

    # float(res[1]) / h        0                -res[1] * float(center[0]) / h + .5 res[1]
    #            0        float(res[0]) / h     -res[0] * float(center[1]) / h + .5 res[0]
    #            0             0                                 1
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot  # To match direction of rotation from cropping
        rot_mat = np.zeros((3, 3))
        rot_rad = rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        # cs -sn  0
        # sn  cs  0
        # 0   0   1
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Need to rotate around center
        # 1 0 -res[1] / 2
        # 0 1 -res[0] / 2
        # 0 0     1
        t_mat = np.eye(3)
        t_mat[0, 2] = -res[1] / 2
        t_mat[1, 2] = -res[0] / 2
        # 1 0 res[1] / 2
        # 0 1 res[0] / 2
        # 0 0     1
        t_inv = t_mat.copy()
        t_inv[:2, 2] *= -1
        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
    return t
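
A quick sanity check with made-up values: for a person centered at (300, 200) with scale 1.28 (a 256-pixel box, since h = 200 * 1.28) and res = (64, 64), the bbox center should land exactly at 0.5*res, and the matrix should factor into the three matrices above.

import numpy as np

t = get_transform(np.array([300., 200.]), 1.28, (64, 64))
print(t.dot([300., 200., 1.]))       # -> [32. 32.  1.]: the bbox center maps to the middle of the output

# translate-to-origin, scale, translate-by-0.5*res, multiplied together
t_decomposed = (np.array([[1., 0., 32.], [0., 1., 32.], [0., 0., 1.]])
                @ np.array([[64. / 256., 0., 0.], [0., 64. / 256., 0.], [0., 0., 1.]])
                @ np.array([[1., 0., -300.], [0., 1., -200.], [0., 0., 1.]]))
print(np.allclose(t, t_decomposed))  # True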

transform
The transform function is then straightforward:
it converts pt into homogeneous coordinates and applies the matrix (or its inverse, when invert=1); a usage example follows the code.

def transform(pt, center, scale, res, invert=0, rot=0):
    # Transform pixel location to different reference
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0], pt[1], 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2].astype(int)
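
For example (same made-up crop as above), a joint at (350, 180) in the original image maps into the 64×64 output, and invert=1 goes the other way, from output coordinates back into the original image; this is exactly how crop below recovers the bbox corners.

import numpy as np

print(transform([350, 180], np.array([300., 200.]), 1.28, (64, 64)))         # -> [44 27]
print(transform([0, 0], np.array([300., 200.]), 1.28, (64, 64), invert=1))   # -> [172  72], the bbox's top-left corner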

crop
Next comes cropping.
First find where (0, 0) of the output maps back to in the original image (i.e., the top-left corner of the bbox),
then find the bottom-right corner in the same way.
(In fact no matrices are needed here; the corners can be computed directly:)

h = scale * 200
ul = (center - h / 2).astype(int)
br = (center + h / 2).astype(int)

With the top-left and bottom-right corners, the size of the bbox is determined.
The region is then copied out (zero-padded wherever the bbox extends beyond the image) and resized to res.

def crop(img, center, scale, res, rot=0):
    # Upper left point
    ul = np.array(transform([0, 0], center, scale, res, invert=1))
    # Bottom right point
    br = np.array(transform(res, center, scale, res, invert=1))

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]

    return cv2.resize(new_img, res)

kpt_affine
This applies an affine transform to the keypoints (a small usage example follows the code).

def kpt_affine(kpt, mat):
    kpt = np.array(kpt)
    shape = kpt.shape
    kpt = kpt.reshape(-1, 2)
    # turn to Homogeneous coordinates then np.dot
    return np.dot(np.concatenate((kpt, kpt[:, 0:1] * 0 + 1), axis=1), mat.T).reshape(shape)
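
It is used just like in loadImage below: pass an (N, 2) array of keypoints together with the first two rows of a transform matrix. Continuing the made-up example:

mat = get_transform(np.array([300., 200.]), 1.28, (64, 64))[:2]   # 2x3 affine matrix
print(kpt_affine([[350, 180], [300, 200]], mat))
# -> [[44.5 27. ]
#     [32.  32. ]]   keypoints now in 64x64 heatmap coordinates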

dp.py

First, the dataset class.
It essentially reads an image, crops out the person, applies augmentation, and generates the target heatmaps (a sketch of the heatmap generation follows the code).

class Dataset(torch.utils.data.Dataset):
    def __init__(self, config, ds, index):
        self.input_res = config['train']['input_res']  # 256
        self.output_res = config['train']['output_res']  # 64
        # config['inference']['num_parts'] = 16
        self.generateHeatmap = GenerateHeatmap(self.output_res, config['inference']['num_parts'])
        self.ds = ds
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        return self.loadImage(self.index[idx % len(self.index)])

    def loadImage(self, idx):
        ds = self.ds

        ## load + crop
        orig_img = ds.get_img(idx)
        path = ds.get_path(idx)
        orig_keypoints = ds.get_kps(idx)
        kptmp = orig_keypoints.copy()
        c = ds.get_center(idx)
        s = ds.get_scale(idx)
        normalize = ds.get_normalized(idx)
        # to crop
        cropped = utils.img.crop(orig_img, c, s, (self.input_res, self.input_res))
        for i in range(np.shape(orig_keypoints)[1]):
            if orig_keypoints[0, i, 0] > 0:
                orig_keypoints[0, i, :2] = utils.img.transform(orig_keypoints[0, i, :2], c, s,
                                                               (self.input_res, self.input_res))
        keypoints = np.copy(orig_keypoints)

        ## augmentation -- to be done to cropped image
        height, width = cropped.shape[0:2]
        center = np.array((width / 2, height / 2))
        scale = max(height, width) / 200

        aug_rot = 0

        aug_rot = (np.random.random() * 2 - 1) * 30.
        aug_scale = np.random.random() * (1.25 - 0.75) + 0.75
        scale *= aug_scale

        mat_mask = utils.img.get_transform(center, scale, (self.output_res, self.output_res), aug_rot)[:2]

        mat = utils.img.get_transform(center, scale, (self.input_res, self.input_res), aug_rot)[:2]
        inp = cv2.warpAffine(cropped, mat, (self.input_res, self.input_res)).astype(np.float32) / 255
        keypoints[:, :, 0:2] = utils.img.kpt_affine(keypoints[:, :, 0:2], mat_mask)
        if 0 == np.random.randint(2):
            inp = self.preprocess(inp)
            inp = inp[:, ::-1]
            keypoints = keypoints[:, ds.flipped_parts['mpii']]
            keypoints[:, :, 0] = self.output_res - keypoints[:, :, 0]
            orig_keypoints = orig_keypoints[:, ds.flipped_parts['mpii']]
            orig_keypoints[:, :, 0] = self.input_res - orig_keypoints[:, :, 0]

        ## set keypoints to 0 when were not visible initially (so heatmap all 0s)
        for i in range(np.shape(orig_keypoints)[1]):
            if kptmp[0, i, 0] == 0 and kptmp[0, i, 1] == 0:
                keypoints[0, i, 0] = 0
                keypoints[0, i, 1] = 0
                orig_keypoints[0, i, 0] = 0
                orig_keypoints[0, i, 1] = 0

        ## generate heatmaps on outres
        heatmaps = self.generateHeatmap(keypoints)

        return inp.astype(np.float32), heatmaps.astype(np.float32)

    def preprocess(self, data):
        # random hue and saturation
        data = cv2.cvtColor(data, cv2.COLOR_RGB2HSV)
        delta = (np.random.random() * 2 - 1) * 0.2
        data[:, :, 0] = np.mod(data[:, :, 0] + (delta * 360 + 360.), 360.)

        delta_sature = np.random.random() + 0.5
        data[:, :, 1] *= delta_sature
        data[:, :, 1] = np.maximum(np.minimum(data[:, :, 1], 1), 0)
        data = cv2.cvtColor(data, cv2.COLOR_HSV2RGB)

        # adjust brightness
        delta = (np.random.random() * 2 - 1) * 0.3
        data += delta

        # adjust contrast
        mean = data.mean(axis=2, keepdims=True)
        data = (data - mean) * (np.random.random() + 0.5) + mean
        data = np.minimum(np.maximum(data, 0), 1)
        return data
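
GenerateHeatmap is not shown above; conceptually it just draws a small 2D Gaussian centered at each visible keypoint on a num_parts × output_res × output_res array. Below is a minimal sketch, not the repo's exact implementation (which pastes a precomputed fixed-size Gaussian patch for speed).

import numpy as np

class GenerateHeatmapSketch:
    # Simplified stand-in for the repo's GenerateHeatmap
    def __init__(self, output_res, num_parts, sigma=1.0):
        self.output_res = output_res
        self.num_parts = num_parts
        self.sigma = sigma

    def __call__(self, keypoints):  # keypoints: (num_people, num_parts, 2+)
        hms = np.zeros((self.num_parts, self.output_res, self.output_res), dtype=np.float32)
        yy, xx = np.mgrid[:self.output_res, :self.output_res]
        for person in keypoints:
            for idx, (x, y) in enumerate(person[:, :2]):
                if x <= 0 and y <= 0:   # invisible joints were zeroed out above -> heatmap stays all zeros
                    continue
                g = np.exp(-((xx - x) ** 2 + (yy - y) ** 2) / (2 * self.sigma ** 2))
                hms[idx] = np.maximum(hms[idx], g)
        return hms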

The flipping is a bit counter-intuitive.
In the original image the person's right hand might be at the top and the left hand at the bottom; after a horizontal flip the left hand is at the top and the right hand at the bottom.
That is why keypoints[:, ds.flipped_parts['mpii']] is needed: it swaps the left and right joints (the assumed index mapping is sketched below).
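For reference, ds.flipped_parts['mpii'] is just a permutation that swaps each left-side joint index with its right-side counterpart while leaving the center joints (pelvis, thorax, neck, head) in place; to my recollection the repo's MPII reference defines it roughly as follows (treat the exact list as an assumption).

# Assumed MPII joint order: 0-5 r_ankle, r_knee, r_hip, l_hip, l_knee, l_ankle;
# 6-9 pelvis, thorax, upper_neck, head_top; 10-15 r_wrist, r_elbow, r_shoulder, l_shoulder, l_elbow, l_wrist
flipped_parts = {'mpii': [5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]}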

The cropping and data augmentation (apart from the flip and the contrast adjustment) could be replaced with an imgaug pipeline along the following lines:

import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables import KeypointsOnImage, Keypoint
# left_x, left_y and right_x, right_y are the bbox corners (the ul/br computed in crop);
# w, h are the original image's width and height
aug = iaa.Sequential([
    iaa.CropAndPad(px=(-left_y, -(w - right_x), -(h - right_y), -left_x), keep_size=False),  # crop out the person bbox
    iaa.Resize({"width": 256, "height": 256}),
    iaa.Affine(scale=(0.75, 1.25),
               rotate=(-30, 30)),
    iaa.WithColorspace(
        to_colorspace="HSV",
        from_colorspace="BGR",
        children=[
            iaa.WithChannels(0, iaa.Add(value=(-0.2 * 360, 0.2 * 360))),
            iaa.WithChannels(1, iaa.Multiply(mul=(0.5, 1.5)))
        ]
    ),
    iaa.AddToBrightness((-0.3, 0.3)),
    # iaa.LinearContrast,
    # iaa.Fliplr(),
])
img_aug, keypoint_aug = aug(image=img, keypoints=kps_array)  # kps_array: a KeypointsOnImage built from the joints
keypoint_aug = iaa.Resize({"width": 64, "height": 64})(keypoints=keypoint_aug)  # bring keypoints down to heatmap resolution

https://towardsdatascience.com/using-hourglass-networks-to-understand-human-poses-1e40e349fa15
https://medium.com/@monadsblog/stacked-hourglass-networks-14bee8c35678