/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
 * Use of this file is governed by the BSD 3-clause license that
 * can be found in the LICENSE.txt file in the project root.
 */
6 #include "atn/EmptyPredictionContext.h"
7 #include "misc/MurmurHash.h"
8 #include "atn/ArrayPredictionContext.h"
9 #include "RuleContext.h"
10 #include "ParserRuleContext.h"
11 #include "atn/RuleTransition.h"
12 #include "support/Arrays.h"
13 #include "support/CPPUtils.h"
15 #include "atn/PredictionContext.h"
17 using namespace antlr4;
18 using namespace antlr4::misc;
19 using namespace antlr4::atn;
21 using namespace antlrcpp;
// Monotonically increasing counter used to hand out a unique id to every
// PredictionContext instance (consumed in the constructor below).
size_t PredictionContext::globalNodeCount = 0;

// Shared sentinel for the empty (root) context; isEmpty() compares against
// this object's identity, so there must be exactly one instance.
const Ref<PredictionContext> PredictionContext::EMPTY = std::make_shared<EmptyPredictionContext>();
26 //----------------- PredictionContext ----------------------------------------------------------------------------------
// Assigns a fresh node id from the global counter and memoizes the
// precomputed hash (contexts are immutable, so the hash never changes).
PredictionContext::PredictionContext(size_t cachedHashCode) : id(globalNodeCount++), cachedHashCode(cachedHashCode) {
// Out-of-line destructor definition (anchors the vtable in this TU).
PredictionContext::~PredictionContext() {
34 Ref<PredictionContext> PredictionContext::fromRuleContext(const ATN &atn, RuleContext *outerContext) {
35 if (outerContext == nullptr) {
36 return PredictionContext::EMPTY;
39 // if we are in RuleContext of start rule, s, then PredictionContext
40 // is EMPTY. Nobody called us. (if we are empty, return empty)
41 if (outerContext->parent == nullptr || outerContext == &ParserRuleContext::EMPTY) {
42 return PredictionContext::EMPTY;
45 // If we have a parent, convert it to a PredictionContext graph
46 Ref<PredictionContext> parent = PredictionContext::fromRuleContext(atn, dynamic_cast<RuleContext *>(outerContext->parent));
48 ATNState *state = atn.states.at(outerContext->invokingState);
49 RuleTransition *transition = (RuleTransition *)state->transitions[0];
50 return SingletonPredictionContext::create(parent, transition->followState->stateNumber);
53 bool PredictionContext::isEmpty() const {
54 return this == EMPTY.get();
57 bool PredictionContext::hasEmptyPath() const {
58 // since EMPTY_RETURN_STATE can only appear in the last position, we check last one
59 return getReturnState(size() - 1) == EMPTY_RETURN_STATE;
// Returns the hash memoized at construction time (contexts are immutable,
// so recomputation is never needed).
size_t PredictionContext::hashCode() const {
  return cachedHashCode;
66 size_t PredictionContext::calculateEmptyHashCode() {
67 size_t hash = MurmurHash::initialize(INITIAL_HASH);
68 hash = MurmurHash::finish(hash, 0);
72 size_t PredictionContext::calculateHashCode(Ref<PredictionContext> parent, size_t returnState) {
73 size_t hash = MurmurHash::initialize(INITIAL_HASH);
74 hash = MurmurHash::update(hash, parent);
75 hash = MurmurHash::update(hash, returnState);
76 hash = MurmurHash::finish(hash, 2);
80 size_t PredictionContext::calculateHashCode(const std::vector<Ref<PredictionContext>> &parents,
81 const std::vector<size_t> &returnStates) {
82 size_t hash = MurmurHash::initialize(INITIAL_HASH);
84 for (auto parent : parents) {
85 hash = MurmurHash::update(hash, parent);
88 for (auto returnState : returnStates) {
89 hash = MurmurHash::update(hash, returnState);
92 return MurmurHash::finish(hash, parents.size() + returnStates.size());
95 Ref<PredictionContext> PredictionContext::merge(const Ref<PredictionContext> &a,
96 const Ref<PredictionContext> &b, bool rootIsWildcard, PredictionContextMergeCache *mergeCache) {
99 // share same graph if both same
100 if (a == b || *a == *b) {
104 if (is<SingletonPredictionContext>(a) && is<SingletonPredictionContext>(b)) {
105 return mergeSingletons(std::dynamic_pointer_cast<SingletonPredictionContext>(a),
106 std::dynamic_pointer_cast<SingletonPredictionContext>(b), rootIsWildcard, mergeCache);
109 // At least one of a or b is array.
110 // If one is $ and rootIsWildcard, return $ as * wildcard.
111 if (rootIsWildcard) {
112 if (is<EmptyPredictionContext>(a)) {
115 if (is<EmptyPredictionContext>(b)) {
120 // convert singleton so both are arrays to normalize
121 Ref<ArrayPredictionContext> left;
122 if (is<SingletonPredictionContext>(a)) {
123 left = std::make_shared<ArrayPredictionContext>(std::dynamic_pointer_cast<SingletonPredictionContext>(a));
125 left = std::dynamic_pointer_cast<ArrayPredictionContext>(a);
127 Ref<ArrayPredictionContext> right;
128 if (is<SingletonPredictionContext>(b)) {
129 right = std::make_shared<ArrayPredictionContext>(std::dynamic_pointer_cast<SingletonPredictionContext>(b));
131 right = std::dynamic_pointer_cast<ArrayPredictionContext>(b);
133 return mergeArrays(left, right, rootIsWildcard, mergeCache);
/// Merge two singleton contexts (one parent + one return state each):
/// handle the $/EMPTY root cases via mergeRoot(), collapse equal payloads by
/// recursively merging parents, share a single parent when both parents are
/// equal, or fall back to a payload-sorted two-entry ArrayPredictionContext.
/// Results are memoized in `mergeCache` (probed in both key orders).
/// NOTE(review): several early-return and closing-brace lines appear to be
/// missing from this listing (e.g. returning a cache hit after the get()
/// probes, and returning rootMerge / parentA / parentB / a_) — verify
/// against the upstream runtime before editing.
Ref<PredictionContext> PredictionContext::mergeSingletons(const Ref<SingletonPredictionContext> &a,
  const Ref<SingletonPredictionContext> &b, bool rootIsWildcard, PredictionContextMergeCache *mergeCache) {

  // Probe the memo cache in both key orders before doing any work.
  if (mergeCache != nullptr) { // Can be null if not given to the ATNState from which this call originates.
    auto existing = mergeCache->get(a, b);
    existing = mergeCache->get(b, a);

  // Handle merges involving the $/EMPTY root first.
  Ref<PredictionContext> rootMerge = mergeRoot(a, b, rootIsWildcard);
  if (mergeCache != nullptr) {
    mergeCache->put(a, b, rootMerge);

  Ref<PredictionContext> parentA = a->parent;
  Ref<PredictionContext> parentB = b->parent;
  if (a->returnState == b->returnState) { // a == b
    // Same payload: merge the parents recursively.
    Ref<PredictionContext> parent = merge(parentA, parentB, rootIsWildcard, mergeCache);

    // If parent is same as existing a or b parent or reduced to a parent, return it.
    if (parent == parentA) { // ax + bx = ax, if a=b
    if (parent == parentB) { // ax + bx = bx, if a=b

    // else: ax + ay = a'[x,y]
    // merge parents x and y, giving array node with x,y then remainders
    // of those graphs. dup a, a' points at merged array
    // new joined parent so create new singleton pointing to it, a'
    Ref<PredictionContext> a_ = SingletonPredictionContext::create(parent, a->returnState);
    if (mergeCache != nullptr) {
      mergeCache->put(a, b, a_);

  // a != b payloads differ
  // see if we can collapse parents due to $+x parents if local ctx
  Ref<PredictionContext> singleParent;
  if (a == b || (*parentA == *parentB)) { // ax + bx = [a,b]x
    singleParent = parentA;
  if (singleParent) { // parents are same, sort payloads and use same parent
    std::vector<size_t> payloads = { a->returnState, b->returnState };
    if (a->returnState > b->returnState) {
      payloads[0] = b->returnState;
      payloads[1] = a->returnState;
    std::vector<Ref<PredictionContext>> parents = { singleParent, singleParent };
    Ref<PredictionContext> a_ = std::make_shared<ArrayPredictionContext>(parents, payloads);
    if (mergeCache != nullptr) {
      mergeCache->put(a, b, a_);

  // parents differ and can't merge them. Just pack together
  // into array; can't merge.
  Ref<PredictionContext> a_;
  if (a->returnState > b->returnState) { // sort by payload
    std::vector<size_t> payloads = { b->returnState, a->returnState };
    std::vector<Ref<PredictionContext>> parents = { b->parent, a->parent };
    a_ = std::make_shared<ArrayPredictionContext>(parents, payloads);
    std::vector<size_t> payloads = {a->returnState, b->returnState};
    std::vector<Ref<PredictionContext>> parents = { a->parent, b->parent };
    a_ = std::make_shared<ArrayPredictionContext>(parents, payloads);
  if (mergeCache != nullptr) {
    mergeCache->put(a, b, a_);
222 Ref<PredictionContext> PredictionContext::mergeRoot(const Ref<SingletonPredictionContext> &a,
223 const Ref<SingletonPredictionContext> &b, bool rootIsWildcard) {
224 if (rootIsWildcard) {
225 if (a == EMPTY) { // * + b = *
228 if (b == EMPTY) { // a + * = *
232 if (a == EMPTY && b == EMPTY) { // $ + $ = $
235 if (a == EMPTY) { // $ + x = [$,x]
236 std::vector<size_t> payloads = { b->returnState, EMPTY_RETURN_STATE };
237 std::vector<Ref<PredictionContext>> parents = { b->parent, nullptr };
238 Ref<PredictionContext> joined = std::make_shared<ArrayPredictionContext>(parents, payloads);
241 if (b == EMPTY) { // x + $ = [$,x] ($ is always first if present)
242 std::vector<size_t> payloads = { a->returnState, EMPTY_RETURN_STATE };
243 std::vector<Ref<PredictionContext>> parents = { a->parent, nullptr };
244 Ref<PredictionContext> joined = std::make_shared<ArrayPredictionContext>(parents, payloads);
/// Merge two array contexts by walking their sorted return-state lists in
/// parallel: entries with equal payloads are merged (recursively merging
/// their parents), the rest are copied through; the result is trimmed,
/// possibly reduced to a singleton, normalized via combineCommonParents(),
/// and memoized in `mergeCache`.
/// NOTE(review): several lines are not visible in this listing — cache-hit
/// returns after the get() probes, the `*M == *a` / `*M == *b` comparisons
/// that should guard the two mergeCache->put(a, b, a/b) calls, k increments,
/// and the final `return M;` — verify against the upstream runtime.
Ref<PredictionContext> PredictionContext::mergeArrays(const Ref<ArrayPredictionContext> &a,
  const Ref<ArrayPredictionContext> &b, bool rootIsWildcard, PredictionContextMergeCache *mergeCache) {

  // Probe the memo cache in both key orders.
  if (mergeCache != nullptr) {
    auto existing = mergeCache->get(a, b);
    existing = mergeCache->get(b, a);

  // merge sorted payloads a + b => M
  size_t i = 0; // walks a
  size_t j = 0; // walks b
  size_t k = 0; // walks target M array

  // Worst case: no payload collapses, so size for all entries of both inputs.
  std::vector<size_t> mergedReturnStates(a->returnStates.size() + b->returnStates.size());
  std::vector<Ref<PredictionContext>> mergedParents(a->returnStates.size() + b->returnStates.size());

  // walk and merge to yield mergedParents, mergedReturnStates
  while (i < a->returnStates.size() && j < b->returnStates.size()) {
    Ref<PredictionContext> a_parent = a->parents[i];
    Ref<PredictionContext> b_parent = b->parents[j];
    if (a->returnStates[i] == b->returnStates[j]) {
      // same payload (stack tops are equal), must yield merged singleton
      size_t payload = a->returnStates[i];

      bool both$ = payload == EMPTY_RETURN_STATE && !a_parent && !b_parent;
      bool ax_ax = (a_parent && b_parent) && *a_parent == *b_parent; // ax+ax -> ax
      if (both$ || ax_ax) {
        mergedParents[k] = a_parent; // choose left
        mergedReturnStates[k] = payload;
      else { // ax+ay -> a'[x,y]
        Ref<PredictionContext> mergedParent = merge(a_parent, b_parent, rootIsWildcard, mergeCache);
        mergedParents[k] = mergedParent;
        mergedReturnStates[k] = payload;
      i++; // hop over left one as usual
      j++; // but also skip one in right side since we merge
    } else if (a->returnStates[i] < b->returnStates[j]) { // copy a[i] to M
      mergedParents[k] = a_parent;
      mergedReturnStates[k] = a->returnStates[i];
    else { // b > a, copy b[j] to M
      mergedParents[k] = b_parent;
      mergedReturnStates[k] = b->returnStates[j];

  // copy over any payloads remaining in either array
  if (i < a->returnStates.size()) {
    for (std::vector<int>::size_type p = i; p < a->returnStates.size(); p++) {
      mergedParents[k] = a->parents[p];
      mergedReturnStates[k] = a->returnStates[p];
    for (std::vector<int>::size_type p = j; p < b->returnStates.size(); p++) {
      mergedParents[k] = b->parents[p];
      mergedReturnStates[k] = b->returnStates[p];

  // trim merged if we combined a few that had same stack tops
  if (k < mergedParents.size()) { // write index < last position; trim
    if (k == 1) { // for just one merged element, return singleton top
      Ref<PredictionContext> a_ = SingletonPredictionContext::create(mergedParents[0], mergedReturnStates[0]);
      if (mergeCache != nullptr) {
        mergeCache->put(a, b, a_);
    mergedParents.resize(k);
    mergedReturnStates.resize(k);

  Ref<ArrayPredictionContext> M = std::make_shared<ArrayPredictionContext>(mergedParents, mergedReturnStates);

  // if we created same array as a or b, return that instead
  // TODO: track whether this is possible above during merge sort for speed
  if (mergeCache != nullptr) {
    mergeCache->put(a, b, a);
  if (mergeCache != nullptr) {
    mergeCache->put(a, b, b);

  // ml: this part differs from Java code. We have to recreate the context as the parents array is copied on creation.
  if (combineCommonParents(mergedParents)) {
    mergedReturnStates.resize(mergedParents.size());
    M = std::make_shared<ArrayPredictionContext>(mergedParents, mergedReturnStates);

  if (mergeCache != nullptr) {
    mergeCache->put(a, b, M);
/// Canonicalize the parent pointers in `parents` so duplicate entries share
/// one representative instance.
/// NOTE(review): std::set<Ref<...>> orders by the owned pointer value, so the
/// deduplication here is by object identity, and the second pass looks like a
/// no-op as written; the function's return statement is also not visible in
/// this listing (the caller in mergeArrays branches on it) — verify both
/// points against the upstream runtime before editing.
bool PredictionContext::combineCommonParents(std::vector<Ref<PredictionContext>> &parents) {

  // First pass: collect one representative per distinct parent pointer.
  std::set<Ref<PredictionContext>> uniqueParents;
  for (size_t p = 0; p < parents.size(); ++p) {
    Ref<PredictionContext> parent = parents[p];
    if (uniqueParents.find(parent) == uniqueParents.end()) { // don't replace
      uniqueParents.insert(parent);
    }

  // Second pass: rewrite every slot to point at its representative.
  for (size_t p = 0; p < parents.size(); ++p) {
    parents[p] = *uniqueParents.find(parents[p]);
/// Render the context graph rooted at `context` as a Graphviz DOT digraph:
/// one node per context (boxes for array contexts), edges to parents.
/// Debugging aid only.
std::string PredictionContext::toDOTString(const Ref<PredictionContext> &context) {
  if (context == nullptr) {

  std::stringstream ss;
  ss << "digraph G {\n" << "rankdir=LR;\n";

  // Emit nodes in ascending id order for stable output.
  std::vector<Ref<PredictionContext>> nodes = getAllContextNodes(context);
  std::sort(nodes.begin(), nodes.end(), [](const Ref<PredictionContext> &o1, const Ref<PredictionContext> &o2) {
    // NOTE(review): returning an id difference from a bool comparator makes
    // every unequal pair compare "less" in both directions, violating the
    // strict weak ordering std::sort requires (UB). Should be o1->id < o2->id.
    return o1->id - o2->id;

  // First pass: node declarations.
  for (auto current : nodes) {
    if (is<SingletonPredictionContext>(current)) {
      std::string s = std::to_string(current->id);
      std::string returnState = std::to_string(current->getReturnState(0));
      if (is<EmptyPredictionContext>(current)) {
      ss << " [label=\"" << returnState << "\"];\n";

    Ref<ArrayPredictionContext> arr = std::static_pointer_cast<ArrayPredictionContext>(current);
    ss << " s" << arr->id << " [shape=box, label=\"" << "[";
    for (auto inv : arr->returnStates) {
      if (inv == EMPTY_RETURN_STATE) {

  // Second pass: parent edges (EMPTY has no parents; null parents skipped).
  for (auto current : nodes) {
    if (current == EMPTY) {
    for (size_t i = 0; i < current->size(); i++) {
      if (!current->getParent(i)) {
      ss << " s" << current->id << "->" << "s" << current->getParent(i)->id;
      if (current->size() > 1) {
        ss << " [label=\"parent[" << i << "]\"];\n";
// The "visited" map is just a temporary structure to control the retrieval process (which is recursive).
/// Return a context equal to `context` built entirely from nodes owned by
/// `contextCache`, inserting newly created nodes into the cache. `visited`
/// memoizes per-call results so shared sub-graphs are translated only once.
Ref<PredictionContext> PredictionContext::getCachedContext(const Ref<PredictionContext> &context,
  PredictionContextCache &contextCache, std::map<Ref<PredictionContext>, Ref<PredictionContext>> &visited) {
  if (context->isEmpty()) {

  // Already translated during this traversal?
  auto iterator = visited.find(context);
  if (iterator != visited.end())
    return iterator->second; // Not necessarily the same as context.

  // Already present in the shared cache?
  // NOTE(review): in the upstream source this redeclaration of `iterator`
  // sits inside its own block scope; the braces appear elided here.
  auto iterator = contextCache.find(context);
  if (iterator != contextCache.end()) {
    visited[context] = *iterator;

  // Recursively translate all parents, tracking whether any changed.
  bool changed = false;

  std::vector<Ref<PredictionContext>> parents(context->size());
  for (size_t i = 0; i < parents.size(); i++) {
    Ref<PredictionContext> parent = getCachedContext(context->getParent(i), contextCache, visited);
    if (changed || parent != context->getParent(i)) {
      for (size_t j = 0; j < context->size(); j++) {
        parents.push_back(context->getParent(j));

  // No parent changed: cache and reuse the original node as-is.
  contextCache.insert(context);
  visited[context] = context;

  // Some parent changed: rebuild this node from the translated parents.
  Ref<PredictionContext> updated;
  if (parents.empty()) {
  } else if (parents.size() == 1) {
    updated = SingletonPredictionContext::create(parents[0], context->getReturnState(0));
    contextCache.insert(updated);
    updated = std::make_shared<ArrayPredictionContext>(parents, std::dynamic_pointer_cast<ArrayPredictionContext>(context)->returnStates);
    contextCache.insert(updated);

  visited[updated] = updated;
  visited[context] = updated;
507 std::vector<Ref<PredictionContext>> PredictionContext::getAllContextNodes(const Ref<PredictionContext> &context) {
508 std::vector<Ref<PredictionContext>> nodes;
509 std::set<PredictionContext *> visited;
510 getAllContextNodes_(context, nodes, visited);
515 void PredictionContext::getAllContextNodes_(const Ref<PredictionContext> &context, std::vector<Ref<PredictionContext>> &nodes,
516 std::set<PredictionContext *> &visited) {
518 if (visited.find(context.get()) != visited.end()) {
519 return; // Already done.
522 visited.insert(context.get());
523 nodes.push_back(context);
525 for (size_t i = 0; i < context->size(); i++) {
526 getAllContextNodes_(context->getParent(i), nodes, visited);
// Renders this context via the generic helper (no recognizer information).
std::string PredictionContext::toString() const {
  return antlrcpp::toString(this);
535 std::string PredictionContext::toString(Recognizer * /*recog*/) const {
539 std::vector<std::string> PredictionContext::toStrings(Recognizer *recognizer, int currentState) {
540 return toStrings(recognizer, EMPTY, currentState);
/// Enumerate the rule-invocation stacks encoded in this context graph, one
/// string per path, following parents until `stop` (or EMPTY) is reached.
/// Each value of `perm` selects one combination of parent choices; the outer
/// loop ends once a permutation indexes past the available alternatives.
/// NOTE(review): the declarations of `bits`, `offset`, `index` and `last`,
/// plus several closing braces and the final return, are not visible in this
/// listing — verify against the upstream runtime before editing.
std::vector<std::string> PredictionContext::toStrings(Recognizer *recognizer, const Ref<PredictionContext> &stop, int currentState) {

  std::vector<std::string> result;

  for (size_t perm = 0; ; perm++) {
    PredictionContext *p = this;
    size_t stateNumber = currentState;

    std::stringstream ss;

    bool outerContinue = false;
    while (!p->isEmpty() && p != stop.get()) {
      // Widen `bits` until it can index any parent slot of this node.
      while ((1ULL << bits) < p->size()) {

      size_t mask = (1 << bits) - 1;
      index = (perm >> offset) & mask;
      last &= index >= p->size() - 1;
      if (index >= p->size()) {
        outerContinue = true;

      if (recognizer != nullptr) {
        if (ss.tellp() > 1) {
          // first char is '[', if more than that this isn't the first rule

        const ATN &atn = recognizer->getATN();
        ATNState *s = atn.states[stateNumber];
        std::string ruleName = recognizer->getRuleNames()[s->ruleIndex];
      } else if (p->getReturnState(index) != EMPTY_RETURN_STATE) {

        if (ss.tellp() > 1) {
          // first char is '[', if more than that this isn't the first rule

          ss << p->getReturnState(index);

      // Walk up: the chosen return state becomes the next ATN state.
      stateNumber = p->getReturnState(index);
      p = p->getParent(index).get();

    result.push_back(ss.str());
612 //----------------- PredictionContextMergeCache ------------------------------------------------------------------------
614 Ref<PredictionContext> PredictionContextMergeCache::put(Ref<PredictionContext> const& key1, Ref<PredictionContext> const& key2,
615 Ref<PredictionContext> const& value) {
616 Ref<PredictionContext> previous;
618 auto iterator = _data.find(key1);
619 if (iterator == _data.end())
620 _data[key1][key2] = value;
622 auto iterator2 = iterator->second.find(key2);
623 if (iterator2 != iterator->second.end())
624 previous = iterator2->second;
625 iterator->second[key2] = value;
631 Ref<PredictionContext> PredictionContextMergeCache::get(Ref<PredictionContext> const& key1, Ref<PredictionContext> const& key2) {
632 auto iterator = _data.find(key1);
633 if (iterator == _data.end())
636 auto iterator2 = iterator->second.find(key2);
637 if (iterator2 == iterator->second.end())
640 return iterator2->second;
643 void PredictionContextMergeCache::clear() {
647 std::string PredictionContextMergeCache::toString() const {
649 for (auto pair : _data)
650 for (auto pair2 : pair.second)
651 result += pair2.second->toString() + "\n";
656 size_t PredictionContextMergeCache::count() const {
658 for (auto entry : _data)
659 result += entry.second.size();