zl程序教程

您现在的位置是:首页 >  后端

当前栏目

矩阵乘法脉动阵列的C++模拟

C++模拟 矩阵 乘法
2023-09-14 09:16:18 时间

自从谷歌的TPU问世以后,被人们遗忘很久的脉动阵列又再次火了一把。矩阵乘法就可以用脉动阵列进行计算,而脉动阵列这种数据流又特别适合用硬件进行实现。下面是用脉动阵列进行矩阵乘法的示意图。可以看到,A的每一行不是同时进入脉动阵列的,而B的每一列也不是同时进入脉动阵列的,相邻行或列进入脉动阵列的时间恰好相差一个时钟周期。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
以下是用c++模拟脉动阵列的代码:

#include <iostream>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#define N 256
using namespace std;

typedef struct PE{
    int weight;
    int neuron;
    int psum;
};

class Systolic{
public:
    PE S[N][N];
public:
    void Init(){
        for(int i=0;i<N;i++)
            for(int j=0;j<N;j++){
                S[i][j].psum=0;
                S[i][j].weight=0;
                S[i][j].neuron=0;
            }
    }
    void calc(){
        for(int i=0;i<N;i++)
            for(int j=0;j<N;j++)
                S[i][j].psum+=S[i][j].weight*S[i][j].neuron;
    }
    void shift(int a[N],int b[N]){
        //水平方向传播矩阵A,a[N]是本次要被读入的列(left->right)
        for(int i=0;i<N;i++)
            for(int j=N-1;j>0;j--){
                S[i][j].neuron=S[i][j-1].neuron;
        }
        for(int i=0;i<N;i++)
            S[i][0].neuron=a[i];
        //竖直方向上传播矩阵B,b[N]是本次要被读入的行(up->bottom)
        for(int j=0;j<N;j++)
            for(int i=N-1;i>0;i--){
                S[i][j].weight=S[i-1][j].weight;
        }
        for(int j=0;j<N;j++)
            S[0][j].weight=b[j];
    }
    void Display(){
        cout<<"weight:"<<endl;
        for(int i=0;i<N;i++){
            for(int j=0;j<N;j++)
                cout<<fixed<<setw(4)<<S[i][j].weight<<",";
            cout<<endl;
        }
        cout<<"neuron:"<<endl;
        for(int i=0;i<N;i++){
            for(int j=0;j<N;j++)
                cout<<fixed<<setw(4)<<S[i][j].neuron<<",";
            cout<<endl;
        }
    }
};

void Print(int A[N][N]){
   for(int i=0;i<N;i++){
       for(int j=0;j<N;j++)
           cout<<fixed<<setw(4)<<A[i][j]<<",";
       cout<<endl;
   }
}


void systolic_mm(int A[N][N],int B[N][N],int C[N][N]){
   Systolic S;
   S.Init();
   int a[N];
   int b[N];
   int clock=0;
   while(clock<=3*N-3){
        //产生a[N]
        for(int i=0;i<N;i++)
            a[i]=(clock>=i&&clock<N+i)?A[i][clock-i]:0;
        //产生b[N]
        for(int j=0;j<N;j++)
            b[j]=(clock>=j&&clock<N+j)?B[clock-j][j]:0;
        S.shift(a,b);
        S.calc();
        //cout<<"clock="<<clock<<endl;
        //S.Display();
        clock++;
   }
   for(int i=0;i<N;i++)
       for(int j=0;j<N;j++)
           C[i][j]=S.S[i][j].psum;
   return;
}
void Matrix_Mult(int A[N][N],int B[N][N],int C[N][N]){
   for(int i=0;i<N;i++)
       for(int j=0;j<N;j++){
            C[i][j]=0;
            for(int k=0;k<N;k++)
                C[i][j]+=A[i][k]*B[k][j];
   }
   return;
}

bool Compare(int O1[N][N],int O2[N][N]){
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
            if(O1[i][j]!=O2[i][j])
                return false;
    return true;
}
int main()
{
    srand(time(0));
    int C1[N][N];
    int C2[N][N];
    int A[N][N];
    int B[N][N];
    int n;
    cout<<"Input check num n:";
    cin>>n;
    cout<<"N="<<N<<endl;
    int i=0;
    while(i++<n){
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++){
            A[i][j]=rand()%50-25;
            B[i][j]=rand()%20-10;
    }
    //cout<<"Matrix A:"<<endl;
    //Print(A);
    //cout<<"Matrix B:"<<endl;
    //Print(B);
    Matrix_Mult(A,B,C1);
    //cout<<"Matrix C1:"<<endl;
    //Print(C1);
    systolic_mm(A,B,C2);
    //cout<<"Matrix C2:"<<endl;
    //Print(C2);
    bool is_right=Compare(C1,C2);
    if(!is_right){
        cout<<"error"<<endl;
        break;}
    cout<<"Compare C1 and C2,and the result is "<<boolalpha<<is_right<<endl;
    }
    return 0;
}

N=256时测试100次,发现改代码不存在错误:
在这里插入图片描述