// Naive single-precision GEMM: C = alpha * (A * B) + beta * C.
//
// Layout (fixed by the indexing below): all matrices are dense and row-major;
// A is M x K, B is K x N, C is M x N.
//
// Thread mapping: the x dimension of the launch covers the M rows of C and the
// y dimension covers the N columns; each thread computes exactly one element
// C[i][j]. The grid must therefore provide at least M threads along x and N
// threads along y. No shared memory is used.
//
// Performance note: consecutive threadIdx.x values map to consecutive *rows*
// i, so within a warp the accesses to A and C are strided by K and N floats
// respectively (not coalesced). sgemm_coalesce swaps the mapping to fix this.
//
// NOTE(review): C is read even when beta == 0, so a NaN/Inf already present
// in C propagates into the result (0 * NaN == NaN).
__global__ void sgemm_naive(int M, int N, int K, const float alpha,
                            const float *A, const float *B, const float beta,
                            float *C) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;  // row of C
    const int j = blockIdx.y * blockDim.y + threadIdx.y;  // column of C

    // The launch may contain more threads than elements (e.g. when M or N is
    // not a multiple of the block size); those threads must not touch memory.
    if (i < M && j < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; ++k) {
            sum += A[i * K + k] * B[k * N + j];
        }
        C[i * N + j] = alpha * sum + beta * C[i * N + j];
    }
}
// Single-precision GEMM with a coalescing-friendly thread mapping:
// C = alpha * (A * B) + beta * C.
//
// Layout (fixed by the indexing below): all matrices are dense and row-major;
// A is M x K, B is K x N, C is M x N.
//
// Thread mapping: the x dimension of the launch covers the N columns of C and
// the y dimension covers the M rows; each thread computes exactly one element
// C[i][j]. The grid must therefore provide at least N threads along x and M
// threads along y. No shared memory is used.
//
// Why this is faster than sgemm_naive: consecutive threadIdx.x values map to
// consecutive columns j with the same row i. Within a warp, reads of
// B[k * N + j] and the read/write of C[i * N + j] therefore hit adjacent
// addresses, and A[i * K + k] is the same address for every lane (a
// broadcast). NOTE(review): this assumes blockDim.x is a multiple of 32 so a
// warp does not straddle two rows.
//
// NOTE(review): C is read even when beta == 0, so a NaN/Inf already present
// in C propagates into the result (0 * NaN == NaN).
__global__ void sgemm_coalesce(int M, int N, int K, const float alpha,
                               const float *A, const float *B,
                               const float beta, float *C) {
    const int j = blockIdx.x * blockDim.x + threadIdx.x;  // column of C
    const int i = blockIdx.y * blockDim.y + threadIdx.y;  // row of C

    // The launch may contain more threads than elements (e.g. when M or N is
    // not a multiple of the block size); those threads must not touch memory.
    if (i < M && j < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; ++k) {
            sum += A[i * K + k] * B[k * N + j];
        }
        C[i * N + j] = alpha * sum + beta * C[i * N + j];
    }
}
sgemm_coalesce只需要392ms。
Tiled Matrix Multiply with Shared Memory
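sgemm_coalesce 的主要开销仍然取决于从全局内存读取数据的次数。一种改进思路是用 shared memory 实现分块矩阵乘法(tiled matrix multiply):把矩阵划分成若干分块(tile),每次把一个分块一次性读入更快的片上 shared memory(这里近似认为其访问时间可以忽略不计),由同一线程块内的线程共享这些数据,从而减少全局内存的整体访问开销。设分块大小为 $T \times T$(记作 $T$,以免与矩阵 $B$ 混淆),则平均到每个输出元素只需从全局内存读取 $2K/T$ 次,再加上对 $C$ 的 1 次访问,总的全局内存访问次数为 $MN\left(\frac{2K}{T}+1\right)$;不分块时($T=1$)则为 $MN(2K+1)$。因此 $T$ 越大,访问全局内存的开销就越少。不过 $T$ 受每个线程块最多 1024 个线程以及 shared memory 容量的限制。下面的实现中每个线程负责把分块中的一个元素读入 shared memory,因此取 $T = 32$(即 `BLOCKSIZE`),对应每个线程块 $32 \times 32 = 1024$ 个线程。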
// Edge length of the square tile used by sgemm_shared: its shared arrays As
// and Bs each hold BLOCKSIZE * BLOCKSIZE floats, and the per-thread tile loads
// index them with threadIdx.y * BLOCKSIZE + threadIdx.x.
// NOTE(review): this presumably implies a launch with
// blockDim == (BLOCKSIZE, BLOCKSIZE), i.e. 32 * 32 = 1024 threads, which is
// the per-block maximum -- confirm against the host-side launch code.
#define BLOCKSIZE 32
__global__ voidsgemm_shared( int M, int N, int K, constfloat alpha, constfloat *A, constfloat *B, constfloat beta, float *C) { __shared__ float As[BLOCKSIZE * BLOCKSIZE]; __shared__ float Bs[BLOCKSIZE * BLOCKSIZE];
A += blockIdx.y * BLOCKSIZE * K; B += blockIdx.x * BLOCKSIZE; C += blockIdx.y * BLOCKSIZE * N + blockIdx.x * BLOCKSIZE;
// avoids memory access error if threads are more than elements if (i < M && j < N) { float fSum = 0.0f; // stores result of (threadIdx.y, threadIdx.x) on each block for (int iBlkIdx = 0; iBlkIdx < K; iBlkIdx += BLOCKSIZE) { if (iBlkIdx + threadIdx.x < K) { As[threadIdx.y * BLOCKSIZE + threadIdx.x] = A[threadIdx.y * K + threadIdx.x]; } if (iBlkIdx + threadIdx.y < K) { Bs[threadIdx.y * BLOCKSIZE + threadIdx.x] = B[threadIdx.y * N + threadIdx.x]; } __syncthreads(); // syncronize until all caches are fulfilled
// updates to the next chunk A += BLOCKSIZE; B += BLOCKSIZE * N;