]> gitweb.ps.run Git - toc/blob - antlr4-cpp-runtime-4.9.2-source/runtime/src/atn/ATNSerializer.cpp
add antlr source code and ReadMe
[toc] / antlr4-cpp-runtime-4.9.2-source / runtime / src / atn / ATNSerializer.cpp
1 /* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
2  * Use of this file is governed by the BSD 3-clause license that
3  * can be found in the LICENSE.txt file in the project root.
4  */
5
6 #include "misc/IntervalSet.h"
7 #include "atn/ATNType.h"
8 #include "atn/ATNState.h"
9 #include "atn/BlockEndState.h"
10
11 #include "atn/DecisionState.h"
12 #include "atn/RuleStartState.h"
13 #include "atn/LoopEndState.h"
14 #include "atn/BlockStartState.h"
15 #include "atn/Transition.h"
16 #include "atn/SetTransition.h"
17 #include "Token.h"
18 #include "misc/Interval.h"
19 #include "atn/ATN.h"
20
21 #include "atn/RuleTransition.h"
22 #include "atn/PrecedencePredicateTransition.h"
23 #include "atn/PredicateTransition.h"
24 #include "atn/RangeTransition.h"
25 #include "atn/AtomTransition.h"
26 #include "atn/ActionTransition.h"
27 #include "atn/ATNDeserializer.h"
28
29 #include "atn/TokensStartState.h"
30 #include "Exceptions.h"
31 #include "support/CPPUtils.h"
32
33 #include "atn/LexerChannelAction.h"
34 #include "atn/LexerCustomAction.h"
35 #include "atn/LexerModeAction.h"
36 #include "atn/LexerPushModeAction.h"
37 #include "atn/LexerTypeAction.h"
38
39 #include "Exceptions.h"
40
41 #include "atn/ATNSerializer.h"
42
43 using namespace antlrcpp;
44 using namespace antlr4::atn;
45
46 ATNSerializer::ATNSerializer(ATN *atn) { this->atn = atn; }
47
48 ATNSerializer::ATNSerializer(ATN *atn, const std::vector<std::string> &tokenNames) {
49   this->atn = atn;
50   _tokenNames = tokenNames;
51 }
52
53 ATNSerializer::~ATNSerializer() { }
54
55 std::vector<size_t> ATNSerializer::serialize() {
56   std::vector<size_t> data;
57   data.push_back(ATNDeserializer::SERIALIZED_VERSION);
58   serializeUUID(data, ATNDeserializer::SERIALIZED_UUID());
59
60   // convert grammar type to ATN const to avoid dependence on ANTLRParser
61   data.push_back(static_cast<size_t>(atn->grammarType));
62   data.push_back(atn->maxTokenType);
63   size_t nedges = 0;
64
65   std::unordered_map<misc::IntervalSet, int> setIndices;
66   std::vector<misc::IntervalSet> sets;
67
68   // dump states, count edges and collect sets while doing so
69   std::vector<size_t> nonGreedyStates;
70   std::vector<size_t> precedenceStates;
71   data.push_back(atn->states.size());
72   for (ATNState *s : atn->states) {
73     if (s == nullptr) {  // might be optimized away
74       data.push_back(ATNState::ATN_INVALID_TYPE);
75       continue;
76     }
77
78     size_t stateType = s->getStateType();
79     if (is<DecisionState *>(s) && (static_cast<DecisionState *>(s))->nonGreedy) {
80       nonGreedyStates.push_back(s->stateNumber);
81     }
82
83     if (is<RuleStartState *>(s) && (static_cast<RuleStartState *>(s))->isLeftRecursiveRule) {
84       precedenceStates.push_back(s->stateNumber);
85     }
86
87     data.push_back(stateType);
88
89     if (s->ruleIndex == INVALID_INDEX) {
90       data.push_back(0xFFFF);
91     }
92     else {
93       data.push_back(s->ruleIndex);
94     }
95
96     if (s->getStateType() == ATNState::LOOP_END) {
97       data.push_back((static_cast<LoopEndState *>(s))->loopBackState->stateNumber);
98     }
99     else if (is<BlockStartState *>(s)) {
100       data.push_back((static_cast<BlockStartState *>(s))->endState->stateNumber);
101     }
102
103     if (s->getStateType() != ATNState::RULE_STOP) {
104       // the deserializer can trivially derive these edges, so there's no need
105       // to serialize them
106       nedges += s->transitions.size();
107     }
108
109     for (size_t i = 0; i < s->transitions.size(); i++) {
110       Transition *t = s->transitions[i];
111       Transition::SerializationType edgeType = t->getSerializationType();
112       if (edgeType == Transition::SET || edgeType == Transition::NOT_SET) {
113         SetTransition *st = static_cast<SetTransition *>(t);
114         if (setIndices.find(st->set) == setIndices.end()) {
115           sets.push_back(st->set);
116           setIndices.insert({ st->set, (int)sets.size() - 1 });
117         }
118       }
119     }
120   }
121
122   // non-greedy states
123   data.push_back(nonGreedyStates.size());
124   for (size_t i = 0; i < nonGreedyStates.size(); i++) {
125     data.push_back(nonGreedyStates.at(i));
126   }
127
128   // precedence states
129   data.push_back(precedenceStates.size());
130   for (size_t i = 0; i < precedenceStates.size(); i++) {
131     data.push_back(precedenceStates.at(i));
132   }
133
134   size_t nrules = atn->ruleToStartState.size();
135   data.push_back(nrules);
136   for (size_t r = 0; r < nrules; r++) {
137     ATNState *ruleStartState = atn->ruleToStartState[r];
138     data.push_back(ruleStartState->stateNumber);
139     if (atn->grammarType == ATNType::LEXER) {
140       if (atn->ruleToTokenType[r] == Token::EOF) {
141         data.push_back(0xFFFF);
142       }
143       else {
144         data.push_back(atn->ruleToTokenType[r]);
145       }
146     }
147   }
148
149   size_t nmodes = atn->modeToStartState.size();
150   data.push_back(nmodes);
151   if (nmodes > 0) {
152     for (const auto &modeStartState : atn->modeToStartState) {
153       data.push_back(modeStartState->stateNumber);
154     }
155   }
156
157   size_t nsets = sets.size();
158   data.push_back(nsets);
159   for (auto set : sets) {
160     bool containsEof = set.contains(Token::EOF);
161     if (containsEof && set.getIntervals().at(0).b == -1) {
162       data.push_back(set.getIntervals().size() - 1);
163     }
164     else {
165       data.push_back(set.getIntervals().size());
166     }
167
168     data.push_back(containsEof ? 1 : 0);
169     for (const auto &interval : set.getIntervals()) {
170       if (interval.a == -1) {
171         if (interval.b == -1) {
172           continue;
173         } else {
174           data.push_back(0);
175         }
176       }
177       else {
178         data.push_back(interval.a);
179       }
180
181       data.push_back(interval.b);
182     }
183   }
184
185   data.push_back(nedges);
186   for (ATNState *s : atn->states) {
187     if (s == nullptr) {
188       // might be optimized away
189       continue;
190     }
191
192     if (s->getStateType() == ATNState::RULE_STOP) {
193       continue;
194     }
195
196     for (size_t i = 0; i < s->transitions.size(); i++) {
197       Transition *t = s->transitions[i];
198
199       if (atn->states[t->target->stateNumber] == nullptr) {
200         throw IllegalStateException("Cannot serialize a transition to a removed state.");
201       }
202
203       size_t src = s->stateNumber;
204       size_t trg = t->target->stateNumber;
205       Transition::SerializationType edgeType = t->getSerializationType();
206       size_t arg1 = 0;
207       size_t arg2 = 0;
208       size_t arg3 = 0;
209       switch (edgeType) {
210         case Transition::RULE:
211           trg = (static_cast<RuleTransition *>(t))->followState->stateNumber;
212           arg1 = (static_cast<RuleTransition *>(t))->target->stateNumber;
213           arg2 = (static_cast<RuleTransition *>(t))->ruleIndex;
214           arg3 = (static_cast<RuleTransition *>(t))->precedence;
215           break;
216         case Transition::PRECEDENCE:
217         {
218           PrecedencePredicateTransition *ppt =
219           static_cast<PrecedencePredicateTransition *>(t);
220           arg1 = ppt->precedence;
221         }
222           break;
223         case Transition::PREDICATE:
224         {
225           PredicateTransition *pt = static_cast<PredicateTransition *>(t);
226           arg1 = pt->ruleIndex;
227           arg2 = pt->predIndex;
228           arg3 = pt->isCtxDependent ? 1 : 0;
229         }
230           break;
231         case Transition::RANGE:
232           arg1 = (static_cast<RangeTransition *>(t))->from;
233           arg2 = (static_cast<RangeTransition *>(t))->to;
234           if (arg1 == Token::EOF) {
235             arg1 = 0;
236             arg3 = 1;
237           }
238
239           break;
240         case Transition::ATOM:
241           arg1 = (static_cast<AtomTransition *>(t))->_label;
242           if (arg1 == Token::EOF) {
243             arg1 = 0;
244             arg3 = 1;
245           }
246
247           break;
248         case Transition::ACTION:
249         {
250           ActionTransition *at = static_cast<ActionTransition *>(t);
251           arg1 = at->ruleIndex;
252           arg2 = at->actionIndex;
253           if (arg2 == INVALID_INDEX) {
254             arg2 = 0xFFFF;
255           }
256
257           arg3 = at->isCtxDependent ? 1 : 0;
258         }
259           break;
260         case Transition::SET:
261           arg1 = setIndices[(static_cast<SetTransition *>(t))->set];
262           break;
263
264         case Transition::NOT_SET:
265           arg1 = setIndices[(static_cast<SetTransition *>(t))->set];
266           break;
267
268         default:
269           break;
270       }
271
272       data.push_back(src);
273       data.push_back(trg);
274       data.push_back(edgeType);
275       data.push_back(arg1);
276       data.push_back(arg2);
277       data.push_back(arg3);
278     }
279   }
280
281   size_t ndecisions = atn->decisionToState.size();
282   data.push_back(ndecisions);
283   for (DecisionState *decStartState : atn->decisionToState) {
284     data.push_back(decStartState->stateNumber);
285   }
286
287   // LEXER ACTIONS
288   if (atn->grammarType == ATNType::LEXER) {
289     data.push_back(atn->lexerActions.size());
290     for (Ref<LexerAction> &action : atn->lexerActions) {
291       data.push_back(static_cast<size_t>(action->getActionType()));
292       switch (action->getActionType()) {
293         case LexerActionType::CHANNEL:
294         {
295           int channel = std::dynamic_pointer_cast<LexerChannelAction>(action)->getChannel();
296           data.push_back(channel != -1 ? channel : 0xFFFF);
297           data.push_back(0);
298           break;
299         }
300
301         case LexerActionType::CUSTOM:
302         {
303           size_t ruleIndex = std::dynamic_pointer_cast<LexerCustomAction>(action)->getRuleIndex();
304           size_t actionIndex = std::dynamic_pointer_cast<LexerCustomAction>(action)->getActionIndex();
305           data.push_back(ruleIndex != INVALID_INDEX ? ruleIndex : 0xFFFF);
306           data.push_back(actionIndex != INVALID_INDEX ? actionIndex : 0xFFFF);
307           break;
308         }
309
310         case LexerActionType::MODE:
311         {
312           int mode = std::dynamic_pointer_cast<LexerModeAction>(action)->getMode();
313           data.push_back(mode != -1 ? mode : 0xFFFF);
314           data.push_back(0);
315           break;
316         }
317
318         case LexerActionType::MORE:
319           data.push_back(0);
320           data.push_back(0);
321           break;
322
323         case LexerActionType::POP_MODE:
324           data.push_back(0);
325           data.push_back(0);
326           break;
327
328         case LexerActionType::PUSH_MODE:
329         {
330           int mode = std::dynamic_pointer_cast<LexerPushModeAction>(action)->getMode();
331           data.push_back(mode != -1 ? mode : 0xFFFF);
332           data.push_back(0);
333           break;
334         }
335
336         case LexerActionType::SKIP:
337           data.push_back(0);
338           data.push_back(0);
339           break;
340
341         case LexerActionType::TYPE:
342         {
343           int type = std::dynamic_pointer_cast<LexerTypeAction>(action)->getType();
344           data.push_back(type != -1 ? type : 0xFFFF);
345           data.push_back(0);
346           break;
347         }
348
349         default:
350           throw IllegalArgumentException("The specified lexer action type " +
351                                          std::to_string(static_cast<size_t>(action->getActionType())) +
352                                          " is not valid.");
353       }
354     }
355   }
356
357   // don't adjust the first value since that's the version number
358   for (size_t i = 1; i < data.size(); i++) {
359     if (data.at(i) > 0xFFFF) {
360       throw UnsupportedOperationException("Serialized ATN data element out of range.");
361     }
362
363     size_t value = (data.at(i) + 2) & 0xFFFF;
364     data.at(i) = value;
365   }
366
367   return data;
368 }
369
370 //------------------------------------------------------------------------------------------------------------
371
372 std::string ATNSerializer::decode(const std::wstring &inpdata) {
373   if (inpdata.size() < 10)
374     throw IllegalArgumentException("Not enough data to decode");
375
376   std::vector<uint16_t> data(inpdata.size());
377   data[0] = (uint16_t)inpdata[0];
378
379   // Don't adjust the first value since that's the version number.
380   for (size_t i = 1; i < inpdata.size(); ++i) {
381     data[i] = (uint16_t)inpdata[i] - 2;
382   }
383
384   std::string buf;
385   size_t p = 0;
386   size_t version = data[p++];
387   if (version != ATNDeserializer::SERIALIZED_VERSION) {
388     std::string reason = "Could not deserialize ATN with version " + std::to_string(version) + "(expected " +
389     std::to_string(ATNDeserializer::SERIALIZED_VERSION) + ").";
390     throw UnsupportedOperationException("ATN Serializer" + reason);
391   }
392
393   Guid uuid = ATNDeserializer::toUUID(data.data(), p);
394   p += 8;
395   if (uuid != ATNDeserializer::SERIALIZED_UUID()) {
396     std::string reason = "Could not deserialize ATN with UUID " + uuid.toString() + " (expected " +
397     ATNDeserializer::SERIALIZED_UUID().toString() + ").";
398     throw UnsupportedOperationException("ATN Serializer" + reason);
399   }
400
401   p++;  // skip grammarType
402   size_t maxType = data[p++];
403   buf.append("max type ").append(std::to_string(maxType)).append("\n");
404   size_t nstates = data[p++];
405   for (size_t i = 0; i < nstates; i++) {
406     size_t stype = data[p++];
407     if (stype == ATNState::ATN_INVALID_TYPE) {  // ignore bad type of states
408       continue;
409     }
410     size_t ruleIndex = data[p++];
411     if (ruleIndex == 0xFFFF) {
412       ruleIndex = INVALID_INDEX;
413     }
414
415     std::string arg = "";
416     if (stype == ATNState::LOOP_END) {
417       int loopBackStateNumber = data[p++];
418       arg = std::string(" ") + std::to_string(loopBackStateNumber);
419     }
420     else if (stype == ATNState::PLUS_BLOCK_START ||
421              stype == ATNState::STAR_BLOCK_START ||
422              stype == ATNState::BLOCK_START) {
423       int endStateNumber = data[p++];
424       arg = std::string(" ") + std::to_string(endStateNumber);
425     }
426     buf.append(std::to_string(i))
427     .append(":")
428     .append(ATNState::serializationNames[stype])
429     .append(" ")
430     .append(std::to_string(ruleIndex))
431     .append(arg)
432     .append("\n");
433   }
434   size_t numNonGreedyStates = data[p++];
435   p += numNonGreedyStates; // Instead of that useless loop below.
436   /*
437    for (int i = 0; i < numNonGreedyStates; i++) {
438    int stateNumber = data[p++];
439    }
440    */
441
442   size_t numPrecedenceStates = data[p++];
443   p += numPrecedenceStates;
444   /*
445    for (int i = 0; i < numPrecedenceStates; i++) {
446    int stateNumber = data[p++];
447    }
448    */
449
450   size_t nrules = data[p++];
451   for (size_t i = 0; i < nrules; i++) {
452     size_t s = data[p++];
453     if (atn->grammarType == ATNType::LEXER) {
454       size_t arg1 = data[p++];
455       buf.append("rule ")
456       .append(std::to_string(i))
457       .append(":")
458       .append(std::to_string(s))
459       .append(" ")
460       .append(std::to_string(arg1))
461       .append("\n");
462     }
463     else {
464       buf.append("rule ")
465       .append(std::to_string(i))
466       .append(":")
467       .append(std::to_string(s))
468       .append("\n");
469     }
470   }
471   size_t nmodes = data[p++];
472   for (size_t i = 0; i < nmodes; i++) {
473     size_t s = data[p++];
474     buf.append("mode ")
475     .append(std::to_string(i))
476     .append(":")
477     .append(std::to_string(s))
478     .append("\n");
479   }
480   size_t nsets = data[p++];
481   for (size_t i = 0; i < nsets; i++) {
482     size_t nintervals = data[p++];
483     buf.append(std::to_string(i)).append(":");
484     bool containsEof = data[p++] != 0;
485     if (containsEof) {
486       buf.append(getTokenName(Token::EOF));
487     }
488
489     for (size_t j = 0; j < nintervals; j++) {
490       if (containsEof || j > 0) {
491         buf.append(", ");
492       }
493
494       buf.append(getTokenName(data[p]))
495       .append("..")
496       .append(getTokenName(data[p + 1]));
497       p += 2;
498     }
499     buf.append("\n");
500   }
501   size_t nedges = data[p++];
502   for (size_t i = 0; i < nedges; i++) {
503     size_t src = data[p];
504     size_t trg = data[p + 1];
505     size_t ttype = data[p + 2];
506     size_t arg1 = data[p + 3];
507     size_t arg2 = data[p + 4];
508     size_t arg3 = data[p + 5];
509     buf.append(std::to_string(src))
510     .append("->")
511     .append(std::to_string(trg))
512     .append(" ")
513     .append(Transition::serializationNames[ttype])
514     .append(" ")
515     .append(std::to_string(arg1))
516     .append(",")
517     .append(std::to_string(arg2))
518     .append(",")
519     .append(std::to_string(arg3))
520     .append("\n");
521     p += 6;
522   }
523   size_t ndecisions = data[p++];
524   for (size_t i = 0; i < ndecisions; i++) {
525     size_t s = data[p++];
526     buf += std::to_string(i) + ":" + std::to_string(s) + "\n";
527   }
528
529   if (atn->grammarType == ATNType::LEXER) {
530     //int lexerActionCount = data[p++];
531
532     //p += lexerActionCount * 3; // Instead of useless loop below.
533     /*
534     for (int i = 0; i < lexerActionCount; i++) {
535       LexerActionType actionType = (LexerActionType)data[p++];
536       int data1 = data[p++];
537       int data2 = data[p++];
538     }
539      */
540   }
541
542   return buf;
543 }
544
545 std::string ATNSerializer::getTokenName(size_t t) {
546   if (t == Token::EOF) {
547     return "EOF";
548   }
549
550   if (atn->grammarType == ATNType::LEXER && t <= 0x10FFFF) {
551     switch (t) {
552       case '\n':
553         return "'\\n'";
554       case '\r':
555         return "'\\r'";
556       case '\t':
557         return "'\\t'";
558       case '\b':
559         return "'\\b'";
560       case '\f':
561         return "'\\f'";
562       case '\\':
563         return "'\\\\'";
564       case '\'':
565         return "'\\''";
566       default:
567         std::string s_hex = antlrcpp::toHexString((int)t);
568         if (s_hex >= "0" && s_hex <= "7F" && !iscntrl((int)t)) {
569           return "'" + std::to_string(t) + "'";
570         }
571
572         // turn on the bit above max "\u10FFFF" value so that we pad with zeros
573         // then only take last 6 digits
574         std::string hex = antlrcpp::toHexString((int)t | 0x1000000).substr(1, 6);
575         std::string unicodeStr = std::string("'\\u") + hex + std::string("'");
576         return unicodeStr;
577     }
578   }
579
580   if (_tokenNames.size() > 0 && t < _tokenNames.size()) {
581     return _tokenNames[t];
582   }
583
584   return std::to_string(t);
585 }
586
587 std::wstring ATNSerializer::getSerializedAsString(ATN *atn) {
588   std::vector<size_t> data = getSerialized(atn);
589   std::wstring result;
590   for (size_t entry : data)
591     result.push_back((wchar_t)entry);
592
593   return result;
594 }
595
596 std::vector<size_t> ATNSerializer::getSerialized(ATN *atn) {
597   return ATNSerializer(atn).serialize();
598 }
599
600 std::string ATNSerializer::getDecoded(ATN *atn, std::vector<std::string> &tokenNames) {
601   std::wstring serialized = getSerializedAsString(atn);
602   return ATNSerializer(atn, tokenNames).decode(serialized);
603 }
604
605 void ATNSerializer::serializeUUID(std::vector<size_t> &data, Guid uuid) {
606   unsigned int twoBytes = 0;
607   bool firstByte = true;
608   for( std::vector<unsigned char>::const_reverse_iterator rit = uuid.rbegin(); rit != uuid.rend(); ++rit )
609   {
610      if (firstByte) {
611        twoBytes = *rit;
612        firstByte = false;
613      } else {
614        twoBytes |= (*rit << 8);
615        data.push_back(twoBytes);
616        firstByte = true;
617      }
618   }
619   if (!firstByte)
620      throw IllegalArgumentException( "The UUID provided is not valid (odd number of bytes)." );
621 }