diff --git a/src/google/protobuf/generated_message_tctable_impl.h b/src/google/protobuf/generated_message_tctable_impl.h index 42cb26ff3d51..0bd0d1e1e673 100644 --- a/src/google/protobuf/generated_message_tctable_impl.h +++ b/src/google/protobuf/generated_message_tctable_impl.h @@ -805,22 +805,24 @@ Parse64FallbackPair(const char* p, int64_t res1) { // correctly, so all we have to do is check that the expected case is true. if (PROTOBUF_PREDICT_TRUE(ptr[9] == 1)) goto done10; - // A value of 0, however, represents an over-serialized varint. This case - // should not happen, but if does (say, due to a nonconforming serializer), - // deassert the continuation bit that came from ptr[8]. - if (ptr[9] == 0) { + if (PROTOBUF_PREDICT_FALSE(ptr[9] & 0x80)) { + // If the continue bit is set, it is an unterminated varint. + return {nullptr, 0}; + } + + // A zero value of the first bit of the 10th byte represents an + // over-serialized varint. This case should not happen, but if does (say, due + // to a nonconforming serializer), deassert the continuation bit that came + // from ptr[8]. + if ((ptr[9] & 1) == 0) { #if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__) // Use a small instruction since this is an uncommon code path. asm("btcq $63,%0" : "+r"(res3)); #else res3 ^= static_cast(1) << 63; #endif - goto done10; } - - // If the 10th byte/ptr[9] itself has any other value, then it is too big to - // fit in 64 bits. If the continue bit is set, it is an unterminated varint. - return {nullptr, 0}; + goto done10; done2: return {p + 2, res1 & res2}; @@ -963,7 +965,7 @@ PROTOBUF_NOINLINE const char* TcParser::FastTV32S1(PROTOBUF_TC_PARAM_DECL) { if (PROTOBUF_PREDICT_FALSE(ptr[6] & 0x80)) { if (PROTOBUF_PREDICT_FALSE(ptr[7] & 0x80)) { if (PROTOBUF_PREDICT_FALSE(ptr[8] & 0x80)) { - if (ptr[9] & 0xFE) return Error(PROTOBUF_TC_PARAM_PASS); + if (ptr[9] & 0x80) return Error(PROTOBUF_TC_PARAM_PASS); *out = RotateLeft(res, 28); ptr += 10; PROTOBUF_MUSTTAIL return ToTagDispatch( diff --git a/src/google/protobuf/generated_message_tctable_lite.cc b/src/google/protobuf/generated_message_tctable_lite.cc index 622702658815..6520894a2be7 100644 --- a/src/google/protobuf/generated_message_tctable_lite.cc +++ b/src/google/protobuf/generated_message_tctable_lite.cc @@ -751,7 +751,9 @@ inline PROTOBUF_ALWAYS_INLINE const char* ParseVarint(const char* p, if (PROTOBUF_PREDICT_FALSE(byte & 0x80)) { byte = (byte - 0x80) | *p++; if (PROTOBUF_PREDICT_FALSE(byte & 0x80)) { - byte = (byte - 0x80) | *p++; + // We only care about the continuation bit and the first bit + // of the 10th byte. + byte = (byte - 0x80) | (*p++ & 0x81); if (PROTOBUF_PREDICT_FALSE(byte & 0x80)) { return nullptr; } diff --git a/src/google/protobuf/generated_message_tctable_lite_test.cc b/src/google/protobuf/generated_message_tctable_lite_test.cc index 85ffc3ff7ba7..b1c960daf8b4 100644 --- a/src/google/protobuf/generated_message_tctable_lite_test.cc +++ b/src/google/protobuf/generated_message_tctable_lite_test.cc @@ -126,6 +126,7 @@ TEST(FastVarints, NameHere) { uint8_t serialize_buffer[64]; for (int size : {8, 32, 64, -8, -32, -64}) { + SCOPED_TRACE(size); auto next_i = [](uint64_t i) { // if i + 1 is a power of two, return that. // (This will also match when i == -1, but for this loop we know that will @@ -136,28 +137,48 @@ TEST(FastVarints, NameHere) { return i + (i - 1); }; for (uint64_t i = 0; i + 1 != 0; i = next_i(i)) { - char fake_msg[64] = { - kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // - kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // - kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // - kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // - kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // - kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // - kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // - kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // - }; - memset(&fake_msg[kHasBitsOffset], 0, sizeof(uint32_t)); - - auto serialize_ptr = WireFormatLite::WriteUInt64ToArray( - /* field_number= */ 1, i, serialize_buffer); - absl::string_view serialized{ - reinterpret_cast(&serialize_buffer[0]), - static_cast(serialize_ptr - serialize_buffer)}; - - const char* ptr = nullptr; - const char* end_ptr = nullptr; - ParseContext ctx(io::CodedInputStream::GetDefaultRecursionLimit(), - /* aliasing= */ false, &ptr, serialized); + SCOPED_TRACE(i); + enum OverlongKind { kNotOverlong, kOverlong, kOverlongWithDroppedBits }; + for (OverlongKind overlong : + {kNotOverlong, kOverlong, kOverlongWithDroppedBits}) { + SCOPED_TRACE(overlong); + alignas(16) char fake_msg[64] = { + kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // + kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // + kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // + kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // + kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // + kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // + kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // + kDND, kDND, kDND, kDND, kDND, kDND, kDND, kDND, // + }; + memset(&fake_msg[kHasBitsOffset], 0, sizeof(uint32_t)); + + auto serialize_ptr = WireFormatLite::WriteUInt64ToArray( + /* field_number= */ 1, i, serialize_buffer); + + if (overlong == kOverlong || overlong == kOverlongWithDroppedBits) { + // 1 for the tag plus 10 for the value + while (serialize_ptr - serialize_buffer < 11) { + serialize_ptr[-1] |= 0x80; + *serialize_ptr++ = 0; + } + if (overlong == kOverlongWithDroppedBits) { + // For this one we add some unused bits to the last byte. + // They should be dropped. Bits 1-6 are dropped. Bit 0 is used and + // bit 7 is checked for continuation. + serialize_ptr[-1] |= 0b0111'1110; + } + } + + absl::string_view serialized{ + reinterpret_cast(&serialize_buffer[0]), + static_cast(serialize_ptr - serialize_buffer)}; + + const char* ptr = nullptr; + const char* end_ptr = nullptr; + ParseContext ctx(io::CodedInputStream::GetDefaultRecursionLimit(), + /* aliasing= */ false, &ptr, serialized); #if 0 // FOR_DEBUGGING GOOGLE_ABSL_LOG(ERROR) << "size=" << size << " i=" << i << " ptr points to " // << +ptr[0] << "," << +ptr[1] << "," // @@ -166,84 +187,88 @@ TEST(FastVarints, NameHere) { << +ptr[6] << "," << +ptr[7] << "," // << +ptr[8] << "," << +ptr[9] << "," << +ptr[10] << "\n"; #endif - TailCallParseFunc fn = nullptr; - switch (size) { - case 8: - fn = &TcParser::FastV8S1; - break; - case -8: - fn = &TcParser::FastTV8S1; - break; - case 32: - fn = &TcParser::FastV32S1; - break; - case -32: - fn = &TcParser::FastTV32S1; - break; - case 64: - fn = &TcParser::FastV64S1; - break; - case -64: - fn = &TcParser::FastTV64S1; - break; - } - fallback_ptr_received = absl::nullopt; - fallback_hasbits_received = absl::nullopt; - fallback_tag_received = absl::nullopt; - end_ptr = fn(reinterpret_cast(fake_msg), ptr, &ctx, - Xor2SerializedBytes(parse_table.fast_entries[0].bits, ptr), - &parse_table.header, /*hasbits=*/0); - switch (size) { - case -8: - case 8: { - if (end_ptr == nullptr) { - // If end_ptr is nullptr, that means the FastParser gave up and - // tried to pass control to MiniParse.... which is expected anytime - // we encounter something other than 0 or 1 encodings. (Since - // FastV8S1 is only used for `bool` fields.) - EXPECT_NE(i, true); - EXPECT_NE(i, false); - EXPECT_THAT(fallback_hasbits_received, Optional(0)); - // Like the mini-parser functions, and unlike the fast-parser - // functions, the fallback receives a ptr already incremented past - // the tag, and receives the actual tag in the `data` parameter. - EXPECT_THAT(fallback_ptr_received, Optional(ptr + 1)); - EXPECT_THAT(fallback_tag_received, Optional(0x7F & *ptr)); - continue; - } - ASSERT_EQ(end_ptr - ptr, serialized.size()); - - auto actual_field = ReadAndReset(&fake_msg[kFieldOffset]); - EXPECT_EQ(actual_field, static_cast(i)) // - << " hex: " << absl::StrCat(absl::Hex(actual_field)); - }; break; - case -32: - case 32: { - ASSERT_EQ(end_ptr - ptr, serialized.size()); - - auto actual_field = ReadAndReset(&fake_msg[kFieldOffset]); - EXPECT_EQ(actual_field, static_cast(i)) // - << " hex: " << absl::StrCat(absl::Hex(actual_field)); - }; break; - case -64: - case 64: { - ASSERT_EQ(end_ptr - ptr, serialized.size()); - - auto actual_field = ReadAndReset(&fake_msg[kFieldOffset]); - EXPECT_EQ(actual_field, static_cast(i)) // - << " hex: " << absl::StrCat(absl::Hex(actual_field)); - }; break; - } - EXPECT_TRUE(!fallback_ptr_received); - EXPECT_TRUE(!fallback_hasbits_received); - EXPECT_TRUE(!fallback_tag_received); - auto hasbits = ReadAndReset(&fake_msg[kHasBitsOffset]); - EXPECT_EQ(hasbits, 1 << kHasBitIndex); - - int offset = 0; - for (char ch : fake_msg) { - EXPECT_EQ(ch, kDND) << " corruption of message at offset " << offset; - ++offset; + TailCallParseFunc fn = nullptr; + switch (size) { + case 8: + fn = &TcParser::FastV8S1; + break; + case -8: + fn = &TcParser::FastTV8S1; + break; + case 32: + fn = &TcParser::FastV32S1; + break; + case -32: + fn = &TcParser::FastTV32S1; + break; + case 64: + fn = &TcParser::FastV64S1; + break; + case -64: + fn = &TcParser::FastTV64S1; + break; + } + fallback_ptr_received = absl::nullopt; + fallback_hasbits_received = absl::nullopt; + fallback_tag_received = absl::nullopt; + end_ptr = fn(reinterpret_cast(fake_msg), ptr, &ctx, + Xor2SerializedBytes(parse_table.fast_entries[0].bits, ptr), + &parse_table.header, /*hasbits=*/0); + switch (size) { + case -8: + case 8: { + if (end_ptr == nullptr) { + // If end_ptr is nullptr, that means the FastParser gave up and + // tried to pass control to MiniParse.... which is expected + // anytime we encounter something other than 0 or 1 encodings. + // (Since FastV8S1 is only used for `bool` fields.) + if (overlong == kNotOverlong) { + EXPECT_NE(i, true); + EXPECT_NE(i, false); + } + EXPECT_THAT(fallback_hasbits_received, Optional(0)); + // Like the mini-parser functions, and unlike the fast-parser + // functions, the fallback receives a ptr already incremented past + // the tag, and receives the actual tag in the `data` parameter. + EXPECT_THAT(fallback_ptr_received, Optional(ptr + 1)); + EXPECT_THAT(fallback_tag_received, Optional(0x7F & *ptr)); + continue; + } + ASSERT_EQ(end_ptr - ptr, serialized.size()); + + auto actual_field = ReadAndReset(&fake_msg[kFieldOffset]); + EXPECT_EQ(actual_field, static_cast(i)) // + << " hex: " << absl::StrCat(absl::Hex(actual_field)); + }; break; + case -32: + case 32: { + ASSERT_TRUE(end_ptr); + ASSERT_EQ(end_ptr - ptr, serialized.size()); + + auto actual_field = ReadAndReset(&fake_msg[kFieldOffset]); + EXPECT_EQ(actual_field, static_cast(i)) // + << " hex: " << absl::StrCat(absl::Hex(actual_field)); + }; break; + case -64: + case 64: { + ASSERT_EQ(end_ptr - ptr, serialized.size()); + + auto actual_field = ReadAndReset(&fake_msg[kFieldOffset]); + EXPECT_EQ(actual_field, static_cast(i)) // + << " hex: " << absl::StrCat(absl::Hex(actual_field)); + }; break; + } + EXPECT_TRUE(!fallback_ptr_received); + EXPECT_TRUE(!fallback_hasbits_received); + EXPECT_TRUE(!fallback_tag_received); + auto hasbits = ReadAndReset(&fake_msg[kHasBitsOffset]); + EXPECT_EQ(hasbits, 1 << kHasBitIndex); + + int offset = 0; + for (char ch : fake_msg) { + EXPECT_EQ(ch, kDND) << " corruption of message at offset " << offset; + ++offset; + } } } } diff --git a/src/google/protobuf/message_unittest.inc b/src/google/protobuf/message_unittest.inc index df9a9572ec30..b69a50b940cf 100644 --- a/src/google/protobuf/message_unittest.inc +++ b/src/google/protobuf/message_unittest.inc @@ -1176,6 +1176,12 @@ TEST(MESSAGE_TEST_NAME, PreservesFloatingPointNegative0) { std::signbit(out_message.optional_double())); } +const uint8_t* SkipTag(const uint8_t* buf) { + while (*buf & 0x80) ++buf; + ++buf; + return buf; +} + // Adds `non_canonical_bytes` bytes to the varint representation at the tail of // the buffer. // `buf` points to the start of the buffer, `p` points to one-past-the-end. @@ -1208,7 +1214,7 @@ std::string EncodeEnumValue(int number, int value, int non_canonical_bytes, } else { p = internal::WireFormatLite::WriteEnumToArray(number, value, p); - p = AddNonCanonicalBytes(buf, p, non_canonical_bytes); + p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); return std::string(buf, p); } } @@ -1257,27 +1263,136 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) { SCOPED_TRACE(use_packed); for (bool use_tail_field : {false, true}) { SCOPED_TRACE(use_tail_field); - for (int non_canonical_bytes = 0; non_canonical_bytes < 5; + for (int non_canonical_bytes = 0; non_canonical_bytes < 9; ++non_canonical_bytes) { SCOPED_TRACE(non_canonical_bytes); + for (bool add_garbage_bits : {false, true}) { + if (add_garbage_bits && non_canonical_bytes != 9) { + // We only add garbage on the 10th byte. + continue; + } + SCOPED_TRACE(add_garbage_bits); + for (int i = 0; i < descriptor->field_count(); ++i) { + const auto* field = descriptor->field(i); + if (field->name() == "other_field") continue; + if (!field->is_repeated() && use_packed) continue; + SCOPED_TRACE(field->full_name()); + const auto* enum_desc = field->enum_type(); + for (int e = 0; e < enum_desc->value_count(); ++e) { + const auto* value_desc = enum_desc->value(e); + if (value_desc->number() < 0 && non_canonical_bytes > 0) { + // Negative numbers only have a canonical representation. + continue; + } + SCOPED_TRACE(value_desc->number()); + GOOGLE_ABSL_CHECK_NE(value_desc->number(), kInvalidValue) + << "Invalid value is a real label."; + auto encoded = + EncodeEnumValue(field->number(), value_desc->number(), + non_canonical_bytes, use_packed); + if (add_garbage_bits) { + // These bits should be discarded even in the `false` case. + encoded.back() |= 0b0111'1110; + } + if (use_tail_field) { + // Make sure that fields after this one can be parsed too. ie + // test that the "next" jump is correct too. + encoded += other_field; + } + + EXPECT_TRUE(obj.ParseFromString(encoded)); + if (field->is_repeated()) { + ASSERT_EQ(ref->FieldSize(obj, field), 1); + EXPECT_EQ(ref->GetRepeatedEnumValue(obj, field, 0), + value_desc->number()); + } else { + EXPECT_TRUE(ref->HasField(obj, field)); + EXPECT_EQ(ref->GetEnumValue(obj, field), value_desc->number()); + } + auto& unknown = ref->GetUnknownFields(obj); + ASSERT_EQ(unknown.field_count(), 0); + } + + { + SCOPED_TRACE("Invalid value"); + // Try an invalid value, which should go to the unknown fields. + EXPECT_TRUE(obj.ParseFromString( + EncodeEnumValue(field->number(), kInvalidValue, + non_canonical_bytes, use_packed))); + if (field->is_repeated()) { + ASSERT_EQ(ref->FieldSize(obj, field), 0); + } else { + EXPECT_FALSE(ref->HasField(obj, field)); + EXPECT_EQ(ref->GetEnumValue(obj, field), + enum_desc->value(0)->number()); + } + auto& unknown = ref->GetUnknownFields(obj); + ASSERT_EQ(unknown.field_count(), 1); + EXPECT_EQ(unknown.field(0).number(), field->number()); + EXPECT_EQ(unknown.field(0).type(), unknown.field(0).TYPE_VARINT); + EXPECT_EQ(unknown.field(0).varint(), kInvalidValue); + } + { + SCOPED_TRACE("Overlong varint"); + // Try an overlong varint. It should fail parsing, but not trigger + // any sanitizer warning. + EXPECT_FALSE(obj.ParseFromString( + EncodeOverlongEnum(field->number(), use_packed))); + } + } + } + } + } + } +} + +std::string EncodeBoolValue(int number, bool value, int non_canonical_bytes) { + uint8_t buf[100]; + uint8_t* p = buf; + + p = internal::WireFormatLite::WriteBoolToArray(number, value, p); + p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); + return std::string(buf, p); +} + +TEST(MESSAGE_TEST_NAME, TestBoolParsers) { + UNITTEST::BoolParseTester obj; + + const auto other_field = EncodeOtherField(); + + // Encode a boolean field for many different cases and verify that it can be + // parsed as expected. + // There are: + // - optional/repeated/packed fields + // - field tags that encode in 1/2/3 bytes + // - canonical and non-canonical encodings of the varint + // - last vs not last field + + auto* ref = obj.GetReflection(); + auto* descriptor = obj.descriptor(); + for (bool use_tail_field : {false, true}) { + SCOPED_TRACE(use_tail_field); + for (int non_canonical_bytes = 0; non_canonical_bytes < 10; + ++non_canonical_bytes) { + SCOPED_TRACE(non_canonical_bytes); + for (bool add_garbage_bits : {false, true}) { + if (add_garbage_bits && non_canonical_bytes != 9) { + // We only add garbage on the 10th byte. + continue; + } + SCOPED_TRACE(add_garbage_bits); for (int i = 0; i < descriptor->field_count(); ++i) { const auto* field = descriptor->field(i); if (field->name() == "other_field") continue; - if (!field->is_repeated() && use_packed) continue; SCOPED_TRACE(field->full_name()); - const auto* enum_desc = field->enum_type(); - for (int e = 0; e < enum_desc->value_count(); ++e) { - const auto* value_desc = enum_desc->value(e); - if (value_desc->number() < 0 && non_canonical_bytes > 0) { - // Negative numbers only have a canonical representation. - continue; - } - SCOPED_TRACE(value_desc->number()); - GOOGLE_ABSL_CHECK_NE(value_desc->number(), kInvalidValue) - << "Invalid value is a real label."; + for (bool value : {false, true}) { + SCOPED_TRACE(value); auto encoded = - EncodeEnumValue(field->number(), value_desc->number(), - non_canonical_bytes, use_packed); + EncodeBoolValue(field->number(), value, non_canonical_bytes); + if (add_garbage_bits) { + // These bits should be discarded even in the `false` case. + encoded.back() |= 0b0111'1110; + } if (use_tail_field) { // Make sure that fields after this one can be parsed too. ie test // that the "next" jump is correct too. @@ -1287,41 +1402,87 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) { EXPECT_TRUE(obj.ParseFromString(encoded)); if (field->is_repeated()) { ASSERT_EQ(ref->FieldSize(obj, field), 1); - EXPECT_EQ(ref->GetRepeatedEnumValue(obj, field, 0), - value_desc->number()); + EXPECT_EQ(ref->GetRepeatedBool(obj, field, 0), value); } else { EXPECT_TRUE(ref->HasField(obj, field)); - EXPECT_EQ(ref->GetEnumValue(obj, field), value_desc->number()); + EXPECT_EQ(ref->GetBool(obj, field), value) + << testing::PrintToString(encoded); } auto& unknown = ref->GetUnknownFields(obj); ASSERT_EQ(unknown.field_count(), 0); } + } + } + } + } +} + +std::string EncodeInt32Value(int number, int32_t value, + int non_canonical_bytes) { + uint8_t buf[100]; + uint8_t* p = buf; + + p = internal::WireFormatLite::WriteInt32ToArray(number, value, p); + p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); + return std::string(buf, p); +} + +TEST(MESSAGE_TEST_NAME, TestInt32Parsers) { + UNITTEST::Int32ParseTester obj; + + const auto other_field = EncodeOtherField(); + + // Encode an int32 field for many different cases and verify that it can be + // parsed as expected. + // There are: + // - optional/repeated/packed fields + // - field tags that encode in 1/2/3 bytes + // - canonical and non-canonical encodings of the varint + // - last vs not last field + + auto* ref = obj.GetReflection(); + auto* descriptor = obj.descriptor(); + for (bool use_tail_field : {false, true}) { + SCOPED_TRACE(use_tail_field); + for (int non_canonical_bytes = 0; non_canonical_bytes < 10; + ++non_canonical_bytes) { + SCOPED_TRACE(non_canonical_bytes); + for (bool add_garbage_bits : {false, true}) { + if (add_garbage_bits && non_canonical_bytes != 9) { + // We only add garbage on the 10th byte. + continue; + } + SCOPED_TRACE(add_garbage_bits); + for (int i = 0; i < descriptor->field_count(); ++i) { + const auto* field = descriptor->field(i); + if (field->name() == "other_field") continue; + SCOPED_TRACE(field->full_name()); + for (int32_t value : {1, 0, -1, (std::numeric_limits::min)(), + (std::numeric_limits::max)()}) { + SCOPED_TRACE(value); + auto encoded = + EncodeInt32Value(field->number(), value, non_canonical_bytes); + if (add_garbage_bits) { + // These bits should be discarded even in the `false` case. + encoded.back() |= 0b0111'1110; + } + if (use_tail_field) { + // Make sure that fields after this one can be parsed too. ie test + // that the "next" jump is correct too. + encoded += other_field; + } - { - SCOPED_TRACE("Invalid value"); - // Try an invalid value, which should go to the unknown fields. - EXPECT_TRUE(obj.ParseFromString( - EncodeEnumValue(field->number(), kInvalidValue, - non_canonical_bytes, use_packed))); + EXPECT_TRUE(obj.ParseFromString(encoded)); if (field->is_repeated()) { - ASSERT_EQ(ref->FieldSize(obj, field), 0); + ASSERT_EQ(ref->FieldSize(obj, field), 1); + EXPECT_EQ(ref->GetRepeatedInt32(obj, field, 0), value); } else { - EXPECT_FALSE(ref->HasField(obj, field)); - EXPECT_EQ(ref->GetEnumValue(obj, field), - enum_desc->value(0)->number()); + EXPECT_TRUE(ref->HasField(obj, field)); + EXPECT_EQ(ref->GetInt32(obj, field), value) + << testing::PrintToString(encoded); } auto& unknown = ref->GetUnknownFields(obj); - ASSERT_EQ(unknown.field_count(), 1); - EXPECT_EQ(unknown.field(0).number(), field->number()); - EXPECT_EQ(unknown.field(0).type(), unknown.field(0).TYPE_VARINT); - EXPECT_EQ(unknown.field(0).varint(), kInvalidValue); - } - { - SCOPED_TRACE("Overlong varint"); - // Try an overlong varint. It should fail parsing, but not trigger - // any sanitizer warning. - EXPECT_FALSE(obj.ParseFromString( - EncodeOverlongEnum(field->number(), use_packed))); + ASSERT_EQ(unknown.field_count(), 0); } } } @@ -1329,21 +1490,22 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) { } } -std::string EncodeBoolValue(int number, bool value, int non_canonical_bytes) { +std::string EncodeInt64Value(int number, int64_t value, + int non_canonical_bytes) { uint8_t buf[100]; uint8_t* p = buf; - p = internal::WireFormatLite::WriteBoolToArray(number, value, p); - p = AddNonCanonicalBytes(buf, p, non_canonical_bytes); + p = internal::WireFormatLite::WriteInt64ToArray(number, value, p); + p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); return std::string(buf, p); } -TEST(MESSAGE_TEST_NAME, TestBoolParsers) { - UNITTEST::BoolParseTester obj; +TEST(MESSAGE_TEST_NAME, TestInt64Parsers) { + UNITTEST::Int64ParseTester obj; const auto other_field = EncodeOtherField(); - // Encode a boolean field for many different cases and verify that it can be + // Encode an int64 field for many different cases and verify that it can be // parsed as expected. // There are: // - optional/repeated/packed fields @@ -1358,30 +1520,44 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) { for (int non_canonical_bytes = 0; non_canonical_bytes < 10; ++non_canonical_bytes) { SCOPED_TRACE(non_canonical_bytes); - for (int i = 0; i < descriptor->field_count(); ++i) { - const auto* field = descriptor->field(i); - if (field->name() == "other_field") continue; - SCOPED_TRACE(field->full_name()); - for (bool value : {false, true}) { - SCOPED_TRACE(value); - auto encoded = - EncodeBoolValue(field->number(), value, non_canonical_bytes); - if (use_tail_field) { - // Make sure that fields after this one can be parsed too. ie test - // that the "next" jump is correct too. - encoded += other_field; - } + for (bool add_garbage_bits : {false, true}) { + if (add_garbage_bits && non_canonical_bytes != 9) { + // We only add garbage on the 10th byte. + continue; + } + SCOPED_TRACE(add_garbage_bits); + for (int i = 0; i < descriptor->field_count(); ++i) { + const auto* field = descriptor->field(i); + if (field->name() == "other_field") continue; + SCOPED_TRACE(field->full_name()); + for (int64_t value : {int64_t{1}, int64_t{0}, int64_t{-1}, + (std::numeric_limits::min)(), + (std::numeric_limits::max)()}) { + SCOPED_TRACE(value); + auto encoded = + EncodeInt64Value(field->number(), value, non_canonical_bytes); + if (add_garbage_bits) { + // These bits should be discarded even in the `false` case. + encoded.back() |= 0b0111'1110; + } + if (use_tail_field) { + // Make sure that fields after this one can be parsed too. ie test + // that the "next" jump is correct too. + encoded += other_field; + } - EXPECT_TRUE(obj.ParseFromString(encoded)); - if (field->is_repeated()) { - ASSERT_EQ(ref->FieldSize(obj, field), 1); - EXPECT_EQ(ref->GetRepeatedBool(obj, field, 0), value); - } else { - EXPECT_TRUE(ref->HasField(obj, field)); - EXPECT_EQ(ref->GetBool(obj, field), value); + EXPECT_TRUE(obj.ParseFromString(encoded)); + if (field->is_repeated()) { + ASSERT_EQ(ref->FieldSize(obj, field), 1); + EXPECT_EQ(ref->GetRepeatedInt64(obj, field, 0), value); + } else { + EXPECT_TRUE(ref->HasField(obj, field)); + EXPECT_EQ(ref->GetInt64(obj, field), value) + << testing::PrintToString(encoded); + } + auto& unknown = ref->GetUnknownFields(obj); + ASSERT_EQ(unknown.field_count(), 0); } - auto& unknown = ref->GetUnknownFields(obj); - ASSERT_EQ(unknown.field_count(), 0); } } } diff --git a/src/google/protobuf/unittest.proto b/src/google/protobuf/unittest.proto index 0cd30adb9fd5..479262f61272 100644 --- a/src/google/protobuf/unittest.proto +++ b/src/google/protobuf/unittest.proto @@ -1565,6 +1565,36 @@ message BoolParseTester { optional int32 other_field = 99; }; +message Int32ParseTester { + optional int32 optional_int32_lowfield = 1; + optional int32 optional_int32_midfield = 1001; + optional int32 optional_int32_hifield = 1000001; + repeated int32 repeated_int32_lowfield = 2; + repeated int32 repeated_int32_midfield = 1002; + repeated int32 repeated_int32_hifield = 1000002; + repeated int32 packed_int32_lowfield = 3 [packed = true]; + repeated int32 packed_int32_midfield = 1003 [packed = true]; + repeated int32 packed_int32_hifield = 1000003 [packed = true]; + + // An arbitrary field we can append to to break the runs of repeated fields. + optional int32 other_field = 99; +}; + +message Int64ParseTester { + optional int64 optional_int64_lowfield = 1; + optional int64 optional_int64_midfield = 1001; + optional int64 optional_int64_hifield = 1000001; + repeated int64 repeated_int64_lowfield = 2; + repeated int64 repeated_int64_midfield = 1002; + repeated int64 repeated_int64_hifield = 1000002; + repeated int64 packed_int64_lowfield = 3 [packed = true]; + repeated int64 packed_int64_midfield = 1003 [packed = true]; + repeated int64 packed_int64_hifield = 1000003 [packed = true]; + + // An arbitrary field we can append to to break the runs of repeated fields. + optional int32 other_field = 99; +}; + message StringParseTester { optional string optional_string_lowfield = 1; optional string optional_string_midfield = 1001;