1.2 Strassen's Algorithm for Matrix Multiplication

2023-09-14 09:06:54

The conventional algorithm

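Strassen's algorithm reduces the number of multiplications needed for a 2×2 matrix product from 8 to 7. Before introducing it, let's first compute a 2×2 matrix product the conventional way.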
$$
A= \left[ \begin{matrix} 1 & 2 \\ 3 & 4 \end{matrix}\right], \qquad
B= \left[ \begin{matrix} 5 & 6 \\ 7 & 8 \end{matrix}\right]
$$
$$
A\times B=\left[ \begin{matrix} 1\times5+2\times7 & 1\times6+2\times8 \\ 3\times5+4\times7 & 3\times6+4\times8 \end{matrix}\right]=\left[ \begin{matrix} 19 & 22 \\ 43 & 50 \end{matrix}\right]
$$
This takes 8 multiplications and 4 additions in total.
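To make the operation count concrete, here is a minimal sketch of the conventional 2×2 product on plain nested lists. The function name `conventional_2x2` is my own and is not part of the implementation later in this post.

```python
def conventional_2x2(a, b):
    """Multiply two 2x2 matrices given as nested lists.

    Each of the 4 output entries needs 2 multiplications and 1 addition,
    so the whole product costs 8 multiplications and 4 additions.
    """
    return [
        [a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1]],
        [a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1]],
    ]


result = conventional_2x2([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(result)  # [[19, 22], [43, 50]]
```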

Strassen's algorithm

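Strassen's algorithm introduces 7 intermediate values and computes the product with 7 multiplications and 18 additions and subtractions, one multiplication fewer than the conventional method. For plain numbers that trade is not a win by itself, but applied recursively to matrix blocks, where a multiplication costs far more than an addition, it lowers the cost from $O(n^3)$ to $O(n^{\log_2 7}) \approx O(n^{2.81})$. The algorithm is as follows.
Let the matrices A and B be: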
$$
A= \left[ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{matrix}\right], \qquad
B= \left[ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{matrix}\right]
$$
Define 7 temporary variables $P_1$ through $P_7$, each computed with exactly one multiplication.
$$
\begin{aligned}
P_1 &= (A_{11}+A_{22})(B_{11}+B_{22})\\
P_2 &= (A_{21}+A_{22})B_{11}\\
P_3 &= A_{11}(B_{12} - B_{22})\\
P_4 &= A_{22}(B_{21} - B_{11})\\
P_5 &= (A_{11} + A_{12})B_{22}\\
P_6 &= (A_{21} - A_{11})(B_{11} + B_{12})\\
P_7 &= (A_{12} - A_{22})(B_{21} + B_{22})
\end{aligned}
$$
$$
\begin{aligned}
C_{11} &= P_1 + P_4 - P_5 + P_7\\
C_{12} &= P_3 + P_5\\
C_{21} &= P_2 + P_4\\
C_{22} &= P_1 - P_2 + P_3 + P_6
\end{aligned}
$$
$$
A\times B=\left[ \begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{matrix}\right]
$$
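As a quick sanity check, the 11 formulas can be typed in directly for plain numbers. This is only a sketch on scalars (the name `strassen_2x2` is mine), using the same A and B as in the conventional example.

```python
def strassen_2x2(a, b):
    """Multiply two 2x2 matrices of numbers using Strassen's 7 products."""
    (a11, a12), (a21, a22) = a
    (b11, b12), (b21, b22) = b

    # Seven products, one multiplication each; left/right order is kept
    p1 = (a11 + a22) * (b11 + b22)
    p2 = (a21 + a22) * b11
    p3 = a11 * (b12 - b22)
    p4 = a22 * (b21 - b11)
    p5 = (a11 + a12) * b22
    p6 = (a21 - a11) * (b11 + b12)
    p7 = (a12 - a22) * (b21 + b22)

    # Combine the products into the four entries of the result
    return [
        [p1 + p4 - p5 + p7, p3 + p5],
        [p2 + p4, p1 - p2 + p3 + p6],
    ]


print(strassen_2x2([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

For this input the seven products are 65, 35, -2, 8, 24, 22 and -30, and combining them gives the same matrix as the conventional method.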
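The formulas are fairly involved, 11 in all, and practically impossible to memorize. My advice: bookmark this post instead of memorizing them, and feel free to follow me while you're at it.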
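Note that in the 11 formulas above the left/right order of every multiplication matters, because the formulas are meant to hold in any ring. A ring is a set together with an addition and a multiplication, where the multiplication is not required to be commutative. What does that mean in practice? It means the entries of the 2×2 matrix do not have to be numbers; they can themselves be matrices. In other words, a large matrix can be partitioned into a 2×2 block matrix and Strassen's algorithm applied to the blocks.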
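The block idea can be sketched in a few lines on nested lists. This is a simplified standalone illustration, not the `Matrix` class from this post: it assumes square matrices whose size is a power of two, and the helper names (`add`, `sub`, `split`, `strassen`, `naive`) and the `counter` argument are mine. The counter records how many scalar multiplications are performed.

```python
def add(x, y):
    return [[p + q for p, q in zip(r, s)] for r, s in zip(x, y)]


def sub(x, y):
    return [[p - q for p, q in zip(r, s)] for r, s in zip(x, y)]


def split(m):
    """Cut a square matrix into four equal quadrants."""
    h = len(m) // 2
    return ([row[:h] for row in m[:h]], [row[h:] for row in m[:h]],
            [row[:h] for row in m[h:]], [row[h:] for row in m[h:]])


def strassen(a, b, counter):
    """Recursive Strassen; assumes len(a) is a power of two."""
    if len(a) == 1:
        counter[0] += 1  # one scalar multiplication
        return [[a[0][0] * b[0][0]]]
    a11, a12, a21, a22 = split(a)
    b11, b12, b21, b22 = split(b)
    # The same seven products, now with blocks as operands
    p1 = strassen(add(a11, a22), add(b11, b22), counter)
    p2 = strassen(add(a21, a22), b11, counter)
    p3 = strassen(a11, sub(b12, b22), counter)
    p4 = strassen(a22, sub(b21, b11), counter)
    p5 = strassen(add(a11, a12), b22, counter)
    p6 = strassen(sub(a21, a11), add(b11, b12), counter)
    p7 = strassen(sub(a12, a22), add(b21, b22), counter)
    c11 = add(sub(add(p1, p4), p5), p7)
    c12 = add(p3, p5)
    c21 = add(p2, p4)
    c22 = add(add(sub(p1, p2), p3), p6)
    # Stitch the quadrants back together row by row
    top = [left + right for left, right in zip(c11, c12)]
    bottom = [left + right for left, right in zip(c21, c22)]
    return top + bottom


def naive(a, b):
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


a = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
b = [[1, 0, 2, 0], [0, 1, 0, 2], [3, 0, 1, 0], [0, 3, 0, 1]]
count = [0]
c = strassen(a, b, count)
print(c == naive(a, b), count[0])  # True 49
```

For a 4×4 product the recursion performs 7 × 7 = 49 scalar multiplications, against 4³ = 64 for the conventional method.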
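One caveat: because of tricks like $A_{11}+A_{22}$, the blocks being added must have the same shape, so a matrix cannot be split this way when its row count or column count is odd. In that case we still have to fall back to the conventional method.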
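To see why odd sizes get in the way, here is a tiny sketch that only computes block shapes, assuming the matrix is cut at `rows // 2` and `cols // 2`. The helper names `block_shapes` and `can_use_strassen` are hypothetical.

```python
def block_shapes(rows, cols):
    """Shapes of the four blocks when a rows x cols matrix is cut at the midpoints."""
    top, left = rows // 2, cols // 2
    bottom, right = rows - top, cols - left
    return {
        "A11": (top, left), "A12": (top, right),
        "A21": (bottom, left), "A22": (bottom, right),
    }


def can_use_strassen(rows, cols):
    """A11 + A22 is only defined when both blocks have the same shape."""
    shapes = block_shapes(rows, cols)
    return shapes["A11"] == shapes["A22"]


print(block_shapes(3, 3))
# {'A11': (1, 1), 'A12': (1, 2), 'A21': (2, 1), 'A22': (2, 2)}
print(can_use_strassen(3, 3), can_use_strassen(4, 6))  # False True
```

A 3×3 matrix splits into a 1×1 block and a 2×2 block on the diagonal, which cannot be added, while a 4×6 matrix splits into four 2×3 blocks.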

Python implementation

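Unlike my earlier posts, this time I did not mix the code for this algorithm with the code from other posts. I wrote a new Python file that implements only Strassen's algorithm and uses divide and conquer to handle large matrices. The code is as follows: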

class Matrix:
    # A matrix, stored as a rectangular window onto a shared list of rows
    @staticmethod
    def create_by_lines(lines):
        # Wrap a full 2-D list: the window covers every row and column
        return Matrix(lines, 0, len(lines), 0, len(lines[0]))

    def __init__(self, lines, row_start, row_end, column_start, column_end):
        self.__lines = lines
        # Four bounds (end-exclusive) so that a block can be described without copying
        self.__column_start = column_start
        self.__column_end = column_end
        self.__row_start = row_start
        self.__row_end = row_end

    def __mul__(self, other):
        # First check that the two matrices can be multiplied
        if self.column_len() != other.row_len():
            raise ValueError("columns of A (%d) != rows of B (%d)" % (self.column_len(), other.row_len()))
        # Then decide how to multiply. There are several cases:
        # 1. (1 x n) * (n x 1): conventional
        # 2. (n x 1) * (1 x n): conventional
        # 3. (2 x 2) * (2 x 2): Strassen on numbers (the blocks are 1 x 1)
        # 4. anything else: split into blocks, Strassen on matrices
        if self.row_len() == 1 or self.column_len() == 1:
            return self.plain_mul(other)

        # Odd sizes cannot be split: rows of A, the shared dimension,
        # and columns of B must all be even
        if self.row_len() & 1 == 1 or self.column_len() & 1 == 1 or other.column_len() & 1 == 1:
            return self.plain_mul(other)

        # At this point Strassen's algorithm can be applied

        a11, a12, a21, a22 = self.sub()
        b11, b12, b21, b22 = other.sub()

        p1 = (a11 + a22) * (b11 + b22)
        p2 = (a21 + a22) * b11
        p3 = a11 * (b12 - b22)
        p4 = a22 * (b21 - b11)
        p5 = (a11 + a12) * b22
        p6 = (a21 - a11) * (b11 + b12)
        p7 = (a12 - a22) * (b21 + b22)

        return Matrix.create(p1 + p4 - p5 + p7, p3 + p5, p2 + p4, p1 - p2 + p3 + p6)

    def __add__(self, other):
        # Build each row separately so the rows are distinct list objects
        arr = [[0] * self.column_len() for _ in range(0, self.row_len())]
        for i in range(0, self.row_len()):
            self_row = self.__lines[self.__row_start + i]
            other_row = other.__lines[other.__row_start + i]
            for j in range(0, self.column_len()):
                arr[i][j] = self_row[self.__column_start + j] + other_row[other.__column_start + j]
        return Matrix.create_by_lines(arr)

    def __sub__(self, other):
        # Build each row separately so the rows are distinct list objects
        arr = [[0] * self.column_len() for _ in range(0, self.row_len())]
        for i in range(0, self.row_len()):
            self_row = self.__lines[self.__row_start + i]
            other_row = other.__lines[other.__row_start + i]
            for j in range(0, self.column_len()):
                arr[i][j] = self_row[self.__column_start + j] - other_row[other.__column_start + j]
        return Matrix.create_by_lines(arr)

    def plain_mul(self, other):
        # Conventional multiplication: build a new matrix with m rows and n columns
        m = self.row_len()
        n = other.column_len()
        p = other.row_len()

        result = [[0] * n for _ in range(0, m)]
        # i is the (absolute) row index in A
        for i in range(self.__row_start, self.__row_end):
            # j is the (absolute) column index in B
            for j in range(other.__column_start, other.__column_end):
                # Sum of products of row i of A with column j of B
                # k runs over the columns of A, which are the rows of B
                for k in range(0, p):
                    self_line = self.__lines[i]
                    other_line = other.__lines[other.__row_start + k]
                    a = self_line[self.__column_start + k]
                    b = other_line[j]
                    mul = a * b
                    result[i - self.__row_start][j - other.__column_start] += mul
        return Matrix.create_by_lines(result)

    def row_len(self):
        return self.__row_end - self.__row_start

    def column_len(self):
        return self.__column_end - self.__column_start

    def sub(self):
        a_middle_row = (self.__row_end + self.__row_start) // 2
        a_middle_column = (self.__column_end + self.__column_start) // 2
        a11 = Matrix(self.__lines, self.__row_start, a_middle_row, self.__column_start, a_middle_column)
        a12 = Matrix(self.__lines, self.__row_start, a_middle_row, a_middle_column, self.__column_end)
        a21 = Matrix(self.__lines, a_middle_row, self.__row_end, self.__column_start, a_middle_column)
        a22 = Matrix(self.__lines, a_middle_row, self.__row_end, a_middle_column, self.__column_end)
        return a11, a12, a21, a22

    @staticmethod
    def create(a11, a12, a21, a22):
        len_rows = a11.row_len() + a21.row_len()
        len_columns = a11.column_len() + a12.column_len()
        lines = [[0] * len_columns for _ in range(0, len_rows)]
        # Copy the four blocks into place, offset by the size of a11
        a11.copy_to(lines, 0, 0)
        a12.copy_to(lines, 0, a11.column_len())
        a21.copy_to(lines, a11.row_len(), 0)
        a22.copy_to(lines, a11.row_len(), a11.column_len())
        return Matrix.create_by_lines(lines)

    def copy_to(self, lines, row_start, column_start):
        for i in range(0, self.row_len()):
            self_row = self.__lines[self.__row_start + i]
            other_row = lines[row_start + i]
            for j in range(0, self.column_len()):
                other_row[column_start + j] = self_row[self.__column_start + j]

    @property
    def lines(self):
        return self.__lines