Skip to content

Commit

Permalink
Revert some changes after pybind#165 was merged.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralf W. Grosse-Kunstleve committed Jun 21, 2024
1 parent d7d37cd commit 2f79409
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 269 deletions.
4 changes: 0 additions & 4 deletions pybind11_protobuf/check_unknown_fields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ std::string MakeAllowListKey(
return absl::StrCat(top_message_descriptor_full_name, ":",
unknown_field_parent_message_fqn);
}
#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)

/// Recurses through the message Descriptor class looking for valid extensions.
/// Stores the result to `memoized`.
Expand Down Expand Up @@ -174,7 +173,6 @@ std::string HasUnknownFields::BuildErrorMessage() const {
return emsg;
}

#endif
} // namespace

void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
Expand All @@ -183,7 +181,6 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
unknown_field_parent_message_fqn));
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* message) {
Expand All @@ -198,6 +195,5 @@ std::optional<std::string> CheckRecursively(
}
return search.BuildErrorMessage();
}
#endif

} // namespace pybind11_protobuf::check_unknown_fields
7 changes: 1 addition & 6 deletions pybind11_protobuf/check_unknown_fields.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@

#include <optional>

#include "absl/strings/string_view.h"
#include "google/protobuf/message.h"

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
#include "python/google/protobuf/proto_api.h"
#endif // PYBIND11_PROTOBUF_ENABLE_PYPROTO_API
#include "absl/strings/string_view.h"

namespace pybind11_protobuf::check_unknown_fields {

Expand Down Expand Up @@ -48,11 +45,9 @@ class ExtensionsWithUnknownFieldsPolicy {
void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
absl::string_view unknown_field_parent_message_fqn);

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* top_message);
#endif // PYBIND11_PROTOBUF_ENABLE_PYPROTO_API

} // namespace pybind11_protobuf::check_unknown_fields

Expand Down
259 changes: 0 additions & 259 deletions pybind11_protobuf/proto_cast_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,6 @@
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/dynamic_message.h"
#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
#include "python/google/protobuf/proto_api.h"
#else
namespace google::protobuf::python {
struct PyProto_API;
}
#endif
#include "pybind11_protobuf/check_unknown_fields.h"

#if defined(GOOGLE_PROTOBUF_VERSION)
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#endif

namespace py = pybind11;

Expand All @@ -51,7 +38,6 @@ using ::google::protobuf::FileDescriptor;
using ::google::protobuf::FileDescriptorProto;
using ::google::protobuf::Message;
using ::google::protobuf::MessageFactory;
using ::google::protobuf::python::PyProto_API;

namespace pybind11_protobuf {

Expand Down Expand Up @@ -154,22 +140,6 @@ absl::optional<std::string> CastToOptionalString(py::handle src) {
return absl::nullopt;
}

#if defined(GOOGLE_PROTOBUF_VERSION)
// The current version, represented as a single integer to make comparison
// easier: major * 10^6 + minor * 10^3 + micro
uint64_t VersionStringToNumericVersion(absl::string_view version_str) {
std::vector<absl::string_view> split = absl::StrSplit(version_str, '.');
uint64_t major = 0, minor = 0, micro = 0;
if (split.size() == 3 && //
absl::SimpleAtoi(split[0], &major) &&
absl::SimpleAtoi(split[1], &minor) &&
absl::SimpleAtoi(split[2], &micro)) {
return major * 1000000 + minor * 1000 + micro;
}
return 0;
}
#endif

class GlobalState {
public:
// Global state singleton intentionally leaks at program termination.
Expand All @@ -182,8 +152,6 @@ class GlobalState {
}

py::handle global_pool() { return global_pool_; }
const PyProto_API* py_proto_api() { return py_proto_api_; }
bool using_fast_cpp() const { return using_fast_cpp_; }

// Allocate a python proto message instance using the native python
// allocations.
Expand All @@ -201,8 +169,6 @@ class GlobalState {
private:
GlobalState();

const PyProto_API* py_proto_api_ = nullptr;
bool using_fast_cpp_ = false;
py::object global_pool_;
py::object factory_;
py::object find_message_type_by_name_;
Expand Down Expand Up @@ -248,61 +214,6 @@ GlobalState::GlobalState() {
get_prototype_ = {};
get_message_class_ = {};
}

// determine the proto implementation.
auto type =
ImportCached("google.protobuf.internal.api_implementation")
.attr("Type")();
using_fast_cpp_ = (CastToOptionalString(type).value_or("") == "cpp");

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
// DANGER: The only way to guarantee that the PyProto_API doesn't have
// incompatible ABI changes is to ensure that the python protobuf .so
// and all other extension .so files are built with the exact same
// environment, including compiler, flags, etc. It's also expected
// that the global_pool() objects are the same. And there's no way for
// bazel to do that right now.
//
// Otherwise, we're left with (1), the PyProto_API module reaching into the
// internals of a potentially incompatible Descriptor type from this CU, (2)
// this CU reaching into the potentially incompatible internals of PyProto_API
// implementation, or (3) disabling access to PyProto_API unless compile
// options suggest otherwise.
//
// By default (3) is used, however if the define is set *and* the version
// matches, then pybind11_protobuf will assume that this will work.
using ::google::protobuf::python::PyProtoAPICapsuleName;
py_proto_api_ =
static_cast<PyProto_API*>(PyCapsule_Import(PyProtoAPICapsuleName(), 0));
if (py_proto_api_ == nullptr) {
// The module implementing fast cpp protos is not loaded, clear the error.
assert(!using_fast_cpp_);
PyErr_Clear();
}
#else
py_proto_api_ = nullptr;
using_fast_cpp_ = false;
#endif

#if defined(GOOGLE_PROTOBUF_VERSION)
/// The C++ version of PyProto_API must match that loaded by python,
/// otherwise the details of the underlying implementation may cause
/// crashes. This limits the ability to pass some protos from C++ to
/// python.
if (py_proto_api_) {
auto version =
ResolveAttrs(ImportCached("google.protobuf"), {"__version__"});
std::string version_str =
version ? CastToOptionalString(*version).value_or("") : "";
if (GOOGLE_PROTOBUF_VERSION != VersionStringToNumericVersion(version_str)) {
std::cerr << "Python version " << version_str
<< " does not match C++ version " << GOOGLE_PROTOBUF_VERSION
<< std::endl;
using_fast_cpp_ = false;
py_proto_api_ = nullptr;
}
}
#endif
}

py::module_ GlobalState::ImportCached(const std::string& module_name) {
Expand Down Expand Up @@ -361,49 +272,6 @@ py::object GlobalState::PyMessageInstance(const Descriptor* descriptor) {
module_name + "?");
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
std::pair<py::object, Message*> GlobalState::PyFastCppProtoMessageInstance(
const Descriptor* descriptor) {
assert(descriptor != nullptr);
assert(py_proto_api_ != nullptr);

// Create a PyDescriptorPool, temporarily, it will be used by the NewMessage
// API call which will store it in the classes it creates.
//
// Note: Creating Python classes is a bit expensive, it might be a good idea
// for client code to create the pool once, and store it somewhere along with
// the C++ pool; then Python pools and classes are cached and reused.
// Otherwise, consecutives calls to this function may or may not reuse
// previous classes, depending on whether the returned instance has been
// kept alive.
//
// IMPORTANT CAVEAT: The C++ DescriptorPool must not be deallocated while
// there are any messages using it.
// Furthermore, since the cache uses the DescriptorPool address, allocating
// a new DescriptorPool with the same address is likely to use dangling
// pointers.
// It is probably better for client code to keep the C++ DescriptorPool alive
// until the end of the process.
// TODO(amauryfa): Add weakref or on-deletion callbacks to C++ DescriptorPool.
py::object descriptor_pool = py::reinterpret_steal<py::object>(
py_proto_api_->DescriptorPool_FromPool(descriptor->file()->pool()));
if (descriptor_pool.ptr() == nullptr) {
throw py::error_already_set();
}

py::object result = py::reinterpret_steal<py::object>(
py_proto_api_->NewMessage(descriptor, nullptr));
if (result.ptr() == nullptr) {
throw py::error_already_set();
}
Message* message = py_proto_api_->GetMutableMessagePointer(result.ptr());
if (message == nullptr) {
throw py::error_already_set();
}
return {std::move(result), message};
}
#endif

// Create C++ DescriptorPools based on Python DescriptorPools.
// The Python pool will provide message definitions when they are needed.
// This gives an efficient way to create C++ Messages from Python definitions.
Expand Down Expand Up @@ -542,26 +410,6 @@ class PythonDescriptorPoolWrapper {
private:
bool CopyToFileDescriptorProto(py::handle py_file_descriptor,
FileDescriptorProto* output) {
#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
if (GlobalState::instance()->py_proto_api()) {
try {
py::object c_proto = py::reinterpret_steal<py::object>(
GlobalState::instance()
->py_proto_api()
->NewMessageOwnedExternally(output, nullptr));
if (c_proto) {
py_file_descriptor.attr("CopyToProto")(c_proto);
return true;
}
} catch (py::error_already_set& e) {
std::cerr << "CopyToFileDescriptorProto raised an error";

// This prints and clears the error.
e.restore();
PyErr_Print();
}
}
#endif

return output->ParsePartialFromString(
PyBytesAsStringView(py_file_descriptor.attr("serialized_pb")));
Expand Down Expand Up @@ -760,82 +608,6 @@ py::handle GenericPyProtoCast(Message* src, py::return_value_policy policy,
return py_proto.release();
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
py::handle GenericFastCppProtoCast(Message* src, py::return_value_policy policy,
py::handle parent, bool is_const) {
assert(policy != pybind11::return_value_policy::automatic);
assert(policy != pybind11::return_value_policy::automatic_reference);
assert(src != nullptr);
assert(PyGILState_Check());
assert(GlobalState::instance()->py_proto_api() != nullptr);

switch (policy) {
case py::return_value_policy::move:
case py::return_value_policy::take_ownership: {
std::pair<py::object, Message*> descriptor_pair =
GlobalState::instance()->PyFastCppProtoMessageInstance(
src->GetDescriptor());
py::object& result = descriptor_pair.first;
Message* result_message = descriptor_pair.second;

if (result_message->GetReflection() == src->GetReflection()) {
// The internals may be Swapped iff the protos use the same Reflection
// instance.
result_message->GetReflection()->Swap(src, result_message);
} else {
auto serialized = src->SerializePartialAsString();
if (!result_message->ParseFromString(serialized)) {
throw py::type_error(
"Failed to copy protocol buffer with mismatched descriptor");
}
}
return result.release();
} break;

#if defined(PYBIND11_HAS_RETURN_VALUE_POLICY_CLIF_AUTOMATIC)
// TODO(clif-team): Review this choice for `_clif_automatic`.
case py::return_value_policy::_clif_automatic:
#endif
case py::return_value_policy::copy: {
std::pair<py::object, Message*> descriptor_pair =
GlobalState::instance()->PyFastCppProtoMessageInstance(
src->GetDescriptor());
py::object& result = descriptor_pair.first;
Message* result_message = descriptor_pair.second;

if (result_message->GetReflection() == src->GetReflection()) {
// The internals may be copied iff the protos use the same Reflection
// instance.
result_message->CopyFrom(*src);
} else {
auto serialized = src->SerializePartialAsString();
if (!result_message->ParseFromString(serialized)) {
throw py::type_error(
"Failed to copy protocol buffer with mismatched descriptor");
}
}
return result.release();
} break;

case py::return_value_policy::reference:
case py::return_value_policy::reference_internal: {
// NOTE: Reference to const are currently unsafe to return.
py::object result = py::reinterpret_steal<py::object>(
GlobalState::instance()->py_proto_api()->NewMessageOwnedExternally(
src, nullptr));
if (policy == py::return_value_policy::reference_internal) {
py::detail::keep_alive_impl(result, parent);
}
return result.release();
} break;

default:
std::string message("pybind11_protobuf unhandled return_value_policy::");
throw py::cast_error(message + ReturnValuePolicyName(policy));
}
}
#endif

py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
py::handle parent, bool is_const) {
assert(src != nullptr);
Expand All @@ -845,38 +617,7 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
// 1. The binary does not have a py_proto_api instance, or
// 2. a) the proto is from the default pool and
// b) the binary is not using fast_cpp_protos.
#if ! defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
return GenericPyProtoCast(src, policy, parent, is_const);
#else
if (GlobalState::instance()->py_proto_api() == nullptr ||
(src->GetDescriptor()->file()->pool() ==
DescriptorPool::generated_pool() &&
!GlobalState::instance()->using_fast_cpp())) {
return GenericPyProtoCast(src, policy, parent, is_const);
}

std::optional<std::string> unknown_field_message =
check_unknown_fields::CheckRecursively(
GlobalState::instance()->py_proto_api(), src);
if (unknown_field_message) {
if (check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::
UnknownFieldsAreDisallowed()) {
throw py::value_error(*unknown_field_message);
}
// Emit one LOG(WARNING) per unique unknown_field_message:
static auto fall_back_log_shown = new std::unordered_set<std::string>();
if (fall_back_log_shown->insert(*unknown_field_message).second) {
LOG(WARNING) << "FALL BACK TO PROTOBUF SERIALIZE/PARSE: "
<< *unknown_field_message;
}
return GenericPyProtoCast(src, policy, parent, is_const);
}

// If this is a dynamically generated proto, then we're going to need to
// construct a mapping between C++ pool() and python pool(), and then
// use the PyProto_API to make it work.
return GenericFastCppProtoCast(src, policy, parent, is_const);
#endif
}

} // namespace pybind11_protobuf

0 comments on commit 2f79409

Please sign in to comment.