1 /* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
2 * Use of this file is governed by the BSD 3-clause license that
3 * can be found in the LICENSE.txt file in the project root.
6 #include "tree/pattern/ParseTreePattern.h"
7 #include "tree/pattern/ParseTreeMatch.h"
8 #include "tree/TerminalNode.h"
9 #include "CommonTokenStream.h"
10 #include "ParserInterpreter.h"
11 #include "tree/pattern/TokenTagToken.h"
12 #include "ParserRuleContext.h"
13 #include "tree/pattern/RuleTagToken.h"
14 #include "tree/pattern/TagChunk.h"
17 #include "BailErrorStrategy.h"
19 #include "ListTokenSource.h"
20 #include "tree/pattern/TextChunk.h"
21 #include "ANTLRInputStream.h"
22 #include "support/Arrays.h"
23 #include "Exceptions.h"
24 #include "support/StringUtils.h"
25 #include "support/CPPUtils.h"
27 #include "tree/pattern/ParseTreePatternMatcher.h"
29 using namespace antlr4;
30 using namespace antlr4::tree;
31 using namespace antlr4::tree::pattern;
32 using namespace antlrcpp;
// Builds a CannotInvokeStartRule from an existing RuntimeException by forwarding
// e.what() to the RuntimeException base-class constructor.
// NOTE(review): the constructor body is not visible in this listing.
34 ParseTreePatternMatcher::CannotInvokeStartRule::CannotInvokeStartRule(const RuntimeException &e) : RuntimeException(e.what()) {
// Out-of-line destructor definition for the nested CannotInvokeStartRule type.
// NOTE(review): the body is not visible in this listing.
37 ParseTreePatternMatcher::CannotInvokeStartRule::~CannotInvokeStartRule() {
// Out-of-line destructor definition for the nested StartRuleDoesNotConsumeFullPattern type.
// NOTE(review): the body is not visible in this listing.
40 ParseTreePatternMatcher::StartRuleDoesNotConsumeFullPattern::~StartRuleDoesNotConsumeFullPattern() {
// Creates a matcher bound to the given lexer and parser.
// The two pointers are stored as-is in _lexer and _parser, after which
// InitializeInstanceFields() is called.
// NOTE(review): ownership and lifetime expectations for `lexer` and `parser` are not
// visible here — confirm that callers must keep both alive while the matcher is in use.
43 ParseTreePatternMatcher::ParseTreePatternMatcher(Lexer *lexer, Parser *parser) : _lexer(lexer), _parser(parser) {
44 InitializeInstanceFields();
// Out-of-line destructor definition for ParseTreePatternMatcher.
// NOTE(review): the body is not visible in this listing.
47 ParseTreePatternMatcher::~ParseTreePatternMatcher() {
// Sets the tag delimiters (`start`, `stop`) and the escape string (`escapeLeft`).
// Throws IllegalArgumentException for a rejected `start` or `stop`.
// NOTE(review): the conditions guarding the two throws and the statements that store
// the arguments are not visible in this listing; going by the messages, the guards
// presumably reject empty strings — confirm against the full file.
50 void ParseTreePatternMatcher::setDelimiters(const std::string &start, const std::string &stop, const std::string &escapeLeft) {
52 throw IllegalArgumentException("start cannot be null or empty");
56 throw IllegalArgumentException("stop cannot be null or empty");
// Convenience overload: compiles `pattern` for rule `patternRuleIndex` via compile()
// and forwards to matches(ParseTree *, const ParseTreePattern &).
// Any exception raised by compile() propagates to the caller.
64 bool ParseTreePatternMatcher::matches(ParseTree *tree, const std::string &pattern, int patternRuleIndex) {
65 ParseTreePattern p = compile(pattern, patternRuleIndex);
66 return matches(tree, p);
// Reports whether `tree` matches the already compiled `pattern`.
// Runs matchImpl() against the pattern's tree with a scratch label map (the collected
// labels are discarded) and returns true exactly when matchImpl() returns nullptr.
69 bool ParseTreePatternMatcher::matches(ParseTree *tree, const ParseTreePattern &pattern) {
70 std::map<std::string, std::vector<ParseTree *>> labels;
71 ParseTree *mismatchedNode = matchImpl(tree, pattern.getPatternTree(), labels);
72 return mismatchedNode == nullptr;
// Convenience overload: compiles `pattern` for rule `patternRuleIndex` via compile()
// and forwards to match(ParseTree *, const ParseTreePattern &).
// Any exception raised by compile() propagates to the caller.
75 ParseTreeMatch ParseTreePatternMatcher::match(ParseTree *tree, const std::string &pattern, int patternRuleIndex) {
76 ParseTreePattern p = compile(pattern, patternRuleIndex);
77 return match(tree, p);
// Matches `tree` against the already compiled `pattern` and returns the full result.
// The returned ParseTreeMatch is constructed from the subject tree, the pattern, the
// label map filled in by matchImpl(), and the node pointer matchImpl() returned
// (nullptr when no mismatch was reported).
80 ParseTreeMatch ParseTreePatternMatcher::match(ParseTree *tree, const ParseTreePattern &pattern) {
81 std::map<std::string, std::vector<ParseTree *>> labels;
82 tree::ParseTree *mismatchedNode = matchImpl(tree, pattern.getPatternTree(), labels);
83 return ParseTreeMatch(tree, pattern, labels, mismatchedNode);
// Compiles the textual `pattern` into a ParseTreePattern for rule `patternRuleIndex`.
//
// Visible steps: tokenize() the pattern, wrap the tokens in a ListTokenSource and a
// CommonTokenStream, parse them with a ParserInterpreter built from _parser's grammar
// file name, vocabulary, rule names and getATNWithBypassAlts(), then require that the
// next token is EOF.
//
// Throws StartRuleDoesNotConsumeFullPattern when input remains after the parse. The
// catch handlers below rethrow or wrap failures raised while parsing.
// NOTE(review): the `try {` line and parts of the handler bodies are not visible in
// this listing.
86 ParseTreePattern ParseTreePatternMatcher::compile(const std::string &pattern, int patternRuleIndex) {
87 ListTokenSource tokenSrc(tokenize(pattern));
88 CommonTokenStream tokens(&tokenSrc);
// The interpreter reads from `tokens`, which in turn reads from `tokenSrc`; both are
// locals of this function.
90 ParserInterpreter parserInterp(_parser->getGrammarFileName(), _parser->getVocabulary(),
91 _parser->getRuleNames(), _parser->getATNWithBypassAlts(), &tokens);
93 ParserRuleContext *tree = nullptr;
// Install a BailErrorStrategy before parsing — presumably so the first syntax error
// aborts the parse instead of attempting recovery (confirm against BailErrorStrategy).
95 parserInterp.setErrorHandler(std::make_shared<BailErrorStrategy>());
96 tree = parserInterp.parse(patternRuleIndex);
97 } catch (ParseCancellationException &e) {
98 #if defined(_MSC_FULL_VER) && _MSC_FULL_VER < 190023026
99 // rethrow_if_nested is not available before VS 2015.
102 std::rethrow_if_nested(e); // Unwrap the nested exception.
104 } catch (RecognitionException &re) {
106 #if defined(_MSC_FULL_VER) && _MSC_FULL_VER < 190023026
107 } catch (std::exception &e) {
108 // throw_with_nested is not available before VS 2015.
111 } catch (std::exception & /*e*/) {
// NOTE(review): std::throw_with_nested only attaches the current exception when its
// argument is a non-final, non-union class type; for a `const char*` it throws just the
// pointer, so the original exception is not nested — confirm and consider throwing an
// ANTLR exception type (e.g. CannotInvokeStartRule) instead.
112 std::throw_with_nested((const char*)"Cannot invoke start rule"); // Wrap any other exception. We should however probably use one of the ANTLR exceptions here.
116 // Make sure tree pattern compilation checks for a complete parse
117 if (tokens.LA(1) != Token::EOF) {
118 throw StartRuleDoesNotConsumeFullPattern();
// The pattern object receives this matcher, the original text, the rule index and the
// parsed pattern tree.
121 return ParseTreePattern(this, pattern, patternRuleIndex, tree);
// Accessor returning a Lexer pointer.
// NOTE(review): the body is not visible in this listing; presumably it returns the
// _lexer pointer stored by the constructor — confirm.
124 Lexer* ParseTreePatternMatcher::getLexer() {
// Accessor returning a Parser pointer.
// NOTE(review): the body is not visible in this listing; presumably it returns the
// _parser pointer stored by the constructor — confirm.
128 Parser* ParseTreePatternMatcher::getParser() {
// Recursively compares `tree` with `patternTree`, appending matched nodes to `labels`.
//
// Parameters:
//   tree        - subject parse tree node; must not be nullptr.
//   patternTree - pattern tree node; must not be nullptr.
//   labels      - in/out map from token name, rule name or label to the matched nodes.
// Returns a ParseTree pointer that starts out as nullptr in every visible branch; the
// matches()/match() callers in this file treat nullptr as "no mismatch".
// Throws IllegalArgumentException when either tree argument is nullptr.
// NOTE(review): several lines are not visible in this listing, including the
// assignments to mismatchedNode, the handling of childMatch, and the final return.
132 ParseTree* ParseTreePatternMatcher::matchImpl(ParseTree *tree, ParseTree *patternTree,
133 std::map<std::string, std::vector<ParseTree *>> &labels) {
134 if (tree == nullptr) {
// NOTE(review): the message reads "nul"; probably meant "null". Left unchanged because
// the string is observable behavior.
135 throw IllegalArgumentException("tree cannot be nul");
138 if (patternTree == nullptr) {
// NOTE(review): same "nul" spelling as above.
139 throw IllegalArgumentException("patternTree cannot be nul");
// Case 1: both nodes are terminal nodes.
142 // x and <ID>, x and y, or x and x; or could be mismatched types
143 if (is<TerminalNode *>(tree) && is<TerminalNode *>(patternTree)) {
144 TerminalNode *t1 = dynamic_cast<TerminalNode *>(tree);
145 TerminalNode *t2 = dynamic_cast<TerminalNode *>(patternTree);
147 ParseTree *mismatchedNode = nullptr;
148 // both are tokens and they have same type
149 if (t1->getSymbol()->getType() == t2->getSymbol()->getType()) {
// A TokenTagToken in the pattern accepts the subject token based on the type check
// above; the subject node is recorded under the token name and, if set, the label.
150 if (is<TokenTagToken *>(t2->getSymbol())) { // x and <ID>
151 TokenTagToken *tokenTagToken = dynamic_cast<TokenTagToken *>(t2->getSymbol());
153 // track label->list-of-nodes for both token name and label (if any)
154 labels[tokenTagToken->getTokenName()].push_back(tree);
155 if (tokenTagToken->getLabel() != "") {
156 labels[tokenTagToken->getLabel()].push_back(tree);
// Otherwise the pattern holds an ordinary token, so the texts are compared.
158 } else if (t1->getText() == t2->getText()) {
// NOTE(review): the bodies of the next two checks are not visible; presumably they
// record the mismatching subject node — confirm.
162 if (mismatchedNode == nullptr) {
167 if (mismatchedNode == nullptr) {
172 return mismatchedNode;
// Case 2: both nodes are rule contexts.
175 if (is<ParserRuleContext *>(tree) && is<ParserRuleContext *>(patternTree)) {
176 ParserRuleContext *r1 = dynamic_cast<ParserRuleContext *>(tree);
177 ParserRuleContext *r2 = dynamic_cast<ParserRuleContext *>(patternTree);
178 ParseTree *mismatchedNode = nullptr;
180 // (expr ...) and <expr>
// getRuleTagToken() returns non-null when the pattern subtree is a single rule tag.
181 RuleTagToken *ruleTagToken = getRuleTagToken(r2);
182 if (ruleTagToken != nullptr) {
183 //ParseTreeMatch *m = nullptr; // unused?
184 if (r1->getRuleIndex() == r2->getRuleIndex()) {
185 // track label->list-of-nodes for both rule name and label (if any)
186 labels[ruleTagToken->getRuleName()].push_back(tree);
187 if (ruleTagToken->getLabel() != "") {
188 labels[ruleTagToken->getLabel()].push_back(tree);
191 if (!mismatchedNode) {
196 return mismatchedNode;
// No rule tag: the two subtrees must have the same number of children.
199 // (expr ...) and (expr ...)
200 if (r1->children.size() != r2->children.size()) {
201 if (mismatchedNode == nullptr) {
205 return mismatchedNode;
// Compare the children pairwise, recursing with the same label map.
208 std::size_t n = r1->children.size();
209 for (size_t i = 0; i < n; i++) {
210 ParseTree *childMatch = matchImpl(r1->children[i], patternTree->children[i], labels);
216 return mismatchedNode;
219 // if nodes aren't both tokens or both rule nodes, can't match
// Returns the RuleTagToken carried by `t` when `t` has exactly one child, that child
// is a TerminalNode, and the terminal's symbol is a RuleTagToken.
// NOTE(review): the fall-through return is not visible in this listing; the caller in
// matchImpl() tests the result against nullptr, so it presumably returns nullptr —
// confirm.
223 RuleTagToken* ParseTreePatternMatcher::getRuleTagToken(ParseTree *t) {
224 if (t->children.size() == 1 && is<TerminalNode *>(t->children[0])) {
225 TerminalNode *c = dynamic_cast<TerminalNode *>(t->children[0]);
226 if (is<RuleTagToken *>(c->getSymbol())) {
227 return dynamic_cast<RuleTagToken *>(c->getSymbol());
// Converts the textual `pattern` into a token list.
//
// The pattern is split into chunks with split(). A tag chunk becomes one synthetic
// token: a TokenTagToken when the tag starts with an upper-case letter, a RuleTagToken
// when it starts with a lower-case letter. A text chunk is lexed with _lexer and all
// of its tokens up to EOF are appended.
//
// Throws IllegalArgumentException for an unknown token name, an unknown rule name, or
// a tag whose first character is neither upper- nor lower-case.
// NOTE(review): the `else` lines and the final return are not visible in this listing.
233 std::vector<std::unique_ptr<Token>> ParseTreePatternMatcher::tokenize(const std::string &pattern) {
234 // split pattern into chunks: sea (raw input) and islands (<ID>, <expr>)
235 std::vector<Chunk> chunks = split(pattern);
237 // create token stream from text and tags
238 std::vector<std::unique_ptr<Token>> tokens;
// NOTE(review): `chunks` stores Chunk by value and `auto chunk` copies each element
// again. If TagChunk/TextChunk derive from Chunk, the elements are sliced and the
// is<TagChunk *> test below cannot observe the derived type — confirm against the
// Chunk class hierarchy.
239 for (auto chunk : chunks) {
240 if (is<TagChunk *>(&chunk)) {
241 TagChunk &tagChunk = (TagChunk&)chunk;
242 // add special rule token or conjure up new token from name
// NOTE(review): isupper/islower receive a plain `char`; the C library requires a value
// representable as unsigned char (or EOF), so a negative char is undefined behavior —
// consider casting to unsigned char.
243 if (isupper(tagChunk.getTag()[0])) {
244 size_t ttype = _parser->getTokenType(tagChunk.getTag());
245 if (ttype == Token::INVALID_TYPE) {
246 throw IllegalArgumentException("Unknown token " + tagChunk.getTag() + " in pattern: " + pattern);
248 tokens.emplace_back(new TokenTagToken(tagChunk.getTag(), (int)ttype, tagChunk.getLabel()));
249 } else if (islower(tagChunk.getTag()[0])) {
250 size_t ruleIndex = _parser->getRuleIndex(tagChunk.getTag());
251 if (ruleIndex == INVALID_INDEX) {
252 throw IllegalArgumentException("Unknown rule " + tagChunk.getTag() + " in pattern: " + pattern);
// The rule tag token's type is looked up in ruleToTokenType of the ATN returned by
// getATNWithBypassAlts(), indexed by the rule index.
254 size_t ruleImaginaryTokenType = _parser->getATNWithBypassAlts().ruleToTokenType[ruleIndex];
255 tokens.emplace_back(new RuleTagToken(tagChunk.getTag(), ruleImaginaryTokenType, tagChunk.getLabel()));
257 throw IllegalArgumentException("invalid tag: " + tagChunk.getTag() + " in pattern: " + pattern);
// Text chunk: point _lexer at the chunk text and collect tokens until EOF.
260 TextChunk &textChunk = (TextChunk&)chunk;
261 ANTLRInputStream input(textChunk.getText());
262 _lexer->setInputStream(&input);
263 std::unique_ptr<Token> t(_lexer->nextToken());
264 while (t->getType() != Token::EOF) {
265 tokens.push_back(std::move(t));
266 t = _lexer->nextToken();
// `input` is a local object, so the lexer's input stream is reset before it goes away.
268 _lexer->setInputStream(nullptr);
// Splits `pattern` into text chunks and tag chunks using the _start, _stop and _escape
// strings of this matcher.
//
// Visible phases:
//   1. Scan the pattern, skipping escaped delimiters and advancing past plain ones.
//   2. Validate the recorded positions (same count, each start before its stop).
//   3. Emit TextChunk/TagChunk objects in pattern order; a tag of the form
//      "label:name" is divided at the first ':'.
//   4. Remove backslashes from the text chunks.
//
// Throws IllegalArgumentException for an unterminated tag, a stop without a start, or
// delimiters out of order.
// NOTE(review): the scan loop header, the updates of `starts`/`stops`, several `if`
// conditions and the final return are not visible in this listing.
275 std::vector<Chunk> ParseTreePatternMatcher::split(const std::string &pattern) {
277 size_t n = pattern.length();
278 std::vector<Chunk> chunks;
280 // find all start and stop indexes first, then collect
281 std::vector<size_t> starts;
282 std::vector<size_t> stops;
// `p == pattern.find(x, p)` is true exactly when `x` occurs at position p.
// NOTE(review): the declaration of `p` is not visible in this listing.
// Escaped delimiters are skipped as a unit so they are not treated as tag boundaries.
284 if (p == pattern.find(_escape + _start,p)) {
285 p += _escape.length() + _start.length();
286 } else if (p == pattern.find(_escape + _stop,p)) {
287 p += _escape.length() + _stop.length();
288 } else if (p == pattern.find(_start,p)) {
290 p += _start.length();
291 } else if (p == pattern.find(_stop,p)) {
299 if (starts.size() > stops.size()) {
300 throw IllegalArgumentException("unterminated tag in pattern: " + pattern);
303 if (starts.size() < stops.size()) {
304 throw IllegalArgumentException("missing start tag in pattern: " + pattern);
307 size_t ntags = starts.size();
308 for (size_t i = 0; i < ntags; i++) {
309 if (starts[i] >= stops[i]) {
310 throw IllegalArgumentException("tag delimiters out of order in pattern: " + pattern);
314 // collect into chunks now
// Whole pattern as a single text chunk.
// NOTE(review): the guarding condition is not visible; presumably this runs only when
// no tags were found — confirm.
316 std::string text = pattern.substr(0, n);
317 chunks.push_back(TextChunk(text));
320 if (ntags > 0 && starts[0] > 0) { // copy text up to first tag into chunks
321 std::string text = pattern.substr(0, starts[0]);
322 chunks.push_back(TextChunk(text));
325 for (size_t i = 0; i < ntags; i++) {
326 // copy inside of <tag>
327 std::string tag = pattern.substr(starts[i] + _start.length(), stops[i] - (starts[i] + _start.length()));
328 std::string ruleOrToken = tag;
329 std::string label = "";
// "label:name" form: text before the first ':' is the label, the rest is the name.
330 size_t colon = tag.find(':');
331 if (colon != std::string::npos) {
332 label = tag.substr(0,colon);
333 ruleOrToken = tag.substr(colon + 1, tag.length() - (colon + 1));
335 chunks.push_back(TagChunk(label, ruleOrToken));
// NOTE(review): starts[i + 1] is only valid while i + 1 < ntags; the guard is not
// visible in this listing — confirm it exists.
337 // copy from end of <tag> to start of next
338 std::string text = pattern.substr(stops[i] + _stop.length(), starts[i + 1] - (stops[i] + _stop.length()));
339 chunks.push_back(TextChunk(text));
// NOTE(review): stops[ntags - 1] requires ntags > 0; the guard is not visible here.
344 size_t afterLastTag = stops[ntags - 1] + _stop.length();
345 if (afterLastTag < n) { // copy text from end of last tag to end
346 std::string text = pattern.substr(afterLastTag, n - afterLastTag);
347 chunks.push_back(TextChunk(text));
// NOTE(review): this removes every literal backslash character, not the configured
// _escape string — confirm that is intended when a custom escape is set.
// NOTE(review): the is<TextChunk *> test operates on an element of std::vector<Chunk>;
// if the chunks are stored sliced it may never succeed — confirm.
351 // strip out all backslashes from text chunks but not tags
352 for (size_t i = 0; i < chunks.size(); i++) {
353 Chunk &c = chunks[i];
354 if (is<TextChunk *>(&c)) {
355 TextChunk &tc = (TextChunk&)c;
356 std::string unescaped = tc.getText();
357 unescaped.erase(std::remove(unescaped.begin(), unescaped.end(), '\\'), unescaped.end());
// Replace the chunk only when at least one backslash was actually removed.
358 if (unescaped.length() < tc.getText().length()) {
359 chunks[i] = TextChunk(unescaped);
// Helper called from the ParseTreePatternMatcher constructor.
// NOTE(review): the body is not visible in this listing; presumably it assigns default
// values to _start, _stop and _escape — confirm.
367 void ParseTreePatternMatcher::InitializeInstanceFields() {