zl程序教程

您现在的位置是:首页 > 其它

当前栏目

FFT实现卷积运算

实现 运算 卷积 FFT
2023-09-14 09:16:18 时间

2D 卷积的实现(单输入通道、单输出通道)。注意:PyTorch 的 `F.conv2d` 实际计算的是互相关(cross-correlation),因此在频域中需要先对卷积核的频谱取共轭,再与输入的频谱相乘:

import numpy as np
def fft_conv(x, weight, padding, stride):
    """Compute a single-channel 2D cross-correlation using the FFT.

    For one input and one output channel this gives the same result as
    ``torch.nn.functional.conv2d``, which computes cross-correlation
    (the kernel is not flipped).

    Args:
        x: 2D input array of shape (H, W).
        weight: 2D kernel array of shape (kh, kw).
        padding: number of zeros added on every side of ``x``.
        stride: step between output samples along both axes.

    Returns:
        Real array of shape
        ((H + 2*padding - kh) // stride + 1, (W + 2*padding - kw) // stride + 1).

    Raises:
        ValueError: if the kernel is larger than the padded input.
    """
    x_padded = np.pad(x, ((padding, padding), (padding, padding)))
    h_pad, w_pad = x_padded.shape
    kh, kw = weight.shape
    if kh > h_pad or kw > w_pad:
        raise ValueError("kernel is larger than the padded input")
    # Pad the kernel on the bottom/right only, so its origin stays at (0, 0).
    weight_padded = np.pad(weight, ((0, h_pad - kh), (0, w_pad - kw)))
    # Multiplying by the conjugate kernel spectrum turns circular convolution
    # into circular cross-correlation.
    spectrum = np.fft.fft2(x_padded) * np.conj(np.fft.fft2(weight_padded))
    # Real input and real kernel give a real result; drop the round-off
    # imaginary residue.
    output = np.fft.ifft2(spectrum).real
    # Only the first (h_pad - kh + 1) x (w_pad - kw + 1) positions are free of
    # circular wrap-around; subsample those by the stride.
    return output[0:h_pad - kh + 1:stride, 0:w_pad - kw + 1:stride]

# Check fft_conv against PyTorch's direct conv2d on a single-channel input.
# torch / F are imported here because this snippet runs before the PyTorch
# imports that appear later in the article.
import torch
import torch.nn.functional as F

Nin = 27  # input height and width
K = 3     # kernel size
S = 2     # stride
P = 1     # padding
Nout = (Nin - K + 2 * P) // S + 1  # expected output height and width
print(Nout)

x = torch.randn((1, 1, Nin, Nin))
w = torch.randn((1, 1, K, K))
y1 = F.conv2d(x, weight=w, bias=None, stride=S, padding=P)
y1 = y1.view(Nout, Nout).numpy()
# Drop the batch/channel dimensions: fft_conv works on plain 2D arrays.
x = x.numpy().reshape(Nin, Nin)
w = w.numpy().reshape(K, K)
y2 = fft_conv(x=x, weight=w, padding=P, stride=S)

# Maximum absolute deviation between the two methods; should be tiny
# (floating-point round-off only).
print(np.max(np.abs(y1 - y2)))


包含多个输入、输出通道时卷积的实现:

def fft_conv2d(input, weight, stride, padding):
    """Compute a multi-channel 2D cross-correlation using the FFT (no batch axis).

    For a single sample without bias this gives the same result as
    ``torch.nn.functional.conv2d``. Each output channel is the sum over
    input channels of the per-channel cross-correlations.

    Args:
        input: array of shape (C, H, W).
        weight: array of shape (M, C, kh, kw).
        stride: step between output samples along both spatial axes.
        padding: number of zeros added on every spatial side of ``input``.

    Returns:
        Real array of shape
        (M, (H + 2*padding - kh) // stride + 1, (W + 2*padding - kw) // stride + 1).

    Raises:
        ValueError: if the channel counts of ``input`` and ``weight`` differ,
            or if the kernel is larger than the padded input.
    """
    in_channels, height, width = input.shape
    out_channels, weight_channels, kh, kw = weight.shape
    if weight_channels != in_channels:
        raise ValueError(
            f"weight expects {weight_channels} input channels, got {in_channels}"
        )
    # Use the `padding` argument itself, never a global, for every size below.
    h_pad = height + 2 * padding
    w_pad = width + 2 * padding
    if kh > h_pad or kw > w_pad:
        raise ValueError("kernel is larger than the padded input")

    x = np.pad(input, ((0, 0), (padding, padding), (padding, padding)))
    # Pad the kernels up to the padded input size, on the bottom/right only,
    # so each kernel's origin stays at (0, 0).
    w = np.pad(weight, ((0, 0), (0, 0), (0, h_pad - kh), (0, w_pad - kw)))
    x_ffted = np.fft.fft2(x, axes=(-2, -1))
    # Important: conjugating the kernel spectrum yields cross-correlation
    # (what conv2d computes) instead of convolution.
    w_ffted = np.conj(np.fft.fft2(w, axes=(-2, -1)))
    # (C, h, w) broadcasts against (M, C, h, w); sum over the input-channel axis.
    out = np.sum(x_ffted * w_ffted, axis=1)
    assert out.shape == (out_channels, h_pad, w_pad)
    # Real inputs give a real result; drop the round-off imaginary residue.
    out = np.fft.ifft2(out, axes=(-2, -1)).real
    # Keep only positions unaffected by circular wrap-around, then apply stride.
    return out[:, 0:h_pad - kh + 1:stride, 0:w_pad - kw + 1:stride]

# Check fft_conv2d against PyTorch's conv2d with C input and M output channels.
# torch / F are imported here because this snippet runs before the PyTorch
# imports that appear later in the article.
import torch
import torch.nn.functional as F

Nin = 27  # input height and width
C = 3     # input channels
M = 4     # output channels
K = 3     # kernel size
S = 2     # stride
P = 1     # padding
Nout = (Nin - K + 2 * P) // S + 1  # expected output height and width
print(Nout)

x = torch.randn((1, C, Nin, Nin))
w = torch.randn((M, C, K, K))
y1 = F.conv2d(x, weight=w, bias=None, stride=S, padding=P)
y1 = y1.view(M, Nout, Nout).numpy()
# Drop the batch dimension: fft_conv2d works on a single (C, H, W) sample.
x = x.numpy().reshape(C, Nin, Nin)
w = w.numpy().reshape(M, C, K, K)
y2 = fft_conv2d(x, weight=w, stride=S, padding=P)
print(np.max(np.abs(y1 - y2)))

运行结果(原文此处为运行结果截图,未能保留):
使用 PyTorch 将基于 2D FFT 的卷积运算封装成一个可即插即用的自定义层:

import torch
import numpy as np
import torch.fft
import torch.nn as nn
import torch.nn.functional as F

class FFTConv2d(nn.Module):
    """2D convolution layer computed with the FFT.

    Gives the same result as ``F.conv2d`` (cross-correlation) for a square
    ``kernel_size`` and integer ``stride``/``padding``. Dilation and groups
    are not supported.

    Args:
        in_channels: number of input channels C.
        out_channels: number of output channels M.
        kernel_size: side length of the square kernel.
        stride: step between output samples along both spatial axes.
        padding: number of zeros added on every spatial side of the input.
        bias: if True, add a learnable per-output-channel bias initialised to 0;
            otherwise ``self.bias`` is None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
        super().__init__()
        self.has_bias = bias
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        w = torch.empty((out_channels, in_channels, kernel_size, kernel_size))
        self.weight = nn.Parameter(w)
        nn.init.xavier_uniform_(self.weight)
        if bias:
            b = torch.empty((out_channels,))
            self.bias = nn.Parameter(b)
            # constant_ is the supported initialiser; nn.init.constant is a
            # deprecated alias.
            nn.init.constant_(self.bias, 0)
        else:
            # Register explicitly so `self.bias` exists (as None), as in nn.Conv2d.
            self.register_parameter("bias", None)

    def forward(self, x):
        """Apply the layer to ``x`` of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, M, (H + 2p - k) // s + 1, (W + 2p - k) // s + 1).
        """
        batch, _, height, width = x.size()
        k, p, s = self.kernel_size, self.padding, self.stride
        h_pad = height + 2 * p
        w_pad = width + 2 * p
        x_padded = F.pad(x, [p, p, p, p], mode="constant", value=0)
        # Pad the kernels on the right/bottom only (F.pad lists the last dim
        # first) up to the padded input size, so the kernel origin stays at (0, 0).
        w_padded = F.pad(self.weight, [0, w_pad - k, 0, h_pad - k], mode="constant", value=0)
        x_ffted = torch.fft.fftn(x_padded, dim=(-2, -1)) \
            .view(batch, 1, self.in_channels, h_pad, w_pad)
        # Conjugating the kernel spectrum turns circular convolution into
        # cross-correlation (what conv2d computes). conj() is out-of-place,
        # so the FFT output is never edited in place.
        w_ffted = torch.fft.fftn(w_padded, dim=(-2, -1)) \
            .view(1, self.out_channels, self.in_channels, h_pad, w_pad).conj()
        # (B, 1, C, h, w) * (1, M, C, h, w) -> sum over C -> (B, M, h, w)
        output = torch.fft.ifftn(torch.sum(w_ffted * x_ffted, dim=2), dim=(-2, -1)).real
        # Keep only positions unaffected by circular wrap-around, then apply stride.
        output = output[:, :, 0:h_pad - k + 1:s, 0:w_pad - k + 1:s]
        if self.has_bias:
            output = output + self.bias.view(1, -1, 1, 1)
        return output





if __name__ == '__main__':
    # Compare FFTConv2d with PyTorch's direct conv2d using identical weights and bias.
    H = 32  # input height
    W = 28  # input width
    C = 24  # input channels
    M = 32  # output channels
    K = 3   # kernel size
    S = 1   # stride
    P = 1   # padding
    B = 8   # batch size

    f = FFTConv2d(in_channels=C, out_channels=M, kernel_size=K, stride=S, padding=P, bias=True)
    x = torch.randn((B, C, H, W))
    w = torch.randn((M, C, K, K))
    b = torch.randn((M,))

    y1 = F.conv2d(x, weight=w, bias=b, stride=S, padding=P)

    # Load the reference parameters without autograd tracking; this is the
    # supported alternative to assigning `.data`.
    with torch.no_grad():
        f.weight.copy_(w)
        f.bias.copy_(b)
    # Call the module rather than .forward() so registered hooks run.
    y3 = f(x)

    print(torch.max(torch.abs(y3 - y1)))