Skip to content

Commit 37ec377

Browse files
committed
fix comment
1 parent 336ba5c commit 37ec377

File tree

2 files changed

+172
-188
lines changed

2 files changed

+172
-188
lines changed

paddle/cinn/optim/simplify_util.cc

Lines changed: 172 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -442,142 +442,194 @@ bool IsPureMath(Expr expr) {
442442
return complex_nodes.empty();
443443
}
444444

445-
Tokenizer::Tokenizer(const std::string &in) : input(in), pos(0) {}
446-
447-
IndexToken Tokenizer::NextToken() {
448-
// skip whitespace
449-
while (pos < input.size() && std::isspace(input[pos])) {
450-
pos++;
451-
}
452-
// check if we reached the end of the input
453-
if (pos >= input.size()) {
454-
return IndexToken(IndexToken::TokenType::kEnd);
455-
}
456-
457-
char c = input[pos++];
445+
/*!
 * \brief Lexical unit handed from the Tokenizer to the Parser.
 */
struct IndexToken {
  // Category of a token. The enumerator order is part of the type and must
  // stay as is.
  enum class TokenType {
    kNumber,      // run of decimal digits
    kVar,         // variable name
    kPlus,        // '+'
    kMinus,       // '-'
    kMultiply,    // '*'
    kDivide,      // '/'
    kModulo,      // '%'
    kLeftParen,   // '('
    kRightParen,  // ')'
    kEnd          // end of input
  };

  // What kind of token this is.
  TokenType type;
  // Source text of the token; left empty when no text is supplied.
  std::string value;

  // Builds a token of category `t` carrying the optional text `v`.
  explicit IndexToken(TokenType t, const std::string &v = "")
      : type{t}, value{v} {}
};
468+
469+
/*!
470+
* \brief Tokenizer for IndexExpr, split the input string into IndexToken.
471+
*/
472+
class Tokenizer {
473+
public:
474+
explicit Tokenizer(const std::string &in) : input(in), pos(0) {}
475+
// generate IndexToken for the next `pos`. it supports the following:
476+
// 1. Number: 123, 1234...
477+
// 2. Variable: a, b, a_1, aa, f1...
478+
// 3. Operator: +, -, *, /, %, (, )
479+
// 4. Whitespace
480+
IndexToken NextToken() {
481+
// skip whitespace
482+
while (pos < input.size() && std::isspace(input[pos])) {
483+
pos++;
484+
}
485+
// check if we reached the end of the input
486+
if (pos >= input.size()) {
487+
return IndexToken(IndexToken::TokenType::kEnd);
488+
}
489+
490+
char c = input[pos++];
491+
492+
// deal with number (0, 1, 11, 123...) not support float.
493+
if (std::isdigit(c)) {
494+
std::string num;
495+
num += c;
496+
while (pos < input.size() && std::isdigit(input[pos])) {
497+
num += input[pos++];
498+
}
499+
return IndexToken(IndexToken::TokenType::kNumber, num);
500+
}
458501

459-
// deal with number (0, 1, 11, 123...) not support float.
460-
if (std::isdigit(c)) {
461-
std::string num;
462-
num += c;
463-
while (pos < input.size() && std::isdigit(input[pos])) {
464-
num += input[pos++];
502+
// deal with variable name (a, b, a1, a123, a_1...).
503+
if (std::isalpha(c) || input[pos] == '_') {
504+
std::string var;
505+
var += c;
506+
while (pos < input.size() &&
507+
(std::isalnum(input[pos]) || input[pos] == '_')) {
508+
var += input[pos++];
509+
}
510+
return IndexToken(IndexToken::TokenType::kVar, var);
511+
}
512+
513+
// deal with operator {+, -, *, /, %, '(', ')'}.
514+
switch (c) {
515+
case '+':
516+
return IndexToken(IndexToken::TokenType::kPlus);
517+
case '-':
518+
return IndexToken(IndexToken::TokenType::kMinus);
519+
case '*':
520+
return IndexToken(IndexToken::TokenType::kMultiply);
521+
case '/':
522+
return IndexToken(IndexToken::TokenType::kDivide);
523+
case '%':
524+
return IndexToken(IndexToken::TokenType::kModulo);
525+
case '(':
526+
return IndexToken(IndexToken::TokenType::kLeftParen);
527+
case ')':
528+
return IndexToken(IndexToken::TokenType::kRightParen);
529+
default:
530+
PADDLE_THROW(::common::errors::InvalidArgument(
531+
"Tokenizer Unexpected character: %s", c));
465532
}
466-
return IndexToken(IndexToken::TokenType::kNumber, num);
467533
}
468534

469-
// deal with variable name (a, b, a1, a123, a_1...).
470-
if (std::isalpha(c) || input[pos] == '_') {
471-
std::string var;
472-
var += c;
473-
while (pos < input.size() &&
474-
(std::isalnum(input[pos]) || input[pos] == '_')) {
475-
var += input[pos++];
535+
private:
536+
const std::string &input;
537+
size_t pos;
538+
};
539+
540+
/*!
541+
* \brief Parser for IndexExpr, parse the input string into ir::Expr.
542+
*/
543+
class Parser {
544+
public:
545+
explicit Parser(const std::string &input)
546+
: tokenizer(input), currentToken(tokenizer.NextToken()) {}
547+
ir::Expr Parse() { return ParseExpression(); }
548+
549+
private:
550+
void Advance() { currentToken = tokenizer.NextToken(); }
551+
552+
// Processing addition and subtraction expressions, with the lowest priority.
553+
ir::Expr ParseExpression() {
554+
auto left = ParseTerm();
555+
556+
while (currentToken.type == IndexToken::TokenType::kPlus ||
557+
currentToken.type == IndexToken::TokenType::kMinus) {
558+
auto op = currentToken.type;
559+
Advance();
560+
auto right = ParseTerm();
561+
562+
if (op == IndexToken::TokenType::kPlus) {
563+
left = ir::Add::Make(left, right);
564+
} else {
565+
left = ir::Sub::Make(left, right);
566+
}
476567
}
477-
return IndexToken(IndexToken::TokenType::kVar, var);
478-
}
479568

480-
// deal with operator {+, -, *, /, %, '(', ')'}.
481-
switch (c) {
482-
case '+':
483-
return IndexToken(IndexToken::TokenType::kPlus);
484-
case '-':
485-
return IndexToken(IndexToken::TokenType::kMinus);
486-
case '*':
487-
return IndexToken(IndexToken::TokenType::kMultiply);
488-
case '/':
489-
return IndexToken(IndexToken::TokenType::kDivide);
490-
case '%':
491-
return IndexToken(IndexToken::TokenType::kModulo);
492-
case '(':
493-
return IndexToken(IndexToken::TokenType::kLeftParen);
494-
case ')':
495-
return IndexToken(IndexToken::TokenType::kRightParen);
496-
default:
497-
PADDLE_THROW(::common::errors::InvalidArgument(
498-
"Tokenizer Unexpected character: %s", c));
569+
return left;
499570
}
500-
}
501-
502-
Parser::Parser(const std::string &input)
503-
: tokenizer(input), currentToken(tokenizer.NextToken()) {}
504-
505-
ir::Expr Parser::Parse() { return ParseExpression(); }
506-
507-
void Parser::Advance() { currentToken = tokenizer.NextToken(); }
508-
509-
ir::Expr Parser::ParseExpression() {
510-
auto left = ParseTerm();
511-
512-
while (currentToken.type == IndexToken::TokenType::kPlus ||
513-
currentToken.type == IndexToken::TokenType::kMinus) {
514-
auto op = currentToken.type;
515-
Advance();
516-
auto right = ParseTerm();
517571

518-
if (op == IndexToken::TokenType::kPlus) {
519-
left = ir::Add::Make(left, right);
520-
} else {
521-
left = ir::Sub::Make(left, right);
572+
// Process multiplication, division and modulo expressions, with higher
573+
// priority than addition and subtraction, and the parsing result appears as
574+
// one Term. e.g. a * b + c, a * b is a Term.
575+
ir::Expr ParseTerm() {
576+
auto left = ParseFactor();
577+
while (currentToken.type == IndexToken::TokenType::kMultiply ||
578+
currentToken.type == IndexToken::TokenType::kDivide ||
579+
currentToken.type == IndexToken::TokenType::kModulo) {
580+
auto op = currentToken.type;
581+
Advance();
582+
auto right = ParseFactor();
583+
584+
if (op == IndexToken::TokenType::kMultiply) {
585+
left = ir::Mul::Make(left, right);
586+
} else if (op == IndexToken::TokenType::kDivide) {
587+
left = ir::Div::Make(left, right);
588+
} else {
589+
left = ir::Mod::Make(left, right);
590+
}
522591
}
523-
}
524592

525-
return left;
526-
}
527-
528-
ir::Expr Parser::ParseTerm() {
529-
auto left = ParseFactor();
593+
return left;
594+
}
530595

531-
while (currentToken.type == IndexToken::TokenType::kMultiply ||
532-
currentToken.type == IndexToken::TokenType::kDivide ||
533-
currentToken.type == IndexToken::TokenType::kModulo) {
534-
auto op = currentToken.type;
535-
Advance();
536-
auto right = ParseFactor();
596+
// Process numeric, variables and brackets, with the highest priority, as
597+
// parameters for each item.
598+
ir::Expr ParseFactor() {
599+
if (currentToken.type == IndexToken::TokenType::kNumber) {
600+
int value = std::stoi(currentToken.value);
601+
Advance();
602+
return ir::Expr(value);
603+
} else if (currentToken.type == IndexToken::TokenType::kVar) {
604+
auto var_name = currentToken.value;
605+
Advance();
606+
return GetOrCreateVar(var_name);
607+
} else if (currentToken.type == IndexToken::TokenType::kLeftParen) {
608+
Advance();
609+
auto expr = ParseExpression();
610+
611+
if (currentToken.type != IndexToken::TokenType::kRightParen) {
612+
PADDLE_THROW(::common::errors::InvalidArgument(
613+
"Parser Expected ')', because of '(' in before."));
614+
}
537615

538-
if (op == IndexToken::TokenType::kMultiply) {
539-
left = ir::Mul::Make(left, right);
540-
} else if (op == IndexToken::TokenType::kDivide) {
541-
left = ir::Div::Make(left, right);
616+
Advance();
617+
return expr;
542618
} else {
543-
left = ir::Mod::Make(left, right);
619+
PADDLE_THROW(
620+
::common::errors::InvalidArgument("Parser Unexpected IndexToken"));
544621
}
545622
}
546-
547-
return left;
548-
}
549-
550-
ir::Expr Parser::ParseFactor() {
551-
if (currentToken.type == IndexToken::TokenType::kNumber) {
552-
int value = std::stoi(currentToken.value);
553-
Advance();
554-
return ir::Expr(value);
555-
} else if (currentToken.type == IndexToken::TokenType::kVar) {
556-
auto var_name = currentToken.value;
557-
Advance();
558-
return GetOrCreateVar(var_name);
559-
} else if (currentToken.type == IndexToken::TokenType::kLeftParen) {
560-
Advance();
561-
auto expr = ParseExpression();
562-
563-
if (currentToken.type != IndexToken::TokenType::kRightParen) {
564-
PADDLE_THROW(::common::errors::InvalidArgument("Parser Expected ')'"));
565-
}
566-
567-
Advance();
568-
return expr;
569-
} else {
570-
PADDLE_THROW(
571-
::common::errors::InvalidArgument("Parser Unexpected IndexToken"));
572-
}
573-
}
574-
575-
ir::Expr Parser::GetOrCreateVar(const std::string &var_name) {
576-
if (vars.find(var_name) == vars.end()) {
577-
vars[var_name] = ir::Var(var_name);
623+
ir::Expr GetOrCreateVar(const std::string &var_name) {
624+
if (vars.find(var_name) == vars.end()) {
625+
vars[var_name] = ir::Var(var_name);
626+
}
627+
return vars[var_name];
578628
}
579-
return vars[var_name];
580-
}
629+
Tokenizer tokenizer;
630+
IndexToken currentToken;
631+
std::unordered_map<std::string, ir::Var> vars;
632+
};
581633

582634
ir::Expr ParseExpressionFromString(const std::string &expr_str) {
583635
Parser parser(expr_str);

paddle/cinn/optim/simplify_util.h

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -280,74 +280,6 @@ bool CheckPattern(const ir::IndexExpr &expr,
280280
// placement of tool functions that are still in use, remove it in the future.
281281
bool IsPureMath(Expr expr);
282282

283-
/*!
284-
* \brief Index Token in Tokenizer and Parser
285-
*/
286-
struct IndexToken {
287-
enum class TokenType {
288-
kNumber,
289-
kVar,
290-
kPlus,
291-
kMinus,
292-
kMultiply,
293-
kDivide,
294-
kModulo,
295-
kLeftParen,
296-
kRightParen,
297-
kEnd
298-
};
299-
300-
TokenType type;
301-
std::string value;
302-
303-
explicit IndexToken(TokenType t, const std::string &v = "")
304-
: type(t), value(v) {}
305-
};
306-
307-
/*!
308-
* \brief Tokenizer for IndexExpr, split the input string into IndexToken.
309-
*/
310-
class Tokenizer {
311-
public:
312-
explicit Tokenizer(const std::string &in);
313-
// generate IndexToken for the next `pos`. it supports the following:
314-
// 1. Number: 123, 1234...
315-
// 2. Variable: a, b, a_1, aa, f1...
316-
// 3. Operator: +, -, *, /, %, (, )
317-
// 4. Whitespace
318-
IndexToken NextToken();
319-
320-
private:
321-
const std::string &input;
322-
size_t pos;
323-
};
324-
325-
/*!
326-
* \brief Parser for IndexExpr, parse the input string into ir::Expr.
327-
*/
328-
class Parser {
329-
public:
330-
explicit Parser(const std::string &input);
331-
ir::Expr Parse();
332-
333-
private:
334-
Tokenizer tokenizer;
335-
IndexToken currentToken;
336-
std::unordered_map<std::string, ir::Var> vars;
337-
338-
void Advance();
339-
// Processing addition and subtraction expressions, with the lowest priority.
340-
ir::Expr ParseExpression();
341-
// Process multiplication, division and modulo expressions, with higher
342-
// priority than addition and subtraction, and the parsing result appears as
343-
// one Term. e.g. a * b + c, a * b is a Term.
344-
ir::Expr ParseTerm();
345-
// Process numeric, variables and brackets, with the highest priority, as
346-
// parameters for each item.
347-
ir::Expr ParseFactor();
348-
ir::Expr GetOrCreateVar(const std::string &var_name);
349-
};
350-
351283
/*!
352284
* \brief Parse the expression from string to Expr.
353285
* \param expr_str The expression to be checked.

0 commit comments

Comments
 (0)