Skip to content

Commit

Permalink
Merge pull request OpenMathLib#4875 from ChipKerchner/addGEMVtoBF16Test
Browse files Browse the repository at this point in the history
Add GEMV to SBGEMx vs SGEMx testing
  • Loading branch information
martin-frbg committed Aug 15, 2024
2 parents a388c4b + c23897f commit 4944148
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion test/compare_sgemm_sbgemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "../common.h"
#define SGEMM BLASFUNC(sgemm)
#define SBGEMM BLASFUNC(sbgemm)
#define SGEMV BLASFUNC(sgemv)
#define SBGEMV BLASFUNC(sbgemv)
typedef union
{
unsigned short v;
Expand Down Expand Up @@ -187,7 +189,79 @@ main (int argc, char *argv[])
free(CC);
}

if (ret != 0)
if (ret != 0) {
fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
return ret;
}

k = 1;
for (x = 1; x <= loop; x++)
{
float *A = (float *)malloc(x * x * sizeof(FLOAT));
float *B = (float *)malloc(x * sizeof(FLOAT));
float *C = (float *)malloc(x * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits));
float *DD = (float *)malloc(x * sizeof(FLOAT));
float *CC = (float *)malloc(x * sizeof(FLOAT));
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
(DD == NULL) || (CC == NULL))
return 1;
bfloat16 atmp, btmp;
blasint one = 1;

for (j = 0; j < x; j++)
{
for (i = 0; i < x; i++)
{
A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one);
AA[j * x + i].v = atmp;
}
B[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
sbstobf16_(&one, &B[j], &one, &btmp, &one);
BB[j].v = btmp;
}
for (y = 0; y < 2; y++)
{
if (y == 0) {
transA = 'N';
} else {
transA = 'T';
}

memset(CC, 0, x * sizeof(FLOAT));
memset(DD, 0, x * sizeof(FLOAT));
memset(C, 0, x * sizeof(FLOAT));

SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);

for (j = 0; j < x; j++)
for (i = 0; i < x; i++)
if (transA == 'N') {
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j]);
} else if (transA == 'T') {
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i]);
}

for (j = 0; j < x; j++) {
if (fabs (CC[j] - C[j]) > 1.0)
ret++;
if (fabs (CC[j] - DD[j]) > 1.0)
ret++;
}
}
free(A);
free(B);
free(C);
free(AA);
free(BB);
free(DD);
free(CC);
}

if (ret != 0)
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);
return ret;
}

0 comments on commit 4944148

Please sign in to comment.