zl程序教程

您现在的位置是:首页 >  后端

当前栏目

C#,数值计算,矩阵相乘的斯特拉森(Strassen’s Matrix Multiplication)分治算法与源代码

c#算法计算 矩阵 源代码 数值 matrix 分治
2023-09-11 14:15:48 时间

Volker Strassen

矩阵乘法机器学习中最基本的运算之一,对其进行优化是多种优化的关键。通常,将两个大小为N X N的矩阵相乘需要N^3次运算。从那以后,我们在更好、更聪明的矩阵乘法算法方面取得了长足的进步。沃尔克·斯特拉森于1969年首次发表了他的算法。这是第一个证明基本O(n^3)运行时不是optiomal的算法。

Strassen算法的基本思想是将A和B分为8个子矩阵,然后递归计算C的子矩阵。这种策略称为分而治之。

伪代码

  1. 如上图所示,将矩阵A和B划分为大小为N/2 x N/2的4个子矩阵。
  2. 递归计算7个矩阵乘法。
  3. 计算C的子矩阵。
  4. 将这些子矩阵组合到我们的新矩阵C中

复杂性

  1. 最坏情况时间复杂度:Θ(n^2.8074)
  2. 最佳情况时间复杂度:Θ(1)
  3. 空间复杂度:Θ(logn)

年青时正在发愁的  Volker Strassen

算法的详细解释:

矩阵相乘在进行3D变换的时候是经常用到的。在应用中常用矩阵相乘的定义算法对其进行计算。这个算法用到了大量的循环和相乘运算,这使得算法效率不高。而矩阵相乘的计算效率很大程度上的影响了整个程序的运行速度,所以对矩阵相乘算法进行一些改进是必要的。

        我们先讨论二阶矩阵的计算方法。

        对于二阶矩阵

        a11    a12                    b11    b12    
        A =    a21    a22    B =    b21    b22
        先计算下面7个量(1)

        x1 = (a11 + a22) * (b11 + b22);
        x2 = (a21 + a22) * b11;
        x3 = a11 * (b12 - b22);
        x4 = a22 * (b21 - b11);
        x5 = (a11 + a12) * b22;
        x6 = (a21 - a11) * (b11 + b12);
        x7 = (a12 - a22) * (b21 + b22);
        再设C = AB。根据矩阵相乘的规则,C的各元素为(2)

        c11 = a11 * b11 + a12 * b21
        c12 = a11 * b12 + a12 * b22
        c21 = a21 * b11 + a22 * b21
        c22 = a21 * b12 + a22 * b22
        比较(1)(2),C的各元素可以表示为(3)

        c11 = x1 + x4 - x5 + x7
        c12 = x3 + x5
        c21 = x2 + x4
        c22 = x1 + x3 - x2 + x6


        根据以上的方法,我们就可以计算4阶矩阵了,先将4阶矩阵A和B划分成四块2阶矩阵,分别利用公式计算它们的乘积,再使用(1)(3)来计算出最后结果。

本文给出了多种算法,大家自己选择吧。

using System;
using System.Text;

namespace Legal.Truffer.Algorithm
{
	/// <summary>
	/// 矩阵相乘的斯特拉森(V. Strassen)方法
	/// </summary>
	public static class Matrix_Calculator
	{
		#region [4x4]x[4x4]矩阵相乘的斯特拉森(V. Strassen)方法(快速算法)

		// 计算2X2矩阵
		private static void Multiply2X2(out double fOut_11, out double fOut_12, out double fOut_21, out double fOut_22,
							double f1_11, double f1_12, double f1_21, double f1_22,
							double f2_11, double f2_12, double f2_21, double f2_22)
		{
			double x1 = ((f1_11 + f1_22) * (f2_11 + f2_22));
			double x2 = ((f1_21 + f1_22) * f2_11);
			double x3 = (f1_11 * (f2_12 - f2_22));
			double x4 = (f1_22 * (f2_21 - f2_11));
			double x5 = ((f1_11 + f1_12) * f2_22);
			double x6 = ((f1_21 - f1_11) * (f2_11 + f2_12));
			double x7 = ((f1_12 - f1_22) * (f2_21 + f2_22));

			fOut_11 = x1 + x4 - x5 + x7;
			fOut_12 = x3 + x5;
			fOut_21 = x2 + x4;
			fOut_22 = x1 - x2 + x3 + x6;
		}

		// 计算4X4矩阵
		public static Matrix Multiply4x4(Matrix a, Matrix b)
		{
			//double c[7,4] = new double[7,4];
			double c_0_0, c_0_1, c_0_2, c_0_3;
			double c_1_0, c_1_1, c_1_2, c_1_3;
			double c_2_0, c_2_1, c_2_2, c_2_3;
			double c_3_0, c_3_1, c_3_2, c_3_3;
			double c_4_0, c_4_1, c_4_2, c_4_3;
			double c_5_0, c_5_1, c_5_2, c_5_3;
			double c_6_0, c_6_1, c_6_2, c_6_3;

			// (ma11 + ma22) * (mb11 + mb22)
			Multiply2X2(out c_0_0, out c_0_1, out c_0_2, out c_0_3,
							a[0] + a[10], a[1] + a[11], a[4] + a[14], a[5] + a[15],
							b[0] + b[10], b[1] + b[11], b[4] + b[14], b[5] + b[15]);

			// (ma21 + ma22) * mb11
			Multiply2X2(out c_1_0, out c_1_1, out c_1_2, out c_1_3,
							a[8] + a[10], a[9] + a[11], a[12] + a[14], a[13] + a[15],
							b[0], b[1], b[4], b[5]);

			// ma11 * (mb12 - mb22)
			Multiply2X2(out c_2_0, out c_2_1, out c_2_2, out c_2_3,
							a[0], a[1], a[4], a[5],
							b[2] - b[10], b[3] - b[11], b[6] - b[14], b[7] - b[15]);

			// ma22 * (mb21 - mb11)
			Multiply2X2(out c_3_0, out c_3_1, out c_3_2, out c_3_3,
							a[10], a[11], a[14], a[15],
							b[8] - b[0], b[9] - b[1], b[12] - b[4], b[13] - b[5]);

			// (ma11 + ma12) * mb22
			Multiply2X2(out c_4_0, out c_4_1, out c_4_2, out c_4_3,
							a[0] + a[2], a[1] + a[3], a[4] + a[6], a[5] + a[7],
							b[10], b[11], b[14], b[15]);

			// (ma21 - ma11) * (mb11 + mb12)
			Multiply2X2(out c_5_0, out c_5_1, out c_5_2, out c_5_3,
							a[8] - a[0], a[9] - a[1], a[12] - a[4], a[13] - a[5],
							b[0] + b[2], b[1] + b[3], b[4] + b[6], b[5] + b[7]);

			// (ma12 - ma22) * (mb21 + mb22)
			Multiply2X2(out c_6_0, out c_6_1, out c_6_2, out c_6_3,
							a[2] - a[10], a[3] - a[11], a[6] - a[14], a[7] - a[15],
							b[8] + b[10], b[9] + b[11], b[12] + b[14], b[13] + b[15]);

			return new Matrix(4, 4, new double[4 * 4] {
				c_0_0 + c_3_0 - c_4_0 + c_6_0,
				c_0_1 + c_3_1 - c_4_1 + c_6_1,
				c_2_0 + c_4_0,
				c_2_1 + c_4_1,

				c_0_2 + c_3_2 - c_4_2 + c_6_2,
				c_0_3 + c_3_3 - c_4_3 + c_6_3,
				c_2_2 + c_4_2,
				c_2_3 + c_4_3,

				c_1_0 + c_3_0,
				c_1_1 + c_3_1,
				c_0_0 - c_1_0 + c_2_0 + c_5_0,
				c_0_1 - c_1_1 + c_2_1 + c_5_1,

				c_1_2 + c_3_2,
				c_1_3 + c_3_3,
				c_0_2 - c_1_2 + c_2_2 + c_5_2,
				c_0_3 - c_1_3 + c_2_3 + c_5_3
			});
		}
		#endregion

		#region 基于Strassen算法的矩阵“分治”乘法(只支持维度为2的幂次的方阵相乘。)

		private static Matrix create(Matrix input, int r1, int r2, int c1, int c2)
		{
			Matrix res = new Matrix(r2 - r1, c2 - c1);
			for (int i = r1, ii = 0; i <= r2 && ii < r2 - r1; i++, ii++)
			{
				for (int j = c1, jj = 0; j < c2 && jj < c2 - c1; j++, jj++)
				{
					res[ii, jj] = input[i, j];
				}
			}
			return res;
		}

		public static Matrix Multipy(Matrix A, Matrix B, int len, int r1 = 0, int c1 = 0)
		{
			if (len == 1)
			{
				return new Matrix(1, 1,
					new double[1] { A[0] * B[0] }
				);
			}

			int lend2 = len / 2;
			Matrix a = create(A, r1, r1 + lend2, c1, c1 + lend2);
			Matrix e = create(B, r1, r1 + lend2, c1, c1 + lend2);
			Matrix b = create(A, r1, r1 + lend2, c1 + lend2, len);
			Matrix f = create(B, r1, r1 + lend2, c1 + lend2, len);
			Matrix c = create(A, r1 + lend2, len, c1, c1 + lend2);
			Matrix g = create(B, r1 + lend2, len, c1, c1 + lend2);
			Matrix d = create(A, r1 + lend2, len, c1 + lend2, len);
			Matrix h = create(B, r1 + lend2, len, c1 + lend2, len);

			Matrix p1 = a * (f - h); // multi(a, sub(f, h, lend2), 0, 0, lend2); 
			Matrix p2 = (a + b) * h; // multi(add(a, b, lend2), h, 0, 0, lend2);
			Matrix p3 = (c + d) * e; // multi(add(c, d, lend2), e, 0, 0, lend2);
			Matrix p4 = d * (g - e); // multi(d, sub(g, e, lend2), 0, 0, lend2);
			Matrix p5 = (a + d) * (e + h); // multi(add(a, d, lend2), add(e, h, lend2), 0, 0, lend2);
			Matrix p6 = (b - d) * (g + h); // multi(sub(b, d, lend2), add(g, h, lend2), 0, 0, lend2);
			Matrix p7 = (a - c) * (e + f); // multi(sub(a, c, lend2), add(e, f, lend2), 0, 0, lend2);

			Matrix r = (((p5 + p4) + p6) - p2); // sub(add(add(p5, p4, lend2), p6, lend2), p2, lend2);
			Matrix s = p1 + p2; // add(p1, p2, lend2);
			Matrix t = p3 + p4; // add(p3, p4, lend2);
			Matrix u = (p5 + p1) - (p3 + p7);// sub(add(p5, p1, lend2), add(p3, p7, lend2), lend2);

			Matrix rr = new Matrix(len, len);
			for (int j = 0; j < lend2; j++)
			{
				for (int jj = 0; jj < lend2; jj++)
				{
					rr[j, jj] = r[j, jj];
				}
			}
			for (int j = 0; j < lend2; j++)
			{
				for (int jj = 0; jj < lend2; jj++)
				{
					rr[j, jj + lend2] = s[j, jj];
				}
			}
			for (int j = 0; j < lend2; j++)
			{
				for (int jj = 0; jj < lend2; jj++)
				{
					rr[j + lend2, jj] = t[j, jj];
				}
			}
			for (int j = 0; j < lend2; j++)
			{
				for (int jj = 0; jj < lend2; jj++)
				{
					rr[j + lend2, jj + lend2] = u[j, jj];
				}
			}
			return rr;
		}
		#endregion

		#region 基于Strassen矩阵乘法的递归分治算法

		/// <summary>
		/// 基于Strassen矩阵乘法的递归分治算法
		/// </summary>
		/// <param name="n"></param>
		/// <param name="A"></param>
		/// <param name="B"></param>
		/// <returns></returns>
		public static Matrix Strassen(int n, Matrix A, Matrix B)
		{
			//2-order 
			if (n == 2)
			{
				return A * B;
			}

			int N = n / 2;
			Matrix A11 = new Matrix(N, N);
			Matrix A12 = new Matrix(N, N);
			Matrix A21 = new Matrix(N, N);
			Matrix A22 = new Matrix(N, N);
			Matrix B11 = new Matrix(N, N);
			Matrix B12 = new Matrix(N, N);
			Matrix B21 = new Matrix(N, N);
			Matrix B22 = new Matrix(N, N);

			//将矩阵A和B分成阶数相同的四个子矩阵,即分治思想。  
			for (int i = 0; i < n / 2; i++)
			{
				for (int j = 0; j < n / 2; j++)
				{
					A11[i, j] = A[i, j];
					A12[i, j] = A[i, j + n / 2];
					A21[i, j] = A[i + n / 2, j];
					A22[i, j] = A[i + n / 2, j + n / 2];

					B11[i, j] = B[i, j];
					B12[i, j] = B[i, j + n / 2];
					B21[i, j] = B[i + n / 2, j];
					B22[i, j] = B[i + n / 2, j + n / 2];
				}
			}

			//Calculate M1 = (A0 + A3) × (B0 + B3)  
			Matrix M1 = Strassen(N, A11 + A22, B11 + B22);

			//Calculate M2 = (A2 + A3) × B0  
			Matrix M2 = Strassen(N, A21 + A22, B11);

			//Calculate M3 = A0 × (B1 - B3)  
			Matrix M3 = Strassen(N, A11, B12 - B22);

			//Calculate M4 = A3 × (B2 - B0)  
			Matrix M4 = Strassen(N, A22, B21 - B11);

			//Calculate M5 = (A0 + A1) × B3  
			Matrix M5 = Strassen(N, A11 + A12, B22);

			//Calculate M6 = (A2 - A0) × (B0 + B1)  
			Matrix M6 = Strassen(N, A21 - A11, B11 + B12);

			//Calculate M7 = (A1 - A3) × (B2 + B3)  
			Matrix M7 = Strassen(N, A12 - A22, B21 + B22);

			//Calculate C0 = M1 + M4 - M5 + M7  
			Matrix C11 = (M1 + M4) + (M7 - M5);

			//Calculate C1 = M3 + M5  
			Matrix C12 = M3 + M5;

			//Calculate C2 = M2 + M4  
			Matrix C21 = M2 + M4;

			//Calculate C3 = M1 - M2 + M3 + M6  
			Matrix C22 = (M1 - M2) + (M3 + M6);

			Matrix C = new Matrix(n, n);
			for (int i = 0; i < N; i++)
			{
				for (int j = 0; j < N; j++)
				{
					C[i, j] = C11[i, j];
					C[i, j + N] = C12[i, j];
					C[i + N, j] = C21[i, j];
					C[i + N, j + N] = C22[i, j];
				}
			}
			return C;
		}
        #endregion
    }
}