Skip to content

Commit

Permalink
Implement enforcement of target constraints
Browse files Browse the repository at this point in the history
This is a new feature allowing fields to be annotated with a `targets` option
specifying what kinds of entities that field may be applied to when used in an
option.

PiperOrigin-RevId: 527990260
  • Loading branch information
acozzette authored and copybara-github committed Apr 28, 2023
1 parent 37dfe80 commit e3848c1
Show file tree
Hide file tree
Showing 4 changed files with 453 additions and 36 deletions.
12 changes: 7 additions & 5 deletions src/google/protobuf/compiler/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ load(
)
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//build_defs:arch_tests.bzl", "aarch64_test", "x86_64_test")
load("//build_defs:cpp_opts.bzl", "COPTS", "LINK_OPTS")
load("//build_defs:cpp_opts.bzl", "COPTS")

proto_library(
name = "plugin_proto",
srcs = ["plugin.proto"],
strip_import_prefix = "/src",
visibility = [
"//:__pkg__",
"//pkg:__pkg__",
],
strip_import_prefix = "/src",
deps = ["//:descriptor_proto"],
)

Expand Down Expand Up @@ -93,11 +93,14 @@ cc_library(
"//src/google/protobuf:descriptor_legacy",
"//src/google/protobuf:protobuf_nowkt",
"//src/google/protobuf/compiler/allowlists",
"@com_google_absl//absl/algorithm",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
)

Expand Down Expand Up @@ -309,22 +312,21 @@ cc_library(
visibility = ["//src/google/protobuf:__subpackages__"],
deps = [
"//src/google/protobuf:protobuf_nowkt",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/types:span",
],
)


cc_test(
name = "retention_unittest",
srcs = ["retention_unittest.cc"],
deps = [
":importer",
":retention",
"//src/google/protobuf/io",
"@com_google_absl//absl/log:die_if_null",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/log:die_if_null",
],
)

Expand Down
257 changes: 226 additions & 31 deletions src/google/protobuf/compiler/command_line_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@

#include "google/protobuf/compiler/command_line_interface.h"

#include "absl/algorithm/container.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "google/protobuf/compiler/allowlists/allowlists.h"
#include "google/protobuf/descriptor_legacy.h"

Expand Down Expand Up @@ -1028,7 +1030,20 @@ struct VisitImpl {
Visitor visitor;
// Hands a field descriptor to the visitor. Nothing further is traversed from
// a field.
void Visit(const FieldDescriptor* descriptor) { visitor(descriptor); }

void Visit(const EnumDescriptor* descriptor) { visitor(descriptor); }
// Hands a single enum value descriptor to the visitor. Nothing further is
// traversed from an enum value.
void Visit(const EnumValueDescriptor* descriptor) { visitor(descriptor); }

// Reports the enum itself to the visitor first, then visits each of its
// values in index order.
void Visit(const EnumDescriptor* descriptor) {
  visitor(descriptor);
  const int value_total = descriptor->value_count();
  for (int index = 0; index < value_total; ++index) {
    Visit(descriptor->value(index));
  }
}

// Hands an extension range to the visitor. Nothing further is traversed from
// an extension range.
void Visit(const Descriptor::ExtensionRange* descriptor) {
visitor(descriptor);
}

// Hands a oneof declaration to the visitor. Nothing further is traversed from
// the oneof here.
void Visit(const OneofDescriptor* descriptor) { visitor(descriptor); }

void Visit(const Descriptor* descriptor) {
visitor(descriptor);
Expand All @@ -1048,10 +1063,27 @@ struct VisitImpl {
for (int i = 0; i < descriptor->extension_count(); i++) {
Visit(descriptor->extension(i));
}

for (int i = 0; i < descriptor->extension_range_count(); i++) {
Visit(descriptor->extension_range(i));
}

for (int i = 0; i < descriptor->oneof_decl_count(); i++) {
Visit(descriptor->oneof_decl(i));
}
}

void Visit(const std::vector<const FileDescriptor*>& descriptors) {
for (auto* descriptor : descriptors) {
// Hands a service method descriptor to the visitor. Nothing further is
// traversed from a method.
void Visit(const MethodDescriptor* method) { visitor(method); }

// Reports the service itself to the visitor first, then visits each of its
// methods in index order.
void Visit(const ServiceDescriptor* descriptor) {
  visitor(descriptor);
  const int method_total = descriptor->method_count();
  for (int index = 0; index < method_total; ++index) {
    Visit(descriptor->method(index));
  }
}

void Visit(absl::Span<const FileDescriptor*> descriptors) {
for (const FileDescriptor* descriptor : descriptors) {
visitor(descriptor);
for (int i = 0; i < descriptor->message_type_count(); i++) {
Visit(descriptor->message_type(i));
Expand All @@ -1062,17 +1094,18 @@ struct VisitImpl {
for (int i = 0; i < descriptor->extension_count(); i++) {
Visit(descriptor->extension(i));
}
for (int i = 0; i < descriptor->service_count(); i++) {
Visit(descriptor->service(i));
}
}
}
};

// Visit every node in the descriptors calling `visitor(node)`.
// The visitor does not need to handle all possible node types. Types that are
// not visitable via `visitor` will be ignored.
// Disclaimer: this is not fully implemented yet to visit _every_ node.
// VisitImpl might need to be updated where needs arise.
template <typename Visitor>
void VisitDescriptors(const std::vector<const FileDescriptor*>& descriptors,
void VisitDescriptors(absl::Span<const FileDescriptor*> descriptors,
Visitor visitor) {
// Provide a fallback to ignore all the nodes that are not interesting to the
// input visitor.
Expand All @@ -1099,8 +1132,151 @@ bool HasReservedFieldNumber(const FieldDescriptor* field) {
namespace {
std::unique_ptr<SimpleDescriptorDatabase>
PopulateSingleSimpleDescriptorDatabase(const std::string& descriptor_set_name);

// Indicates whether the field is compatible with the given target type.
bool IsFieldCompatible(const FieldDescriptor& field,
FieldOptions::OptionTargetType target_type) {
const RepeatedField<int>& allowed_targets = field.options().targets();
return allowed_targets.empty() ||
absl::c_linear_search(allowed_targets, target_type);
}

// Maps an OptionTargetType value to the human-readable noun used in error
// messages. Any value without a dedicated name below yields "unknown".
absl::string_view TargetTypeString(FieldOptions::OptionTargetType target_type) {
  absl::string_view label = "unknown";
  switch (target_type) {
    case FieldOptions::TARGET_TYPE_FILE:
      label = "file";
      break;
    case FieldOptions::TARGET_TYPE_EXTENSION_RANGE:
      label = "extension range";
      break;
    case FieldOptions::TARGET_TYPE_MESSAGE:
      label = "message";
      break;
    case FieldOptions::TARGET_TYPE_FIELD:
      label = "field";
      break;
    case FieldOptions::TARGET_TYPE_ONEOF:
      label = "oneof";
      break;
    case FieldOptions::TARGET_TYPE_ENUM:
      label = "enum";
      break;
    case FieldOptions::TARGET_TYPE_ENUM_ENTRY:
      label = "enum entry";
      break;
    case FieldOptions::TARGET_TYPE_SERVICE:
      label = "service";
      break;
    case FieldOptions::TARGET_TYPE_METHOD:
      label = "method";
      break;
    default:
      // Keep the "unknown" fallback for every other value.
      break;
  }
  return label;
}

// Recursively validates that the options message (or subpiece of an options
// message) is compatible with the given target type.
bool ValidateTargetConstraintsRecursive(
const Message& m, DescriptorPool::ErrorCollector& error_collector,
absl::string_view file_name, FieldOptions::OptionTargetType target_type) {
std::vector<const FieldDescriptor*> fields;
const Reflection* reflection = m.GetReflection();
reflection->ListFields(m, &fields);
bool success = true;
for (const auto* field : fields) {
if (!IsFieldCompatible(*field, target_type)) {
success = false;
error_collector.RecordError(
file_name, "", nullptr, DescriptorPool::ErrorCollector::OPTION_NAME,
absl::StrCat("Option ", field->full_name(),
" cannot be set on an entity of type `",
TargetTypeString(target_type), "`."));
}
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
if (field->is_repeated()) {
int field_size = reflection->FieldSize(m, field);
for (int i = 0; i < field_size; ++i) {
if (!ValidateTargetConstraintsRecursive(
reflection->GetRepeatedMessage(m, field, i), error_collector,
file_name, target_type)) {
success = false;
}
}
} else if (!ValidateTargetConstraintsRecursive(
reflection->GetMessage(m, field), error_collector,
file_name, target_type)) {
success = false;
}
}
}
return success;
}

// Validates that the options message is correct with respect to target
// constraints, returning true if successful. This function converts the
// options message to a DynamicMessage so that we have visibility into custom
// options. We take the element name as a FunctionRef so that we do not have to
// pay the cost of constructing it unless there is an error.
bool ValidateTargetConstraints(const Message& options,
const DescriptorPool& pool,
DescriptorPool::ErrorCollector& error_collector,
absl::string_view file_name,
FieldOptions::OptionTargetType target_type) {
const Descriptor* descriptor =
pool.FindMessageTypeByName(options.GetTypeName());
if (descriptor == nullptr) {
// We were unable to find the options message in the descriptor pool. This
// implies that the proto files we are working with do not depend on
// descriptor.proto, in which case there are no custom options to worry
// about. We can therefore skip the use of DynamicMessage.
return ValidateTargetConstraintsRecursive(options, error_collector,
file_name, target_type);
} else {
DynamicMessageFactory factory;
std::unique_ptr<Message> dynamic_message(
factory.GetPrototype(descriptor)->New());
std::string serialized;
ABSL_CHECK(options.SerializeToString(&serialized));
ABSL_CHECK(dynamic_message->ParseFromString(serialized));
return ValidateTargetConstraintsRecursive(*dynamic_message, error_collector,
file_name, target_type);
}
}

// GetTargetType() overload set: maps a descriptor pointer type to the
// OptionTargetType enum value associated with that kind of entity. Only the
// static type of the argument matters; the pointer itself is never read.

// File -> TARGET_TYPE_FILE.
FieldOptions::OptionTargetType GetTargetType(const FileDescriptor* /*file*/) {
  return FieldOptions::TARGET_TYPE_FILE;
}

// Extension range -> TARGET_TYPE_EXTENSION_RANGE.
FieldOptions::OptionTargetType GetTargetType(
    const Descriptor::ExtensionRange* /*extension_range*/) {
  return FieldOptions::TARGET_TYPE_EXTENSION_RANGE;
}

// Message -> TARGET_TYPE_MESSAGE.
FieldOptions::OptionTargetType GetTargetType(const Descriptor* /*message*/) {
  return FieldOptions::TARGET_TYPE_MESSAGE;
}

// Field -> TARGET_TYPE_FIELD.
FieldOptions::OptionTargetType GetTargetType(const FieldDescriptor* /*field*/) {
  return FieldOptions::TARGET_TYPE_FIELD;
}

// Oneof -> TARGET_TYPE_ONEOF.
FieldOptions::OptionTargetType GetTargetType(const OneofDescriptor* /*oneof*/) {
  return FieldOptions::TARGET_TYPE_ONEOF;
}

// Enum -> TARGET_TYPE_ENUM.
FieldOptions::OptionTargetType GetTargetType(const EnumDescriptor* /*enm*/) {
  return FieldOptions::TARGET_TYPE_ENUM;
}

// Enum value -> TARGET_TYPE_ENUM_ENTRY.
FieldOptions::OptionTargetType GetTargetType(
    const EnumValueDescriptor* /*enum_value*/) {
  return FieldOptions::TARGET_TYPE_ENUM_ENTRY;
}

// Service -> TARGET_TYPE_SERVICE.
FieldOptions::OptionTargetType GetTargetType(
    const ServiceDescriptor* /*service*/) {
  return FieldOptions::TARGET_TYPE_SERVICE;
}

// Method -> TARGET_TYPE_METHOD.
FieldOptions::OptionTargetType GetTargetType(
    const MethodDescriptor* /*method*/) {
  return FieldOptions::TARGET_TYPE_METHOD;
}
} // namespace

int CommandLineInterface::Run(int argc, const char* const argv[]) {
Clear();

Expand Down Expand Up @@ -1189,31 +1365,50 @@ int CommandLineInterface::Run(int argc, const char* const argv[]) {

bool validation_error = false; // Defer exiting so we log more warnings.

VisitDescriptors(parsed_files, [&](const FieldDescriptor* field) {
if (HasReservedFieldNumber(field)) {
const char* error_link = nullptr;
validation_error = true;
std::string error;
if (field->number() >= FieldDescriptor::kFirstReservedNumber &&
field->number() <= FieldDescriptor::kLastReservedNumber) {
error = absl::Substitute(
"Field numbers $0 through $1 are reserved "
"for the protocol buffer library implementation.",
FieldDescriptor::kFirstReservedNumber,
FieldDescriptor::kLastReservedNumber);
} else {
error = absl::Substitute(
"Field number $0 is reserved for specific purposes.",
field->number());
}
if (error_link) {
absl::StrAppend(&error, "(See ", error_link, ")");
}
static_cast<DescriptorPool::ErrorCollector*>(error_collector.get())
->RecordError(field->file()->name(), field->full_name(), nullptr,
DescriptorPool::ErrorCollector::NUMBER, error);
}
});
VisitDescriptors(
absl::Span<const FileDescriptor*>(parsed_files.data(),
parsed_files.size()),
[&](const FieldDescriptor* field) {
if (HasReservedFieldNumber(field)) {
const char* error_link = nullptr;
validation_error = true;
std::string error;
if (field->number() >= FieldDescriptor::kFirstReservedNumber &&
field->number() <= FieldDescriptor::kLastReservedNumber) {
error = absl::Substitute(
"Field numbers $0 through $1 are reserved "
"for the protocol buffer library implementation.",
FieldDescriptor::kFirstReservedNumber,
FieldDescriptor::kLastReservedNumber);
} else {
error = absl::Substitute(
"Field number $0 is reserved for specific purposes.",
field->number());
}
if (error_link) {
absl::StrAppend(&error, "(See ", error_link, ")");
}
static_cast<DescriptorPool::ErrorCollector*>(error_collector.get())
->RecordError(field->file()->name(), field->full_name(), nullptr,
DescriptorPool::ErrorCollector::NUMBER, error);
}
});

// We visit one file at a time because we need to provide the file name for
// error messages. Usually we can get the file name from any descriptor with
// something like descriptor->file()->name(), but ExtensionRange does not
// support this.
for (const google::protobuf::FileDescriptor* file : parsed_files) {
VisitDescriptors(
absl::Span<const FileDescriptor*>(&file, 1),
[&](const auto* descriptor) {
if (!ValidateTargetConstraints(
descriptor->options(), *descriptor_pool, *error_collector,
file->name(), GetTargetType(descriptor))) {
validation_error = true;
}
});
}


if (validation_error) {
Expand Down
Loading

0 comments on commit e3848c1

Please sign in to comment.