]> gitweb.ps.run Git - toc/blobdiff - antlr4-cpp-runtime-4.9.2-source/runtime/src/atn/ATNSerializer.cpp
add antlr source code and ReadMe
[toc] / antlr4-cpp-runtime-4.9.2-source / runtime / src / atn / ATNSerializer.cpp
diff --git a/antlr4-cpp-runtime-4.9.2-source/runtime/src/atn/ATNSerializer.cpp b/antlr4-cpp-runtime-4.9.2-source/runtime/src/atn/ATNSerializer.cpp
new file mode 100644 (file)
index 0000000..293bee5
--- /dev/null
@@ -0,0 +1,621 @@
+/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+ * Use of this file is governed by the BSD 3-clause license that
+ * can be found in the LICENSE.txt file in the project root.
+ */
+
+#include "misc/IntervalSet.h"
+#include "atn/ATNType.h"
+#include "atn/ATNState.h"
+#include "atn/BlockEndState.h"
+
+#include "atn/DecisionState.h"
+#include "atn/RuleStartState.h"
+#include "atn/LoopEndState.h"
+#include "atn/BlockStartState.h"
+#include "atn/Transition.h"
+#include "atn/SetTransition.h"
+#include "Token.h"
+#include "misc/Interval.h"
+#include "atn/ATN.h"
+
+#include "atn/RuleTransition.h"
+#include "atn/PrecedencePredicateTransition.h"
+#include "atn/PredicateTransition.h"
+#include "atn/RangeTransition.h"
+#include "atn/AtomTransition.h"
+#include "atn/ActionTransition.h"
+#include "atn/ATNDeserializer.h"
+
+#include "atn/TokensStartState.h"
+#include "Exceptions.h"
+#include "support/CPPUtils.h"
+
+#include "atn/LexerChannelAction.h"
+#include "atn/LexerCustomAction.h"
+#include "atn/LexerModeAction.h"
+#include "atn/LexerPushModeAction.h"
+#include "atn/LexerTypeAction.h"
+
+#include "Exceptions.h"
+
+#include "atn/ATNSerializer.h"
+
+using namespace antlrcpp;
+using namespace antlr4::atn;
+
+/// Creates a serializer for the given ATN; no token names are recorded.
+ATNSerializer::ATNSerializer(ATN *atn) : atn(atn) {
+}
+
+/// Creates a serializer for the given ATN and keeps a copy of `tokenNames`,
+/// which getTokenName() consults when labelling token types.
+ATNSerializer::ATNSerializer(ATN *atn, const std::vector<std::string> &tokenNames)
+  : atn(atn), _tokenNames(tokenNames) {
+}
+
+/// Nothing to release explicitly: the destructor body is empty.
+ATNSerializer::~ATNSerializer() = default;
+
+std::vector<size_t> ATNSerializer::serialize() {
+  std::vector<size_t> data;
+  data.push_back(ATNDeserializer::SERIALIZED_VERSION);
+  serializeUUID(data, ATNDeserializer::SERIALIZED_UUID());
+
+  // convert grammar type to ATN const to avoid dependence on ANTLRParser
+  data.push_back(static_cast<size_t>(atn->grammarType));
+  data.push_back(atn->maxTokenType);
+  size_t nedges = 0;
+
+  std::unordered_map<misc::IntervalSet, int> setIndices;
+  std::vector<misc::IntervalSet> sets;
+
+  // dump states, count edges and collect sets while doing so
+  std::vector<size_t> nonGreedyStates;
+  std::vector<size_t> precedenceStates;
+  data.push_back(atn->states.size());
+  for (ATNState *s : atn->states) {
+    if (s == nullptr) {  // might be optimized away
+      data.push_back(ATNState::ATN_INVALID_TYPE);
+      continue;
+    }
+
+    size_t stateType = s->getStateType();
+    if (is<DecisionState *>(s) && (static_cast<DecisionState *>(s))->nonGreedy) {
+      nonGreedyStates.push_back(s->stateNumber);
+    }
+
+    if (is<RuleStartState *>(s) && (static_cast<RuleStartState *>(s))->isLeftRecursiveRule) {
+      precedenceStates.push_back(s->stateNumber);
+    }
+
+    data.push_back(stateType);
+
+    if (s->ruleIndex == INVALID_INDEX) {
+      data.push_back(0xFFFF);
+    }
+    else {
+      data.push_back(s->ruleIndex);
+    }
+
+    if (s->getStateType() == ATNState::LOOP_END) {
+      data.push_back((static_cast<LoopEndState *>(s))->loopBackState->stateNumber);
+    }
+    else if (is<BlockStartState *>(s)) {
+      data.push_back((static_cast<BlockStartState *>(s))->endState->stateNumber);
+    }
+
+    if (s->getStateType() != ATNState::RULE_STOP) {
+      // the deserializer can trivially derive these edges, so there's no need
+      // to serialize them
+      nedges += s->transitions.size();
+    }
+
+    for (size_t i = 0; i < s->transitions.size(); i++) {
+      Transition *t = s->transitions[i];
+      Transition::SerializationType edgeType = t->getSerializationType();
+      if (edgeType == Transition::SET || edgeType == Transition::NOT_SET) {
+        SetTransition *st = static_cast<SetTransition *>(t);
+        if (setIndices.find(st->set) == setIndices.end()) {
+          sets.push_back(st->set);
+          setIndices.insert({ st->set, (int)sets.size() - 1 });
+        }
+      }
+    }
+  }
+
+  // non-greedy states
+  data.push_back(nonGreedyStates.size());
+  for (size_t i = 0; i < nonGreedyStates.size(); i++) {
+    data.push_back(nonGreedyStates.at(i));
+  }
+
+  // precedence states
+  data.push_back(precedenceStates.size());
+  for (size_t i = 0; i < precedenceStates.size(); i++) {
+    data.push_back(precedenceStates.at(i));
+  }
+
+  size_t nrules = atn->ruleToStartState.size();
+  data.push_back(nrules);
+  for (size_t r = 0; r < nrules; r++) {
+    ATNState *ruleStartState = atn->ruleToStartState[r];
+    data.push_back(ruleStartState->stateNumber);
+    if (atn->grammarType == ATNType::LEXER) {
+      if (atn->ruleToTokenType[r] == Token::EOF) {
+        data.push_back(0xFFFF);
+      }
+      else {
+        data.push_back(atn->ruleToTokenType[r]);
+      }
+    }
+  }
+
+  size_t nmodes = atn->modeToStartState.size();
+  data.push_back(nmodes);
+  if (nmodes > 0) {
+    for (const auto &modeStartState : atn->modeToStartState) {
+      data.push_back(modeStartState->stateNumber);
+    }
+  }
+
+  size_t nsets = sets.size();
+  data.push_back(nsets);
+  for (auto set : sets) {
+    bool containsEof = set.contains(Token::EOF);
+    if (containsEof && set.getIntervals().at(0).b == -1) {
+      data.push_back(set.getIntervals().size() - 1);
+    }
+    else {
+      data.push_back(set.getIntervals().size());
+    }
+
+    data.push_back(containsEof ? 1 : 0);
+    for (const auto &interval : set.getIntervals()) {
+      if (interval.a == -1) {
+        if (interval.b == -1) {
+          continue;
+        } else {
+          data.push_back(0);
+        }
+      }
+      else {
+        data.push_back(interval.a);
+      }
+
+      data.push_back(interval.b);
+    }
+  }
+
+  data.push_back(nedges);
+  for (ATNState *s : atn->states) {
+    if (s == nullptr) {
+      // might be optimized away
+      continue;
+    }
+
+    if (s->getStateType() == ATNState::RULE_STOP) {
+      continue;
+    }
+
+    for (size_t i = 0; i < s->transitions.size(); i++) {
+      Transition *t = s->transitions[i];
+
+      if (atn->states[t->target->stateNumber] == nullptr) {
+        throw IllegalStateException("Cannot serialize a transition to a removed state.");
+      }
+
+      size_t src = s->stateNumber;
+      size_t trg = t->target->stateNumber;
+      Transition::SerializationType edgeType = t->getSerializationType();
+      size_t arg1 = 0;
+      size_t arg2 = 0;
+      size_t arg3 = 0;
+      switch (edgeType) {
+        case Transition::RULE:
+          trg = (static_cast<RuleTransition *>(t))->followState->stateNumber;
+          arg1 = (static_cast<RuleTransition *>(t))->target->stateNumber;
+          arg2 = (static_cast<RuleTransition *>(t))->ruleIndex;
+          arg3 = (static_cast<RuleTransition *>(t))->precedence;
+          break;
+        case Transition::PRECEDENCE:
+        {
+          PrecedencePredicateTransition *ppt =
+          static_cast<PrecedencePredicateTransition *>(t);
+          arg1 = ppt->precedence;
+        }
+          break;
+        case Transition::PREDICATE:
+        {
+          PredicateTransition *pt = static_cast<PredicateTransition *>(t);
+          arg1 = pt->ruleIndex;
+          arg2 = pt->predIndex;
+          arg3 = pt->isCtxDependent ? 1 : 0;
+        }
+          break;
+        case Transition::RANGE:
+          arg1 = (static_cast<RangeTransition *>(t))->from;
+          arg2 = (static_cast<RangeTransition *>(t))->to;
+          if (arg1 == Token::EOF) {
+            arg1 = 0;
+            arg3 = 1;
+          }
+
+          break;
+        case Transition::ATOM:
+          arg1 = (static_cast<AtomTransition *>(t))->_label;
+          if (arg1 == Token::EOF) {
+            arg1 = 0;
+            arg3 = 1;
+          }
+
+          break;
+        case Transition::ACTION:
+        {
+          ActionTransition *at = static_cast<ActionTransition *>(t);
+          arg1 = at->ruleIndex;
+          arg2 = at->actionIndex;
+          if (arg2 == INVALID_INDEX) {
+            arg2 = 0xFFFF;
+          }
+
+          arg3 = at->isCtxDependent ? 1 : 0;
+        }
+          break;
+        case Transition::SET:
+          arg1 = setIndices[(static_cast<SetTransition *>(t))->set];
+          break;
+
+        case Transition::NOT_SET:
+          arg1 = setIndices[(static_cast<SetTransition *>(t))->set];
+          break;
+
+        default:
+          break;
+      }
+
+      data.push_back(src);
+      data.push_back(trg);
+      data.push_back(edgeType);
+      data.push_back(arg1);
+      data.push_back(arg2);
+      data.push_back(arg3);
+    }
+  }
+
+  size_t ndecisions = atn->decisionToState.size();
+  data.push_back(ndecisions);
+  for (DecisionState *decStartState : atn->decisionToState) {
+    data.push_back(decStartState->stateNumber);
+  }
+
+  // LEXER ACTIONS
+  if (atn->grammarType == ATNType::LEXER) {
+    data.push_back(atn->lexerActions.size());
+    for (Ref<LexerAction> &action : atn->lexerActions) {
+      data.push_back(static_cast<size_t>(action->getActionType()));
+      switch (action->getActionType()) {
+        case LexerActionType::CHANNEL:
+        {
+          int channel = std::dynamic_pointer_cast<LexerChannelAction>(action)->getChannel();
+          data.push_back(channel != -1 ? channel : 0xFFFF);
+          data.push_back(0);
+          break;
+        }
+
+        case LexerActionType::CUSTOM:
+        {
+          size_t ruleIndex = std::dynamic_pointer_cast<LexerCustomAction>(action)->getRuleIndex();
+          size_t actionIndex = std::dynamic_pointer_cast<LexerCustomAction>(action)->getActionIndex();
+          data.push_back(ruleIndex != INVALID_INDEX ? ruleIndex : 0xFFFF);
+          data.push_back(actionIndex != INVALID_INDEX ? actionIndex : 0xFFFF);
+          break;
+        }
+
+        case LexerActionType::MODE:
+        {
+          int mode = std::dynamic_pointer_cast<LexerModeAction>(action)->getMode();
+          data.push_back(mode != -1 ? mode : 0xFFFF);
+          data.push_back(0);
+          break;
+        }
+
+        case LexerActionType::MORE:
+          data.push_back(0);
+          data.push_back(0);
+          break;
+
+        case LexerActionType::POP_MODE:
+          data.push_back(0);
+          data.push_back(0);
+          break;
+
+        case LexerActionType::PUSH_MODE:
+        {
+          int mode = std::dynamic_pointer_cast<LexerPushModeAction>(action)->getMode();
+          data.push_back(mode != -1 ? mode : 0xFFFF);
+          data.push_back(0);
+          break;
+        }
+
+        case LexerActionType::SKIP:
+          data.push_back(0);
+          data.push_back(0);
+          break;
+
+        case LexerActionType::TYPE:
+        {
+          int type = std::dynamic_pointer_cast<LexerTypeAction>(action)->getType();
+          data.push_back(type != -1 ? type : 0xFFFF);
+          data.push_back(0);
+          break;
+        }
+
+        default:
+          throw IllegalArgumentException("The specified lexer action type " +
+                                         std::to_string(static_cast<size_t>(action->getActionType())) +
+                                         " is not valid.");
+      }
+    }
+  }
+
+  // don't adjust the first value since that's the version number
+  for (size_t i = 1; i < data.size(); i++) {
+    if (data.at(i) > 0xFFFF) {
+      throw UnsupportedOperationException("Serialized ATN data element out of range.");
+    }
+
+    size_t value = (data.at(i) + 2) & 0xFFFF;
+    data.at(i) = value;
+  }
+
+  return data;
+}
+
+//------------------------------------------------------------------------------------------------------------
+
+std::string ATNSerializer::decode(const std::wstring &inpdata) {
+  if (inpdata.size() < 10)
+    throw IllegalArgumentException("Not enough data to decode");
+
+  std::vector<uint16_t> data(inpdata.size());
+  data[0] = (uint16_t)inpdata[0];
+
+  // Don't adjust the first value since that's the version number.
+  for (size_t i = 1; i < inpdata.size(); ++i) {
+    data[i] = (uint16_t)inpdata[i] - 2;
+  }
+
+  std::string buf;
+  size_t p = 0;
+  size_t version = data[p++];
+  if (version != ATNDeserializer::SERIALIZED_VERSION) {
+    std::string reason = "Could not deserialize ATN with version " + std::to_string(version) + "(expected " +
+    std::to_string(ATNDeserializer::SERIALIZED_VERSION) + ").";
+    throw UnsupportedOperationException("ATN Serializer" + reason);
+  }
+
+  Guid uuid = ATNDeserializer::toUUID(data.data(), p);
+  p += 8;
+  if (uuid != ATNDeserializer::SERIALIZED_UUID()) {
+    std::string reason = "Could not deserialize ATN with UUID " + uuid.toString() + " (expected " +
+    ATNDeserializer::SERIALIZED_UUID().toString() + ").";
+    throw UnsupportedOperationException("ATN Serializer" + reason);
+  }
+
+  p++;  // skip grammarType
+  size_t maxType = data[p++];
+  buf.append("max type ").append(std::to_string(maxType)).append("\n");
+  size_t nstates = data[p++];
+  for (size_t i = 0; i < nstates; i++) {
+    size_t stype = data[p++];
+    if (stype == ATNState::ATN_INVALID_TYPE) {  // ignore bad type of states
+      continue;
+    }
+    size_t ruleIndex = data[p++];
+    if (ruleIndex == 0xFFFF) {
+      ruleIndex = INVALID_INDEX;
+    }
+
+    std::string arg = "";
+    if (stype == ATNState::LOOP_END) {
+      int loopBackStateNumber = data[p++];
+      arg = std::string(" ") + std::to_string(loopBackStateNumber);
+    }
+    else if (stype == ATNState::PLUS_BLOCK_START ||
+             stype == ATNState::STAR_BLOCK_START ||
+             stype == ATNState::BLOCK_START) {
+      int endStateNumber = data[p++];
+      arg = std::string(" ") + std::to_string(endStateNumber);
+    }
+    buf.append(std::to_string(i))
+    .append(":")
+    .append(ATNState::serializationNames[stype])
+    .append(" ")
+    .append(std::to_string(ruleIndex))
+    .append(arg)
+    .append("\n");
+  }
+  size_t numNonGreedyStates = data[p++];
+  p += numNonGreedyStates; // Instead of that useless loop below.
+  /*
+   for (int i = 0; i < numNonGreedyStates; i++) {
+   int stateNumber = data[p++];
+   }
+   */
+
+  size_t numPrecedenceStates = data[p++];
+  p += numPrecedenceStates;
+  /*
+   for (int i = 0; i < numPrecedenceStates; i++) {
+   int stateNumber = data[p++];
+   }
+   */
+
+  size_t nrules = data[p++];
+  for (size_t i = 0; i < nrules; i++) {
+    size_t s = data[p++];
+    if (atn->grammarType == ATNType::LEXER) {
+      size_t arg1 = data[p++];
+      buf.append("rule ")
+      .append(std::to_string(i))
+      .append(":")
+      .append(std::to_string(s))
+      .append(" ")
+      .append(std::to_string(arg1))
+      .append("\n");
+    }
+    else {
+      buf.append("rule ")
+      .append(std::to_string(i))
+      .append(":")
+      .append(std::to_string(s))
+      .append("\n");
+    }
+  }
+  size_t nmodes = data[p++];
+  for (size_t i = 0; i < nmodes; i++) {
+    size_t s = data[p++];
+    buf.append("mode ")
+    .append(std::to_string(i))
+    .append(":")
+    .append(std::to_string(s))
+    .append("\n");
+  }
+  size_t nsets = data[p++];
+  for (size_t i = 0; i < nsets; i++) {
+    size_t nintervals = data[p++];
+    buf.append(std::to_string(i)).append(":");
+    bool containsEof = data[p++] != 0;
+    if (containsEof) {
+      buf.append(getTokenName(Token::EOF));
+    }
+
+    for (size_t j = 0; j < nintervals; j++) {
+      if (containsEof || j > 0) {
+        buf.append(", ");
+      }
+
+      buf.append(getTokenName(data[p]))
+      .append("..")
+      .append(getTokenName(data[p + 1]));
+      p += 2;
+    }
+    buf.append("\n");
+  }
+  size_t nedges = data[p++];
+  for (size_t i = 0; i < nedges; i++) {
+    size_t src = data[p];
+    size_t trg = data[p + 1];
+    size_t ttype = data[p + 2];
+    size_t arg1 = data[p + 3];
+    size_t arg2 = data[p + 4];
+    size_t arg3 = data[p + 5];
+    buf.append(std::to_string(src))
+    .append("->")
+    .append(std::to_string(trg))
+    .append(" ")
+    .append(Transition::serializationNames[ttype])
+    .append(" ")
+    .append(std::to_string(arg1))
+    .append(",")
+    .append(std::to_string(arg2))
+    .append(",")
+    .append(std::to_string(arg3))
+    .append("\n");
+    p += 6;
+  }
+  size_t ndecisions = data[p++];
+  for (size_t i = 0; i < ndecisions; i++) {
+    size_t s = data[p++];
+    buf += std::to_string(i) + ":" + std::to_string(s) + "\n";
+  }
+
+  if (atn->grammarType == ATNType::LEXER) {
+    //int lexerActionCount = data[p++];
+
+    //p += lexerActionCount * 3; // Instead of useless loop below.
+    /*
+    for (int i = 0; i < lexerActionCount; i++) {
+      LexerActionType actionType = (LexerActionType)data[p++];
+      int data1 = data[p++];
+      int data2 = data[p++];
+    }
+     */
+  }
+
+  return buf;
+}
+
+std::string ATNSerializer::getTokenName(size_t t) {
+  if (t == Token::EOF) {
+    return "EOF";
+  }
+
+  if (atn->grammarType == ATNType::LEXER && t <= 0x10FFFF) {
+    switch (t) {
+      case '\n':
+        return "'\\n'";
+      case '\r':
+        return "'\\r'";
+      case '\t':
+        return "'\\t'";
+      case '\b':
+        return "'\\b'";
+      case '\f':
+        return "'\\f'";
+      case '\\':
+        return "'\\\\'";
+      case '\'':
+        return "'\\''";
+      default:
+        std::string s_hex = antlrcpp::toHexString((int)t);
+        if (s_hex >= "0" && s_hex <= "7F" && !iscntrl((int)t)) {
+          return "'" + std::to_string(t) + "'";
+        }
+
+        // turn on the bit above max "\u10FFFF" value so that we pad with zeros
+        // then only take last 6 digits
+        std::string hex = antlrcpp::toHexString((int)t | 0x1000000).substr(1, 6);
+        std::string unicodeStr = std::string("'\\u") + hex + std::string("'");
+        return unicodeStr;
+    }
+  }
+
+  if (_tokenNames.size() > 0 && t < _tokenNames.size()) {
+    return _tokenNames[t];
+  }
+
+  return std::to_string(t);
+}
+
+/**
+ * Serializes the given ATN and packs the result into a wide string, one
+ * character per serialized element (each element is cast to wchar_t).
+ *
+ * @param atn The ATN to serialize.
+ * @return The serialized data as a wide string.
+ */
+std::wstring ATNSerializer::getSerializedAsString(ATN *atn) {
+  const std::vector<size_t> elements = getSerialized(atn);
+
+  std::wstring packed;
+  for (size_t index = 0; index < elements.size(); ++index) {
+    packed.push_back(static_cast<wchar_t>(elements[index]));
+  }
+
+  return packed;
+}
+
+/**
+ * Convenience wrapper: serializes the given ATN with a temporary serializer.
+ *
+ * @param atn The ATN to serialize.
+ * @return The serialized data, as produced by serialize().
+ */
+std::vector<size_t> ATNSerializer::getSerialized(ATN *atn) {
+  ATNSerializer serializer(atn);
+  return serializer.serialize();
+}
+
+/**
+ * Serializes the given ATN and immediately decodes it again into its textual
+ * dump, using `tokenNames` for labelling.
+ *
+ * @param atn The ATN to serialize and dump.
+ * @param tokenNames Token names handed to the decoding serializer.
+ * @return The textual dump produced by decode().
+ */
+std::string ATNSerializer::getDecoded(ATN *atn, std::vector<std::string> &tokenNames) {
+  const std::wstring encoded = getSerializedAsString(atn);
+  ATNSerializer decoder(atn, tokenNames);
+  return decoder.decode(encoded);
+}
+
+void ATNSerializer::serializeUUID(std::vector<size_t> &data, Guid uuid) {
+  unsigned int twoBytes = 0;
+  bool firstByte = true;
+  for( std::vector<unsigned char>::const_reverse_iterator rit = uuid.rbegin(); rit != uuid.rend(); ++rit )
+  {
+     if (firstByte) {
+       twoBytes = *rit;
+       firstByte = false;
+     } else {
+       twoBytes |= (*rit << 8);
+       data.push_back(twoBytes);
+       firstByte = true;
+     }
+  }
+  if (!firstByte)
+     throw IllegalArgumentException( "The UUID provided is not valid (odd number of bytes)." );
+}