Skip to content

Commit

Permalink
MAINT: Use HighsCallbackFunctionType more
Browse files Browse the repository at this point in the history
Co-authored-by: jajhall <jajhall@users.noreply.github.com>
  • Loading branch information
HaoZeke and jajhall committed Oct 21, 2023
1 parent 910af61 commit 58cc6c2
Showing 1 changed file with 92 additions and 83 deletions.
175 changes: 92 additions & 83 deletions check/TestCallbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "HCheckConfig.h"
#include "Highs.h"
#include "catch.hpp"
#include "lp_data/HighsCallback.h"

const bool dev_run = false;

Expand All @@ -26,92 +27,100 @@ using std::strncmp;
using std::strstr;

// Callback that saves message for comparison
std::function<void(int, const std::string&, const HighsCallbackDataOut*, HighsCallbackDataIn*, void*)>
myLogCallback = [](int callback_type, const std::string& message, const HighsCallbackDataOut* data_out, HighsCallbackDataIn* data_in, void* user_callback_data) {
strcpy(printed_log, message.c_str());
};

std::function<void(int, const std::string&, const HighsCallbackDataOut*, HighsCallbackDataIn*, void*)>
userInterruptCallback = [](int callback_type, const std::string& message, const HighsCallbackDataOut* data_out, HighsCallbackDataIn* data_in, void* user_callback_data) {
// Extract local_callback_data from user_callback_data unless it
// is nullptr
if (callback_type == kCallbackMipImprovingSolution) {
// Use local_callback_data to maintain the objective value from
// the previous callback
assert(user_callback_data);
// Extract the double value pointed to from void* user_callback_data
const double local_callback_data = *(double*)user_callback_data;
if (dev_run)
printf(
"userCallback(type %2d; data %11.4g): %s with objective %g and "
"solution[0] = %g\n",
callback_type, local_callback_data, message.c_str(),
data_out->objective_function_value, data_out->mip_solution[0]);
REQUIRE(local_callback_data >= data_out->objective_function_value);
// Update the double value pointed to from void* user_callback_data
*(double*)user_callback_data = data_out->objective_function_value;
} else {
const int local_callback_data =
user_callback_data
? static_cast<int>(reinterpret_cast<intptr_t>(user_callback_data))
: kUserCallbackNoData;
if (user_callback_data) {
REQUIRE(local_callback_data == kUserCallbackData);
} else {
REQUIRE(local_callback_data == kUserCallbackNoData);
}
if (callback_type == kCallbackLogging) {
if (dev_run)
printf("userInterruptCallback(type %2d; data %2d): %s", callback_type,
local_callback_data, message.c_str());
} else if (callback_type == kCallbackSimplexInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with iteration "
"count = "
"%d\n",
callback_type, local_callback_data, message.c_str(),
int(data_out->simplex_iteration_count));
data_in->user_interrupt =
data_out->simplex_iteration_count > adlittle_simplex_iteration_limit;
} else if (callback_type == kCallbackIpmInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with iteration "
"count = "
"%d\n",
callback_type, local_callback_data, message.c_str(),
int(data_out->ipm_iteration_count));
data_in->user_interrupt =
data_out->ipm_iteration_count > adlittle_ipm_iteration_limit;
} else if (callback_type == kCallbackMipInterrupt) {
HighsCallbackFunctionType myLogCallback =
[](int callback_type, const std::string& message,
const HighsCallbackDataOut* data_out, HighsCallbackDataIn* data_in,
void* user_callback_data) { strcpy(printed_log, message.c_str()); };

HighsCallbackFunctionType userInterruptCallback =
[](int callback_type, const std::string& message,
const HighsCallbackDataOut* data_out, HighsCallbackDataIn* data_in,
void* user_callback_data) {
// Extract local_callback_data from user_callback_data unless it
// is nullptr
if (callback_type == kCallbackMipImprovingSolution) {
// Use local_callback_data to maintain the objective value from
// the previous callback
assert(user_callback_data);
// Extract the double value pointed to from void* user_callback_data
const double local_callback_data = *(double*)user_callback_data;
if (dev_run)
printf(
"userCallback(type %2d; data %11.4g): %s with objective %g and "
"solution[0] = %g\n",
callback_type, local_callback_data, message.c_str(),
data_out->objective_function_value, data_out->mip_solution[0]);
REQUIRE(local_callback_data >= data_out->objective_function_value);
// Update the double value pointed to from void* user_callback_data
*(double*)user_callback_data = data_out->objective_function_value;
} else {
const int local_callback_data =
user_callback_data ? static_cast<int>(reinterpret_cast<intptr_t>(
user_callback_data))
: kUserCallbackNoData;
if (user_callback_data) {
REQUIRE(local_callback_data == kUserCallbackData);
} else {
REQUIRE(local_callback_data == kUserCallbackNoData);
}
if (callback_type == kCallbackLogging) {
if (dev_run)
printf("userInterruptCallback(type %2d; data %2d): %s",
callback_type, local_callback_data, message.c_str());
} else if (callback_type == kCallbackSimplexInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with iteration "
"count = "
"%d\n",
callback_type, local_callback_data, message.c_str(),
int(data_out->simplex_iteration_count));
data_in->user_interrupt = data_out->simplex_iteration_count >
adlittle_simplex_iteration_limit;
} else if (callback_type == kCallbackIpmInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with iteration "
"count = "
"%d\n",
callback_type, local_callback_data, message.c_str(),
int(data_out->ipm_iteration_count));
data_in->user_interrupt =
data_out->ipm_iteration_count > adlittle_ipm_iteration_limit;
} else if (callback_type == kCallbackMipInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with Bounds "
"(%11.4g, %11.4g); Gap = %11.4g; Objective = "
"%g\n",
callback_type, local_callback_data, message.c_str(),
data_out->mip_dual_bound, data_out->mip_primal_bound,
data_out->mip_gap, data_out->objective_function_value);
data_in->user_interrupt =
data_out->objective_function_value < egout_objective_target;
}
}
};

std::function<void(int, const std::string&, const HighsCallbackDataOut*,
HighsCallbackDataIn*, void*)>
userDataCallback = [](int callback_type, const std::string& message,
const HighsCallbackDataOut* data_out,
HighsCallbackDataIn* data_in,
void* user_callback_data) {
assert(callback_type == kCallbackMipInterrupt ||
callback_type == kCallbackMipLogging ||
callback_type == kCallbackMipImprovingSolution);
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with Bounds "
"(%11.4g, %11.4g); Gap = %11.4g; Objective = "
"%g\n",
callback_type, local_callback_data, message.c_str(),
"userDataCallback: Node count = %" PRId64
"; Time = %6.2f; "
"Bounds (%11.4g, %11.4g); Gap = %11.4g; Objective = %11.4g: %s\n",
data_out->mip_node_count, data_out->running_time,
data_out->mip_dual_bound, data_out->mip_primal_bound,
data_out->mip_gap, data_out->objective_function_value);
data_in->user_interrupt =
data_out->objective_function_value < egout_objective_target;
}
}
};

std::function<void(int, const std::string&, const HighsCallbackDataOut*, HighsCallbackDataIn*, void*)>
userDataCallback = [](int callback_type, const std::string& message, const HighsCallbackDataOut* data_out, HighsCallbackDataIn* data_in, void* user_callback_data) {
assert(callback_type == kCallbackMipInterrupt ||
callback_type == kCallbackMipLogging ||
callback_type == kCallbackMipImprovingSolution);
if (dev_run)
printf("userDataCallback: Node count = %" PRId64
"; Time = %6.2f; "
"Bounds (%11.4g, %11.4g); Gap = %11.4g; Objective = %11.4g: %s\n",
data_out->mip_node_count, data_out->running_time,
data_out->mip_dual_bound, data_out->mip_primal_bound,
data_out->mip_gap, data_out->objective_function_value, message.c_str());
};
data_out->mip_gap, data_out->objective_function_value,
message.c_str());
};

TEST_CASE("my-callback-logging", "[highs-callback]") {
bool output_flag = true; // Still runs quietly
Expand Down

0 comments on commit 58cc6c2

Please sign in to comment.