@@ -442,142 +442,194 @@ bool IsPureMath(Expr expr) {
442
442
return complex_nodes.empty ();
443
443
}
444
444
445
- Tokenizer::Tokenizer (const std::string &in) : input(in), pos(0 ) {}
446
-
447
- IndexToken Tokenizer::NextToken () {
448
- // skip whitespace
449
- while (pos < input.size () && std::isspace (input[pos])) {
450
- pos++;
451
- }
452
- // check if we reached the end of the input
453
- if (pos >= input.size ()) {
454
- return IndexToken (IndexToken::TokenType::kEnd );
455
- }
456
-
457
- char c = input[pos++];
445
+ /* !
446
+ * \brief Index Token in Tokenizer and Parser
447
+ */
448
+ struct IndexToken {
449
+ enum class TokenType {
450
+ kNumber ,
451
+ kVar ,
452
+ kPlus ,
453
+ kMinus ,
454
+ kMultiply ,
455
+ kDivide ,
456
+ kModulo ,
457
+ kLeftParen ,
458
+ kRightParen ,
459
+ kEnd
460
+ };
461
+
462
+ TokenType type;
463
+ std::string value;
464
+
465
+ explicit IndexToken (TokenType t, const std::string &v = " " )
466
+ : type(t), value(v) {}
467
+ };
468
+
469
+ /* !
470
+ * \brief Tokenizer for IndexExpr, split the input string into IndexToken.
471
+ */
472
+ class Tokenizer {
473
+ public:
474
+ explicit Tokenizer (const std::string &in) : input(in), pos(0 ) {}
475
+ // generate IndexToken for the next `pos`. it supports the following:
476
+ // 1. Number: 123, 1234...
477
+ // 2. Variable: a, b, a_1, aa, f1...
478
+ // 3. Operator: +, -, *, /, %, (, )
479
+ // 4. Whitespace
480
+ IndexToken NextToken () {
481
+ // skip whitespace
482
+ while (pos < input.size () && std::isspace (input[pos])) {
483
+ pos++;
484
+ }
485
+ // check if we reached the end of the input
486
+ if (pos >= input.size ()) {
487
+ return IndexToken (IndexToken::TokenType::kEnd );
488
+ }
489
+
490
+ char c = input[pos++];
491
+
492
+ // deal with number (0, 1, 11, 123...) not support float.
493
+ if (std::isdigit (c)) {
494
+ std::string num;
495
+ num += c;
496
+ while (pos < input.size () && std::isdigit (input[pos])) {
497
+ num += input[pos++];
498
+ }
499
+ return IndexToken (IndexToken::TokenType::kNumber , num);
500
+ }
458
501
459
- // deal with number (0, 1, 11, 123...) not support float.
460
- if (std::isdigit (c)) {
461
- std::string num;
462
- num += c;
463
- while (pos < input.size () && std::isdigit (input[pos])) {
464
- num += input[pos++];
502
+ // deal with variable name (a, b, a1, a123, a_1...).
503
+ if (std::isalpha (c) || input[pos] == ' _' ) {
504
+ std::string var;
505
+ var += c;
506
+ while (pos < input.size () &&
507
+ (std::isalnum (input[pos]) || input[pos] == ' _' )) {
508
+ var += input[pos++];
509
+ }
510
+ return IndexToken (IndexToken::TokenType::kVar , var);
511
+ }
512
+
513
+ // deal with operator {+, -, *, /, %, '(', ')'}.
514
+ switch (c) {
515
+ case ' +' :
516
+ return IndexToken (IndexToken::TokenType::kPlus );
517
+ case ' -' :
518
+ return IndexToken (IndexToken::TokenType::kMinus );
519
+ case ' *' :
520
+ return IndexToken (IndexToken::TokenType::kMultiply );
521
+ case ' /' :
522
+ return IndexToken (IndexToken::TokenType::kDivide );
523
+ case ' %' :
524
+ return IndexToken (IndexToken::TokenType::kModulo );
525
+ case ' (' :
526
+ return IndexToken (IndexToken::TokenType::kLeftParen );
527
+ case ' )' :
528
+ return IndexToken (IndexToken::TokenType::kRightParen );
529
+ default :
530
+ PADDLE_THROW (::common::errors::InvalidArgument (
531
+ " Tokenizer Unexpected character: %s" , c));
465
532
}
466
- return IndexToken (IndexToken::TokenType::kNumber , num);
467
533
}
468
534
469
- // deal with variable name (a, b, a1, a123, a_1...).
470
- if (std::isalpha (c) || input[pos] == ' _' ) {
471
- std::string var;
472
- var += c;
473
- while (pos < input.size () &&
474
- (std::isalnum (input[pos]) || input[pos] == ' _' )) {
475
- var += input[pos++];
535
+ private:
536
+ const std::string &input;
537
+ size_t pos;
538
+ };
539
+
540
+ /* !
541
+ * \brief Parser for IndexExpr, parse the input string into ir::Expr.
542
+ */
543
+ class Parser {
544
+ public:
545
+ explicit Parser (const std::string &input)
546
+ : tokenizer(input), currentToken(tokenizer.NextToken()) {}
547
+ ir::Expr Parse () { return ParseExpression (); }
548
+
549
+ private:
550
+ void Advance () { currentToken = tokenizer.NextToken (); }
551
+
552
+ // Processing addition and subtraction expressions, with the lowest priority.
553
+ ir::Expr ParseExpression () {
554
+ auto left = ParseTerm ();
555
+
556
+ while (currentToken.type == IndexToken::TokenType::kPlus ||
557
+ currentToken.type == IndexToken::TokenType::kMinus ) {
558
+ auto op = currentToken.type ;
559
+ Advance ();
560
+ auto right = ParseTerm ();
561
+
562
+ if (op == IndexToken::TokenType::kPlus ) {
563
+ left = ir::Add::Make (left, right);
564
+ } else {
565
+ left = ir::Sub::Make (left, right);
566
+ }
476
567
}
477
- return IndexToken (IndexToken::TokenType::kVar , var);
478
- }
479
568
480
- // deal with operator {+, -, *, /, %, '(', ')'}.
481
- switch (c) {
482
- case ' +' :
483
- return IndexToken (IndexToken::TokenType::kPlus );
484
- case ' -' :
485
- return IndexToken (IndexToken::TokenType::kMinus );
486
- case ' *' :
487
- return IndexToken (IndexToken::TokenType::kMultiply );
488
- case ' /' :
489
- return IndexToken (IndexToken::TokenType::kDivide );
490
- case ' %' :
491
- return IndexToken (IndexToken::TokenType::kModulo );
492
- case ' (' :
493
- return IndexToken (IndexToken::TokenType::kLeftParen );
494
- case ' )' :
495
- return IndexToken (IndexToken::TokenType::kRightParen );
496
- default :
497
- PADDLE_THROW (::common::errors::InvalidArgument (
498
- " Tokenizer Unexpected character: %s" , c));
569
+ return left;
499
570
}
500
- }
501
-
502
- Parser::Parser (const std::string &input)
503
- : tokenizer(input), currentToken(tokenizer.NextToken()) {}
504
-
505
- ir::Expr Parser::Parse () { return ParseExpression (); }
506
-
507
- void Parser::Advance () { currentToken = tokenizer.NextToken (); }
508
-
509
- ir::Expr Parser::ParseExpression () {
510
- auto left = ParseTerm ();
511
-
512
- while (currentToken.type == IndexToken::TokenType::kPlus ||
513
- currentToken.type == IndexToken::TokenType::kMinus ) {
514
- auto op = currentToken.type ;
515
- Advance ();
516
- auto right = ParseTerm ();
517
571
518
- if (op == IndexToken::TokenType::kPlus ) {
519
- left = ir::Add::Make (left, right);
520
- } else {
521
- left = ir::Sub::Make (left, right);
572
+ // Process multiplication, division and modulo expressions, with higher
573
+ // priority than addition and subtraction, and the parsing result appears as
574
+ // one Term. e.g. a * b + c, a * b is a Term.
575
+ ir::Expr ParseTerm () {
576
+ auto left = ParseFactor ();
577
+ while (currentToken.type == IndexToken::TokenType::kMultiply ||
578
+ currentToken.type == IndexToken::TokenType::kDivide ||
579
+ currentToken.type == IndexToken::TokenType::kModulo ) {
580
+ auto op = currentToken.type ;
581
+ Advance ();
582
+ auto right = ParseFactor ();
583
+
584
+ if (op == IndexToken::TokenType::kMultiply ) {
585
+ left = ir::Mul::Make (left, right);
586
+ } else if (op == IndexToken::TokenType::kDivide ) {
587
+ left = ir::Div::Make (left, right);
588
+ } else {
589
+ left = ir::Mod::Make (left, right);
590
+ }
522
591
}
523
- }
524
592
525
- return left;
526
- }
527
-
528
- ir::Expr Parser::ParseTerm () {
529
- auto left = ParseFactor ();
593
+ return left;
594
+ }
530
595
531
- while (currentToken.type == IndexToken::TokenType::kMultiply ||
532
- currentToken.type == IndexToken::TokenType::kDivide ||
533
- currentToken.type == IndexToken::TokenType::kModulo ) {
534
- auto op = currentToken.type ;
535
- Advance ();
536
- auto right = ParseFactor ();
596
+ // Process numeric, variables and brackets, with the highest priority, as
597
+ // parameters for each item.
598
+ ir::Expr ParseFactor () {
599
+ if (currentToken.type == IndexToken::TokenType::kNumber ) {
600
+ int value = std::stoi (currentToken.value );
601
+ Advance ();
602
+ return ir::Expr (value);
603
+ } else if (currentToken.type == IndexToken::TokenType::kVar ) {
604
+ auto var_name = currentToken.value ;
605
+ Advance ();
606
+ return GetOrCreateVar (var_name);
607
+ } else if (currentToken.type == IndexToken::TokenType::kLeftParen ) {
608
+ Advance ();
609
+ auto expr = ParseExpression ();
610
+
611
+ if (currentToken.type != IndexToken::TokenType::kRightParen ) {
612
+ PADDLE_THROW (::common::errors::InvalidArgument (
613
+ " Parser Expected ')', because of '(' in before." ));
614
+ }
537
615
538
- if (op == IndexToken::TokenType::kMultiply ) {
539
- left = ir::Mul::Make (left, right);
540
- } else if (op == IndexToken::TokenType::kDivide ) {
541
- left = ir::Div::Make (left, right);
616
+ Advance ();
617
+ return expr;
542
618
} else {
543
- left = ir::Mod::Make (left, right);
619
+ PADDLE_THROW (
620
+ ::common::errors::InvalidArgument (" Parser Unexpected IndexToken" ));
544
621
}
545
622
}
546
-
547
- return left;
548
- }
549
-
550
- ir::Expr Parser::ParseFactor () {
551
- if (currentToken.type == IndexToken::TokenType::kNumber ) {
552
- int value = std::stoi (currentToken.value );
553
- Advance ();
554
- return ir::Expr (value);
555
- } else if (currentToken.type == IndexToken::TokenType::kVar ) {
556
- auto var_name = currentToken.value ;
557
- Advance ();
558
- return GetOrCreateVar (var_name);
559
- } else if (currentToken.type == IndexToken::TokenType::kLeftParen ) {
560
- Advance ();
561
- auto expr = ParseExpression ();
562
-
563
- if (currentToken.type != IndexToken::TokenType::kRightParen ) {
564
- PADDLE_THROW (::common::errors::InvalidArgument (" Parser Expected ')'" ));
565
- }
566
-
567
- Advance ();
568
- return expr;
569
- } else {
570
- PADDLE_THROW (
571
- ::common::errors::InvalidArgument (" Parser Unexpected IndexToken" ));
572
- }
573
- }
574
-
575
- ir::Expr Parser::GetOrCreateVar (const std::string &var_name) {
576
- if (vars.find (var_name) == vars.end ()) {
577
- vars[var_name] = ir::Var (var_name);
623
+ ir::Expr GetOrCreateVar (const std::string &var_name) {
624
+ if (vars.find (var_name) == vars.end ()) {
625
+ vars[var_name] = ir::Var (var_name);
626
+ }
627
+ return vars[var_name];
578
628
}
579
- return vars[var_name];
580
- }
629
+ Tokenizer tokenizer;
630
+ IndexToken currentToken;
631
+ std::unordered_map<std::string, ir::Var> vars;
632
+ };
581
633
582
634
ir::Expr ParseExpressionFromString (const std::string &expr_str) {
583
635
Parser parser (expr_str);
0 commit comments