1.2 矩阵乘法Strassen算法
传统算法
Strassen算法将2X2矩阵的乘法次数从8次减少到了7次。在介绍strassen算法之前,先用传统的算法计算一下2*2的矩阵乘法。
A
=
[
1
2
3
4
]
B
=
[
5
6
7
8
]
A
×
B
=
[
1
×
5
+
2
×
7
1
×
6
+
2
×
8
3
×
5
+
4
×
7
3
×
6
+
4
×
8
]
=
[
19
22
43
50
]
A= \left[ \begin{matrix} 1 & 2 \\ 3 & 4 \end{matrix}\right]\\ B= \left[ \begin{matrix} 5 & 6 \\ 7 & 8 \end{matrix}\right]\\ A\times B=\left[ \begin{matrix} 1\times5+2\times7 & 1\times6+2\times8 \\ 3\times5+4\times7 & 3\times6+4\times8 \end{matrix}\right]=\left[ \begin{matrix} 19 & 22 \\ 43 & 50 \end{matrix}\right]\\
A=[1324]B=[5768]A×B=[1×5+2×73×5+4×71×6+2×83×6+4×8]=[19432250]
总共使用了8次乘法和4次加法。
Strassen算法
Strassen算法使用了7个中间变量,巧妙地用7次乘法合18次加法,减少了1次乘法操作,提高了算法的性能。其算法如下:
设矩阵A、B为:
A
=
[
A
11
A
12
A
21
A
22
]
B
=
[
B
11
B
12
B
21
B
22
]
A= \left[ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{matrix}\right]\\ B= \left[ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{matrix}\right]\\
A=[A11A21A12A22]B=[B11B21B12B22]
建立7个临时变量
P
1
P_1
P1到
P
7
P_7
P7,每个变量使用一次乘法运算。
P
1
=
(
A
11
+
A
22
)
(
B
11
+
B
22
)
P
2
=
(
A
21
+
A
22
)
B
11
P
3
=
A
11
(
B
12
−
B
22
)
P
4
=
A
22
(
B
21
−
B
11
)
P
5
=
(
A
11
+
A
12
)
B
22
P
6
=
(
A
21
−
A
11
)
(
B
11
+
B
12
)
P
7
=
(
A
12
−
A
22
)
(
B
21
+
B
22
)
C
11
=
P
1
+
P
4
−
P
5
+
P
7
C
12
=
P
3
+
P
5
C
21
=
P
2
+
P
4
C
22
=
P
1
−
P
2
+
P
3
+
P
6
A
×
B
=
[
C
11
C
12
C
21
C
22
]
P_1 = (A_{11}+A_{22})(B_{11}+B_{22})\\ P_2 = (A_{21}+A_{22})B_{11}\\ P_3 = A_{11}(B_{12} − B_{22})\\ P_4 = A_{22}(B_{21} − B_{11})\\ P_5 = (A_{11} + A_{12})B_{22}\\ P_6 = (A_{21} − A_{11})(B_{11} + B_{12})\\ P_7 = (A_{12} − A_{22})(B_{21 }+ B_{22})\\ C_{11} = P_1 + P_4 − P_5 + P_7\\ C_{12} = P_3 + P_5\\ C_{21} = P_2 + P_4\\ C_{22} = P_1 − P_2 + P_3 + P_6\\ A\times B=\left[ \begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{matrix}\right]\\
P1=(A11+A22)(B11+B22)P2=(A21+A22)B11P3=A11(B12−B22)P4=A22(B21−B11)P5=(A11+A12)B22P6=(A21−A11)(B11+B12)P7=(A12−A22)(B21+B22)C11=P1+P4−P5+P7C12=P3+P5C21=P2+P4C22=P1−P2+P3+P6A×B=[C11C21C12C22]
公式比较复杂,总共11个公式呢,根本记不住,所以我建议,收藏我的博文,不要去记忆,当然也可以顺便关注我一波。
需要注意的是上面11个公式中,乘法的左右顺序特别重要,因为这个公式可以适用于任何代数环。代数环就是乘法不需要符合交换律的集合、加法与乘法运算符。这意味着什么,这意味着2X2矩阵中的元素不仅可以是数字,还可以是矩阵。也就是说可以利用分块矩阵的方法,将大矩阵拆分为2X2的矩阵再使用Strassen算法。
不过需要注意的是因为存在
A
11
+
A
22
A_{11}+A_{22}
A11+A22这样的骚操作,所以进行矩阵分块时,行数或者列数不能是奇数,所以在为奇数的时候还是要用传统的方法啊。
python实现
跟我以往的文章不同,这次我没有把本文的算法代码和其他博文的代码混在一起。我新写了一个python文件,只做Strassen算法,而且使用了分治以处理大矩阵,代码如下:
class Matrix:
# 矩阵
@staticmethod
def create_by_lines(lines):
# 为了支持分块,设置四个属性
return Matrix(lines, 0, len(lines), 0, len(lines[0]))
def __init__(self, lines, row_start, row_end, column_start, column_end):
self.__lines = lines
# 为了支持分块,设置四个属性
self.__column_start = column_start
self.__column_end = column_end
self.__row_start = row_start
self.__row_end = row_end
def __mul__(self, other):
# 首先判断能不能相乘
if self.column_len() != other.row_len():
raise Exception("矩阵A列数%d != 矩阵B的行数%d" % (len(self.__lines[0]), len(other.__lines)))
# 然后判断是不是2X2矩阵
# 这里场景比较多:
# 1 1 x n n x 1
# 2 n x 1 1 x n
# 3 2 x 2 2 x 2 strassen 数值运算
# 4 其他,进行分块 strassen 矩阵运算
if self.row_len() == 1 or self.column_len() == 1:
return self.plain_mul(other)
# 奇数不能分块
if self.row_len() & 1 == 1 or self.column_len() & 1 == 1 or other.row_len() & 1 == 1:
return self.plain_mul(other)
# 这个时候就可以使用strassen算法了
a11, a12, a21, a22 = self.sub()
b11, b12, b21, b22 = other.sub()
p1 = (a11 + a22) * (b11 + b22)
p2 = (a21 + a22) * b11
p3 = a11 * (b12 - b22)
p4 = a22 * (b21 - b11)
p5 = (a11 + a12) * b22
p6 = (a21 - a11) * (b11 + b12)
p7 = (a12 - a22) * (b21 + b22)
return Matrix.create(p1 + p4 - p5 + p7, p3 + p5, p2 + p4, p1 - p2 + p3 + p6)
def __add__(self, other):
arr = [[0] * self.column_len() for _ in range(0, self.row_len())]
# 里面不能是同一个数组
for i in range(0, self.row_len()):
self_row = self.__lines[self.__row_start + i]
other_row = other.__lines[other.__row_start + i]
for j in range(0, self.column_len()):
arr[i][j] = self_row[self.__column_start + j] + other_row[other.__column_start + j]
return Matrix.create_by_lines(arr)
def __sub__(self, other):
arr = [[0] * self.column_len() for _ in range(0, self.row_len())]
# 里面不能是同一个数组
for i in range(0, self.row_len()):
self_row = self.__lines[self.__row_start + i]
other_row = other.__lines[other.__row_start + i]
for j in range(0, self.column_len()):
arr[i][j] = self_row[self.__column_start + j] - other_row[other.__column_start + j]
return Matrix.create_by_lines(arr)
def plain_mul(self, other):
# 弄一个m行n列的新矩阵
m = self.row_len()
n = other.column_len()
p = other.row_len()
result = [[0] * n for _ in range(0, m)]
# i 代表 A矩阵的行
for i in range(self.__row_start, self.__row_end):
# j 代表 B 矩阵的列
for j in range(other.__column_start, other.__column_end):
# 第一个矩阵的行 与第二个矩阵列的乘积和
# k 代表 A矩阵的列和B矩阵的行
for k in range(0, p):
self_line = self.__lines[i]
other_line = other.__lines[other.__row_start + k]
a = self_line[self.__column_start + k]
b = other_line[j]
mul = a * b
result[i - self.__row_start][j - other.__column_start] += mul
return Matrix.create_by_lines(result)
def row_len(self):
return self.__row_end - self.__row_start
def column_len(self):
return self.__column_end - self.__column_start
def sub(self):
a_middle_row = (self.__row_end + self.__row_start) // 2
a_middle_column = (self.__column_end + self.__column_start) // 2
a11 = Matrix(self.__lines, self.__row_start, a_middle_row, self.__column_start, a_middle_column)
a12 = Matrix(self.__lines, self.__row_start, a_middle_row, a_middle_column, self.__column_end)
a21 = Matrix(self.__lines, a_middle_row, self.__row_end, self.__column_start, a_middle_column)
a22 = Matrix(self.__lines, a_middle_row, self.__row_end, a_middle_column, self.__column_end)
return a11, a12, a21, a22
@staticmethod
def create(a11, a12, a21, a22):
len_rows = a11.row_len() + a21.row_len()
len_columns = a11.column_len() + a12.column_len()
lines = [[0] * len_columns for _ in range(0, len_rows)]
# 拷贝进去
a11.copy_to(lines, 0, 0)
a12.copy_to(lines, 0, a11.column_len())
a21.copy_to(lines, a11.row_len(), 0)
a22.copy_to(lines, a12.row_len(), a21.column_len())
return Matrix.create_by_lines(lines)
def copy_to(self, lines, row_start, column_start):
for i in range(0, self.row_len()):
self_row = self.__lines[self.__row_start + i]
other_row = lines[row_start + i]
for j in range(0, self.column_len()):
other_row[column_start + j] = self_row[self.__column_start + j]
@property
def lines(self):
return self.__lines
相关文章
- 编写js程序实现n的阶乘_javascript矩阵算法
- 智能优化算法简介
- 量子算法征服了一种新的问题
- 经典排序算法(1)——冒泡排序算法详解
- java 随机数算法_Java随机数算法原理与实现方法实例详解
- 从矩阵链式求导的角度来深入理解BP算法(原理+代码)
- 3D重建算法综述
- 算法面试题
- 全网最详细!油管1小时视频详解AlphaTensor矩阵乘法算法
- 带你详细了解AES算法《附带java、vue实现》
- 【安全算法之SHA384】SHA384摘要运算的C语言源码实现
- 【数据挖掘】聚类 Cluster 简介 ( 概念 | 应用场景 | 质量 | 相似度 | 算法要求 | 数据矩阵 | 相似度矩阵 | 二模矩阵 | 单模矩阵 )
- 【计算理论】计算复杂性 ( 时间复杂度时间单位 : 步数 | 算法分析 | 算法复杂性分析 )
- nginx负载均衡算法
- 算法练习题(六)——Z字型打印矩阵
- 复杂度估算和一些简单排序算法
- 一致性Hash算法的Java实现详解编程语言
- MySQL 中 SEPOR 算法的作用和原理简介(mysql中sepor)
- Redis实现浏览量排序算法(redis设置浏览量排序)
- JavaScript实现的图像模糊算法代码分享
- C++实现矩阵原地转置算法