Skip to content

Commit

Permalink
Add COMPLEX_RETSTYLE_FNDA for Windows x64
Browse files Browse the repository at this point in the history
Windows x64 automatically forces return values onto the stack if they
are larger than 64 bits wide [0].  This causes return values from e.g.
`zdotc` to be pushed onto a secret first argument, but not the return
values from e.g. `cdotc`.

To address this, we add a new complex return style, "Float Normal,
Double Argument", to specify that `complex float`-returning functions
use the normal return style, whereas `complex double`-returning
functions use the argument return style.

This should fix JuliaLinearAlgebra/BLISBLAS.jl#15

[0] https://learn.microsoft.com/en-us/cpp/build/x64-calling-convention?view=msvc-170
  • Loading branch information
staticfloat committed Jun 13, 2024
1 parent 07c3509 commit dfd4f5f
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 51 deletions.
78 changes: 64 additions & 14 deletions src/autodetection.c
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ int32_t autodetect_complex_return_style(void * handle, const char * suffix) {
if (env_lowercase_match("LBT_FORCE_RETSTYLE", "argument")) {
return LBT_COMPLEX_RETSTYLE_ARGUMENT;
}
if (env_lowercase_match("LBT_FORCE_RETSTYLE", "fnda")) {
return LBT_COMPLEX_RETSTYLE_FNDA;
}
char symbol_name[MAX_SYMBOL_LEN];

build_symbol_name(symbol_name, "zdotc_", suffix);
Expand All @@ -222,37 +225,84 @@ int32_t autodetect_complex_return_style(void * handle, const char * suffix) {
return LBT_COMPLEX_RETSTYLE_UNKNOWN;
}

build_symbol_name(symbol_name, "cdotc_", suffix);
void * cdotc_addr = lookup_symbol(handle, symbol_name);
if (cdotc_addr == NULL) {
return LBT_COMPLEX_RETSTYLE_UNKNOWN;
}

// Typecast to function pointer for easier usage below
double complex (*zdotc_normal)( int64_t *, double complex *, int64_t *, double complex *, int64_t *) = zdotc_addr;
void (*zdotc_retarg)(double complex *, int64_t *, double complex *, int64_t *, double complex *, int64_t *) = zdotc_addr;

// Typecast to function pointer for easier usage below
float complex (*cdotc_normal)( int64_t *, float complex *, int64_t *, float complex *, int64_t *) = cdotc_addr;
void (*cdotc_retarg)(float complex *, int64_t *, float complex *, int64_t *, float complex *, int64_t *) = cdotc_addr;

/*
* First, check to see if `zdotc` zeros out the first argument if all arguments are zero.
* Supposedly, most well-behaved implementations will return `0 + 0*I` if the length of
* the inputs is zero; so if it is using a "return argument", that's a good way to find out.
*
* We detect this by setting `retval` to an initial value of `0.0 + 1.0*I`. This has the
* added benefit of being interpretable as `0` if looked at as an `int{32,64}_t *`, which
* makes this invocation safe across the full normal-return/argument-return vs. lp64/ilp64
* compatibility square.
* We detect this by setting `retval` to an initial value of `-1` typecast to a complex
* value. The floating-point values are unimportant as they will be written to, but if
* it is interpreted as an `int{32,64}_t`, it will be a negative value (which is not
* allowed and should end the routine immediately). This makes this invocation safe
* across the full normal/argument, lp64/ilp64, cdotc/zdotc compatibility cube.
*/
double complex retval = 0.0 + 1.0*I;
double complex retval_double = 0.0 + 1.0*I;
int64_t zero = 0;
double complex zeroc = 0.0 + 0.0*I;
zdotc_retarg(&retval, &zero, &zeroc, &zero, &zeroc, &zero);
double complex zeroc_double = 0.0 + 0.0*I;
zdotc_retarg(&retval_double, &zero, &zeroc_double, &zero, &zeroc_double, &zero);

if (creal(retval) == 0.0 && cimag(retval) == 0.0) {
return LBT_COMPLEX_RETSTYLE_ARGUMENT;
/*
* Next, do the same with `cdotc`, in order to detect situations where the ABI is
* automatically inserting an extra argument to return 128-bit-wide values.
* We call this `FNDA` for "Float Normal, Double Argument" style.
*/
int64_t neg1 = -1;
float complex retval_float = *(complex float *)(&neg1);
float complex zeroc_float = 0.0f + 0.0f*I;
cdotc_retarg(&retval_float, &zero, &zeroc_float, &zero, &zeroc_float, &zero);

if (creal(retval_double) == 0.0 && cimag(retval_double) == 0.0) {
// If the double values were reset, and the float values were also,
// this is easy, we're just always argument-style:
if (creal(retval_float) == 0.0f && cimag(retval_float) == 0.0f) {
return LBT_COMPLEX_RETSTYLE_ARGUMENT;
}

// If the float values were not, let's try the normal return style:
retval_float = 0.0f + 1.0f*I;
retval_float = cdotc_normal(&zero, &zeroc_float, &zero, &zeroc_float, &zero);


// If this works, we are in FNDA style (currently only observed on Windows x64)
if (creal(retval_float) == 0.0f && cimag(retval_float) == 0.0f) {
return LBT_COMPLEX_RETSTYLE_FNDA;
}

// Otherwise, cdotc is throwing a fit and we don't know what's up.
return LBT_COMPLEX_RETSTYLE_UNKNOWN;
}

// If it was _not_ reset, let's hazard a guess that we're dealing with a normal return style:
retval = 0.0 + 1.0*I;
retval = zdotc_normal(&zero, &zeroc, &zero, &zeroc, &zero);
if (creal(retval) == 0.0 && cimag(retval) == 0.0) {
// If our double values were _not_ reset, let's hazard a guess that
// we're dealing with a normal return style and test both types again:
retval_double = 0.0 + 1.0*I;
retval_double = zdotc_normal(&zero, &zeroc_double, &zero, &zeroc_double, &zero);
retval_float = 0.0f + 1.0f*I;
retval_float = cdotc_normal(&zero, &zeroc_float, &zero, &zeroc_float, &zero);


// We only test for both working; we don't have a retstyle for float
// being argument style and double being normal style.
if ((creal(retval_double) == 0.0 && cimag(retval_double) == 0.0) &&
(creal(retval_float) == 0.0f && cimag(retval_float) == 0.0f)) {
return LBT_COMPLEX_RETSTYLE_NORMAL;
}

// If that was not reset either, we have no idea what's going on.
// If we get here, zdotc and cdotc are being uncooperative and we
// do not appreciate it at all, not we don't my precious.
return LBT_COMPLEX_RETSTYLE_UNKNOWN;
}
#endif // COMPLEX_RETSTYLE_AUTODETECTION
Expand Down
4 changes: 2 additions & 2 deletions src/cblas_adapters.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ void lbt_cblas_cdotc_sub(const int32_t N,
}

extern float complex cdotc_64_(const int64_t *,
const float complex *, const int64_t *,
const float complex *, const int64_t *);
const float complex *, const int64_t *,
const float complex *, const int64_t *);
void lbt_cblas_cdotc_sub64_(const int64_t N,
const float complex *X, const int64_t incX,
const float complex *Y, const int64_t incY,
Expand Down
16 changes: 8 additions & 8 deletions src/complex_return_style_adapters.c
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ extern void (*cmplxret_cdotc__addr)(float complex * z,
const float complex *, const int32_t *,
const float complex *, const int32_t *);
float complex cmplxret_cdotc_(const int32_t * N,
const float complex *X, const int32_t * incX,
const float complex *Y, const int32_t * incY)
const float complex *X, const int32_t * incX,
const float complex *Y, const int32_t * incY)
{
float complex c;
cmplxret_cdotc__addr(&c, N, X, incX, Y, incY);
Expand All @@ -85,8 +85,8 @@ extern void (*cmplxret_cdotc_64__addr)(float complex * z,
const float complex *, const int64_t *,
const float complex *, const int64_t *);
float complex cmplxret_cdotc_64_(const int64_t * N,
const float complex *X, const int64_t * incX,
const float complex *Y, const int64_t * incY)
const float complex *X, const int64_t * incX,
const float complex *Y, const int64_t * incY)
{
float complex c;
cmplxret_cdotc_64__addr(&c, N, X, incX, Y, incY);
Expand All @@ -100,8 +100,8 @@ extern void (*cmplxret_cdotu__addr)(float complex * z,
const float complex *, const int32_t *,
const float complex *, const int32_t *);
float complex cmplxret_cdotu_(const int32_t * N,
const float complex *X, const int32_t * incX,
const float complex *Y, const int32_t * incY)
const float complex *X, const int32_t * incX,
const float complex *Y, const int32_t * incY)
{
float complex c;
cmplxret_cdotu__addr(&c, N, X, incX, Y, incY);
Expand All @@ -113,8 +113,8 @@ extern void (*cmplxret_cdotu_64__addr)(float complex * z,
const float complex *, const int64_t *,
const float complex *, const int64_t *);
float complex cmplxret_cdotu_64_(const int64_t * N,
const float complex *X, const int64_t * incX,
const float complex *Y, const int64_t * incY)
const float complex *X, const int64_t * incX,
const float complex *Y, const int64_t * incY)
{
float complex c;
cmplxret_cdotu_64__addr(&c, N, X, incX, Y, incY);
Expand Down
41 changes: 22 additions & 19 deletions src/libblastrampoline.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,29 @@ int32_t set_forward_by_index(int32_t symbol_idx, const void * addr, int32_t inte
}

#ifdef COMPLEX_RETSTYLE_AUTODETECTION
if (complex_retstyle == LBT_COMPLEX_RETSTYLE_ARGUMENT) {
// Check to see if this symbol is one of the complex-returning functions
for (int complex_symbol_idx=0; cmplxret_func_idxs[complex_symbol_idx] != -1; ++complex_symbol_idx) {
// Skip any symbols that aren't ours
if (cmplxret_func_idxs[complex_symbol_idx] != symbol_idx)
continue;

// Report to the user that we're cblas-wrapping this one
if (verbose) {
char exported_name[MAX_SYMBOL_LEN];
build_symbol_name(exported_name, exported_func_names[symbol_idx], interface == LBT_INTERFACE_ILP64 ? "64_" : "");
printf(" - [%04d] complex(%s)\n", symbol_idx, exported_name);
}
for (int array_idx=0; array_idx < sizeof(cmplxret_func_idxs)/sizeof(int *); ++array_idx) {
if ((complex_retstyle == LBT_COMPLEX_RETSTYLE_ARGUMENT) ||
((complex_retstyle == LBT_COMPLEX_RETSTYLE_FNDA) && array_idx == 1)) {
// Check to see if this symbol is one of the complex-returning functions
for (int complex_symbol_idx=0; cmplxret_func_idxs[array_idx][complex_symbol_idx] != -1; ++complex_symbol_idx) {
// Skip any symbols that aren't ours
if (cmplxret_func_idxs[array_idx][complex_symbol_idx] != symbol_idx)
continue;

// Report to the user that we're cmplxret-wrapping this one
if (verbose) {
char exported_name[MAX_SYMBOL_LEN];
build_symbol_name(exported_name, exported_func_names[symbol_idx], interface == LBT_INTERFACE_ILP64 ? "64_" : "");
printf(" - [%04d] complex(%s)\n", symbol_idx, exported_name);
}

if (interface == LBT_INTERFACE_LP64) {
(*cmplxret_func32_addrs[complex_symbol_idx]) = (*exported_func32_addrs[symbol_idx]);
(*exported_func32_addrs[symbol_idx]) = cmplxret32_func_wrappers[complex_symbol_idx];
} else {
(*cmplxret_func64_addrs[complex_symbol_idx]) = (*exported_func64_addrs[symbol_idx]);
(*exported_func64_addrs[symbol_idx]) = cmplxret64_func_wrappers[complex_symbol_idx];
if (interface == LBT_INTERFACE_LP64) {
(*cmplxret_func32_addrs[array_idx][complex_symbol_idx]) = (*exported_func32_addrs[symbol_idx]);
(*exported_func32_addrs[symbol_idx]) = cmplxret_func32_wrappers[array_idx][complex_symbol_idx];
} else {
(*cmplxret_func64_addrs[array_idx][complex_symbol_idx]) = (*exported_func64_addrs[symbol_idx]);
(*exported_func64_addrs[symbol_idx]) = cmplxret_func64_wrappers[array_idx][complex_symbol_idx];
}
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/libblastrampoline.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ typedef struct {
// Possible values for `retstyle` in `lbt_library_info_t`
// These describe whether a library is using "normal" return value passing (e.g. through
// the `XMM{0,1}` registers on x86_64, or the `ST{0,1}` floating-point registers on i686)
// This is further complicated by the fact that on certain platforms (such as Windows x64
// this is dependent on the size of the value being returned, e.g. a complex64 value will
// be returned through registers, but a complex128 value will not. We therefore have a
// special value that denotes this situation)
#define LBT_COMPLEX_RETSTYLE_NORMAL 0
#define LBT_COMPLEX_RETSTYLE_ARGUMENT 1
#define LBT_COMPLEX_RETSTYLE_FNDA 2 // "Float Normal, Double Argument"
#define LBT_COMPLEX_RETSTYLE_UNKNOWN -1

// Possible values for `cblas` in `lbt_library_info_t`
Expand Down
55 changes: 48 additions & 7 deletions src/libblastrampoline_complex_retdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,35 @@ COMPLEX128_FUNCS(XX_64)
// Build mapping from cmplxret-index to `_addr` instance
#define XX(name, index) &cmplxret_##name##_addr,
#define XX_64(name, index) &cmplxret_##name##64__addr,
const void ** cmplxret_func32_addrs[] = {
const void ** cmplx64ret_func32_addrs[] = {
COMPLEX64_FUNCS(XX)
NULL
};
const void ** cmplx128ret_func32_addrs[] = {
COMPLEX128_FUNCS(XX)
NULL
};
const void ** cmplxret_func64_addrs[] = {
const void ** cmplx64ret_func64_addrs[] = {
COMPLEX64_FUNCS(XX_64)
NULL
};
const void ** cmplx128ret_func64_addrs[] = {
COMPLEX128_FUNCS(XX_64)
NULL
};
#undef XX
#undef XX_64

const void *** cmplxret_func32_addrs[] = {
cmplx64ret_func32_addrs,
cmplx128ret_func32_addrs
};
const void *** cmplxret_func64_addrs[] = {
cmplx64ret_func64_addrs,
cmplx128ret_func64_addrs
};



// Forward-declare some functions
#define XX(name, index) extern const void * cmplxret_##name ;
Expand All @@ -40,24 +56,49 @@ COMPLEX128_FUNCS(XX_64)
// locations, allowing a cblas index -> function lookup
#define XX(name, index) &cmplxret_##name,
#define XX_64(name, index) &cmplxret_##name##64_,
const void ** cmplxret32_func_wrappers[] = {
const void ** cmplx64ret_func32_wrappers[] = {
COMPLEX64_FUNCS(XX)
NULL
};
const void ** cmplx128ret_func32_wrappers[] = {
COMPLEX128_FUNCS(XX)
NULL
};
const void ** cmplxret64_func_wrappers[] = {
const void ** cmplx64ret_func64_wrappers[] = {
COMPLEX64_FUNCS(XX_64)
NULL
};
const void ** cmplx128ret_func64_wrappers[] = {
COMPLEX128_FUNCS(XX_64)
NULL
};
#undef XX
#undef XX_64

// Finally, an array that maps cblas index -> exported symbol index
const void *** cmplxret_func32_wrappers[] = {
cmplx64ret_func32_wrappers,
cmplx128ret_func32_wrappers
};
const void *** cmplxret_func64_wrappers[] = {
cmplx64ret_func64_wrappers,
cmplx128ret_func64_wrappers
};



// Finally, an array that maps cmplxret index -> exported symbol index
#define XX(name, index) index,
const int cmplxret_func_idxs[] = {
const int cmplx64ret_func_idxs[] = {
COMPLEX64_FUNCS(XX)
-1
};
const int cmplx128ret_func_idxs[] = {
COMPLEX128_FUNCS(XX)
-1
};
#undef XX
#undef XX

const int * cmplxret_func_idxs[] = {
cmplx64ret_func_idxs,
cmplx128ret_func_idxs
};
2 changes: 1 addition & 1 deletion test/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ lbt_handle = dlopen("$(lbt_prefix)/$(binlib)/lib$(lbt_link_name).$(shlib_ext)",
@test libs[1].f2c == LBT_F2C_PLAIN
if Sys.ARCH (:x86_64, :aarch64)
if Sys.iswindows()
@test libs[1].complex_retstyle == LBT_COMPLEX_RETSTYLE_ARGUMENT
@test libs[1].complex_retstyle == LBT_COMPLEX_RETSTYLE_FNDA
else
@test libs[1].complex_retstyle == LBT_COMPLEX_RETSTYLE_NORMAL
end
Expand Down
1 change: 1 addition & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ const LBT_INTERFACE_ILP64 = 64
const LBT_F2C_PLAIN = 0
const LBT_COMPLEX_RETSTYLE_NORMAL = 0
const LBT_COMPLEX_RETSTYLE_ARGUMENT = 1
const LBT_COMPLEX_RETSTYLE_FNDA = 2
const LBT_COMPLEX_RETSTYLE_UNKNOWN = -1
const LBT_CBLAS_CONFORMANT = 0
const LBT_CBLAS_DIVERGENT = 1
Expand Down

0 comments on commit dfd4f5f

Please sign in to comment.