Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C++] Avoid copying statically generated serialized ATNs #3613

Merged
merged 1 commit into from
Mar 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions runtime/Cpp/runtime/src/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct BypassAltsAtnCache final {
/// bypass alternatives.
///
/// <seealso cref= ATNDeserializationOptions#isGenerateRuleBypassTransitions() </seealso>
std::map<std::vector<int32_t>, std::unique_ptr<const atn::ATN>> map;
std::map<std::vector<int32_t>, std::unique_ptr<const atn::ATN>, std::less<>> map;
};

BypassAltsAtnCache* getBypassAltsAtnCache() {
Expand Down Expand Up @@ -227,9 +227,8 @@ TokenFactory<CommonToken>* Parser::getTokenFactory() {
return _input->getTokenSource()->getTokenFactory();
}


const atn::ATN& Parser::getATNWithBypassAlts() {
const std::vector<int32_t> &serializedAtn = getSerializedATN();
auto serializedAtn = getSerializedATN();
if (serializedAtn.empty()) {
throw UnsupportedOperationException("The current parser does not support an ATN with bypass alternatives.");
}
Expand All @@ -244,15 +243,16 @@ const atn::ATN& Parser::getATNWithBypassAlts() {
}
}

std::unique_lock<std::shared_mutex> lock(cache->mutex);
auto existing = cache->map.find(serializedAtn);
if (existing != cache->map.end()) {
return *existing->second;
}
atn::ATNDeserializationOptions deserializationOptions;
deserializationOptions.setGenerateRuleBypassTransitions(true);
atn::ATNDeserializer deserializer(deserializationOptions);
auto atn = deserializer.deserialize(serializedAtn);

{
std::unique_lock<std::shared_mutex> lock(cache->mutex);
return *cache->map.insert(std::make_pair(serializedAtn, std::move(atn))).first->second;
}
return *cache->map.insert(std::make_pair(std::vector<int32_t>(serializedAtn.begin(), serializedAtn.end()), std::move(atn))).first->second;
}

tree::pattern::ParseTreePattern Parser::compileParseTreePattern(const std::string &pattern, int patternRuleIndex) {
Expand Down
3 changes: 2 additions & 1 deletion runtime/Cpp/runtime/src/Recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "ProxyErrorListener.h"
#include "support/Casts.h"
#include "atn/SerializedATNView.h"

namespace antlr4 {

Expand Down Expand Up @@ -53,7 +54,7 @@ namespace antlr4 {
/// For interpreters, we don't know their serialized ATN despite having
/// created the interpreter from it.
/// </summary>
virtual const std::vector<int32_t>& getSerializedATN() const {
virtual atn::SerializedATNView getSerializedATN() const {
throw "there is no serialized ATN";
}

Expand Down
1 change: 1 addition & 0 deletions runtime/Cpp/runtime/src/antlr4-runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
#include "atn/RuleStopState.h"
#include "atn/RuleTransition.h"
#include "atn/SemanticContext.h"
#include "atn/SerializedATNView.h"
#include "atn/SetTransition.h"
#include "atn/SingletonPredictionContext.h"
#include "atn/StarBlockStartState.h"
Expand Down
6 changes: 3 additions & 3 deletions runtime/Cpp/runtime/src/atn/ATNDeserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,12 @@ namespace {
return s;
}

ssize_t readUnicodeInt32(const std::vector<int32_t>& data, int& p) {
ssize_t readUnicodeInt32(SerializedATNView data, int& p) {
return static_cast<ssize_t>(data[p++]);
}

void deserializeSets(
const std::vector<int32_t>& data,
SerializedATNView data,
int& p,
std::vector<misc::IntervalSet>& sets) {
size_t nsets = data[p++];
Expand Down Expand Up @@ -255,7 +255,7 @@ ATNDeserializer::ATNDeserializer() : ATNDeserializer(ATNDeserializationOptions::

ATNDeserializer::ATNDeserializer(ATNDeserializationOptions deserializationOptions) : _deserializationOptions(std::move(deserializationOptions)) {}

std::unique_ptr<ATN> ATNDeserializer::deserialize(const std::vector<int32_t>& data) const {
std::unique_ptr<ATN> ATNDeserializer::deserialize(SerializedATNView data) const {
int p = 0;
int version = data[p++];
if (version != SERIALIZED_VERSION) {
Expand Down
3 changes: 2 additions & 1 deletion runtime/Cpp/runtime/src/atn/ATNDeserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#pragma once

#include "atn/ATNDeserializationOptions.h"
#include "atn/SerializedATNView.h"
#include "atn/LexerAction.h"
#include "atn/Transition.h"

Expand All @@ -20,7 +21,7 @@ namespace atn {

explicit ATNDeserializer(ATNDeserializationOptions deserializationOptions);

std::unique_ptr<ATN> deserialize(const std::vector<int32_t> &input) const;
std::unique_ptr<ATN> deserialize(SerializedATNView input) const;
void verifyATN(const ATN &atn) const;

private:
Expand Down
101 changes: 101 additions & 0 deletions runtime/Cpp/runtime/src/atn/SerializedATNView.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
* Use of this file is governed by the BSD 3-clause license that
* can be found in the LICENSE.txt file in the project root.
*/

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <vector>

#include "antlr4-common.h"
#include "misc/MurmurHash.h"

namespace antlr4 {
namespace atn {

/// <summary>
/// Read-only view over a contiguous run of int32_t values that make up a
/// serialized ATN. The view stores only a pointer and an element count; it
/// neither copies nor owns the elements it refers to.
/// </summary>
class ANTLR4CPP_PUBLIC SerializedATNView final {
public:
  // Container-style member types. Both iterator flavours are const pointers,
  // so the viewed elements can never be modified through the view.
  using value_type = int32_t;
  using size_type = size_t;
  using difference_type = ptrdiff_t;
  using reference = int32_t&;
  using const_reference = const int32_t&;
  using pointer = int32_t*;
  using const_pointer = const int32_t*;
  using iterator = const_pointer;
  using const_iterator = const_pointer;
  using reverse_iterator = std::reverse_iterator<iterator>;
  using const_reverse_iterator = std::reverse_iterator<const_iterator>;

  /// Creates an empty view (null data pointer, zero elements).
  SerializedATNView() = default;

  /// Creates a view over {@code size} elements starting at {@code data}.
  SerializedATNView(const_pointer data, size_type size) : _data(data), _size(size) {}

  /// Creates a view over the elements of {@code serializedATN}. Intentionally
  /// implicit so a vector can be passed wherever a view is expected.
  SerializedATNView(const std::vector<int32_t> &serializedATN)
      : SerializedATNView(serializedATN.data(), serializedATN.size()) {}

  SerializedATNView(const SerializedATNView&) = default;

  SerializedATNView& operator=(const SerializedATNView&) = default;

  // Forward iteration.
  const_iterator begin() const { return _data; }

  const_iterator cbegin() const { return _data; }

  const_iterator end() const { return _data + _size; }

  const_iterator cend() const { return _data + _size; }

  // Reverse iteration.
  const_reverse_iterator rbegin() const { return const_reverse_iterator(_data + _size); }

  const_reverse_iterator crbegin() const { return const_reverse_iterator(_data + _size); }

  const_reverse_iterator rend() const { return const_reverse_iterator(_data); }

  const_reverse_iterator crend() const { return const_reverse_iterator(_data); }

  /// True when the view refers to zero elements.
  bool empty() const { return _size == 0; }

  /// Pointer to the first viewed element (null for a default-constructed view).
  const_pointer data() const { return _data; }

  /// Number of int32_t elements in the view.
  size_type size() const { return _size; }

  /// Number of bytes spanned by the viewed elements.
  size_type size_bytes() const { return _size * sizeof(value_type); }

  /// Unchecked element access; {@code index} is not validated against size().
  const_reference operator[](size_type index) const { return _data[index]; }

private:
  const_pointer _data = nullptr;
  size_type _size = 0;
};

/// <summary>
/// Two views are equal when they have the same number of elements and those
/// elements compare equal byte for byte.
/// </summary>
/// <param name="lhs"> left-hand view </param>
/// <param name="rhs"> right-hand view </param>
/// <returns> {@code true} if both views describe the same element sequence </returns>
inline bool operator==(const SerializedATNView &lhs, const SerializedATNView &rhs) {
  if (lhs.size() != rhs.size()) {
    return false;
  }
  // Identical storage, or nothing to compare. The empty check also keeps a
  // possibly-null data pointer away from memcmp, which requires valid
  // pointers even when the length is zero.
  if (lhs.data() == rhs.data() || lhs.empty()) {
    return true;
  }
  return std::memcmp(lhs.data(), rhs.data(), lhs.size_bytes()) == 0;
}

/// Negation of operator== for two views.
inline bool operator!=(const SerializedATNView &lhs, const SerializedATNView &rhs) {
  const bool same = (lhs == rhs);
  return !same;
}

/// <summary>
/// Lexicographic, element-wise ordering of two views.
///
/// The comparison is done on int32_t values rather than on raw bytes so that
/// it agrees with std::vector&lt;int32_t&gt;::operator&lt;. A byte-wise memcmp orders
/// negative values and (on little-endian hosts) multi-byte values differently,
/// which breaks heterogeneous lookup in ordered containers that are keyed by
/// std::vector&lt;int32_t&gt; and probed with a view through a transparent comparator.
/// </summary>
/// <param name="lhs"> left-hand view </param>
/// <param name="rhs"> right-hand view </param>
/// <returns> {@code true} if {@code lhs} orders strictly before {@code rhs} </returns>
inline bool operator<(const SerializedATNView &lhs, const SerializedATNView &rhs) {
  // Works for empty / default-constructed views too: both ranges are then
  // empty and no element is ever dereferenced.
  return std::lexicographical_compare(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
}

} // namespace atn
} // namespace antlr4

namespace std {

/// std::hash support for SerializedATNView: hashes the viewed elements with
/// MurmurHash using its default seed.
template <>
struct hash<::antlr4::atn::SerializedATNView> {
  size_t operator()(const ::antlr4::atn::SerializedATNView &serializedATNView) const {
    const auto *elements = serializedATNView.data();
    const auto elementCount = serializedATNView.size();
    return ::antlr4::misc::MurmurHash::hashCode(elements, elementCount);
  }
};

} // namespace std
18 changes: 18 additions & 0 deletions runtime/Cpp/runtime/src/misc/MurmurHash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <cstddef>
#include <cstdint>
#include <cstring>

#include "misc/MurmurHash.h"

Expand Down Expand Up @@ -62,6 +63,23 @@ size_t MurmurHash::update(size_t hash, size_t value) {
return hash;
}

/// Folds {@code size} raw bytes starting at {@code data} into {@code hash}.
/// The bytes are consumed in size_t-sized words; a trailing partial word is
/// zero-padded before being folded in. Returns the updated intermediate hash.
size_t MurmurHash::update(size_t hash, const void *data, size_t size) {
  constexpr size_t wordSize = sizeof(size_t);
  const uint8_t *cursor = static_cast<const uint8_t*>(data);
  size_t remaining = size;

  // Whole words. memcpy is used so the source does not need to be aligned.
  for (; remaining >= wordSize; cursor += wordSize, remaining -= wordSize) {
    size_t word;
    std::memcpy(&word, cursor, wordSize);
    hash = update(hash, word);
  }

  // Leftover bytes (fewer than one word): pad the unused bytes with zero.
  if (remaining != 0) {
    size_t tail = 0;
    std::memcpy(&tail, cursor, remaining);
    hash = update(hash, tail);
  }
  return hash;
}

size_t MurmurHash::finish(size_t hash, size_t entryCount) {
hash ^= entryCount * 8;
hash ^= hash >> 33;
Expand Down
21 changes: 20 additions & 1 deletion runtime/Cpp/runtime/src/misc/MurmurHash.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#pragma once

#include <cstdint>
#include <type_traits>

#include "antlr4-common.h"

Expand Down Expand Up @@ -47,6 +48,13 @@ namespace misc {
return update(hash, value != nullptr ? value->hashCode() : 0);
}

static size_t update(size_t hash, const void *data, size_t size);

/// Typed convenience overload: folds {@code size} elements of type T starting
/// at {@code data} into {@code hash} by forwarding to the raw-byte overload
/// with the length converted from elements to bytes.
template <typename T>
static size_t update(size_t hash, const T *data, size_t size) {
  const size_t byteCount = size * sizeof(std::remove_reference_t<T>);
  const void *rawBytes = static_cast<const void*>(data);
  return update(hash, rawBytes, byteCount);
}

/// <summary>
/// Apply the final computation steps to the intermediate value {@code hash}
/// to form the final result of the MurmurHash 3 hash function.
Expand All @@ -63,14 +71,25 @@ namespace misc {
/// <param name="seed"> the seed for the MurmurHash algorithm </param>
/// <returns> the hash code of the data </returns>
template<typename T> // where T is C array type
static size_t hashCode(const std::vector<Ref<T>> &data, size_t seed) {
static size_t hashCode(const std::vector<Ref<T>> &data, size_t seed = DEFAULT_SEED) {
size_t hash = initialize(seed);
for (auto &entry : data) {
hash = update(hash, entry);
}
return finish(hash, data.size());
}

/// Computes the hash of {@code size} raw bytes starting at {@code data}:
/// initialize with {@code seed}, fold in the bytes, then finish using the
/// byte count.
static size_t hashCode(const void *data, size_t size, size_t seed = DEFAULT_SEED) {
  const size_t seeded = initialize(seed);
  const size_t folded = update(seeded, data, size);
  return finish(folded, size);
}

/// Typed convenience overload: hashes {@code size} elements of type T starting
/// at {@code data} by forwarding to the raw-byte overload with the length
/// converted from elements to bytes.
template <typename T>
static size_t hashCode(const T *data, size_t size, size_t seed = DEFAULT_SEED) {
  const size_t byteCount = size * sizeof(std::remove_reference_t<T>);
  const void *rawBytes = static_cast<const void*>(data);
  return hashCode(rawBytes, byteCount, seed);
}

private:
MurmurHash() = delete;

Expand Down
13 changes: 4 additions & 9 deletions runtime/Cpp/runtime/src/tree/xpath/XPathLexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct XPathLexerStaticData final {
const std::vector<std::string> literalNames;
const std::vector<std::string> symbolicNames;
const antlr4::dfa::Vocabulary vocabulary;
std::vector<int32_t> serializedATN;
antlr4::atn::SerializedATNView serializedATN;
std::unique_ptr<antlr4::atn::ATN> atn;
};

Expand Down Expand Up @@ -61,7 +61,7 @@ void xpathLexerInitialize() {
"STRING"
}
);
static const int32_t serializedATNSegment0[] = {
static const int32_t serializedATNSegment[] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These could even be static constexpr. Probably sizeof() below will be smart enough to already know this at compile-time even with const, but with constexpr it doesn't have an excuse not to know.

0x4, 0x0, 0x8, 0x32, 0x6, -1, 0x2, 0x0, 0x7, 0x0, 0x2, 0x1, 0x7,
0x1, 0x2, 0x2, 0x7, 0x2, 0x2, 0x3, 0x7, 0x3, 0x2, 0x4, 0x7, 0x4,
0x2, 0x5, 0x7, 0x5, 0x2, 0x6, 0x7, 0x6, 0x2, 0x7, 0x7, 0x7, 0x1,
Expand Down Expand Up @@ -102,12 +102,7 @@ void xpathLexerInitialize() {
0x1, 0x0, 0x0, 0x0, 0x4, 0x0, 0x1e, 0x25, 0x2d, 0x1, 0x1, 0x4, 0x0,
};

size_t serializedATNSize = 0;
serializedATNSize += sizeof(serializedATNSegment0) / sizeof(serializedATNSegment0[0]);
staticData->serializedATN.reserve(serializedATNSize);

staticData->serializedATN.insert(staticData->serializedATN.end(), serializedATNSegment0,
serializedATNSegment0 + sizeof(serializedATNSegment0) / sizeof(serializedATNSegment0[0]));
staticData->serializedATN = antlr4::atn::SerializedATNView(serializedATNSegment, sizeof(serializedATNSegment) / sizeof(serializedATNSegment[0]));

atn::ATNDeserializer deserializer;
staticData->atn = deserializer.deserialize(staticData->serializedATN);
Expand Down Expand Up @@ -151,7 +146,7 @@ const dfa::Vocabulary& XPathLexer::getVocabulary() const {
return xpathLexerStaticData->vocabulary;
}

const std::vector<int32_t>& XPathLexer::getSerializedATN() const {
antlr4::atn::SerializedATNView XPathLexer::getSerializedATN() const {
return xpathLexerStaticData->serializedATN;
}

Expand Down
2 changes: 1 addition & 1 deletion runtime/Cpp/runtime/src/tree/xpath/XPathLexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class XPathLexer : public antlr4::Lexer {

virtual const antlr4::dfa::Vocabulary& getVocabulary() const override;

virtual const std::vector<int32_t>& getSerializedATN() const override;
virtual antlr4::atn::SerializedATNView getSerializedATN() const override;

virtual const antlr4::atn::ATN& getATN() const override;

Expand Down
Loading