矩阵乘法脉动阵列的C++模拟
2023-09-14 09:16:18 时间
自从谷歌的TPU问世以后,被人们遗忘很久的脉动阵列又再次火了一把。矩阵乘法就可以用脉动阵列进行计算,而脉动阵列这种数据流又特别适合用硬件进行实现。下面是用脉动阵列进行矩阵乘法的示意图。可以看到,A的每一行不是同时进入脉动阵列的,而B的每一列也不是同时进入脉动阵列的,相邻行或列进入脉动阵列的时间恰好相差一个时钟周期。
以下是用c++模拟脉动阵列的代码:
#include <iostream>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#define N 256
using namespace std;
typedef struct PE{
int weight;
int neuron;
int psum;
};
class Systolic{
public:
PE S[N][N];
public:
void Init(){
for(int i=0;i<N;i++)
for(int j=0;j<N;j++){
S[i][j].psum=0;
S[i][j].weight=0;
S[i][j].neuron=0;
}
}
void calc(){
for(int i=0;i<N;i++)
for(int j=0;j<N;j++)
S[i][j].psum+=S[i][j].weight*S[i][j].neuron;
}
void shift(int a[N],int b[N]){
//水平方向传播矩阵A,a[N]是本次要被读入的列(left->right)
for(int i=0;i<N;i++)
for(int j=N-1;j>0;j--){
S[i][j].neuron=S[i][j-1].neuron;
}
for(int i=0;i<N;i++)
S[i][0].neuron=a[i];
//竖直方向上传播矩阵B,b[N]是本次要被读入的行(up->bottom)
for(int j=0;j<N;j++)
for(int i=N-1;i>0;i--){
S[i][j].weight=S[i-1][j].weight;
}
for(int j=0;j<N;j++)
S[0][j].weight=b[j];
}
void Display(){
cout<<"weight:"<<endl;
for(int i=0;i<N;i++){
for(int j=0;j<N;j++)
cout<<fixed<<setw(4)<<S[i][j].weight<<",";
cout<<endl;
}
cout<<"neuron:"<<endl;
for(int i=0;i<N;i++){
for(int j=0;j<N;j++)
cout<<fixed<<setw(4)<<S[i][j].neuron<<",";
cout<<endl;
}
}
};
void Print(int A[N][N]){
for(int i=0;i<N;i++){
for(int j=0;j<N;j++)
cout<<fixed<<setw(4)<<A[i][j]<<",";
cout<<endl;
}
}
void systolic_mm(int A[N][N],int B[N][N],int C[N][N]){
Systolic S;
S.Init();
int a[N];
int b[N];
int clock=0;
while(clock<=3*N-3){
//产生a[N]
for(int i=0;i<N;i++)
a[i]=(clock>=i&&clock<N+i)?A[i][clock-i]:0;
//产生b[N]
for(int j=0;j<N;j++)
b[j]=(clock>=j&&clock<N+j)?B[clock-j][j]:0;
S.shift(a,b);
S.calc();
//cout<<"clock="<<clock<<endl;
//S.Display();
clock++;
}
for(int i=0;i<N;i++)
for(int j=0;j<N;j++)
C[i][j]=S.S[i][j].psum;
return;
}
void Matrix_Mult(int A[N][N],int B[N][N],int C[N][N]){
for(int i=0;i<N;i++)
for(int j=0;j<N;j++){
C[i][j]=0;
for(int k=0;k<N;k++)
C[i][j]+=A[i][k]*B[k][j];
}
return;
}
bool Compare(int O1[N][N],int O2[N][N]){
for(int i=0;i<N;i++)
for(int j=0;j<N;j++)
if(O1[i][j]!=O2[i][j])
return false;
return true;
}
int main()
{
srand(time(0));
int C1[N][N];
int C2[N][N];
int A[N][N];
int B[N][N];
int n;
cout<<"Input check num n:";
cin>>n;
cout<<"N="<<N<<endl;
int i=0;
while(i++<n){
for(int i=0;i<N;i++)
for(int j=0;j<N;j++){
A[i][j]=rand()%50-25;
B[i][j]=rand()%20-10;
}
//cout<<"Matrix A:"<<endl;
//Print(A);
//cout<<"Matrix B:"<<endl;
//Print(B);
Matrix_Mult(A,B,C1);
//cout<<"Matrix C1:"<<endl;
//Print(C1);
systolic_mm(A,B,C2);
//cout<<"Matrix C2:"<<endl;
//Print(C2);
bool is_right=Compare(C1,C2);
if(!is_right){
cout<<"error"<<endl;
break;}
cout<<"Compare C1 and C2,and the result is "<<boolalpha<<is_right<<endl;
}
return 0;
}
N=256时测试100次,发现改代码不存在错误:
相关文章
- C++ 标准库类型-String,Vector and Bitset
- <现代C++实战30讲>笔记 01 | 堆、栈、RAII:C++里该如何管理资源?
- C++中的指针与引用、如何参数传递
- java实现第二届蓝桥杯地铁换乘(C++)
- C语言/C++常见习题问答集锦(九)
- paip.提升用户体验---c++ qt 悬浮窗实现
- 【华为OD机试 2023】 网上商城优惠活动 / 模拟商场优惠打折II(C++ Java Javascript Python)
- 解答私信@被c++折磨头秃的花季美少女 //C++ 利用指针数组输入10个单词,编写函数对10个单词进行排序并输出,要求判断是否有相同的单词,如果有相同的单词在输出时该单词只输出一次。
- 解答私信@被c++折磨头秃的花季美少女 //C++ 编写一个进阶版的进制转换程序,运行功能如下:请选择要输入的数字的进制(2、8、10、16):请输入该数字:请选择要转换成的进制(2、8。。。
- 解答私信@被c++折磨头秃的花季美少女 //C++ 写一个带命令行参数的程序,可以实现将参数求和、求平均值以及排序之后输出(参数的数量不确定)。
- 通过c++11的std::bind及std::function实现类方法回调,模拟Qt实现信号槽
- C++ string顺序查找和逆序查找
- Ubuntu20.04下,qt交叉编译报错::15: warning: identifier ‘nullptr‘ is a keyword in C++11 [-Wc++0x-compat]
- C++查看变量类型
- C# 与C/C++相互调用
- C++ 文件和流
- 中介模式C++实现
- 【opencv-c++】 关于opencv.hpp头文件
- PAT 1049 C++版
- C++ 特性之多态
- C++基础知识要点--字符串、向量和数组 (Primer C++ 第五版 · 阅读笔记)
- 【C++要笑着学】深浅拷贝 | string 模拟实现 | 传统写法与现代写法