/*
  In case you're wondering, dgemm stands for Double-precision, GEneral
  Matrix-Matrix multiplication.
*/

#include <stddef.h>

/*
  Human-readable label for this implementation.  Presumably printed by
  the benchmark driver linked with this file (NOTE(review): confirm).
  Keep the declared type unchanged: other translation units may declare
  it `extern const char *`.
*/
const char *dgemm_desc = "Simple blocked dgemm.";

/*
  Edge length, in matrix elements, of the square tiles processed by
  do_block.  Can be overridden at build time with -DBLOCK_SIZE=<n>.
  The original author flagged 16 as a placeholder expected to be tuned.
*/
#ifndef BLOCK_SIZE
#define BLOCK_SIZE ((int)16)
#endif

/*
  basic_dgemm: C += A * B on column-major sub-matrices.

  A is M-by-K, B is K-by-N, C is M-by-N.  All three are stored in
  column-major order and share the leading dimension lda (the M of
  square_dgemm): element (r, c) of A is A[r + c*lda], likewise for B
  and C.  C is accumulated into, not overwritten.

  C is assumed not to overlap A or B (as for BLAS dgemm); the result
  is unspecified if it does.

  Column offsets are computed in size_t so that c*lda cannot overflow
  int when N, K or lda is large.

  Loop order is j-k-i: the innermost loop walks one column of A and one
  column of C with unit stride, rather than striding through A by lda.
  Each C(i,j) still receives its K updates in increasing k order.
  Further speedups (unrolling, register blocking for the common
  full-BLOCK_SIZE case) are left as future work.
*/

void
basic_dgemm (const int lda,
             const int M, const int N, const int K,
             const double *A, const double *B, double *C)
{
     int i, j, k;

     for (j = 0; j < N; ++j) {
          const double *B_j = B + (size_t) j * lda;   /* column j of B */
          double *C_j = C + (size_t) j * lda;         /* column j of C */

          for (k = 0; k < K; ++k) {
               const double *A_k = A + (size_t) k * lda;   /* column k of A */
               const double b_kj = B_j[k];                  /* reused for the whole column */

               for (i = 0; i < M; ++i) {
                    C_j[i] += A_k[i] * b_kj;
               }
          }
     }
}

void
do_block (const int lda,
          const double *A, const double *B, double *C,
          const int i, const int j, const int k)
{
     /*
       Multiply one tile:
         C(i:i+M, j:j+N) += A(i:i+M, k:k+K) * B(k:k+K, j:j+N)
       where (i, j, k) are the tile's starting indices in the lda-by-lda
       column-major matrices.

       Tiles are BLOCK_SIZE on a side except at the fringes, where each
       extent is clipped to the matrix edge.  For a 7x7 matrix with 3x3
       tiles, the last tile row/column gives 1x3, 3x1 and 1x1 fringe tiles:

             xxxoooX
             xxxoooX
             xxxoooX
             oooxxxO
             oooxxxO
             oooxxxO
             XXXOOOX

       The original author notes that making this fast requires a better
       way to handle the fringe tiles; that is not done here.
     */

     /* Clip each extent to the matrix edge.  Written as lda - x < BLOCK_SIZE
        (equivalent to x + BLOCK_SIZE > lda) so the sum is never formed. */
     const int M = (lda - i < BLOCK_SIZE ? lda - i : BLOCK_SIZE);
     const int N = (lda - j < BLOCK_SIZE ? lda - j : BLOCK_SIZE);
     const int K = (lda - k < BLOCK_SIZE ? lda - k : BLOCK_SIZE);

     /* Column offsets are computed in size_t: in int, k*lda and j*lda
        overflow (undefined behavior) once lda reaches roughly 46,000. */
     basic_dgemm (lda, M, N, K,
                  A + i + (size_t) k * lda,
                  B + k + (size_t) j * lda,
                  C + i + (size_t) j * lda);
}

void
square_dgemm (const int M,
              const double *A, const double *B, double *C)
{
     /*
       C += A * B for M-by-M column-major matrices (leading dimension M),
       computed one BLOCK_SIZE x BLOCK_SIZE tile at a time.

       The outer two loops select the tile of C starting at (row, col).
       The innermost loop walks the matching tiles along the shared
       dimension, starting at depth.  Offsets advance in steps of
       BLOCK_SIZE while still inside the matrix, so a partial last tile
       is visited.  do_block clips that tile to the edge.
     */
     int row, col, depth;

     for (row = 0; row < M; row += BLOCK_SIZE) {
          for (col = 0; col < M; col += BLOCK_SIZE) {
               for (depth = 0; depth < M; depth += BLOCK_SIZE) {
                    do_block (M, A, B, C, row, col, depth);
               }
          }
     }
}

