Skip to content

Commit

Permalink
MAINT: Rework to HIGHS_INT64 work with pybind11
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoZeke committed Oct 15, 2023
1 parent 94f53c5 commit 39a892e
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions highspy/highs_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,22 @@ HighsStatus highs_passModelPointers(
const double* col_upper_ptr = static_cast<double*>(col_upper_info.ptr);
const double* row_lower_ptr = static_cast<double*>(row_lower_info.ptr);
const double* row_upper_ptr = static_cast<double*>(row_upper_info.ptr);
const double* a_value_ptr = static_cast<double*>(a_value_info.ptr);
const double* q_value_ptr = static_cast<double*>(q_value_info.ptr);
const HighsInt* a_start_ptr = static_cast<HighsInt*>(a_start_info.ptr);
const HighsInt* a_index_ptr = static_cast<HighsInt*>(a_index_info.ptr);
const double* a_value_ptr = static_cast<double*>(a_value_info.ptr);
const HighsInt* q_start_ptr = static_cast<HighsInt*>(q_start_info.ptr);
const HighsInt* q_index_ptr = static_cast<HighsInt*>(q_index_info.ptr);
const double* q_value_ptr = static_cast<double*>(q_value_info.ptr);
const HighsInt* integrality_ptr =
static_cast<HighsInt*>(integrality_info.ptr);

return h->passModel(num_col, num_row, num_nz, q_num_nz, a_format, q_format,
sense, offset, col_cost_ptr, col_lower_ptr, col_upper_ptr,
row_lower_ptr, row_upper_ptr, a_start_ptr, a_index_ptr,
a_value_ptr, q_start_ptr, q_index_ptr, q_value_ptr,
integrality_ptr);
return h->passModel(
static_cast<HighsInt>(num_col), static_cast<HighsInt>(num_row),
static_cast<HighsInt>(num_nz), static_cast<HighsInt>(q_num_nz),
static_cast<HighsInt>(a_format), static_cast<HighsInt>(q_format),
static_cast<HighsInt>(sense), offset, col_cost_ptr, col_lower_ptr,
col_upper_ptr, row_lower_ptr, row_upper_ptr, a_start_ptr, a_index_ptr,
a_value_ptr, q_start_ptr, q_index_ptr, q_value_ptr, integrality_ptr);
}

HighsStatus highs_passLp(Highs* h, HighsLp& lp) { return h->passModel(lp); }
Expand Down Expand Up @@ -90,10 +92,12 @@ HighsStatus highs_passLpPointers(
const HighsInt* integrality_ptr =
static_cast<HighsInt*>(integrality_info.ptr);

return h->passModel(num_col, num_row, num_nz, a_format, sense, offset,
col_cost_ptr, col_lower_ptr, col_upper_ptr, row_lower_ptr,
row_upper_ptr, a_start_ptr, a_index_ptr, a_value_ptr,
integrality_ptr);
return h->passModel(
static_cast<HighsInt>(num_col), static_cast<HighsInt>(num_row),
static_cast<HighsInt>(num_nz), static_cast<HighsInt>(a_format),
static_cast<HighsInt>(sense), offset, col_cost_ptr, col_lower_ptr,
col_upper_ptr, row_lower_ptr, row_upper_ptr, a_start_ptr, a_index_ptr,
a_value_ptr, integrality_ptr);
}

HighsStatus highs_passHessian(Highs* h, HighsHessian& hessian) {
Expand Down Expand Up @@ -148,7 +152,7 @@ HighsStatus highs_addRow(Highs* h, double lower, double upper,
py::buffer_info indices_info = indices.request();
py::buffer_info values_info = values.request();

HighsInt* indices_ptr = static_cast<HighsInt*>(indices_info.ptr);
HighsInt* indices_ptr = reinterpret_cast<HighsInt*>(indices_info.ptr);
double* values_ptr = static_cast<double*>(values_info.ptr);

return h->addRow(lower, upper, num_new_nz, indices_ptr, values_ptr);
Expand All @@ -167,8 +171,8 @@ HighsStatus highs_addRows(Highs* h, HighsInt num_row, py::array_t<double> lower,

double* lower_ptr = static_cast<double*>(lower_info.ptr);
double* upper_ptr = static_cast<double*>(upper_info.ptr);
HighsInt* starts_ptr = static_cast<HighsInt*>(starts_info.ptr);
HighsInt* indices_ptr = static_cast<HighsInt*>(indices_info.ptr);
HighsInt* starts_ptr = reinterpret_cast<HighsInt*>(starts_info.ptr);
HighsInt* indices_ptr = reinterpret_cast<HighsInt*>(indices_info.ptr);
double* values_ptr = static_cast<double*>(values_info.ptr);

return h->addRows(num_row, lower_ptr, upper_ptr, num_new_nz, starts_ptr,
Expand All @@ -181,7 +185,7 @@ HighsStatus highs_addCol(Highs* h, double cost, double lower, double upper,
py::buffer_info indices_info = indices.request();
py::buffer_info values_info = values.request();

HighsInt* indices_ptr = static_cast<HighsInt*>(indices_info.ptr);
HighsInt* indices_ptr = reinterpret_cast<HighsInt*>(indices_info.ptr);
double* values_ptr = static_cast<double*>(values_info.ptr);

return h->addCol(cost, lower, upper, num_new_nz, indices_ptr, values_ptr);
Expand All @@ -202,8 +206,8 @@ HighsStatus highs_addCols(Highs* h, HighsInt num_col, py::array_t<double> cost,
double* cost_ptr = static_cast<double*>(cost_info.ptr);
double* lower_ptr = static_cast<double*>(lower_info.ptr);
double* upper_ptr = static_cast<double*>(upper_info.ptr);
HighsInt* starts_ptr = static_cast<HighsInt*>(starts_info.ptr);
HighsInt* indices_ptr = static_cast<HighsInt*>(indices_info.ptr);
HighsInt* starts_ptr = reinterpret_cast<HighsInt*>(starts_info.ptr);
const HighsInt* indices_ptr = reinterpret_cast<HighsInt*>(indices_info.ptr);
double* values_ptr = static_cast<double*>(values_info.ptr);

return h->addCols(num_col, cost_ptr, lower_ptr, upper_ptr, num_new_nz,
Expand Down Expand Up @@ -363,7 +367,8 @@ std::tuple<HighsStatus, double, double, double, HighsInt> highs_getCol(
double cost, lower, upper;
HighsInt get_num_col;
HighsInt get_num_nz;
HighsStatus status = h->getCols(1, &col, get_num_col, &cost, &lower, &upper,
HighsInt col_ = static_cast<HighsInt>(col);
HighsStatus status = h->getCols(1, &col_, get_num_col, &cost, &lower, &upper,
get_num_nz, nullptr, nullptr, nullptr);
return std::make_tuple(status, cost, lower, upper, get_num_nz);
}
Expand All @@ -373,7 +378,8 @@ highs_getColEntries(Highs* h, HighsInt col) {
double cost, lower, upper;
HighsInt get_num_col;
HighsInt get_num_nz;
h->getCols(1, &col, get_num_col, nullptr, nullptr, nullptr, get_num_nz,
HighsInt col_ = static_cast<HighsInt>(col);
h->getCols(1, &col_, get_num_col, nullptr, nullptr, nullptr, get_num_nz,
nullptr, nullptr, nullptr);
get_num_nz = get_num_nz > 0 ? get_num_nz : 1;
HighsInt start;
Expand All @@ -382,7 +388,7 @@ highs_getColEntries(Highs* h, HighsInt col) {
HighsInt* index_ptr = static_cast<HighsInt*>(index.data());
double* value_ptr = static_cast<double*>(value.data());
HighsStatus status =
h->getCols(1, &col, get_num_col, nullptr, nullptr, nullptr, get_num_nz,
h->getCols(1, &col_, get_num_col, nullptr, nullptr, nullptr, get_num_nz,
&start, index_ptr, value_ptr);
return std::make_tuple(status, py::cast(index), py::cast(value));
}
Expand All @@ -392,7 +398,8 @@ std::tuple<HighsStatus, double, double, HighsInt> highs_getRow(Highs* h,
double cost, lower, upper;
HighsInt get_num_row;
HighsInt get_num_nz;
HighsStatus status = h->getRows(1, &row, get_num_row, &lower, &upper,
HighsInt row_ = static_cast<HighsInt>(row);
HighsStatus status = h->getRows(1, &row_, get_num_row, &lower, &upper,
get_num_nz, nullptr, nullptr, nullptr);
return std::make_tuple(status, lower, upper, get_num_nz);
}
Expand All @@ -402,15 +409,16 @@ highs_getRowEntries(Highs* h, HighsInt row) {
double cost, lower, upper;
HighsInt get_num_row;
HighsInt get_num_nz;
h->getRows(1, &row, get_num_row, nullptr, nullptr, get_num_nz, nullptr,
HighsInt row_ = static_cast<HighsInt>(row);
h->getRows(1, &row_, get_num_row, nullptr, nullptr, get_num_nz, nullptr,
nullptr, nullptr);
get_num_nz = get_num_nz > 0 ? get_num_nz : 1;
HighsInt start;
std::vector<HighsInt> index(get_num_nz);
std::vector<double> value(get_num_nz);
HighsInt* index_ptr = static_cast<HighsInt*>(index.data());
double* value_ptr = static_cast<double*>(value.data());
HighsStatus status = h->getRows(1, &row, get_num_row, nullptr, nullptr,
HighsStatus status = h->getRows(1, &row_, get_num_row, nullptr, nullptr,
get_num_nz, &start, index_ptr, value_ptr);
return std::make_tuple(status, py::cast(index), py::cast(value));
}
Expand Down

0 comments on commit 39a892e

Please sign in to comment.