1 /* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
2 * Use of this file is governed by the BSD 3-clause license that
3 * can be found in the LICENSE.txt file in the project root.
6 #include "misc/IntervalSet.h"
7 #include "atn/ATNType.h"
8 #include "atn/ATNState.h"
9 #include "atn/BlockEndState.h"
11 #include "atn/DecisionState.h"
12 #include "atn/RuleStartState.h"
13 #include "atn/LoopEndState.h"
14 #include "atn/BlockStartState.h"
15 #include "atn/Transition.h"
16 #include "atn/SetTransition.h"
18 #include "misc/Interval.h"
21 #include "atn/RuleTransition.h"
22 #include "atn/PrecedencePredicateTransition.h"
23 #include "atn/PredicateTransition.h"
24 #include "atn/RangeTransition.h"
25 #include "atn/AtomTransition.h"
26 #include "atn/ActionTransition.h"
27 #include "atn/ATNDeserializer.h"
29 #include "atn/TokensStartState.h"
30 #include "Exceptions.h"
31 #include "support/CPPUtils.h"
33 #include "atn/LexerChannelAction.h"
34 #include "atn/LexerCustomAction.h"
35 #include "atn/LexerModeAction.h"
36 #include "atn/LexerPushModeAction.h"
37 #include "atn/LexerTypeAction.h"
39 #include "Exceptions.h"
41 #include "atn/ATNSerializer.h"
43 using namespace antlrcpp;
44 using namespace antlr4::atn;
// Constructs a serializer for the given ATN. Only the pointer is stored (in the
// `atn` member); nothing else is initialised here.
46 ATNSerializer::ATNSerializer(ATN *atn) { this->atn = atn; }
// Constructs a serializer for the given ATN together with a list of token
// names. `tokenNames` is assigned to the `_tokenNames` member, which
// getTokenName() consults when it builds display names.
// NOTE(review): the assignment of `atn` and the closing brace are not visible
// in this excerpt — confirm against the full file that `atn` is stored here.
48 ATNSerializer::ATNSerializer(ATN *atn, const std::vector<std::string> &tokenNames) {
50 _tokenNames = tokenNames;
// Destructor with an empty body; no cleanup is performed here.
53 ATNSerializer::~ATNSerializer() { }
// Serializes the ATN held in `atn` into a flat sequence of size_t values.
// Sections are appended to `data` in this order: version, UUID, grammar type,
// max token type, states, non-greedy states, precedence states, rules, modes,
// sets, edges, decisions and (lexer grammars only) lexer actions.
//
// Throws IllegalStateException when a transition targets a state whose slot in
// atn->states is null, IllegalArgumentException mentioning the lexer action
// type (presumably from the switch's default branch, which is not visible), and
// UnsupportedOperationException when any element after the first exceeds 0xFFFF.
//
// NOTE(review): several lines of this function (among them the declarations of
// `nedges`, `arg1`, `arg2`, `arg3`, closing braces, `break`s and the return
// statement) are not visible in this excerpt — confirm details against the
// full file.
55 std::vector<size_t> ATNSerializer::serialize() {
56 std::vector<size_t> data;
57 data.push_back(ATNDeserializer::SERIALIZED_VERSION);
58 serializeUUID(data, ATNDeserializer::SERIALIZED_UUID());
60 // convert grammar type to ATN const to avoid dependence on ANTLRParser
61 data.push_back(static_cast<size_t>(atn->grammarType));
62 data.push_back(atn->maxTokenType);
// `sets` keeps each distinct interval set in first-seen order; `setIndices`
// maps a set to its position in `sets` so edges can refer to it by index.
65 std::unordered_map<misc::IntervalSet, int> setIndices;
66 std::vector<misc::IntervalSet> sets;
68 // dump states, count edges and collect sets while doing so
69 std::vector<size_t> nonGreedyStates;
70 std::vector<size_t> precedenceStates;
71 data.push_back(atn->states.size());
72 for (ATNState *s : atn->states) {
// A null slot is still written, as ATN_INVALID_TYPE.
73 if (s == nullptr) { // might be optimized away
74 data.push_back(ATNState::ATN_INVALID_TYPE);
78 size_t stateType = s->getStateType();
// Remember non-greedy decision states and left-recursive rule start states;
// their state numbers are emitted as separate lists after the state section.
79 if (is<DecisionState *>(s) && (static_cast<DecisionState *>(s))->nonGreedy) {
80 nonGreedyStates.push_back(s->stateNumber);
83 if (is<RuleStartState *>(s) && (static_cast<RuleStartState *>(s))->isLeftRecursiveRule) {
84 precedenceStates.push_back(s->stateNumber);
87 data.push_back(stateType);
// An invalid rule index is written as the 0xFFFF sentinel.
89 if (s->ruleIndex == INVALID_INDEX) {
90 data.push_back(0xFFFF);
93 data.push_back(s->ruleIndex);
// Loop-end states additionally record their loop-back state number; block
// start states record the number of their end state.
96 if (s->getStateType() == ATNState::LOOP_END) {
97 data.push_back((static_cast<LoopEndState *>(s))->loopBackState->stateNumber);
99 else if (is<BlockStartState *>(s)) {
100 data.push_back((static_cast<BlockStartState *>(s))->endState->stateNumber);
// Transitions of rule stop states are not counted in `nedges`.
103 if (s->getStateType() != ATNState::RULE_STOP) {
104 // the deserializer can trivially derive these edges, so there's no need
106 nedges += s->transitions.size();
// Register every set used by a SET / NOT_SET transition the first time it is
// seen; its index is the position it gets in `sets`.
109 for (size_t i = 0; i < s->transitions.size(); i++) {
110 Transition *t = s->transitions[i];
111 Transition::SerializationType edgeType = t->getSerializationType();
112 if (edgeType == Transition::SET || edgeType == Transition::NOT_SET) {
113 SetTransition *st = static_cast<SetTransition *>(t);
114 if (setIndices.find(st->set) == setIndices.end()) {
115 sets.push_back(st->set);
116 setIndices.insert({ st->set, (int)sets.size() - 1 });
// Non-greedy states: count followed by the state numbers.
123 data.push_back(nonGreedyStates.size());
124 for (size_t i = 0; i < nonGreedyStates.size(); i++) {
125 data.push_back(nonGreedyStates.at(i));
// Precedence states: count followed by the state numbers.
129 data.push_back(precedenceStates.size());
130 for (size_t i = 0; i < precedenceStates.size(); i++) {
131 data.push_back(precedenceStates.at(i));
// Rules: count, then the start state number of each rule. For lexer grammars
// the rule's token type follows, with Token::EOF written as 0xFFFF.
134 size_t nrules = atn->ruleToStartState.size();
135 data.push_back(nrules);
136 for (size_t r = 0; r < nrules; r++) {
137 ATNState *ruleStartState = atn->ruleToStartState[r];
138 data.push_back(ruleStartState->stateNumber);
139 if (atn->grammarType == ATNType::LEXER) {
140 if (atn->ruleToTokenType[r] == Token::EOF) {
141 data.push_back(0xFFFF);
144 data.push_back(atn->ruleToTokenType[r]);
// Modes: count, then the start state number of each mode.
149 size_t nmodes = atn->modeToStartState.size();
150 data.push_back(nmodes);
152 for (const auto &modeStartState : atn->modeToStartState) {
153 data.push_back(modeStartState->stateNumber);
// Sets: count, then per set the number of intervals, a 0/1 "contains EOF" flag
// and the interval bounds.
157 size_t nsets = sets.size();
158 data.push_back(nsets);
159 for (auto set : sets) {
160 bool containsEof = set.contains(Token::EOF);
// When the set contains EOF and its first interval ends at -1, one interval
// fewer is announced than the set actually holds.
161 if (containsEof && set.getIntervals().at(0).b == -1) {
162 data.push_back(set.getIntervals().size() - 1);
165 data.push_back(set.getIntervals().size());
168 data.push_back(containsEof ? 1 : 0);
// Intervals starting at -1 get special treatment; the bodies of these two
// branches are not visible in this excerpt.
169 for (const auto &interval : set.getIntervals()) {
170 if (interval.a == -1) {
171 if (interval.b == -1) {
178 data.push_back(interval.a);
181 data.push_back(interval.b);
// Edges: the count gathered above, then one record per transition.
185 data.push_back(nedges);
186 for (ATNState *s : atn->states) {
188 // might be optimized away
192 if (s->getStateType() == ATNState::RULE_STOP) {
196 for (size_t i = 0; i < s->transitions.size(); i++) {
197 Transition *t = s->transitions[i];
// The target's slot in atn->states must still be populated.
199 if (atn->states[t->target->stateNumber] == nullptr) {
200 throw IllegalStateException("Cannot serialize a transition to a removed state.");
203 size_t src = s->stateNumber;
204 size_t trg = t->target->stateNumber;
205 Transition::SerializationType edgeType = t->getSerializationType();
// Rule transitions store the follow state as the edge target; the state the
// transition actually points to goes into arg1.
210 case Transition::RULE:
211 trg = (static_cast<RuleTransition *>(t))->followState->stateNumber;
212 arg1 = (static_cast<RuleTransition *>(t))->target->stateNumber;
213 arg2 = (static_cast<RuleTransition *>(t))->ruleIndex;
214 arg3 = (static_cast<RuleTransition *>(t))->precedence;
216 case Transition::PRECEDENCE:
218 PrecedencePredicateTransition *ppt =
219 static_cast<PrecedencePredicateTransition *>(t);
220 arg1 = ppt->precedence;
223 case Transition::PREDICATE:
225 PredicateTransition *pt = static_cast<PredicateTransition *>(t);
226 arg1 = pt->ruleIndex;
227 arg2 = pt->predIndex;
228 arg3 = pt->isCtxDependent ? 1 : 0;
231 case Transition::RANGE:
232 arg1 = (static_cast<RangeTransition *>(t))->from;
233 arg2 = (static_cast<RangeTransition *>(t))->to;
234 if (arg1 == Token::EOF) {
240 case Transition::ATOM:
241 arg1 = (static_cast<AtomTransition *>(t))->_label;
242 if (arg1 == Token::EOF) {
248 case Transition::ACTION:
250 ActionTransition *at = static_cast<ActionTransition *>(t);
251 arg1 = at->ruleIndex;
252 arg2 = at->actionIndex;
253 if (arg2 == INVALID_INDEX) {
257 arg3 = at->isCtxDependent ? 1 : 0;
// Set-based transitions refer to their set through the index assigned while
// the states were dumped.
260 case Transition::SET:
261 arg1 = setIndices[(static_cast<SetTransition *>(t))->set];
264 case Transition::NOT_SET:
265 arg1 = setIndices[(static_cast<SetTransition *>(t))->set];
274 data.push_back(edgeType);
275 data.push_back(arg1);
276 data.push_back(arg2);
277 data.push_back(arg3);
// Decisions: count, then the state number of each decision state.
281 size_t ndecisions = atn->decisionToState.size();
282 data.push_back(ndecisions);
283 for (DecisionState *decStartState : atn->decisionToState) {
284 data.push_back(decStartState->stateNumber);
// Lexer actions (lexer grammars only): count, then per action its type followed
// by the visible type-specific values. A value of -1 / INVALID_INDEX is written
// as 0xFFFF.
288 if (atn->grammarType == ATNType::LEXER) {
289 data.push_back(atn->lexerActions.size());
290 for (Ref<LexerAction> &action : atn->lexerActions) {
291 data.push_back(static_cast<size_t>(action->getActionType()));
292 switch (action->getActionType()) {
293 case LexerActionType::CHANNEL:
295 int channel = std::dynamic_pointer_cast<LexerChannelAction>(action)->getChannel();
296 data.push_back(channel != -1 ? channel : 0xFFFF);
301 case LexerActionType::CUSTOM:
303 size_t ruleIndex = std::dynamic_pointer_cast<LexerCustomAction>(action)->getRuleIndex();
304 size_t actionIndex = std::dynamic_pointer_cast<LexerCustomAction>(action)->getActionIndex();
305 data.push_back(ruleIndex != INVALID_INDEX ? ruleIndex : 0xFFFF);
306 data.push_back(actionIndex != INVALID_INDEX ? actionIndex : 0xFFFF);
310 case LexerActionType::MODE:
312 int mode = std::dynamic_pointer_cast<LexerModeAction>(action)->getMode();
313 data.push_back(mode != -1 ? mode : 0xFFFF);
318 case LexerActionType::MORE:
323 case LexerActionType::POP_MODE:
328 case LexerActionType::PUSH_MODE:
330 int mode = std::dynamic_pointer_cast<LexerPushModeAction>(action)->getMode();
331 data.push_back(mode != -1 ? mode : 0xFFFF);
336 case LexerActionType::SKIP:
341 case LexerActionType::TYPE:
343 int type = std::dynamic_pointer_cast<LexerTypeAction>(action)->getType();
344 data.push_back(type != -1 ? type : 0xFFFF);
350 throw IllegalArgumentException("The specified lexer action type " +
351 std::to_string(static_cast<size_t>(action->getActionType())) +
// Final pass over every element except the version: values above 0xFFFF are
// rejected, and the rest are offset by 2 and masked to 16 bits. decode()
// subtracts the same offset when reading.
// NOTE(review): the statement that stores `value` is not visible here.
357 // don't adjust the first value since that's the version number
358 for (size_t i = 1; i < data.size(); i++) {
359 if (data.at(i) > 0xFFFF) {
360 throw UnsupportedOperationException("Serialized ATN data element out of range.");
363 size_t value = (data.at(i) + 2) & 0xFFFF;
370 //------------------------------------------------------------------------------------------------------------
// Decodes a serialized ATN given as a wide string and appends a textual
// description of its sections to `buf`. The sections are read in the order
// serialize() writes them.
//
// Throws IllegalArgumentException when the input has fewer than 10 elements,
// and UnsupportedOperationException when the version or the UUID read from the
// input does not match the ATNDeserializer constants.
//
// NOTE(review): the declarations of `buf` and `p`, many literal separators,
// closing braces and the return statement are not visible in this excerpt —
// confirm details against the full file.
372 std::string ATNSerializer::decode(const std::wstring &inpdata) {
373 if (inpdata.size() < 10)
374 throw IllegalArgumentException("Not enough data to decode");
// Work on a 16-bit copy of the input; element 0 (the version) is taken as is.
376 std::vector<uint16_t> data(inpdata.size());
377 data[0] = (uint16_t)inpdata[0];
// Every other element has 2 subtracted, undoing the offset that serialize()
// adds.
379 // Don't adjust the first value since that's the version number.
380 for (size_t i = 1; i < inpdata.size(); ++i) {
381 data[i] = (uint16_t)inpdata[i] - 2;
386 size_t version = data[p++];
// NOTE(review): "ATN Serializer" + reason is concatenated without a separator,
// so the message starts "ATN SerializerCould not ..."; "(expected " also lacks
// the leading space that the UUID message below has. Confirm whether intended.
387 if (version != ATNDeserializer::SERIALIZED_VERSION) {
388 std::string reason = "Could not deserialize ATN with version " + std::to_string(version) + "(expected " +
389 std::to_string(ATNDeserializer::SERIALIZED_VERSION) + ").";
390 throw UnsupportedOperationException("ATN Serializer" + reason);
393 Guid uuid = ATNDeserializer::toUUID(data.data(), p);
395 if (uuid != ATNDeserializer::SERIALIZED_UUID()) {
396 std::string reason = "Could not deserialize ATN with UUID " + uuid.toString() + " (expected " +
397 ATNDeserializer::SERIALIZED_UUID().toString() + ").";
398 throw UnsupportedOperationException("ATN Serializer" + reason);
401 p++; // skip grammarType
402 size_t maxType = data[p++];
403 buf.append("max type ").append(std::to_string(maxType)).append("\n");
// States: one entry per state, with type, rule index and, for loop-end and
// block-start states, one extra state number.
404 size_t nstates = data[p++];
405 for (size_t i = 0; i < nstates; i++) {
406 size_t stype = data[p++];
407 if (stype == ATNState::ATN_INVALID_TYPE) { // ignore bad type of states
410 size_t ruleIndex = data[p++];
// 0xFFFF is the serialized form of an invalid rule index.
411 if (ruleIndex == 0xFFFF) {
412 ruleIndex = INVALID_INDEX;
415 std::string arg = "";
416 if (stype == ATNState::LOOP_END) {
417 int loopBackStateNumber = data[p++];
418 arg = std::string(" ") + std::to_string(loopBackStateNumber);
420 else if (stype == ATNState::PLUS_BLOCK_START ||
421 stype == ATNState::STAR_BLOCK_START ||
422 stype == ATNState::BLOCK_START) {
423 int endStateNumber = data[p++];
424 arg = std::string(" ") + std::to_string(endStateNumber);
426 buf.append(std::to_string(i))
428 .append(ATNState::serializationNames[stype])
430 .append(std::to_string(ruleIndex))
// The non-greedy and precedence state lists are skipped, not printed.
// NOTE(review): the loops below presumably sit inside a block comment whose
// delimiters are not visible here; if they were live, `p` would advance twice.
434 size_t numNonGreedyStates = data[p++];
435 p += numNonGreedyStates; // Instead of that useless loop below.
437 for (int i = 0; i < numNonGreedyStates; i++) {
438 int stateNumber = data[p++];
442 size_t numPrecedenceStates = data[p++];
443 p += numPrecedenceStates;
445 for (int i = 0; i < numPrecedenceStates; i++) {
446 int stateNumber = data[p++];
// Rules: start state per rule, plus one extra value for lexer grammars.
// NOTE(review): the grammar type is taken from the `atn` member, not from the
// value stored in the input (which was skipped above) — confirm that decode()
// is only used on data produced from the same ATN.
450 size_t nrules = data[p++];
451 for (size_t i = 0; i < nrules; i++) {
452 size_t s = data[p++];
453 if (atn->grammarType == ATNType::LEXER) {
454 size_t arg1 = data[p++];
456 .append(std::to_string(i))
458 .append(std::to_string(s))
460 .append(std::to_string(arg1))
465 .append(std::to_string(i))
467 .append(std::to_string(s))
// Modes: one start state number per mode.
471 size_t nmodes = data[p++];
472 for (size_t i = 0; i < nmodes; i++) {
473 size_t s = data[p++];
475 .append(std::to_string(i))
477 .append(std::to_string(s))
// Sets: interval count, "contains EOF" flag, then the intervals, whose bounds
// are printed through getTokenName().
480 size_t nsets = data[p++];
481 for (size_t i = 0; i < nsets; i++) {
482 size_t nintervals = data[p++];
483 buf.append(std::to_string(i)).append(":");
484 bool containsEof = data[p++] != 0;
486 buf.append(getTokenName(Token::EOF));
489 for (size_t j = 0; j < nintervals; j++) {
490 if (containsEof || j > 0) {
494 buf.append(getTokenName(data[p]))
496 .append(getTokenName(data[p + 1]));
// Edges: each record is six values (source, target, type and three arguments),
// read at offsets 0..5 from `p`.
501 size_t nedges = data[p++];
502 for (size_t i = 0; i < nedges; i++) {
503 size_t src = data[p];
504 size_t trg = data[p + 1];
505 size_t ttype = data[p + 2];
506 size_t arg1 = data[p + 3];
507 size_t arg2 = data[p + 4];
508 size_t arg3 = data[p + 5];
509 buf.append(std::to_string(src))
511 .append(std::to_string(trg))
513 .append(Transition::serializationNames[ttype])
515 .append(std::to_string(arg1))
517 .append(std::to_string(arg2))
519 .append(std::to_string(arg3))
// Decisions: printed as "<index>:<state number>" lines.
523 size_t ndecisions = data[p++];
524 for (size_t i = 0; i < ndecisions; i++) {
525 size_t s = data[p++];
526 buf += std::to_string(i) + ":" + std::to_string(s) + "\n";
// NOTE(review): `lexerActionCount` is only declared in a commented-out line, so
// the loop below is presumably disabled in the full file as well — confirm.
529 if (atn->grammarType == ATNType::LEXER) {
530 //int lexerActionCount = data[p++];
532 //p += lexerActionCount * 3; // Instead of useless loop below.
534 for (int i = 0; i < lexerActionCount; i++) {
535 LexerActionType actionType = (LexerActionType)data[p++];
536 int data1 = data[p++];
537 int data2 = data[p++];
// Returns a display string for the token type (or, in lexer grammars, the
// character value) `t`. Token::EOF is checked first. For lexer grammars with
// t <= 0x10FFFF a quoted form is built. Otherwise `_tokenNames[t]` is used when
// `t` is a valid index into it, and the decimal value of `t` as a last resort.
// NOTE(review): the EOF branch body and several lines of the lexer branch are
// not visible in this excerpt — confirm details against the full file.
545 std::string ATNSerializer::getTokenName(size_t t) {
546 if (t == Token::EOF) {
550 if (atn->grammarType == ATNType::LEXER && t <= 0x10FFFF) {
// NOTE(review): `s_hex` is compared with "0" and "7F" as strings, i.e.
// lexicographically, not numerically; multi-digit hex values may pass the range
// test and then reach iscntrl() with a value outside the unsigned char range.
// The quoted result holds the decimal digits of `t`. Confirm whether intended.
567 std::string s_hex = antlrcpp::toHexString((int)t);
568 if (s_hex >= "0" && s_hex <= "7F" && !iscntrl((int)t)) {
569 return "'" + std::to_string(t) + "'";
572 // turn on the bit above max "\u10FFFF" value so that we pad with zeros
573 // then only take last 6 digits
574 std::string hex = antlrcpp::toHexString((int)t | 0x1000000).substr(1, 6);
575 std::string unicodeStr = std::string("'\\u") + hex + std::string("'");
// Fall back to the supplied token names when `t` is a valid index.
580 if (_tokenNames.size() > 0 && t < _tokenNames.size()) {
581 return _tokenNames[t];
584 return std::to_string(t);
// Serializes `atn` via getSerialized() and converts the values to a wide
// string: each entry is cast to wchar_t and appended to `result`.
// NOTE(review): the declaration and the return of `result` are not visible in
// this excerpt.
587 std::wstring ATNSerializer::getSerializedAsString(ATN *atn) {
588 std::vector<size_t> data = getSerialized(atn);
590 for (size_t entry : data)
591 result.push_back((wchar_t)entry);
// Convenience wrapper: builds a temporary ATNSerializer for `atn` and returns
// the result of its serialize() call.
596 std::vector<size_t> ATNSerializer::getSerialized(ATN *atn) {
597 return ATNSerializer(atn).serialize();
// Serializes `atn` to a wide string and immediately decodes it again with a
// temporary serializer that was given `tokenNames`, returning the text produced
// by decode().
600 std::string ATNSerializer::getDecoded(ATN *atn, std::vector<std::string> &tokenNames) {
601 std::wstring serialized = getSerializedAsString(atn);
602 return ATNSerializer(atn, tokenNames).decode(serialized);
// Appends the bytes of `uuid` to `data`, walking the bytes in reverse order.
// The visible lines OR a byte, shifted left by 8, into `twoBytes` and push the
// combined value. `firstByte` presumably tracks which half of a pair is being
// read, but the branch around it is not visible in this excerpt — confirm.
// Throws IllegalArgumentException; per its message, this covers a UUID with an
// odd number of bytes.
605 void ATNSerializer::serializeUUID(std::vector<size_t> &data, Guid uuid) {
606 unsigned int twoBytes = 0;
607 bool firstByte = true;
608 for( std::vector<unsigned char>::const_reverse_iterator rit = uuid.rbegin(); rit != uuid.rend(); ++rit )
614 twoBytes |= (*rit << 8);
615 data.push_back(twoBytes);
620 throw IllegalArgumentException( "The UUID provided is not valid (odd number of bytes)." );