zl程序教程

您现在的位置是:首页 >  其它

当前栏目

hdu4965 巧用矩阵乘法结合律

矩阵 乘法 巧用
2023-09-11 14:14:00 时间
题意:
     给两个矩阵,n*m的矩阵A,和m*n的矩阵B,
求(A*B)^(n*n)其中 m<=6,n<=1000。
思路:
      一开始直接模拟,写了个矩阵快速幂,超时了,因为A*B后得到的是1000*1000的矩阵,做乘法直接超时了,后来写了个这样的
    (A*B)^(n*n) = (A*B)*(A*B)*(A*B)...
                       = A * (B*A)*(B*A)*(B*A)...*B

矩阵虽然没有交换律但是有结合律,我们直接先B*A(得到的是一个最大6*6的矩阵)然后快速幂,然后再A * BA^(n*n-1) * B这样就行了,然后又超时了,算了很多次,感觉不可能超时,但还是超时了,原因就是我所有的矩阵用的都是mat[1002][1002]为了方便我都开结构体了,结果各种超时,最后没办法了,全都开数组,然后去模拟,A[1002][8],B[8][1002],BA[8][8]...,这样就AC了,难道开大的数组也会浪费很多时间?(这个地方头一次碰到)。


#include<stdio.h>
#include<string.h>

typedef struct
{
   int mat[8][8];
}AA;

int A[1002][8] ,B[8][1002] ,C[1002][1002];
int nmm[1002][8];

AA mat_matba(int n ,int m)
{
    AA c;
    memset(c.mat ,0 ,sizeof(c.mat)); 
    for(int k = 1 ;k <= n ;k ++)
    for(int i = 1 ;i <= m ;i ++)
    if(B[i][k])
    for(int j = 1 ;j <= m ;j ++)
    c.mat[i][j] = (c.mat[i][j] + B[i][k] * A[k][j])%6 ;
    return c;
}

AA mat_mat(AA a ,AA b ,int n)
{
   AA c;
   memset(c.mat ,0 ,sizeof(c.mat));
   for(int k = 1 ;k <= n ;k ++)
   for(int i = 1 ;i <= n ;i ++)
   if(a.mat[i][k])
   for(int j = 1 ;j <= n ;j ++)
   c.mat[i][j] = (c.mat[i][j] + a.mat[i][k] * b.mat[k][j]) % 6;
   return c;
}


AA quick_mat(AA a ,int b ,int n)
{
   AA c;
   memset(c.mat ,0 ,sizeof(c.mat));
   for(int i = 1 ;i <= n ;i ++)
   c.mat[i][i] = 1;
   while(b)
   {
      if(b&1) c = mat_mat(c ,a ,n);
      a = mat_mat(a ,a ,n);
      b >>= 1;
   }
   return c;
}

void mat_matnmm(AA mm ,int n ,int m)
{
   memset(nmm ,0 ,sizeof(nmm));
   for(int k = 1 ;k <= m ;k ++)
   for(int i = 1 ;i <= n ;i ++)
   if(A[i][k])  
   for(int j = 1 ;j <= m ;j ++)
   nmm[i][j] = (nmm[i][j] + A[i][k] * mm.mat[k][j]) % 6;
}

void mat_matnmmn(int n ,int m)
{
   memset(C ,0 ,sizeof(C));
   for(int k = 1 ;k <= m ;k ++)
   for(int i = 1 ;i <= n ;i ++)
   for(int j = 1 ;j <= n ;j ++)
   C[i][j] = (C[i][j] + nmm[i][k] * B[k][j]) % 6;
} 



int main ()
{
    int n ,m ,i ,j;
    while(~scanf("%d %d" ,&n ,&m) && n + m)
    {
       for(i = 1 ;i <= n ;i ++)
       for(j = 1 ;j <= m ;j ++)
       scanf("%d" ,&A[i][j]);
       for(i = 1 ;i <= m ;i ++)
       for(j = 1 ;j <= n ;j ++)
       scanf("%d" ,&B[i][j]);
       AA c = mat_matba(n ,m);
       AA ban = quick_mat(c ,n*n-1 ,m);
       mat_matnmm(ban ,n ,m);
       mat_matnmmn(n ,m);
     
       
       int sum = 0;
       for(i = 1 ;i <= n ;i ++)
       for(j = 1 ;j <= n ;j ++)
       sum += C[i][j];
       printf("%d\n" ,sum);
     }
     return 0;
}