Use the "shldq" decoder for the specialized 64-bit Varint parsers, ra…
Browse files Browse the repository at this point in the history
…ther than

using the "RotRight7" decoder.  The "shldq" technique is much faster on recent
Intel and AMD CPUs, when processing larger integers, especially on Zen.

PiperOrigin-RevId: 498078103
protobuf-github-bot authored and copybara-github committed Dec 28, 2022
1 parent b3ec9ec commit 0ca97a1
Showing 3 changed files with 427 additions and 170 deletions.
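For context, here is an illustrative sketch (not part of this commit) of the "shift left, fill with ones" step that the new shldq-based decoder builds on: each varint byte is shifted into position with the vacated low bits set to one, so the partial results can be combined with a single bitwise AND, as the Parse64FallbackPair comment in the diff below explains. The helper name ShiftLeftFillWithOnes is a stand-in that mirrors the new shift_left_fill_with_ones template.

// Illustrative sketch only (not from this commit): a portable model of the
// "shift left, fill with ones" step behind the shldq-based varint decoder.
// It decodes the two-byte varint {0x96, 0x01}, which encodes 150.
#include <cstdint>
#include <cstdio>

// Shift `byte` into position n * 7 and fill the vacated low bits with ones
// taken from `ones`; on x86-64 the real decoder does this with one shldq.
template <int n>
uint64_t ShiftLeftFillWithOnes(uint64_t byte, uint64_t ones) {
  return (byte << (n * 7)) | (ones >> (64 - (n * 7)));
}

int main() {
  const int8_t bytes[] = {static_cast<int8_t>(0x96), 0x01};
  // Sign-extending the first byte sets every high bit, because its
  // continuation bit (bit 7) is set.
  int64_t res1 = bytes[0];  // 0xFFFF'FFFF'FFFF'FF96
  uint64_t ones = static_cast<uint64_t>(res1);
  // Place the second byte at bits 7..14; the low 7 bits become ones so they
  // do not mask the payload already held in res1.
  uint64_t res2 = ShiftLeftFillWithOnes<1>(static_cast<uint64_t>(bytes[1]), ones);
  // ANDing the partial results keeps 7 payload bits from each byte.
  uint64_t value = static_cast<uint64_t>(res1) & res2;
  std::printf("%llu\n", static_cast<unsigned long long>(value));  // prints 150
  return 0;
}

On x86-64 the new header replaces this portable shift-or with a single shldq and reads the sign flag straight from the condition codes; the sketch keeps the portable form for clarity.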
257 changes: 206 additions & 51 deletions src/google/protobuf/generated_message_tctable_impl.h
@@ -259,16 +259,12 @@ enum FieldType : uint16_t {
 }  // namespace field_layout
 
 #ifndef NDEBUG
-template <size_t align>
-void AlignFail(uintptr_t address) {
-  GOOGLE_ABSL_LOG(FATAL) << "Unaligned (" << align << ") access at " << address;
-
-  // Explicit abort to let compilers know this function does not return
-  abort();
-}
-
-extern template void AlignFail<4>(uintptr_t);
-extern template void AlignFail<8>(uintptr_t);
+PROTOBUF_EXPORT void AlignFail(std::integral_constant<size_t, 4>,
+                               std::uintptr_t address);
+PROTOBUF_EXPORT void AlignFail(std::integral_constant<size_t, 8>,
+                               std::uintptr_t address);
+inline void AlignFail(std::integral_constant<size_t, 1>,
+                      std::uintptr_t address) {}
 #endif
 
 // TcParser implements most of the parsing logic for tailcall tables.
@@ -365,29 +361,39 @@ class PROTOBUF_EXPORT TcParser final {
 
   // Manually unrolled and specialized Varint parsing.
   template <typename FieldType, int data_offset, int hasbit_idx>
-  static const char* SpecializedUnrolledVImpl1(PROTOBUF_TC_PARAM_DECL);
+  static const char* FastTV32S1(PROTOBUF_TC_PARAM_DECL);
+  template <typename FieldType, int data_offset, int hasbit_idx>
+  static const char* FastTV64S1(PROTOBUF_TC_PARAM_DECL);
   template <int data_offset, int hasbit_idx>
-  static const char* SpecializedFastV8S1(PROTOBUF_TC_PARAM_DECL);
+  static const char* FastTV8S1(PROTOBUF_TC_PARAM_DECL);
 
   template <typename FieldType, int data_offset, int hasbit_idx>
   static constexpr TailCallParseFunc SingularVarintNoZag1() {
-    if (data_offset < 100) {
-      if (sizeof(FieldType) == 1) {
-        return &SpecializedFastV8S1<data_offset, hasbit_idx>;
-      }
-      return &SpecializedUnrolledVImpl1<FieldType, data_offset, hasbit_idx>;
-    } else if (sizeof(FieldType) == 1) {
-      return &FastV8S1;
-    } else if (sizeof(FieldType) == 4) {
-      return &FastV32S1;
-    } else if (sizeof(FieldType) == 8) {
-      return &FastV64S1;
-    } else {
-      static_assert(sizeof(FieldType) == 1 || sizeof(FieldType) == 4 ||
-                        sizeof(FieldType) == 8,
-                    "");
-      return nullptr;
-    }
+    if (sizeof(FieldType) == 1) {
+      if (data_offset < 100) {
+        return &FastTV8S1<data_offset, hasbit_idx>;
+      } else {
+        return &FastV8S1;
+      }
+    }
+    if (sizeof(FieldType) == 4) {
+      if (data_offset < 100) {
+        return &FastTV32S1<FieldType, data_offset, hasbit_idx>;
+      } else {  //
+        return &FastV32S1;
+      }
+    }
+    if (sizeof(FieldType) == 8) {
+      if (data_offset < 128) {
+        return &FastTV64S1<FieldType, data_offset, hasbit_idx>;
+      } else {
+        return &FastV64S1;
+      }
+    }
+    static_assert(sizeof(FieldType) == 1 || sizeof(FieldType) == 4 ||
+                      sizeof(FieldType) == 8,
+                  "");
+    std::abort();  // unreachable
   }
 
   // Functions referenced by generated fast tables (closed enum):
@@ -482,7 +488,10 @@ class PROTOBUF_EXPORT TcParser final {
 #ifndef NDEBUG
     if (PROTOBUF_PREDICT_FALSE(
            reinterpret_cast<uintptr_t>(target) % alignof(T) != 0)) {
-      AlignFail<alignof(T)>(reinterpret_cast<uintptr_t>(target));
+      AlignFail(std::integral_constant<size_t, alignof(T)>(),
+                reinterpret_cast<uintptr_t>(target));
+      // Explicit abort to let compilers know this code-path does not return
+      abort();
     }
 #endif
     return *target;
@@ -495,7 +504,10 @@ class PROTOBUF_EXPORT TcParser final {
 #ifndef NDEBUG
     if (PROTOBUF_PREDICT_FALSE(
            reinterpret_cast<uintptr_t>(target) % alignof(T) != 0)) {
-      AlignFail<alignof(T)>(reinterpret_cast<uintptr_t>(target));
+      AlignFail(std::integral_constant<size_t, alignof(T)>(),
+                reinterpret_cast<uintptr_t>(target));
+      // Explicit abort to let compilers know this code-path does not return
+      abort();
     }
 #endif
     return *target;
@@ -537,7 +549,7 @@ class PROTOBUF_EXPORT TcParser final {
   };
   static TestMiniParseResult TestMiniParse(PROTOBUF_TC_PARAM_DECL);
   template <bool export_called_function>
-  static const char* MiniParseImpl(PROTOBUF_TC_PARAM_DECL);
+  static const char* MiniParse(PROTOBUF_TC_PARAM_DECL);
 
   template <typename TagType, bool group_coding, bool aux_is_table>
   static inline const char* SingularParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL);
@@ -714,12 +726,127 @@ class PROTOBUF_EXPORT TcParser final {
   static const char* MpFallback(PROTOBUF_TC_PARAM_DECL);
 };
 
+// Shift "byte" left by n * 7 bits, filling vacated bits with ones.
+template <int n>
+inline PROTOBUF_ALWAYS_INLINE uint64_t
+shift_left_fill_with_ones(uint64_t byte, uint64_t ones) {
+  return (byte << (n * 7)) | (ones >> (64 - (n * 7)));
+}
+
+// Shift "byte" left by n * 7 bits, filling vacated bits with ones, and
+// put the new value in res. Return whether the result was negative.
+template <int n>
+inline PROTOBUF_ALWAYS_INLINE bool shift_left_fill_with_ones_was_negative(
+    uint64_t byte, uint64_t ones, int64_t& res) {
+#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
+  // For the first two rounds (ptr[1] and ptr[2]), micro benchmarks show a
+  // substantial improvement from capturing the sign from the condition code
+  // register on x86-64.
+  bool sign_bit;
+  asm("shldq %3, %2, %1"
+      : "=@ccs"(sign_bit), "+r"(byte)
+      : "r"(ones), "i"(n * 7));
+  res = byte;
+  return sign_bit;
+#else
+  // Generic fallback:
+  res = shift_left_fill_with_ones<n>(byte, ones);
+  return static_cast<int64_t>(res) < 0;
+#endif
+}
+
+inline PROTOBUF_ALWAYS_INLINE std::pair<const char*, uint64_t>
+Parse64FallbackPair(const char* p, int64_t res1) {
+  auto ptr = reinterpret_cast<const int8_t*>(p);
+
+  // The algorithm relies on sign extension for each byte to set all high bits
+  // when the varint continues. It also relies on asserting all of the lower
+  // bits for each successive byte read. This allows the result to be aggregated
+  // using a bitwise AND. For example:
+  //
+  //          8       1          64     57 ... 24     17  16      9  8       1
+  // ptr[0] = 1aaa aaaa ; res1 = 1111 1111 ... 1111 1111  1111 1111  1aaa aaaa
+  // ptr[1] = 1bbb bbbb ; res2 = 1111 1111 ... 1111 1111  11bb bbbb  b111 1111
+  // ptr[2] = 1ccc cccc ; res3 = 0000 0000 ... 000c cccc  cc11 1111  1111 1111
+  //                             ---------------------------------------------
+  //        res1 & res2 & res3 = 0000 0000 ... 000c cccc  ccbb bbbb  baaa aaaa
+  //
+  // On x86-64, a shld from a single register filled with enough 1s in the high
+  // bits can accomplish all this in one instruction. It so happens that res1
+  // has 57 high bits of ones, which is enough for the largest shift done.
+  //
+  // Just as importantly, by keeping results in res1, res2, and res3, we take
+  // advantage of the superscalar abilities of the CPU.
+  GOOGLE_ABSL_DCHECK_EQ(res1 >> 7, -1);
+  uint64_t ones = res1;  // save the high 1 bits from res1 (input to SHLD)
+  int64_t res2, res3;    // accumulated result chunks
+
+  if (!shift_left_fill_with_ones_was_negative<1>(ptr[1], ones, res2))
+    goto done2;
+  if (!shift_left_fill_with_ones_was_negative<2>(ptr[2], ones, res3))
+    goto done3;
+
+  // For the remainder of the chunks, check the sign of the AND result.
+  res1 &= shift_left_fill_with_ones<3>(ptr[3], ones);
+  if (res1 >= 0) goto done4;
+  res2 &= shift_left_fill_with_ones<4>(ptr[4], ones);
+  if (res2 >= 0) goto done5;
+  res3 &= shift_left_fill_with_ones<5>(ptr[5], ones);
+  if (res3 >= 0) goto done6;
+  res1 &= shift_left_fill_with_ones<6>(ptr[6], ones);
+  if (res1 >= 0) goto done7;
+  res2 &= shift_left_fill_with_ones<7>(ptr[7], ones);
+  if (res2 >= 0) goto done8;
+  res3 &= shift_left_fill_with_ones<8>(ptr[8], ones);
+  if (res3 >= 0) goto done9;
+
+  // For valid 64bit varints, the 10th byte/ptr[9] should be exactly 1. In this
+  // case, the continuation bit of ptr[8] already set the top bit of res3
+  // correctly, so all we have to do is check that the expected case is true.
+  if (PROTOBUF_PREDICT_TRUE(ptr[9] == 1)) goto done10;
+
+  // A value of 0, however, represents an over-serialized varint. This case
+  // should not happen, but if does (say, due to a nonconforming serializer),
+  // deassert the continuation bit that came from ptr[8].
+  if (ptr[9] == 0) {
+#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
+    // Use a small instruction since this is an uncommon code path.
+    asm("btcq $63,%0" : "+r"(res3));
+#else
+    res3 ^= static_cast<uint64_t>(1) << 63;
+#endif
+    goto done10;
+  }
+
+  // If the 10th byte/ptr[9] itself has any other value, then it is too big to
+  // fit in 64 bits. If the continue bit is set, it is an unterminated varint.
+  return {nullptr, 0};
+
+done2:
+  return {p + 2, res1 & res2};
+done3:
+  return {p + 3, res1 & res2 & res3};
+done4:
+  return {p + 4, res1 & res2 & res3};
+done5:
+  return {p + 5, res1 & res2 & res3};
+done6:
+  return {p + 6, res1 & res2 & res3};
+done7:
+  return {p + 7, res1 & res2 & res3};
+done8:
+  return {p + 8, res1 & res2 & res3};
+done9:
+  return {p + 9, res1 & res2 & res3};
+done10:
+  return {p + 10, res1 & res2 & res3};
+}
+
 // Notes:
 // 1) if data_offset is negative, it's read from data.offset()
 // 2) if hasbit_idx is negative, it's read from data.hasbit_idx()
 template <int data_offset, int hasbit_idx>
-PROTOBUF_NOINLINE const char* TcParser::SpecializedFastV8S1(
-    PROTOBUF_TC_PARAM_DECL) {
+PROTOBUF_NOINLINE const char* TcParser::FastTV8S1(PROTOBUF_TC_PARAM_DECL) {
   using TagType = uint8_t;
 
   // Special case for a varint bool field with a tag of 1 byte:
Expand Down Expand Up @@ -766,8 +893,40 @@ PROTOBUF_NOINLINE const char* TcParser::SpecializedFastV8S1(
}

template <typename FieldType, int data_offset, int hasbit_idx>
PROTOBUF_NOINLINE const char* TcParser::SpecializedUnrolledVImpl1(
PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_NOINLINE const char* TcParser::FastTV64S1(PROTOBUF_TC_PARAM_DECL) {
using TagType = uint8_t;
// super-early success test...
if (PROTOBUF_PREDICT_TRUE(((data.data) & 0x80FF) == 0)) {
ptr += sizeof(TagType); // Consume tag
if (hasbit_idx < 32) {
hasbits |= (uint64_t{1} << hasbit_idx);
}
uint8_t value = data.data >> 8;
RefAt<FieldType>(msg, data_offset) = value;
ptr += 1;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
}
if (PROTOBUF_PREDICT_FALSE(data.coded_tag<TagType>() != 0)) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS);
}
ptr += sizeof(TagType); // Consume tag
if (hasbit_idx < 32) {
hasbits |= (uint64_t{1} << hasbit_idx);
}

auto tmp = Parse64FallbackPair(ptr, static_cast<int8_t>(data.data >> 8));
data.data = 0; // Indicate to the compiler that we don't need this anymore.
ptr = tmp.first;
if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
return Error(PROTOBUF_TC_PARAM_PASS);
}

RefAt<FieldType>(msg, data_offset) = static_cast<FieldType>(tmp.second);
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
}

template <typename FieldType, int data_offset, int hasbit_idx>
PROTOBUF_NOINLINE const char* TcParser::FastTV32S1(PROTOBUF_TC_PARAM_DECL) {
using TagType = uint8_t;
// super-early success test...
if (PROTOBUF_PREDICT_TRUE(((data.data) & 0x80FF) == 0)) {
@@ -800,34 +959,30 @@ PROTOBUF_NOINLINE const char* TcParser::SpecializedUnrolledVImpl1(
         if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
           res = RotRight7AndReplaceLowByte(res, ptr[4]);
           if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
-            res = RotRight7AndReplaceLowByte(res, ptr[5]);
-            if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
-              res = RotRight7AndReplaceLowByte(res, ptr[6]);
-              if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
-                res = RotRight7AndReplaceLowByte(res, ptr[7]);
-                if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
-                  res = RotRight7AndReplaceLowByte(res, ptr[8]);
-                  if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
+            if (PROTOBUF_PREDICT_FALSE(ptr[5] & 0x80)) {
+              if (PROTOBUF_PREDICT_FALSE(ptr[6] & 0x80)) {
+                if (PROTOBUF_PREDICT_FALSE(ptr[7] & 0x80)) {
+                  if (PROTOBUF_PREDICT_FALSE(ptr[8] & 0x80)) {
                     if (ptr[9] & 0xFE) return Error(PROTOBUF_TC_PARAM_PASS);
-                    res = RotateLeft(res, -7) & ~1;
-                    res += ptr[9] & 1;
-                    *out = RotateLeft(res, 63);
+                    *out = RotateLeft(res, 28);
                     ptr += 10;
-                    PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
+                    PROTOBUF_MUSTTAIL return ToTagDispatch(
+                        PROTOBUF_TC_PARAM_PASS);
                   }
-                  *out = RotateLeft(res, 56);
+                  *out = RotateLeft(res, 28);
                   ptr += 9;
-                  PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
+                  PROTOBUF_MUSTTAIL return ToTagDispatch(
+                      PROTOBUF_TC_PARAM_PASS);
                 }
-                *out = RotateLeft(res, 49);
+                *out = RotateLeft(res, 28);
                 ptr += 8;
                 PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
               }
-              *out = RotateLeft(res, 42);
+              *out = RotateLeft(res, 28);
               ptr += 7;
               PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
             }
-            *out = RotateLeft(res, 35);
+            *out = RotateLeft(res, 28);
             ptr += 6;
             PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
           }
