Skip to content

Commit

Permalink
forward to GEMV when one argument is actually a vector
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-frbg authored May 20, 2024
1 parent 700ea74 commit c2a9b19
Showing 1 changed file with 45 additions and 4 deletions.
49 changes: 45 additions & 4 deletions interface/gemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,29 @@
#define SMP_THRESHOLD_MIN 65536.0
#ifdef XDOUBLE
#define ERROR_NAME "QGEMM "
#define GEMV BLASFUNC(qgemv)
#elif defined(DOUBLE)
#define ERROR_NAME "DGEMM "
#define GEMV BLASFUNC(dgemv)
#elif defined(BFLOAT16)
#define ERROR_NAME "SBGEMM "
#define GEMV BLASFUNC(sbgemv)
#else
#define ERROR_NAME "SGEMM "
#define GEMV BLASFUNC(sgemv)
#endif
#else
#define SMP_THRESHOLD_MIN 8192.0
#ifndef GEMM3M
#ifdef XDOUBLE
#define ERROR_NAME "XGEMM "
#define GEMV BLASFUNC(xgemv)
#elif defined(DOUBLE)
#define ERROR_NAME "ZGEMM "
#define GEMV BLASFUNC(zgemv)
#else
#define ERROR_NAME "CGEMM "
#define GEMV BLASFUNC(cgemv)
#endif
#else
#ifdef XDOUBLE
Expand Down Expand Up @@ -485,9 +492,38 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
}
#endif
#endif // defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16)
// fprintf(stderr,"G E M M interface m n k %d %d %d\n",args.m,args.n,args.k);

if ((args.m == 0) || (args.n == 0)) return;

#if 1
#ifndef GEMM3M
/*
 * Shortcut: when the product has a single row (m == 1) or a single
 * column (n == 1), hand the work to GEMV and return without entering the
 * GEMM path that follows.
 *
 * The transpose flag lives in a two-byte automatic array. It was
 * previously obtained from malloc() and never released, which leaked on
 * every forwarded call (and in one branch the allocation was immediately
 * overwritten by a string-literal pointer). The bytes handed to GEMV are
 * the same NUL-terminated "N" / "T" as before.
 *
 * For reference, the callee's Fortran signature is
 *   SUBROUTINE SGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
 */
if (args.m == 1) {
char NT[2] = "N";
if (transb & 1) NT[0] = 'T';
/* B is the matrix operand, A the input vector, C the output vector.
 * NOTE(review): both increments are passed as args.m (== 1 here), which
 * presumes A and C are contiguous along the vector; the TRANS/M/N choice
 * for B should also be re-checked against the GEMV definition - confirm. */
GEMV(NT, &args.n, &args.k, args.alpha, args.b, &args.ldb, args.a, &args.m, args.beta, args.c, &args.m);
return;
}
if (args.n == 1) {
char NT[2] = "N";
#ifdef CBLAS
/* Only the CBLAS build looks at the transpose bit here; the other
 * build always passes "N".
 * NOTE(review): the bit tested is transb although the matrix operand
 * is A - confirm this is intended. */
if (transb & 1) NT[0] = 'T';
#endif
/* A is the matrix operand, B the input vector, C the output vector.
 * NOTE(review): the increment for B is args.n (== 1 here), which
 * presumes B is contiguous along the vector - confirm for transposed B. */
GEMV(NT, &args.m, &args.k, args.alpha, args.a, &args.lda, args.b, &args.n, args.beta, args.c, &args.n);
return;
}
#endif
#endif
#if 0
fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n",
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
Expand Down Expand Up @@ -521,10 +557,15 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS

buffer = (XFLOAT *)blas_memory_alloc(0);

//For target LOONGSON3R5, applying an offset to the buffer is essential
//for minimizing cache conflicts and optimizing performance.
#if defined(ARCH_LOONGARCH64) && !defined(NO_AFFINITY)
sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A);
//For Loongson servers, like the 3C5000 (featuring 16 cores), applying an
//offset to the buffer is essential for minimizing cache conflicts and optimizing performance.
#if defined(LOONGSON3R5) && !defined(NO_AFFINITY)
char model_name[128];
get_cpu_model(model_name);
if ((strstr(model_name, "3C5000") != NULL) || (strstr(model_name, "3D5000") != NULL))
sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A);
else
sa = (XFLOAT *)((BLASLONG)buffer + GEMM_OFFSET_A);
#else
sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
#endif
Expand Down

0 comments on commit c2a9b19

Please sign in to comment.