Skip to content

Commit

Permalink
ENH: Ditch the global variables for callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoZeke committed Oct 8, 2023
1 parent 8b96274 commit a9c53e1
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 40 deletions.
14 changes: 8 additions & 6 deletions src/Highs.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ HighsInt highsVersionPatch();
const char* highsGithash();
const char* highsCompilationDate();

using HighsCallbackFunctionType =
std::function<void(int, const std::string&, const HighsCallbackDataOut*,
HighsCallbackDataIn*, void*)>;

/**
* @brief Class to set parameters and run HiGHS
*/
Expand Down Expand Up @@ -1019,11 +1023,10 @@ class Highs {
/**
* @brief Set the callback method to use for HiGHS
*/
HighsStatus setCallback(
std::function<void(int, const std::string&, const HighsCallbackDataOut*,
HighsCallbackDataIn*, void*)>
user_callback,
void* user_callback_data = nullptr);
HighsStatus setCallback(HighsCallbackFunctionType user_callback,
void* user_callback_data = nullptr);
HighsStatus setCallback(HighsCCallbackType c_callback,
void* user_callback_data = nullptr);

/**
* @brief Start callback of given type
Expand Down Expand Up @@ -1461,5 +1464,4 @@ class Highs {
void cpp_callback_forwarder(int code, const std::string& msg,
const HighsCallbackDataOut* out_data,
HighsCallbackDataIn* in_data, void* user_data);

#endif
17 changes: 5 additions & 12 deletions src/interfaces/highs_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@

#include "Highs.h"

// This global variable can hold the C function pointer between the calls Needed
// since std::function can take any callable, not just function pointers Can be
// dropped if users *promise* to never pass anything but a function pointer
CCallbackType g_user_callback = nullptr;

HighsInt Highs_lpCall(const HighsInt num_col, const HighsInt num_row,
const HighsInt num_nz, const HighsInt a_format,
const HighsInt sense, const double offset,
Expand Down Expand Up @@ -630,13 +625,11 @@ HighsInt Highs_setSolution(void* highs, const double* col_value,
return (HighsInt)((Highs*)highs)->setSolution(solution);
}

HighsInt Highs_setCallback(void* highs, CCallbackType user_callback, void* user_callback_data) {
// Store the C function pointer globally
g_user_callback = user_callback;

// Use the forwarder as the C++ callback
auto status = static_cast<Highs*>(highs)->setCallback(cpp_callback_forwarder, user_callback_data);
return static_cast<int>(status);
HighsInt Highs_setCallback(void* highs, HighsCCallbackType user_callback,
void* user_callback_data) {
auto status = static_cast<Highs*>(highs)->setCallback(user_callback,
user_callback_data);
return static_cast<int>(status);
}

HighsInt Highs_startCallback(void* highs, const int callback_type) {
Expand Down
8 changes: 3 additions & 5 deletions src/interfaces/highs_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#ifndef HIGHS_C_API
#define HIGHS_C_API

//#include "util/HighsInt.h"
// #include "util/HighsInt.h"
#include "lp_data/HighsCallbackStruct.h"

const HighsInt kHighsMaximumStringLength = 512;
Expand Down Expand Up @@ -1078,10 +1078,8 @@ HighsInt Highs_setSolution(void* highs, const double* col_value,
*
* @returns A `kHighsStatus` constant indicating whether the call succeeded.
*/
HighsInt Highs_setCallback(
void* highs,
CCallbackType user_callback,
void* user_callback_data);
HighsInt Highs_setCallback(void* highs, HighsCCallbackType user_callback,
void* user_callback_data);

/**
* Start callback of given type
Expand Down
33 changes: 19 additions & 14 deletions src/lp_data/Highs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
#include <algorithm>
#include <cassert>
#include <csignal>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>

#include "io/Filereader.h"
#include "io/LoadOptions.h"
#include "lp_data/HighsCallbackStruct.h"
#include "lp_data/HighsInfoDebug.h"
#include "lp_data/HighsLpSolverObject.h"
#include "lp_data/HighsSolve.h"
Expand Down Expand Up @@ -53,15 +55,6 @@ void highsSignalHandler(int signum) {
exit(signum);
}

// C++ callback function that forwards to the C function pointer
void cpp_callback_forwarder(int code, const std::string& msg,
const HighsCallbackDataOut* out_data,
HighsCallbackDataIn* in_data, void* user_data) {
if (g_user_callback) {
g_user_callback(code, msg.c_str(), out_data, in_data, user_data);
}
}

Highs::Highs() { signal(SIGINT, highsSignalHandler); }

HighsStatus Highs::clear() {
Expand Down Expand Up @@ -1888,11 +1881,8 @@ HighsStatus Highs::setSolution(const HighsSolution& solution) {
return returnFromHighs(return_status);
}

HighsStatus Highs::setCallback(
std::function<void(int, const std::string&, const HighsCallbackDataOut*,
HighsCallbackDataIn*, void*)>
user_callback,
void* user_callback_data) {
HighsStatus Highs::setCallback(HighsCallbackFunctionType user_callback,
void* user_callback_data) {
this->callback_.clear();
this->callback_.user_callback = user_callback;
this->callback_.user_callback_data = user_callback_data;
Expand All @@ -1903,6 +1893,21 @@ HighsStatus Highs::setCallback(
return HighsStatus::kOk;
}

HighsStatus Highs::setCallback(HighsCCallbackType c_callback,
void* user_callback_data) {
this->callback_.clear();
this->callback_.user_callback =
[c_callback](int a, const std::string& b, const HighsCallbackDataOut* c,
HighsCallbackDataIn* d,
void* e) { c_callback(a, b.c_str(), c, d, e); };
this->callback_.user_callback_data = user_callback_data;

options_.log_options.user_callback = this->callback_.user_callback;
options_.log_options.user_callback_data = this->callback_.user_callback_data;
options_.log_options.user_callback_active = false;
return HighsStatus::kOk;
}

HighsStatus Highs::startCallback(const int callback_type) {
const bool callback_type_ok =
callback_type >= kCallbackMin && callback_type <= kCallbackMax;
Expand Down
1 change: 1 addition & 0 deletions src/lp_data/HighsCallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct HighsCallback {
std::function<void(int, const std::string&, const HighsCallbackDataOut*,
HighsCallbackDataIn*, void*)>
user_callback = nullptr;
HighsCCallbackType c_callback = nullptr;
void* user_callback_data = nullptr;
std::vector<bool> active;
HighsCallbackDataOut data_out;
Expand Down
6 changes: 3 additions & 3 deletions src/lp_data/HighsCallbackStruct.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ typedef struct {
} HighsCallbackDataIn;

// Additional callback handling
typedef void (*CCallbackType)(int, const char*, const HighsCallbackDataOut*,
HighsCallbackDataIn*, void*);
extern CCallbackType g_user_callback;
typedef void (*HighsCCallbackType)(int, const char*,
const HighsCallbackDataOut*,
HighsCallbackDataIn*, void*);

#ifdef __cplusplus
}
Expand Down

0 comments on commit a9c53e1

Please sign in to comment.