zl程序教程

您现在的位置是:首页 >  后端

当前栏目

使用PyTorch进行语义分割「建议收藏」

PyTorch 使用 建议 收藏 进行 分割 语义
2023-06-13 09:11:15 时间

大家好,又见面了,我是你们的朋友全栈君。

本篇文章使用进行pytorch进行语义分割的实验。

1.什么是语义分割?

语义分割是一项图像分析任务,我们将图像中的每个像素分类为对应的类。 这类似于我们人类在默认情况下一直在做的事情。每当我们看到某些画面时,我们都会尝试“分割”图像的哪一部分属于哪个类/标签/类别。 从本质上讲,语义分割是我们可以在计算机中实现这一点的技术。 您可以在我们关于图像分割的帖子中阅读更多关于分割的内容。 这篇文章的重点是语义分割 ,所以,假设我们有下面的图像。

经过语义分割,会得到如下输出:

如您所见,图像中的每个像素都被分类为各自的类。例如,人是一个类,自行车是另一个类,第三个是背景。 简单说,这就是语义分割。

2.语义分割的应用

语义分割最常见的用例是:

2.1 自动驾驶

在自主驾驶中,计算机驾驶汽车需要对前面的道路场景有很好的理解。分割出汽车、行人、车道和交通标志等物体是很重要的。

2.2 人脸分割

面部分割用于将面部的每个部分分割成语义相似的区域-嘴唇、眼睛等。这在许多现实世界的应用程序中都是有用的。一个非常有趣的应用程序可以是虚拟美颜。

2.3 室内物体分割

你能猜出这个在哪里使用吗?在AR(增强现实)和VR(虚拟现实)中。应用程序可以分割整个室内区域,以了解椅子、桌子、人、墙和其他类似物体的位置,从而可以高效地放置和操作虚拟物体。

2.4 大地遥感

地球遥感是一种将卫星图像中的每个像素分类为一个类别的方法,以便我们可以跟踪每个区域的土地覆盖情况。因此,如果在某些地区发生了严重的森林砍伐,那么就可以采取适当的措施。在卫星图像上使用语义分割可以有更多的应用。

让我们看看如何使用PyTorch和Torchvision进行语义分割。

3 torchvision的语义分割

我们将研究两个基于深度学习的语义分割模型。全卷积网络(FCN)和DeepLabv3。这些模型已经在COCO Train 2017数据集的子集上进行了训练,该子集对应于PASCALVOC数据集。模型共支持20个类别。

3.1.输入和输出

在我们开始之前,让我们了解模型的输入和输出。 这些模型期望输入一个3通道图像(RGB),它使用Imagenet的均值和标准差归一化,即, 平均值=[0.485,0.456,0.406],标准差=[0.229,0.224,0.225] 。

输入维数为[Ni x Ci x Hi x Wi] ,其中,

Ni ->批次大小

Ci ->通道数(即3)

Hi ->图片的高度

Wi ->图像的宽度

而模型的输出维数为[No x Co x Ho x Wo] ,其中

No ->批次大小(与Ni相同)

Co->是数据集的类数!

Ho ->图像的高度(几乎在所有情况下都与Hi相同)

Wo ->图像的宽度(几乎在所有情况下都与Wi相同)

注:torchvision模型的输出是一个有序的字典,而不是一个torch.Tensor(张量)。 在推理(.val()模式)过程中,输出是一个有序的Dict,且只有一个键-值对,键为out,它的相应值具有[No x Co x Ho x Wo]的形状。

3.2.具有Resnet-101骨干的FCN 全卷积网络

FCN是第一次成功的使用神经网络用于语义分割工作。让我们看看如何在Torchvision中使用该模型。

3.2.1 加载模型

from torchvision import models
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()

很简单!我们有一个基于Resnet101的预先训练的FCN模型。如果模型尚未存在于缓存中,则pretrained=True标志将下载该模型。该.val方法将以推理模式加载它。

3.2.2.加载图像

接下来,让我们加载一个图像!我们直接从URL下载一个鸟的图像并保存它。我们使用PIL加载图像。

在当前目录中,下载一张图片。

wget -nv https://static.independent.co.uk/s3fs-public/thumbnails/image/2018/04/10/19/pinyon-jay-bird.jpg -O bird.png

加载显示图像

from PIL import Image
import matplotlib.pyplot as plt
import torch
 
img = Image.open('./bird.png')
plt.imshow(img); plt.show()

3.2.3.对图像进行预处理

为了使图像达到输入格式要求,以便使用模型进行推理,我们需要对其进行预处理并对其进行正则化! 因此,对于预处理步骤,我们进行以下操作。

将图像大小调整为(256×256)

将其转换为(224×224)

将其转换为张量-图像中的所有元素值都将被缩放,以便在[0,1]之间而不是原来的[0,255]范围内。

将其正则化,使用Imagenet数据 的均值=[0.485,0.456,0.406],标准差=[0.229,0.224,0.225]

最后,我们对图像进行增加维度,使它从[C x H x W]变成[1x C x H x W]。这是必需的,因为模型需要按批处理图像。

# Apply the transformations needed
import torchvision.transforms as T
trf = T.Compose([T.Resize(256),
                 T.CenterCrop(224),
                 T.ToTensor(), 
                 T.Normalize(mean = [0.485, 0.456, 0.406], 
                             std = [0.229, 0.224, 0.225])])
inp = trf(img).unsqueeze(0)

让我们看看上面的代码单元是做什么的。 torchvision有许多有用的函数。其中之一是用于预处理图像的Transforms。T.Compose是一个函数,它接受一个列表,其中每个元素都是transforms 类型,它返回一个对象,我们可以通过这个对象传递一批图像,所有所需的转换都将应用于图像。

让我们来看看应用于图像上的转换:

T.Resize(256):将图像尺寸调整为256×256

T.CenterCrop(224):从图像的中心抠图,大小为224×224

T.ToTensor():将图像转换为张量,并将值缩放到[0,1]范围

T.Normalize(mean, std):用给定的均值和标准差对图像进行正则化。

3.2.4.正向传递通过网络

现在我们已经对图像进行了所有的预处理,让我们通过模型并得到OUT键。 正如我们前面提到的,模型的输出是一个有序的Dict,所以我们需要从其中取出out键的value来获得模型的输出。

# Pass the input through the net
out = fcn(inp)['out']
print (out.shape)

输出为:

torch.Size([1, 21, 224, 224])

所以,out是模型的最终输出。正如我们所看到的,它的形状是[1 x 21 x H x W],正如前面所讨论的。因为,模型是在21个类上训练的,输出有21个通道!(包括背景类) 现在我们需要做的是,使这21个通道输出到一个2D图像或一个1通道图像,其中该图像的每个像素对应于一个类! 因此,2D图像(形状[HxW])的每个像素将与相应的类标签对应,对于该2D图像中的每个(x,y)像素将对应于表示类的0-20之间的数字。 我们如何从这个[1 x 21 x H x W]的列表到达那里?我们为每个像素位置取一个最大索引,该索引表示类的下标,看到这里是否似曾相识,对了,之前的文章讲到,多分类的输出是一个列表,存有每个类的置信度,这里每个像素点的21个通道对应着每个类的置信度。

import numpy as np
om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
print (om.shape)
(224, 224)

out.squeeze()对out进行降纬,变为[21 x H x W], dim=0说明对第一纬(21)进行取极大值操作,故结果为[H x W]。

print (np.unique(om))
[0 3]

可见,处理完的列表中共有2种元素,0(背景),3(鸟)。正如我们所看到的,现在我们有了一个2D图像,其中每个像素属于一个类。最后一件事是把这个2D图像转换成一个分割图像,每个类标签对应于一个RGB颜色,从而使图像易于观看。

3.2.5.解码输出

我们将使用以下函数将此2D图像转换为RGB图像,其中每个(元素)标签映射到相应的颜色。

# Define the helper function
def decode_segmap(image, nc=21):
   
  label_colors = np.array([(0, 0, 0),  # 0=background
               # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
               (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
               # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
               (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
               # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
               (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
               # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
               (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])
 
  r = np.zeros_like(image).astype(np.uint8)
  g = np.zeros_like(image).astype(np.uint8)
  b = np.zeros_like(image).astype(np.uint8)
   
  for l in range(0, nc):
    idx = image == l
    r[idx] = label_colors[l, 0]
    g[idx] = label_colors[l, 1]
    b[idx] = label_colors[l, 2]
     
  rgb = np.stack([r, g, b], axis=2)
  return rgb

idx = image == l r[idx] = label_colors[l, 0] 使用了numpy列表的高级索引功能,即使用布尔列表进行索引,在这里就是每个元素赋值成对应标签的颜色。

首先,列表label_colors根据索引存储每个类的颜色。因此,第一类的颜色是背景,存储在label_colors列表的第0个索引处。第二类,即飞机,存储在索引1中,以此类推。 现在,我们必须从我们拥有的2D图像中创建一个RGB图像。因此,我们所做的是为所有3个通道创建空的2D矩阵。 因此,r、g和b是构成最终图像的RGB通道的列表,这些列表中的每一个的形状都是[HxW](这与2D图像的形状相同)。 现在,我们循环存储在label_colors中的每个颜色,并在存在特定类标签的2D图像中获取索引。然后,对于每个通道,我们将其相应的颜色放置到存在该类标签的像素上。 最后,我们将3个独立的通道叠加起来,形成RGB图像。 好吧!现在,让我们使用这个函数来查看最终的输出!

rgb = decode_segmap(om)
plt.imshow(rgb); plt.show()

3.2.6.最终结果

接下来,让我们把所有操作放入一个函数下,并测试更多的图像!

def segment(net, path):
  img = Image.open(path)
  plt.imshow(img); plt.axis('off'); plt.show()
  # Comment the Resize and CenterCrop for better inference results
  trf = T.Compose([T.Resize(256), 
                   T.CenterCrop(224), 
                   T.ToTensor(), 
                   T.Normalize(mean = [0.485, 0.456, 0.406], 
                               std = [0.229, 0.224, 0.225])])
  inp = trf(img).unsqueeze(0)
  out = net(inp)['out']
  om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
  rgb = decode_segmap(om)
  plt.imshow(rgb); plt.axis('off'); plt.show()

3.4.多个对象

如果我们测试一个更复杂的图像,那么我们可以开始看到一些不同的结果。

wget -nv "https://images.pexels.com/photos/2385051/pexels-photo-2385051.jpeg" -O dog-park.png
img = Image.open('./dog-park.png')
plt.imshow(img); plt.show()
 
print ('Segmenatation Image on FCN')
segment(fcn, path='./dog-park.png', show_orig=False)
 
print ('Segmenatation Image on DeepLabv3')
segment(dlab, path='./dog-park.png', show_orig=False)

FCN结果

DeepLab结果

两者有细微差别

参考连接:https://www.learnopencv.com/pytorch-for-beginners-semantic-segmentation-using-torchvision/