From b5d841f10a3c94a2b505476d2a84af58c817df6c Mon Sep 17 00:00:00 2001
From: wszhdshys <1925792291@qq.com>
Date: Mon, 21 Jul 2025 13:30:37 +0800
Subject: [PATCH 1/7] feat: support correlated subqueries by storing the
 parent in the subquery binder.

---
 src/binder/create_table.rs        |   1 +
 src/binder/expr.rs                |   8 ++
 src/binder/mod.rs                 |   4 +
 src/binder/select.rs              |  45 +++++++++--
 src/db.rs                         |   1 +
 src/optimizer/core/memo.rs        |   1 +
 tests/slt/correlated_subquery.slt | 119 ++++++++++++++++++++++++++++++
 7 files changed, 171 insertions(+), 8 deletions(-)
 create mode 100644 tests/slt/correlated_subquery.slt

diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs
index ec0971c0..7b359479 100644
--- a/src/binder/create_table.rs
+++ b/src/binder/create_table.rs
@@ -189,6 +189,7 @@ mod tests {
             &scala_functions,
             &table_functions,
             Arc::new(AtomicUsize::new(0)),
+            vec![]
         ),
         &[],
         None,
diff --git a/src/binder/expr.rs b/src/binder/expr.rs
index 890d6d62..3a45ed54 100644
--- a/src/binder/expr.rs
+++ b/src/binder/expr.rs
@@ -279,6 +279,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
             scala_functions,
             table_functions,
             temp_table_id,
+            parent_name,
             ..
         } = &self.context;
         let mut binder = Binder::new(
@@ -289,10 +290,17 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
                 scala_functions,
                 table_functions,
                 temp_table_id.clone(),
+                parent_name.clone(),
             ),
             self.args,
             Some(self),
         );
+
+        self.context.bind_table.iter().find(|((t, _, _), _)| {
+            binder.context.parent_name.push(t.as_str().to_string());
+            true
+        });
+
         let mut sub_query = binder.bind_query(subquery)?;
         let sub_query_schema = sub_query.output_schema();

diff --git a/src/binder/mod.rs b/src/binder/mod.rs
index e70c4d2d..09e15407 100644
--- a/src/binder/mod.rs
+++ b/src/binder/mod.rs
@@ -114,6 +114,7 @@ pub struct BinderContext<'a, T: Transaction> {
     using: HashSet,

     bind_step: QueryBindStep,
+    parent_name: Vec,
     sub_queries: HashMap>,

     temp_table_id: Arc,
@@ -171,6 +172,7 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
         scala_functions: &'a ScalaFunctions,
         table_functions: &'a TableFunctions,
         temp_table_id: Arc,
+        parent_name: Vec,
     ) -> Self {
         BinderContext {
             scala_functions,
@@ -185,6 +187,7 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
             agg_calls: Default::default(),
             using: Default::default(),
             bind_step: QueryBindStep::From,
+            parent_name,
             sub_queries: Default::default(),
             temp_table_id,
             allow_default: false,
@@ -550,6 +553,7 @@ pub mod test {
             &scala_functions,
             &table_functions,
             Arc::new(AtomicUsize::new(0)),
+            vec![],
         ),
         &[],
         None,
diff --git a/src/binder/select.rs b/src/binder/select.rs
index 9c437f94..e6a8eb8c 100644
--- a/src/binder/select.rs
+++ b/src/binder/select.rs
@@ -35,11 +35,7 @@ use crate::types::tuple::{Schema, SchemaRef};
 use crate::types::value::Utf8Type;
 use crate::types::{ColumnId, LogicalType};
 use itertools::Itertools;
-use sqlparser::ast::{
-    CharLengthUnits, Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset,
-    OrderByExpr, Query, Select, SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier,
-    TableAlias, TableFactor, TableWithJoins,
-};
+use sqlparser::ast::{CharLengthUnits, Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset, OrderByExpr, Query, Select, SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier, TableAlias, TableFactor, TableWithJoins};

 impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> {
     pub(crate) fn bind_query(&mut self, query: &Query) -> Result {
@@ -559,6 +555,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
             scala_functions,
             table_functions,
             temp_table_id,
+            parent_name,
             ..
         } = &self.context;
         let mut binder = Binder::new(
@@ -569,6 +566,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
                 scala_functions,
                 table_functions,
                 temp_table_id.clone(),
+                parent_name.clone(),
             ),
             self.args,
             Some(self),
@@ -592,8 +590,39 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
         predicate: &Expr,
     ) -> Result {
         self.context.step(QueryBindStep::Where);
+        let mut need_parent = true;
+        let mut predicate_result = ScalarExpression::Empty;

-        let predicate = self.bind_expr(predicate)?;
+        while need_parent {
+            predicate_result = match self.bind_expr(predicate) {
+                Ok(expr) => {
+                    need_parent = false;
+                    expr
+                },
+                Err(DatabaseError::InvalidTable(table)) => {
+                    if !self.context.parent_name.contains(&table) {
+                        return Err(DatabaseError::InvalidTable(table))
+                    }
+                    let plan = self.bind_table_ref(&TableWithJoins{
+                        relation: TableFactor::Table {
+                            name: ObjectName(vec![Ident::new(table)]),
+                            alias: None,
+                            args: None,
+                            with_hints: vec![],
+                        },
+                        joins: vec![],
+                    })?;
+                    children = LJoinOperator::build(
+                        children,
+                        plan,
+                        JoinCondition::None,
+                        JoinType::Full,
+                    );
+                    continue;
+                },
+                Err(e) => return Err(e)
+            };
+        }

         if let Some(sub_queries) = self.context.sub_queries_at_now() {
             for sub_query in sub_queries {
@@ -664,7 +693,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
         };

         Self::extract_join_keys(
-            predicate.clone(),
+            predicate_result.clone(),
             &mut on_keys,
             &mut filter,
             children.output_schema(),
@@ -694,7 +723,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
             }
             return Ok(children);
         }
-        Ok(FilterOperator::build(predicate, children, false))
+        Ok(FilterOperator::build(predicate_result, children, false))
     }

     fn bind_having(
diff --git a/src/db.rs b/src/db.rs
index ff08a4fa..e104e6a1 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -150,6 +150,7 @@ impl State {
             scala_functions,
             table_functions,
             Arc::new(AtomicUsize::new(0)),
+            vec![],
         ),
         &params,
         None,
diff --git a/src/optimizer/core/memo.rs b/src/optimizer/core/memo.rs
index 5324b1e5..e4df42f4 100644
--- a/src/optimizer/core/memo.rs
+++ b/src/optimizer/core/memo.rs
@@ -138,6 +138,7 @@ mod tests {
             &scala_functions,
             &table_functions,
             Arc::new(AtomicUsize::new(0)),
+            vec![]
         ),
         &[],
         None,
diff --git a/tests/slt/correlated_subquery.slt b/tests/slt/correlated_subquery.slt
new file mode 100644
index 00000000..90331034
--- /dev/null
+++ b/tests/slt/correlated_subquery.slt
@@ -0,0 +1,119 @@
+statement ok
+CREATE TABLE t1 (id INT PRIMARY KEY, v1 VARCHAR(50), v2 INT);
+
+statement ok
+CREATE TABLE t2 (id INT PRIMARY KEY, v1 VARCHAR(50), v2 INT);
+
+statement ok
+CREATE TABLE t3 (id INT PRIMARY KEY, v1 INT, v2 INT);
+
+statement ok
+insert into t1(id, v1, v2) values (1,'a',9)
+
+statement ok
+insert into t1(id, v1, v2) values (2,'b',6)
+
+statement ok
+insert into t1(id, v1, v2) values (3,'c',11)
+
+statement ok
+insert into t2(id, v1, v2) values (1,'A',10)
+
+statement ok
+insert into t2(id, v1, v2) values (2,'B',11)
+
+statement ok
+insert into t2(id, v1, v2) values (3,'C',9)
+
+statement ok
+insert into t3(id, v1, v2) values (1,6,10)
+
+statement ok
+insert into t3(id, v1, v2) values (2,5,10)
+
+statement ok
+insert into t3(id, v1, v2) values (3,4,10)
+
+query IT rowsort
+SELECT id, v1 FROM t1 WHERE id IN ( SELECT t2.id FROM t2 WHERE t2.v2 < t1.v2 )
+----
+1 a
+3 c
+
+query I rowsort
+SELECT v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t3 WHERE t3.id = t1.id )
+----
+a
+b
+c
+
+query TT rowsort
+SELECT t1.v1, t2.v1 FROM t1 JOIN t2 ON t1.id = t2.id WHERE t2.v2 > ( SELECT AVG(v2) FROM t1 )
+----
+a A
+b B
+c C
+
+query IT rowsort
+SELECT id, v1 FROM t1 WHERE NOT EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND t2.v2 = t1.v2 )
+----
+1 a
+2 b
+3 c
+
+query IT rowsort
+SELECT id, v1 FROM t1 WHERE v2 > ( SELECT MIN(v2) FROM t2 WHERE t2.id = t1.id )
+----
+3 c
+
+query IT rowsort
+SELECT id, v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND EXISTS ( SELECT 1 FROM t3 WHERE t3.id = t1.id ) )
+----
+1 a
+2 b
+3 c
+
+# query ITI rowsort
+# SELECT id, v1, ( SELECT COUNT(*) FROM t2 WHERE t2.v2 >= 10 ) as cnt FROM t1
+# ----
+# 1 a 2
+# 2 b 3
+# 3 c 2
+
+query IT rowsort
+SELECT id, v1 FROM t1 WHERE v2 - 5 > ( SELECT AVG(v1) FROM t3 WHERE t3.id <= t1.id )
+----
+3 c
+
+query IT rowsort
+SELECT id, v1 FROM t1 WHERE id NOT IN ( SELECT t2.id FROM t2 WHERE t2.v2 > t1.v2 )
+----
+
+
+query IT rowsort
+SELECT id, v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND t2.v2 + t1.v2 > 15 )
+----
+1 a
+2 b
+3 c
+
+query IT rowsort
+SELECT id, v1 FROM t1 WHERE v2 = ( SELECT MAX(v2) FROM t2 WHERE t2.id <= t1.id ) ORDER BY id
+----
+3 c
+
+query IT rowsort
+SELECT id, v1 FROM t1 WHERE ( SELECT COUNT(*) FROM t2 WHERE t2.v2 = t1.v2 ) = 2
+----
+1 a
+2 b
+3 c
+
+statement ok
+DROP TABLE t1;
+
+statement ok
+DROP TABLE t2;
+
+statement ok
+DROP TABLE t3;
\ No newline at end of file

From b79fc7fd7dd4e9e0d3239ac5d45bb0b1d275c506 Mon Sep 17 00:00:00 2001
From: wszhdshys <1925792291@qq.com>
Date: Mon, 21 Jul 2025 13:32:25 +0800
Subject: [PATCH 2/7] fmt

---
 src/binder/create_table.rs |  2 +-
 src/binder/select.rs       | 24 ++++++++++++------------
 src/db.rs                  |  2 +-
 src/optimizer/core/memo.rs |  2 +-
 4 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs
index 7b359479..7c6dae15 100644
--- a/src/binder/create_table.rs
+++ b/src/binder/create_table.rs
@@ -189,7 +189,7 @@ mod tests {
             &scala_functions,
             &table_functions,
             Arc::new(AtomicUsize::new(0)),
-            vec![]
+            vec![],
         ),
         &[],
         None,
diff --git a/src/binder/select.rs b/src/binder/select.rs
index e6a8eb8c..0a355bba 100644
--- a/src/binder/select.rs
+++ b/src/binder/select.rs
@@ -35,7 +35,11 @@ use crate::types::tuple::{Schema, SchemaRef};
 use crate::types::value::Utf8Type;
 use crate::types::{ColumnId, LogicalType};
 use itertools::Itertools;
-use sqlparser::ast::{CharLengthUnits, Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset, OrderByExpr, Query, Select, SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier, TableAlias, TableFactor, TableWithJoins};
+use sqlparser::ast::{
+    CharLengthUnits, Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset,
+    OrderByExpr, Query, Select, SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier,
+    TableAlias, TableFactor, TableWithJoins,
+};

 impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> {
     pub(crate) fn bind_query(&mut self, query: &Query) -> Result {
@@ -598,12 +602,12 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
                 Ok(expr) => {
                     need_parent = false;
                     expr
-                },
+                }
                 Err(DatabaseError::InvalidTable(table)) => {
                     if !self.context.parent_name.contains(&table) {
-                        return Err(DatabaseError::InvalidTable(table))
+                        return Err(DatabaseError::InvalidTable(table));
                     }
-                    let plan = 
self.bind_table_ref(&TableWithJoins{ + let plan = self.bind_table_ref(&TableWithJoins { relation: TableFactor::Table { name: ObjectName(vec![Ident::new(table)]), alias: None, @@ -612,15 +616,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }, joins: vec![], })?; - children = LJoinOperator::build( - children, - plan, - JoinCondition::None, - JoinType::Full, - ); + children = + LJoinOperator::build(children, plan, JoinCondition::None, JoinType::Full); continue; - }, - Err(e) => return Err(e) + } + Err(e) => return Err(e), }; } diff --git a/src/db.rs b/src/db.rs index e104e6a1..57d02bc9 100644 --- a/src/db.rs +++ b/src/db.rs @@ -150,7 +150,7 @@ impl State { scala_functions, table_functions, Arc::new(AtomicUsize::new(0)), - vec![], + vec![], ), ¶ms, None, diff --git a/src/optimizer/core/memo.rs b/src/optimizer/core/memo.rs index e4df42f4..4b50f914 100644 --- a/src/optimizer/core/memo.rs +++ b/src/optimizer/core/memo.rs @@ -138,7 +138,7 @@ mod tests { &scala_functions, &table_functions, Arc::new(AtomicUsize::new(0)), - vec![] + vec![], ), &[], None, From 23e79016f64ea3b84a103273e548000e1a750c54 Mon Sep 17 00:00:00 2001 From: wszhdshys <1925792291@qq.com> Date: Tue, 22 Jul 2025 00:46:41 +0800 Subject: [PATCH 3/7] Changed the binding method of table in parent, which should support more cases --- src/binder/create_index.rs | 2 +- src/binder/create_table.rs | 1 - src/binder/create_view.rs | 3 +- src/binder/expr.rs | 50 ++-- src/binder/insert.rs | 2 +- src/binder/mod.rs | 23 +- src/binder/select.rs | 223 ++++++++++++++---- src/binder/update.rs | 2 +- src/db.rs | 1 - src/execution/ddl/create_index.rs | 2 +- src/execution/dql/aggregate/hash_agg.rs | 4 +- src/execution/dql/join/hash_join.rs | 4 +- src/execution/dql/join/nested_loop_join.rs | 18 +- src/expression/evaluator.rs | 2 +- src/expression/mod.rs | 58 +++-- src/expression/range_detacher.rs | 4 +- src/expression/simplify.rs | 10 +- src/expression/visitor.rs | 2 +- src/expression/visitor_mut.rs | 2 +- src/optimizer/core/memo.rs | 1 - .../rule/normalization/simplification.rs | 12 +- src/planner/operator/mod.rs | 6 +- src/types/index.rs | 2 +- 23 files changed, 302 insertions(+), 132 deletions(-) diff --git a/src/binder/create_index.rs b/src/binder/create_index.rs index 4b5332ba..2af5fde2 100644 --- a/src/binder/create_index.rs +++ b/src/binder/create_index.rs @@ -43,7 +43,7 @@ impl> Binder<'_, '_, T, A> for expr in exprs { // TODO: Expression Index match self.bind_expr(&expr.expr)? 
{ - ScalarExpression::ColumnRef(column) => columns.push(column), + ScalarExpression::ColumnRef(column, false) => columns.push(column), expr => { return Err(DatabaseError::UnsupportedStmt(format!( "'CREATE INDEX' by {}", diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 7c6dae15..ec0971c0 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -189,7 +189,6 @@ mod tests { &scala_functions, &table_functions, Arc::new(AtomicUsize::new(0)), - vec![], ), &[], None, diff --git a/src/binder/create_view.rs b/src/binder/create_view.rs index 2d99b0bd..aba140c7 100644 --- a/src/binder/create_view.rs +++ b/src/binder/create_view.rs @@ -39,9 +39,10 @@ impl> Binder<'_, '_, T, A> column.set_ref_table(view_name.clone(), Ulid::new(), true); ScalarExpression::Alias { - expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone())), + expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone(), false)), alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef( ColumnRef::from(column), + false, ))), } }) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 3a45ed54..2ffba817 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; use std::slice; use std::sync::Arc; -use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType}; +use super::{lower_ident, Binder, BinderContext, QueryBindStep, Source, SubQueryType}; use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction}; use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction}; use crate::expression::function::FunctionSummary; @@ -259,9 +259,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let alias_expr = ScalarExpression::Alias { expr: Box::new(expr), - alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from( - alias_column, - )))), + alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef( + ColumnRef::from(alias_column), + false, + ))), }; let alias_plan = self.bind_project(sub_query, vec![alias_expr.clone()])?; Ok((alias_expr, alias_plan)) @@ -279,7 +280,6 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T scala_functions, table_functions, temp_table_id, - parent_name, .. 
} = &self.context; let mut binder = Binder::new( @@ -290,17 +290,11 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T scala_functions, table_functions, temp_table_id.clone(), - parent_name.clone(), ), self.args, Some(self), ); - self.context.bind_table.iter().find(|((t, _, _), _)| { - binder.context.parent_name.push(t.as_str().to_string()); - true - }); - let mut sub_query = binder.bind_query(subquery)?; let sub_query_schema = sub_query.output_schema(); @@ -319,13 +313,13 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let columns = sub_query_schema .iter() - .map(|column| ScalarExpression::ColumnRef(column.clone())) + .map(|column| ScalarExpression::ColumnRef(column.clone(), false)) .collect::>(); ScalarExpression::Tuple(columns) } else { fn_check(1)?; - ScalarExpression::ColumnRef(sub_query_schema[0].clone()) + ScalarExpression::ColumnRef(sub_query_schema[0].clone(), false) }; Ok((sub_query, expr)) } @@ -376,13 +370,39 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T try_default!(&full_name.0, full_name.1); } if let Some(table) = full_name.0.or(bind_table_name) { - let source = self.context.bind_source(&table)?; + let mut source: &Source; + let mut from_parent: bool; + let (mut parent, mut parent_context) = if let Some(parent) = self.parent { + (Some(parent), Some(&parent.context)) + } else { + (None, None) + }; + + loop { + (source, from_parent) = match self.context.bind_source(parent_context, &table) { + (Ok(source), from_parent) => (source, from_parent), + (Err(e), _) => { + if let Some(p) = parent { + (parent, parent_context) = match p.parent { + Some(parent) => (Some(parent), Some(&parent.context)), + None => return Err(e), + } + } else { + return Err(e); + } + continue; + } + }; + break; + } + let schema_buf = self.table_schema_buf.entry(Arc::new(table)).or_default(); Ok(ScalarExpression::ColumnRef( source .column(&full_name.1, schema_buf) .ok_or_else(|| DatabaseError::ColumnNotFound(full_name.1.to_string()))?, + from_parent, )) } else { let op = @@ -411,7 +431,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T table_schema_buf.entry(table_name.clone()).or_default(); source.column(&full_name.1, schema_buf) } { - *got_column = Some(ScalarExpression::ColumnRef(column)); + *got_column = Some(ScalarExpression::ColumnRef(column, false)); } } }; diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 4c8e78c1..c8f36843 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -51,7 +51,7 @@ impl> Binder<'_, '_, T, A> slice::from_ref(ident), Some(table_name.to_string()), )? 
{ - ScalarExpression::ColumnRef(catalog) => columns.push(catalog), + ScalarExpression::ColumnRef(catalog, _) => columns.push(catalog), _ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())), } } diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 09e15407..b85778b2 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -114,7 +114,6 @@ pub struct BinderContext<'a, T: Transaction> { using: HashSet, bind_step: QueryBindStep, - parent_name: Vec, sub_queries: HashMap>, temp_table_id: Arc, @@ -172,7 +171,6 @@ impl<'a, T: Transaction> BinderContext<'a, T> { scala_functions: &'a ScalaFunctions, table_functions: &'a TableFunctions, temp_table_id: Arc, - parent_name: Vec, ) -> Self { BinderContext { scala_functions, @@ -187,7 +185,6 @@ impl<'a, T: Transaction> BinderContext<'a, T> { agg_calls: Default::default(), using: Default::default(), bind_step: QueryBindStep::From, - parent_name, sub_queries: Default::default(), temp_table_id, allow_default: false, @@ -278,14 +275,27 @@ impl<'a, T: Transaction> BinderContext<'a, T> { Ok(source) } - pub fn bind_source<'b: 'a>(&self, table_name: &str) -> Result<&Source, DatabaseError> { + pub fn bind_source<'b: 'a>( + &self, + parent: Option<&'a BinderContext<'_, T>>, + table_name: &str, + ) -> (Result<&'b Source, DatabaseError>, bool) { if let Some(source) = self.bind_table.iter().find(|((t, alias, _), _)| { t.as_str() == table_name || matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true)) }) { - Ok(source.1) + (Ok(source.1), false) + } else if let Some(context) = parent { + if let Some(source) = context.bind_table.iter().find(|((t, alias, _), _)| { + t.as_str() == table_name + || matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true)) + }) { + (Ok(source.1), true) + } else { + (Err(DatabaseError::InvalidTable(table_name.into())), false) + } } else { - Err(DatabaseError::InvalidTable(table_name.into())) + (Err(DatabaseError::InvalidTable(table_name.into())), false) } } @@ -553,7 +563,6 @@ pub mod test { &scala_functions, &table_functions, Arc::new(AtomicUsize::new(0)), - vec![], ), &[], None, diff --git a/src/binder/select.rs b/src/binder/select.rs index 0a355bba..f9a5bd55 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -101,6 +101,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }; let mut select_list = self.normalize_select_item(&select.projection, &plan)?; + for expr in &select_list { + plan = self.bind_parent(plan, expr)?; + } + if let Some(predicate) = &select.selection { plan = self.bind_where(plan, predicate)?; } @@ -216,7 +220,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let distinct_exprs = left_schema .iter() .cloned() - .map(ScalarExpression::ColumnRef) + .map(|col| ScalarExpression::ColumnRef(col, false)) .collect_vec(); Ok(self.bind_distinct( @@ -358,10 +362,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' ); let alias_column_expr = ScalarExpression::Alias { - expr: Box::new(ScalarExpression::ColumnRef(column)), - alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from( - alias_column, - )))), + expr: Box::new(ScalarExpression::ColumnRef(column, false)), + alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef( + ColumnRef::from(alias_column), + false, + ))), }; self.context.add_alias( Some(table_alias.to_string()), @@ -488,7 +493,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' && 
matches!(join_used.map(|used| used.contains(column_name)), Some(true)) }; for (_, alias_expr) in context.expr_aliases.iter().filter(|(_, expr)| { - if let ScalarExpression::ColumnRef(col) = expr.unpack_alias_ref() { + if let ScalarExpression::ColumnRef(col, _) = expr.unpack_alias_ref() { let column_name = col.name(); if Some(&table_name) == col.table_name() @@ -524,7 +529,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if fn_used(column_name, context, join_used.as_deref()) { continue; } - let expr = ScalarExpression::ColumnRef(column.clone()); + let expr = ScalarExpression::ColumnRef(column.clone(), false); if let Some(used) = join_used.as_mut() { used.insert(column_name.to_string()); @@ -534,6 +539,155 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' Ok(()) } + fn bind_parent( + &mut self, + mut plan: LogicalPlan, + expr: &ScalarExpression, + ) -> Result { + match expr { + ScalarExpression::ColumnRef(columnref, true) => { + let table = columnref.0.table_name().unwrap(); + let parent = self.bind_table_ref(&TableWithJoins { + relation: TableFactor::Table { + name: ObjectName(vec![Ident::new(table.as_str().to_string())]), + alias: None, + args: None, + with_hints: vec![], + }, + joins: vec![], + })?; + Ok(LJoinOperator::build( + plan, + parent, + JoinCondition::None, + JoinType::Full, + )) + } + ScalarExpression::Alias { expr, .. } + | ScalarExpression::TypeCast { expr, .. } + | ScalarExpression::IsNull { expr, .. } + | ScalarExpression::Unary { expr, .. } + | ScalarExpression::Reference { expr, .. } => self.bind_parent(plan, expr), + ScalarExpression::Binary { + left_expr, + right_expr, + .. + } => { + plan = self.bind_parent(plan, left_expr)?; + self.bind_parent(plan, right_expr) + } + ScalarExpression::AggCall { args, .. } => { + for expr in args { + plan = self.bind_parent(plan, expr)?; + } + Ok(plan) + } + ScalarExpression::In { expr, args, .. } => { + for expr in args { + plan = self.bind_parent(plan, expr)?; + } + self.bind_parent(plan, expr) + } + ScalarExpression::Between { + expr, + left_expr, + right_expr, + .. + } => { + plan = self.bind_parent(plan, left_expr)?; + plan = self.bind_parent(plan, right_expr)?; + self.bind_parent(plan, expr) + } + ScalarExpression::SubString { + expr, + from_expr, + for_expr, + } => { + if let Some(from_expr) = from_expr { + plan = self.bind_parent(plan, from_expr)?; + } + if let Some(for_expr) = for_expr { + plan = self.bind_parent(plan, for_expr)?; + } + self.bind_parent(plan, expr) + } + ScalarExpression::Position { expr, in_expr } => { + plan = self.bind_parent(plan, in_expr)?; + self.bind_parent(plan, expr) + } + ScalarExpression::Trim { + expr, + trim_what_expr, + .. + } => { + if let Some(trim_what_expr) = trim_what_expr { + plan = self.bind_parent(plan, trim_what_expr)?; + } + self.bind_parent(plan, expr) + } + ScalarExpression::Tuple(exprs) | ScalarExpression::Coalesce { exprs, .. } => { + for expr in exprs { + plan = self.bind_parent(plan, expr)?; + } + Ok(plan) + } + ScalarExpression::ScalaFunction(exprs) => { + for expr in &exprs.args { + plan = self.bind_parent(plan, expr)?; + } + Ok(plan) + } + ScalarExpression::TableFunction(exprs) => { + for expr in &exprs.args { + plan = self.bind_parent(plan, expr)?; + } + Ok(plan) + } + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. 
+ } => { + plan = self.bind_parent(plan, left_expr)?; + plan = self.bind_parent(plan, right_expr)?; + self.bind_parent(plan, condition) + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. + } + | ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => { + plan = self.bind_parent(plan, left_expr)?; + self.bind_parent(plan, right_expr) + } + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. + } => { + if let Some(operand_expr) = operand_expr { + plan = self.bind_parent(plan, operand_expr)?; + } + for (left_expr, right_expr) in expr_pairs { + plan = self.bind_parent(plan, left_expr)?; + plan = self.bind_parent(plan, right_expr)?; + } + if let Some(else_expr) = else_expr { + plan = self.bind_parent(plan, else_expr)?; + } + Ok(plan) + } + _ => Ok(plan), + } + } + fn bind_join( &mut self, mut left: LogicalPlan, @@ -559,7 +713,6 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' scala_functions, table_functions, temp_table_id, - parent_name, .. } = &self.context; let mut binder = Binder::new( @@ -570,7 +723,6 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' scala_functions, table_functions, temp_table_id.clone(), - parent_name.clone(), ), self.args, Some(self), @@ -594,35 +746,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' predicate: &Expr, ) -> Result { self.context.step(QueryBindStep::Where); - let mut need_parent = true; - let mut predicate_result = ScalarExpression::Empty; - while need_parent { - predicate_result = match self.bind_expr(predicate) { - Ok(expr) => { - need_parent = false; - expr - } - Err(DatabaseError::InvalidTable(table)) => { - if !self.context.parent_name.contains(&table) { - return Err(DatabaseError::InvalidTable(table)); - } - let plan = self.bind_table_ref(&TableWithJoins { - relation: TableFactor::Table { - name: ObjectName(vec![Ident::new(table)]), - alias: None, - args: None, - with_hints: vec![], - }, - joins: vec![], - })?; - children = - LJoinOperator::build(children, plan, JoinCondition::None, JoinType::Full); - continue; - } - Err(e) => return Err(e), - }; - } + let predicate = self.bind_expr(predicate)?; + + children = self.bind_parent(children, &predicate)?; if let Some(sub_queries) = self.context.sub_queries_at_now() { for sub_query in sub_queries { @@ -657,6 +784,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }, left_expr: Box::new(ScalarExpression::ColumnRef( agg.output_schema()[0].clone(), + false, )), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32( 1, @@ -693,7 +821,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }; Self::extract_join_keys( - predicate_result.clone(), + predicate.clone(), &mut on_keys, &mut filter, children.output_schema(), @@ -723,7 +851,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } return Ok(children); } - Ok(FilterOperator::build(predicate_result, children, false)) + Ok(FilterOperator::build(predicate, children, false)) } fn bind_having( @@ -832,7 +960,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } for column in select_items { - if let ScalarExpression::ColumnRef(col) = column { + if let ScalarExpression::ColumnRef(col, _) = column { let _ = table_force_nullable .iter() .find(|(table_name, source, _)| { @@ -895,7 +1023,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' 
schema .iter() .find(|column| column.name() == name) - .map(|column| ScalarExpression::ColumnRef(column.clone())) + .map(|column| ScalarExpression::ColumnRef(column.clone(), false)) }; for ident in idents { let name = lower_ident(ident); @@ -928,8 +1056,8 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' right_schema.iter().find(|column| column.name() == *name), ) { on_keys.push(( - ScalarExpression::ColumnRef(left_column.clone()), - ScalarExpression::ColumnRef(right_column.clone()), + ScalarExpression::ColumnRef(left_column.clone(), false), + ScalarExpression::ColumnRef(right_column.clone(), false), )); } } @@ -979,7 +1107,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' BinaryOperator::Eq => { match (left_expr.unpack_alias_ref(), right_expr.unpack_alias_ref()) { // example: foo = bar - (ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => { + ( + ScalarExpression::ColumnRef(l, _), + ScalarExpression::ColumnRef(r, _), + ) => { // reorder left and right joins keys to pattern: (left, right) if fn_contains(left_schema, l.summary()) && fn_contains(right_schema, r.summary()) @@ -1001,8 +1132,8 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }); } } - (ScalarExpression::ColumnRef(column), _) - | (_, ScalarExpression::ColumnRef(column)) => { + (ScalarExpression::ColumnRef(column, _), _) + | (_, ScalarExpression::ColumnRef(column, _)) => { if fn_or_contains(left_schema, right_schema, column.summary()) { accum_filter.push(ScalarExpression::Binary { left_expr, diff --git a/src/binder/update.rs b/src/binder/update.rs index a160a24a..0670d56c 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -41,7 +41,7 @@ impl> Binder<'_, '_, T, A> slice::from_ref(ident), Some(table_name.to_string()), )? { - ScalarExpression::ColumnRef(column) => { + ScalarExpression::ColumnRef(column, _) => { let mut expr = if matches!(expression, ScalarExpression::Empty) { let default_value = column .default_value()? 
diff --git a/src/db.rs b/src/db.rs index 57d02bc9..ff08a4fa 100644 --- a/src/db.rs +++ b/src/db.rs @@ -150,7 +150,6 @@ impl State { scala_functions, table_functions, Arc::new(AtomicUsize::new(0)), - vec![], ), ¶ms, None, diff --git a/src/execution/ddl/create_index.rs b/src/execution/ddl/create_index.rs index 51c0dfa0..f6d1ec62 100644 --- a/src/execution/ddl/create_index.rs +++ b/src/execution/ddl/create_index.rs @@ -48,7 +48,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CreateIndex { .filter_map(|column| { column .id() - .map(|id| (id, ScalarExpression::ColumnRef(column))) + .map(|id| (id, ScalarExpression::ColumnRef(column, false))) }) .unzip(); let schema = self.input.output_schema().clone(); diff --git a/src/execution/dql/aggregate/hash_agg.rs b/src/execution/dql/aggregate/hash_agg.rs index 5df8bcc8..547bd3d7 100644 --- a/src/execution/dql/aggregate/hash_agg.rs +++ b/src/execution/dql/aggregate/hash_agg.rs @@ -145,11 +145,11 @@ mod test { ]); let operator = AggregateOperator { - groupby_exprs: vec![ScalarExpression::ColumnRef(t1_schema[0].clone())], + groupby_exprs: vec![ScalarExpression::ColumnRef(t1_schema[0].clone(), false)], agg_calls: vec![ScalarExpression::AggCall { distinct: false, kind: AggKind::Sum, - args: vec![ScalarExpression::ColumnRef(t1_schema[1].clone())], + args: vec![ScalarExpression::ColumnRef(t1_schema[1].clone(), false)], ty: LogicalType::Integer, }], is_distinct: false, diff --git a/src/execution/dql/join/hash_join.rs b/src/execution/dql/join/hash_join.rs index f7cdce8e..bfd023dc 100644 --- a/src/execution/dql/join/hash_join.rs +++ b/src/execution/dql/join/hash_join.rs @@ -342,8 +342,8 @@ mod test { ]; let on_keys = vec![( - ScalarExpression::ColumnRef(t1_columns[0].clone()), - ScalarExpression::ColumnRef(t2_columns[0].clone()), + ScalarExpression::ColumnRef(t1_columns[0].clone(), false), + ScalarExpression::ColumnRef(t2_columns[0].clone(), false), )]; let values_t1 = LogicalPlan { diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index e79dca4a..8c2d4355 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -434,8 +434,8 @@ mod test { let on_keys = if eq { vec![( - ScalarExpression::ColumnRef(t1_columns[1].clone()), - ScalarExpression::ColumnRef(t2_columns[1].clone()), + ScalarExpression::ColumnRef(t1_columns[1].clone(), false), + ScalarExpression::ColumnRef(t2_columns[1].clone(), false), )] } else { vec![] @@ -505,12 +505,14 @@ mod test { let filter = ScalarExpression::Binary { op: crate::expression::BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from( - ColumnCatalog::new("c1".to_owned(), true, desc.clone()), - ))), - right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from( - ColumnCatalog::new("c4".to_owned(), true, desc.clone()), - ))), + left_expr: Box::new(ScalarExpression::ColumnRef( + ColumnRef::from(ColumnCatalog::new("c1".to_owned(), true, desc.clone())), + false, + )), + right_expr: Box::new(ScalarExpression::ColumnRef( + ColumnRef::from(ColumnCatalog::new("c4".to_owned(), true, desc.clone())), + false, + )), evaluator: Some(BinaryEvaluatorBox(Arc::new(Int32GtBinaryEvaluator))), ty: LogicalType::Boolean, }; diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 069b80dd..b3ccf9c8 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -33,7 +33,7 @@ impl ScalarExpression { match self { ScalarExpression::Constant(val) => Ok(val.clone()), - 
ScalarExpression::ColumnRef(col) => { + ScalarExpression::ColumnRef(col, _) => { let Some((tuple, schema)) = tuple else { return Ok(DataValue::Null); }; diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 5ce161da..e78e058f 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -39,7 +39,7 @@ pub enum AliasType { #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub enum ScalarExpression { Constant(DataValue), - ColumnRef(ColumnRef), + ColumnRef(ColumnRef, bool), Alias { expr: Box, alias: AliasType, @@ -289,7 +289,7 @@ impl ScalarExpression { pub fn return_type(&self) -> LogicalType { match self { ScalarExpression::Constant(v) => v.logical_type(), - ScalarExpression::ColumnRef(col) => col.datatype().clone(), + ScalarExpression::ColumnRef(col, _) => col.datatype().clone(), ScalarExpression::Binary { ty: return_type, .. } @@ -353,7 +353,7 @@ impl ScalarExpression { vec.push(expr.output_column()); } match expr { - ScalarExpression::ColumnRef(col) => { + ScalarExpression::ColumnRef(col, _) => { vec.push(col.clone()); } ScalarExpression::Alias { expr, .. } => { @@ -480,7 +480,7 @@ impl ScalarExpression { pub fn has_table_ref_column(&self) -> bool { match self { ScalarExpression::Constant(_) => false, - ScalarExpression::ColumnRef(column) => { + ScalarExpression::ColumnRef(column, _) => { column.table_name().is_some() && column.id().is_some() } ScalarExpression::Alias { expr, .. } => expr.has_table_ref_column(), @@ -600,7 +600,7 @@ impl ScalarExpression { match self { ScalarExpression::AggCall { .. } => true, ScalarExpression::Constant(_) => false, - ScalarExpression::ColumnRef(_) => false, + ScalarExpression::ColumnRef(_, _) => false, ScalarExpression::Alias { expr, .. } => expr.has_agg_call(), ScalarExpression::TypeCast { expr, .. } => expr.has_agg_call(), ScalarExpression::IsNull { expr, .. } => expr.has_agg_call(), @@ -690,7 +690,7 @@ impl ScalarExpression { pub fn output_name(&self) -> String { match self { ScalarExpression::Constant(value) => format!("{}", value), - ScalarExpression::ColumnRef(col) => col.full_name(), + ScalarExpression::ColumnRef(col, _) => col.full_name(), ScalarExpression::Alias { alias, expr } => match alias { AliasType::Name(alias) => alias.to_string(), AliasType::Expr(alias_expr) => { @@ -881,7 +881,7 @@ impl ScalarExpression { pub fn output_column(&self) -> ColumnRef { match self { - ScalarExpression::ColumnRef(col) => col.clone(), + ScalarExpression::ColumnRef(col, _) => col.clone(), ScalarExpression::Alias { alias: AliasType::Expr(expr), .. 
@@ -1105,33 +1105,39 @@ mod test { )?; fn_assert( &mut cursor, - ScalarExpression::ColumnRef(ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c3".to_string(), - relation: ColumnRelation::Table { - column_id: c3_column_id, - table_name: Arc::new("t1".to_string()), - is_temp: false, + ScalarExpression::ColumnRef( + ColumnRef::from(ColumnCatalog::direct_new( + ColumnSummary { + name: "c3".to_string(), + relation: ColumnRelation::Table { + column_id: c3_column_id, + table_name: Arc::new("t1".to_string()), + is_temp: false, + }, }, - }, - false, - ColumnDesc::new(LogicalType::Integer, None, false, None)?, + false, + ColumnDesc::new(LogicalType::Integer, None, false, None)?, + false, + )), false, - ))), + ), Some((&transaction, &table_cache)), &mut reference_tables, )?; fn_assert( &mut cursor, - ScalarExpression::ColumnRef(ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c4".to_string(), - relation: ColumnRelation::None, - }, - false, - ColumnDesc::new(LogicalType::Boolean, None, false, None)?, + ScalarExpression::ColumnRef( + ColumnRef::from(ColumnCatalog::direct_new( + ColumnSummary { + name: "c4".to_string(), + relation: ColumnRelation::None, + }, + false, + ColumnDesc::new(LogicalType::Boolean, None, false, None)?, + false, + )), false, - ))), + ), Some((&transaction, &table_cache)), &mut reference_tables, )?; diff --git a/src/expression/range_detacher.rs b/src/expression/range_detacher.rs index 7284038c..f3928d4f 100644 --- a/src/expression/range_detacher.rs +++ b/src/expression/range_detacher.rs @@ -216,7 +216,7 @@ impl<'a> RangeDetacher<'a> { ScalarExpression::Position { expr, .. } => self.detach(expr)?, ScalarExpression::Trim { expr, .. } => self.detach(expr)?, ScalarExpression::IsNull { expr, negated, .. } => match expr.as_ref() { - ScalarExpression::ColumnRef(column) => { + ScalarExpression::ColumnRef(column, _) => { if let (Some(col_id), Some(col_table)) = (column.id(), column.table_name()) { if &col_id == self.column_id && col_table.as_str() == self.table_name { return if *negated { @@ -253,7 +253,7 @@ impl<'a> RangeDetacher<'a> { | ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), }, - ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_) => None, + ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_, _) => None, // FIXME: support [RangeDetacher::_detach] ScalarExpression::Tuple(_) | ScalarExpression::AggCall { .. 
} diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 820206e4..44b2e05f 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -133,7 +133,7 @@ impl VisitorMut<'_> for Simplify { match (left_expr.unpack_col(false), right_expr.unpack_col(false)) { (Some(col), None) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::ColumnRef(col), + column_expr: ScalarExpression::ColumnRef(col, false), val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -142,7 +142,7 @@ impl VisitorMut<'_> for Simplify { } (None, Some(col)) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::ColumnRef(col), + column_expr: ScalarExpression::ColumnRef(col, false), val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -157,7 +157,7 @@ impl VisitorMut<'_> for Simplify { match (left_expr.unpack_col(true), right_expr.unpack_col(true)) { (Some(col), None) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::ColumnRef(col), + column_expr: ScalarExpression::ColumnRef(col, false), val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -166,7 +166,7 @@ impl VisitorMut<'_> for Simplify { } (None, Some(col)) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::ColumnRef(col), + column_expr: ScalarExpression::ColumnRef(col, false), val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -483,7 +483,7 @@ impl ScalarExpression { pub(crate) fn unpack_col(&self, is_deep: bool) -> Option { match self { - ScalarExpression::ColumnRef(col) => Some(col.clone()), + ScalarExpression::ColumnRef(col, _) => Some(col.clone()), ScalarExpression::Alias { expr, .. } => expr.unpack_col(is_deep), ScalarExpression::Unary { expr, .. 
} => expr.unpack_col(is_deep), ScalarExpression::Binary { diff --git a/src/expression/visitor.rs b/src/expression/visitor.rs index e0484e76..691d45e3 100644 --- a/src/expression/visitor.rs +++ b/src/expression/visitor.rs @@ -254,7 +254,7 @@ pub fn walk_expr<'a, V: Visitor<'a>>( ) -> Result<(), DatabaseError> { match expr { ScalarExpression::Constant(value) => visitor.visit_constant(value), - ScalarExpression::ColumnRef(column_ref) => visitor.visit_column_ref(column_ref), + ScalarExpression::ColumnRef(column_ref, _) => visitor.visit_column_ref(column_ref), ScalarExpression::Alias { expr, alias } => visitor.visit_alias(expr, alias), ScalarExpression::TypeCast { expr, ty } => visitor.visit_type_cast(expr, ty), ScalarExpression::IsNull { negated, expr } => visitor.visit_is_null(*negated, expr), diff --git a/src/expression/visitor_mut.rs b/src/expression/visitor_mut.rs index 87aead51..e3e2317b 100644 --- a/src/expression/visitor_mut.rs +++ b/src/expression/visitor_mut.rs @@ -254,7 +254,7 @@ pub fn walk_mut_expr<'a, V: VisitorMut<'a>>( ) -> Result<(), DatabaseError> { match expr { ScalarExpression::Constant(value) => visitor.visit_constant(value), - ScalarExpression::ColumnRef(column_ref) => visitor.visit_column_ref(column_ref), + ScalarExpression::ColumnRef(column_ref, _) => visitor.visit_column_ref(column_ref), ScalarExpression::Alias { expr, alias } => visitor.visit_alias(expr, alias), ScalarExpression::TypeCast { expr, ty } => visitor.visit_type_cast(expr, ty), ScalarExpression::IsNull { negated, expr } => visitor.visit_is_null(*negated, expr), diff --git a/src/optimizer/core/memo.rs b/src/optimizer/core/memo.rs index 4b50f914..5324b1e5 100644 --- a/src/optimizer/core/memo.rs +++ b/src/optimizer/core/memo.rs @@ -138,7 +138,6 @@ mod tests { &scala_functions, &table_functions, Arc::new(AtomicUsize::new(0)), - vec![], ), &[], None, diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index a220c065..8426b9c8 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -296,9 +296,10 @@ mod test { op: UnaryOperator::Minus, expr: Box::new(ScalarExpression::Binary { op: BinaryOperator::Plus, - left_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from( - c1_col - ))), + left_expr: Box::new(ScalarExpression::ColumnRef( + ColumnRef::from(c1_col), + false + )), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, ty: LogicalType::Integer, @@ -306,7 +307,10 @@ mod test { evaluator: None, ty: LogicalType::Integer, }), - right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from(c2_col))), + right_expr: Box::new(ScalarExpression::ColumnRef( + ColumnRef::from(c2_col), + false + )), evaluator: None, ty: LogicalType::Boolean, } diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index 620c7b31..78a25b30 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -136,7 +136,7 @@ impl Operator { Operator::TableScan(op) => Some( op.columns .values() - .map(|column| ScalarExpression::ColumnRef(column.clone())) + .map(|column| ScalarExpression::ColumnRef(column.clone(), false)) .collect_vec(), ), Operator::Sort(_) | Operator::Limit(_) => None, @@ -148,7 +148,7 @@ impl Operator { schema_ref .iter() .cloned() - .map(ScalarExpression::ColumnRef) + .map(|col| ScalarExpression::ColumnRef(col, false)) .collect_vec(), ), Operator::FunctionScan(op) => Some( @@ -156,7 +156,7 @@ impl Operator { .inner 
.output_schema() .iter() - .map(|column| ScalarExpression::ColumnRef(column.clone())) + .map(|column| ScalarExpression::ColumnRef(column.clone(), false)) .collect_vec(), ), Operator::ShowTable diff --git a/src/types/index.rs b/src/types/index.rs index df57e070..8762566f 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -46,7 +46,7 @@ impl IndexMeta { for column_id in self.column_ids.iter() { if let Some(column) = table.get_column_by_id(column_id) { - exprs.push(ScalarExpression::ColumnRef(column.clone())); + exprs.push(ScalarExpression::ColumnRef(column.clone(), false)); } else { return Err(DatabaseError::ColumnNotFound(column_id.to_string())); } From 591119b23398aafb28f861c2f8664fbf31f0f2cf Mon Sep 17 00:00:00 2001 From: wszhdshys <1925792291@qq.com> Date: Tue, 22 Jul 2025 18:15:26 +0800 Subject: [PATCH 4/7] Changed the binding method of table in parent, which should support more cases --- src/binder/create_index.rs | 2 +- src/binder/create_view.rs | 3 +- src/binder/expr.rs | 44 +--- src/binder/insert.rs | 2 +- src/binder/mod.rs | 28 ++- src/binder/select.rs | 190 +++--------------- src/binder/update.rs | 2 +- src/execution/ddl/create_index.rs | 2 +- src/execution/dql/aggregate/hash_agg.rs | 4 +- src/execution/dql/join/hash_join.rs | 4 +- src/execution/dql/join/nested_loop_join.rs | 6 +- src/expression/evaluator.rs | 2 +- src/expression/mod.rs | 16 +- src/expression/range_detacher.rs | 4 +- src/expression/simplify.rs | 10 +- src/expression/visitor.rs | 2 +- src/expression/visitor_mut.rs | 2 +- .../rule/normalization/simplification.rs | 2 - src/planner/operator/mod.rs | 6 +- src/types/index.rs | 2 +- 20 files changed, 80 insertions(+), 253 deletions(-) diff --git a/src/binder/create_index.rs b/src/binder/create_index.rs index 2af5fde2..4b5332ba 100644 --- a/src/binder/create_index.rs +++ b/src/binder/create_index.rs @@ -43,7 +43,7 @@ impl> Binder<'_, '_, T, A> for expr in exprs { // TODO: Expression Index match self.bind_expr(&expr.expr)? 
{ - ScalarExpression::ColumnRef(column, false) => columns.push(column), + ScalarExpression::ColumnRef(column) => columns.push(column), expr => { return Err(DatabaseError::UnsupportedStmt(format!( "'CREATE INDEX' by {}", diff --git a/src/binder/create_view.rs b/src/binder/create_view.rs index aba140c7..2d99b0bd 100644 --- a/src/binder/create_view.rs +++ b/src/binder/create_view.rs @@ -39,10 +39,9 @@ impl> Binder<'_, '_, T, A> column.set_ref_table(view_name.clone(), Ulid::new(), true); ScalarExpression::Alias { - expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone(), false)), + expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone())), alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef( ColumnRef::from(column), - false, ))), } }) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 2ffba817..8e7bcbbf 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -7,11 +7,11 @@ use sqlparser::ast::{ BinaryOperator, CharLengthUnits, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Query, UnaryOperator, Value, }; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::slice; use std::sync::Arc; -use super::{lower_ident, Binder, BinderContext, QueryBindStep, Source, SubQueryType}; +use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType}; use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction}; use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction}; use crate::expression::function::FunctionSummary; @@ -259,10 +259,9 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let alias_expr = ScalarExpression::Alias { expr: Box::new(expr), - alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef( - ColumnRef::from(alias_column), - false, - ))), + alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from( + alias_column, + )))), }; let alias_plan = self.bind_project(sub_query, vec![alias_expr.clone()])?; Ok((alias_expr, alias_plan)) @@ -313,13 +312,13 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let columns = sub_query_schema .iter() - .map(|column| ScalarExpression::ColumnRef(column.clone(), false)) + .map(|column| ScalarExpression::ColumnRef(column.clone())) .collect::>(); ScalarExpression::Tuple(columns) } else { fn_check(1)?; - ScalarExpression::ColumnRef(sub_query_schema[0].clone(), false) + ScalarExpression::ColumnRef(sub_query_schema[0].clone()) }; Ok((sub_query, expr)) } @@ -370,30 +369,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T try_default!(&full_name.0, full_name.1); } if let Some(table) = full_name.0.or(bind_table_name) { - let mut source: &Source; - let mut from_parent: bool; - let (mut parent, mut parent_context) = if let Some(parent) = self.parent { - (Some(parent), Some(&parent.context)) - } else { - (None, None) - }; + let (source,is_parent) = self.context.bind_source::(self.parent, &table, false)?; - loop { - (source, from_parent) = match self.context.bind_source(parent_context, &table) { - (Ok(source), from_parent) => (source, from_parent), - (Err(e), _) => { - if let Some(p) = parent { - (parent, parent_context) = match p.parent { - Some(parent) => (Some(parent), Some(&parent.context)), - None => return Err(e), - } - } else { - return Err(e); - } - continue; - } - }; - break; + if is_parent { + self.parent_table_col.entry(Arc::new(table.clone())).or_insert(HashSet::new()).insert(full_name.1.clone()); } let 
schema_buf = self.table_schema_buf.entry(Arc::new(table)).or_default(); @@ -402,7 +381,6 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T source .column(&full_name.1, schema_buf) .ok_or_else(|| DatabaseError::ColumnNotFound(full_name.1.to_string()))?, - from_parent, )) } else { let op = @@ -431,7 +409,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T table_schema_buf.entry(table_name.clone()).or_default(); source.column(&full_name.1, schema_buf) } { - *got_column = Some(ScalarExpression::ColumnRef(column, false)); + *got_column = Some(ScalarExpression::ColumnRef(column)); } } }; diff --git a/src/binder/insert.rs b/src/binder/insert.rs index c8f36843..4c8e78c1 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -51,7 +51,7 @@ impl> Binder<'_, '_, T, A> slice::from_ref(ident), Some(table_name.to_string()), )? { - ScalarExpression::ColumnRef(catalog, _) => columns.push(catalog), + ScalarExpression::ColumnRef(catalog) => columns.push(catalog), _ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())), } } diff --git a/src/binder/mod.rs b/src/binder/mod.rs index b85778b2..3068a7a3 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -25,7 +25,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use crate::catalog::view::View; -use crate::catalog::{ColumnRef, TableCatalog, TableName}; +use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableName}; use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; @@ -275,28 +275,22 @@ impl<'a, T: Transaction> BinderContext<'a, T> { Ok(source) } - pub fn bind_source<'b: 'a>( + pub fn bind_source<'b: 'a, A: AsRef<[(&'static str, DataValue)]> >( &self, - parent: Option<&'a BinderContext<'_, T>>, + parent: Option<&'a Binder<'a,'b,T,A>>, table_name: &str, - ) -> (Result<&'b Source, DatabaseError>, bool) { + is_parent: bool, + ) -> Result<(&'b Source, bool), DatabaseError> { if let Some(source) = self.bind_table.iter().find(|((t, alias, _), _)| { t.as_str() == table_name || matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true)) }) { - (Ok(source.1), false) - } else if let Some(context) = parent { - if let Some(source) = context.bind_table.iter().find(|((t, alias, _), _)| { - t.as_str() == table_name - || matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true)) - }) { - (Ok(source.1), true) - } else { - (Err(DatabaseError::InvalidTable(table_name.into())), false) - } + Ok((source.1, is_parent)) + } else if let Some(binder) = parent { + binder.context.bind_source(binder.parent, table_name,true) } else { - (Err(DatabaseError::InvalidTable(table_name.into())), false) - } + Err(DatabaseError::InvalidTable(table_name.into())) + } } // Tips: The order of this index is based on Aggregate being bound first. 
@@ -335,6 +329,7 @@ pub struct Binder<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> args: &'a A, with_pk: Option, pub(crate) parent: Option<&'b Binder<'a, 'b, T, A>>, + pub(crate) parent_table_col: HashMap>, } impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> { @@ -349,6 +344,7 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' args, with_pk: None, parent, + parent_table_col: Default::default(), } } diff --git a/src/binder/select.rs b/src/binder/select.rs index f9a5bd55..fc1114dc 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -101,9 +101,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }; let mut select_list = self.normalize_select_item(&select.projection, &plan)?; - for expr in &select_list { - plan = self.bind_parent(plan, expr)?; - } + plan = self.bind_parent(plan)?; if let Some(predicate) = &select.selection { plan = self.bind_where(plan, predicate)?; @@ -220,7 +218,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let distinct_exprs = left_schema .iter() .cloned() - .map(|col| ScalarExpression::ColumnRef(col, false)) + .map(ScalarExpression::ColumnRef) .collect_vec(); Ok(self.bind_distinct( @@ -362,11 +360,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' ); let alias_column_expr = ScalarExpression::Alias { - expr: Box::new(ScalarExpression::ColumnRef(column, false)), - alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef( - ColumnRef::from(alias_column), - false, - ))), + expr: Box::new(ScalarExpression::ColumnRef(column)), + alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from( + alias_column, + )))), }; self.context.add_alias( Some(table_alias.to_string()), @@ -493,7 +490,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' && matches!(join_used.map(|used| used.contains(column_name)), Some(true)) }; for (_, alias_expr) in context.expr_aliases.iter().filter(|(_, expr)| { - if let ScalarExpression::ColumnRef(col, _) = expr.unpack_alias_ref() { + if let ScalarExpression::ColumnRef(col) = expr.unpack_alias_ref() { let column_name = col.name(); if Some(&table_name) == col.table_name() @@ -529,7 +526,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if fn_used(column_name, context, join_used.as_deref()) { continue; } - let expr = ScalarExpression::ColumnRef(column.clone(), false); + let expr = ScalarExpression::ColumnRef(column.clone()); if let Some(used) = join_used.as_mut() { used.insert(column_name.to_string()); @@ -542,150 +539,17 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' fn bind_parent( &mut self, mut plan: LogicalPlan, - expr: &ScalarExpression, ) -> Result { - match expr { - ScalarExpression::ColumnRef(columnref, true) => { - let table = columnref.0.table_name().unwrap(); - let parent = self.bind_table_ref(&TableWithJoins { - relation: TableFactor::Table { - name: ObjectName(vec![Ident::new(table.as_str().to_string())]), - alias: None, - args: None, - with_hints: vec![], - }, - joins: vec![], - })?; - Ok(LJoinOperator::build( - plan, - parent, - JoinCondition::None, - JoinType::Full, - )) - } - ScalarExpression::Alias { expr, .. } - | ScalarExpression::TypeCast { expr, .. } - | ScalarExpression::IsNull { expr, .. } - | ScalarExpression::Unary { expr, .. } - | ScalarExpression::Reference { expr, .. 
} => self.bind_parent(plan, expr), - ScalarExpression::Binary { - left_expr, - right_expr, - .. - } => { - plan = self.bind_parent(plan, left_expr)?; - self.bind_parent(plan, right_expr) - } - ScalarExpression::AggCall { args, .. } => { - for expr in args { - plan = self.bind_parent(plan, expr)?; - } - Ok(plan) - } - ScalarExpression::In { expr, args, .. } => { - for expr in args { - plan = self.bind_parent(plan, expr)?; - } - self.bind_parent(plan, expr) - } - ScalarExpression::Between { - expr, - left_expr, - right_expr, - .. - } => { - plan = self.bind_parent(plan, left_expr)?; - plan = self.bind_parent(plan, right_expr)?; - self.bind_parent(plan, expr) - } - ScalarExpression::SubString { - expr, - from_expr, - for_expr, - } => { - if let Some(from_expr) = from_expr { - plan = self.bind_parent(plan, from_expr)?; - } - if let Some(for_expr) = for_expr { - plan = self.bind_parent(plan, for_expr)?; - } - self.bind_parent(plan, expr) - } - ScalarExpression::Position { expr, in_expr } => { - plan = self.bind_parent(plan, in_expr)?; - self.bind_parent(plan, expr) - } - ScalarExpression::Trim { - expr, - trim_what_expr, - .. - } => { - if let Some(trim_what_expr) = trim_what_expr { - plan = self.bind_parent(plan, trim_what_expr)?; - } - self.bind_parent(plan, expr) - } - ScalarExpression::Tuple(exprs) | ScalarExpression::Coalesce { exprs, .. } => { - for expr in exprs { - plan = self.bind_parent(plan, expr)?; - } - Ok(plan) - } - ScalarExpression::ScalaFunction(exprs) => { - for expr in &exprs.args { - plan = self.bind_parent(plan, expr)?; - } - Ok(plan) - } - ScalarExpression::TableFunction(exprs) => { - for expr in &exprs.args { - plan = self.bind_parent(plan, expr)?; - } - Ok(plan) - } - ScalarExpression::If { - condition, - left_expr, - right_expr, - .. - } => { - plan = self.bind_parent(plan, left_expr)?; - plan = self.bind_parent(plan, right_expr)?; - self.bind_parent(plan, condition) - } - ScalarExpression::IfNull { - left_expr, - right_expr, - .. - } - | ScalarExpression::NullIf { - left_expr, - right_expr, - .. - } => { - plan = self.bind_parent(plan, left_expr)?; - self.bind_parent(plan, right_expr) - } - ScalarExpression::CaseWhen { - operand_expr, - expr_pairs, - else_expr, - .. 
- } => { - if let Some(operand_expr) = operand_expr { - plan = self.bind_parent(plan, operand_expr)?; - } - for (left_expr, right_expr) in expr_pairs { - plan = self.bind_parent(plan, left_expr)?; - plan = self.bind_parent(plan, right_expr)?; - } - if let Some(else_expr) = else_expr { - plan = self.bind_parent(plan, else_expr)?; - } - Ok(plan) - } - _ => Ok(plan), + for (table,columns) in self.parent_table_col.clone().into_iter() { + let parent = self._bind_single_table_ref(None,table.as_str(),None)?; + plan = LJoinOperator::build( + plan, + parent, + JoinCondition::None, + JoinType::Full, + ); } + Ok(plan) } fn bind_join( @@ -749,7 +613,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let predicate = self.bind_expr(predicate)?; - children = self.bind_parent(children, &predicate)?; + children = self.bind_parent(children)?; if let Some(sub_queries) = self.context.sub_queries_at_now() { for sub_query in sub_queries { @@ -784,7 +648,6 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }, left_expr: Box::new(ScalarExpression::ColumnRef( agg.output_schema()[0].clone(), - false, )), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32( 1, @@ -960,7 +823,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } for column in select_items { - if let ScalarExpression::ColumnRef(col, _) = column { + if let ScalarExpression::ColumnRef(col) = column { let _ = table_force_nullable .iter() .find(|(table_name, source, _)| { @@ -1023,7 +886,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' schema .iter() .find(|column| column.name() == name) - .map(|column| ScalarExpression::ColumnRef(column.clone(), false)) + .map(|column| ScalarExpression::ColumnRef(column.clone())) }; for ident in idents { let name = lower_ident(ident); @@ -1056,8 +919,8 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' right_schema.iter().find(|column| column.name() == *name), ) { on_keys.push(( - ScalarExpression::ColumnRef(left_column.clone(), false), - ScalarExpression::ColumnRef(right_column.clone(), false), + ScalarExpression::ColumnRef(left_column.clone()), + ScalarExpression::ColumnRef(right_column.clone()), )); } } @@ -1107,10 +970,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' BinaryOperator::Eq => { match (left_expr.unpack_alias_ref(), right_expr.unpack_alias_ref()) { // example: foo = bar - ( - ScalarExpression::ColumnRef(l, _), - ScalarExpression::ColumnRef(r, _), - ) => { + (ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => { // reorder left and right joins keys to pattern: (left, right) if fn_contains(left_schema, l.summary()) && fn_contains(right_schema, r.summary()) @@ -1132,8 +992,8 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }); } } - (ScalarExpression::ColumnRef(column, _), _) - | (_, ScalarExpression::ColumnRef(column, _)) => { + (ScalarExpression::ColumnRef(column), _) + | (_, ScalarExpression::ColumnRef(column)) => { if fn_or_contains(left_schema, right_schema, column.summary()) { accum_filter.push(ScalarExpression::Binary { left_expr, diff --git a/src/binder/update.rs b/src/binder/update.rs index 0670d56c..a160a24a 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -41,7 +41,7 @@ impl> Binder<'_, '_, T, A> slice::from_ref(ident), Some(table_name.to_string()), )? 
{ - ScalarExpression::ColumnRef(column, _) => { + ScalarExpression::ColumnRef(column) => { let mut expr = if matches!(expression, ScalarExpression::Empty) { let default_value = column .default_value()? diff --git a/src/execution/ddl/create_index.rs b/src/execution/ddl/create_index.rs index f6d1ec62..51c0dfa0 100644 --- a/src/execution/ddl/create_index.rs +++ b/src/execution/ddl/create_index.rs @@ -48,7 +48,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CreateIndex { .filter_map(|column| { column .id() - .map(|id| (id, ScalarExpression::ColumnRef(column, false))) + .map(|id| (id, ScalarExpression::ColumnRef(column))) }) .unzip(); let schema = self.input.output_schema().clone(); diff --git a/src/execution/dql/aggregate/hash_agg.rs b/src/execution/dql/aggregate/hash_agg.rs index 547bd3d7..5df8bcc8 100644 --- a/src/execution/dql/aggregate/hash_agg.rs +++ b/src/execution/dql/aggregate/hash_agg.rs @@ -145,11 +145,11 @@ mod test { ]); let operator = AggregateOperator { - groupby_exprs: vec![ScalarExpression::ColumnRef(t1_schema[0].clone(), false)], + groupby_exprs: vec![ScalarExpression::ColumnRef(t1_schema[0].clone())], agg_calls: vec![ScalarExpression::AggCall { distinct: false, kind: AggKind::Sum, - args: vec![ScalarExpression::ColumnRef(t1_schema[1].clone(), false)], + args: vec![ScalarExpression::ColumnRef(t1_schema[1].clone())], ty: LogicalType::Integer, }], is_distinct: false, diff --git a/src/execution/dql/join/hash_join.rs b/src/execution/dql/join/hash_join.rs index bfd023dc..f7cdce8e 100644 --- a/src/execution/dql/join/hash_join.rs +++ b/src/execution/dql/join/hash_join.rs @@ -342,8 +342,8 @@ mod test { ]; let on_keys = vec![( - ScalarExpression::ColumnRef(t1_columns[0].clone(), false), - ScalarExpression::ColumnRef(t2_columns[0].clone(), false), + ScalarExpression::ColumnRef(t1_columns[0].clone()), + ScalarExpression::ColumnRef(t2_columns[0].clone()), )]; let values_t1 = LogicalPlan { diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index 8c2d4355..67492caa 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -434,8 +434,8 @@ mod test { let on_keys = if eq { vec![( - ScalarExpression::ColumnRef(t1_columns[1].clone(), false), - ScalarExpression::ColumnRef(t2_columns[1].clone(), false), + ScalarExpression::ColumnRef(t1_columns[1].clone()), + ScalarExpression::ColumnRef(t2_columns[1].clone()), )] } else { vec![] @@ -507,11 +507,9 @@ mod test { op: crate::expression::BinaryOperator::Gt, left_expr: Box::new(ScalarExpression::ColumnRef( ColumnRef::from(ColumnCatalog::new("c1".to_owned(), true, desc.clone())), - false, )), right_expr: Box::new(ScalarExpression::ColumnRef( ColumnRef::from(ColumnCatalog::new("c4".to_owned(), true, desc.clone())), - false, )), evaluator: Some(BinaryEvaluatorBox(Arc::new(Int32GtBinaryEvaluator))), ty: LogicalType::Boolean, diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index b3ccf9c8..069b80dd 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -33,7 +33,7 @@ impl ScalarExpression { match self { ScalarExpression::Constant(val) => Ok(val.clone()), - ScalarExpression::ColumnRef(col, _) => { + ScalarExpression::ColumnRef(col) => { let Some((tuple, schema)) = tuple else { return Ok(DataValue::Null); }; diff --git a/src/expression/mod.rs b/src/expression/mod.rs index e78e058f..0f9f03aa 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -39,7 +39,7 @@ pub enum AliasType { #[derive(Debug, 
PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub enum ScalarExpression { Constant(DataValue), - ColumnRef(ColumnRef, bool), + ColumnRef(ColumnRef), Alias { expr: Box, alias: AliasType, @@ -289,7 +289,7 @@ impl ScalarExpression { pub fn return_type(&self) -> LogicalType { match self { ScalarExpression::Constant(v) => v.logical_type(), - ScalarExpression::ColumnRef(col, _) => col.datatype().clone(), + ScalarExpression::ColumnRef(col) => col.datatype().clone(), ScalarExpression::Binary { ty: return_type, .. } @@ -353,7 +353,7 @@ impl ScalarExpression { vec.push(expr.output_column()); } match expr { - ScalarExpression::ColumnRef(col, _) => { + ScalarExpression::ColumnRef(col) => { vec.push(col.clone()); } ScalarExpression::Alias { expr, .. } => { @@ -480,7 +480,7 @@ impl ScalarExpression { pub fn has_table_ref_column(&self) -> bool { match self { ScalarExpression::Constant(_) => false, - ScalarExpression::ColumnRef(column, _) => { + ScalarExpression::ColumnRef(column) => { column.table_name().is_some() && column.id().is_some() } ScalarExpression::Alias { expr, .. } => expr.has_table_ref_column(), @@ -600,7 +600,7 @@ impl ScalarExpression { match self { ScalarExpression::AggCall { .. } => true, ScalarExpression::Constant(_) => false, - ScalarExpression::ColumnRef(_, _) => false, + ScalarExpression::ColumnRef(_) => false, ScalarExpression::Alias { expr, .. } => expr.has_agg_call(), ScalarExpression::TypeCast { expr, .. } => expr.has_agg_call(), ScalarExpression::IsNull { expr, .. } => expr.has_agg_call(), @@ -690,7 +690,7 @@ impl ScalarExpression { pub fn output_name(&self) -> String { match self { ScalarExpression::Constant(value) => format!("{}", value), - ScalarExpression::ColumnRef(col, _) => col.full_name(), + ScalarExpression::ColumnRef(col) => col.full_name(), ScalarExpression::Alias { alias, expr } => match alias { AliasType::Name(alias) => alias.to_string(), AliasType::Expr(alias_expr) => { @@ -881,7 +881,7 @@ impl ScalarExpression { pub fn output_column(&self) -> ColumnRef { match self { - ScalarExpression::ColumnRef(col, _) => col.clone(), + ScalarExpression::ColumnRef(col) => col.clone(), ScalarExpression::Alias { alias: AliasType::Expr(expr), .. @@ -1119,7 +1119,6 @@ mod test { ColumnDesc::new(LogicalType::Integer, None, false, None)?, false, )), - false, ), Some((&transaction, &table_cache)), &mut reference_tables, @@ -1136,7 +1135,6 @@ mod test { ColumnDesc::new(LogicalType::Boolean, None, false, None)?, false, )), - false, ), Some((&transaction, &table_cache)), &mut reference_tables, diff --git a/src/expression/range_detacher.rs b/src/expression/range_detacher.rs index f3928d4f..7284038c 100644 --- a/src/expression/range_detacher.rs +++ b/src/expression/range_detacher.rs @@ -216,7 +216,7 @@ impl<'a> RangeDetacher<'a> { ScalarExpression::Position { expr, .. } => self.detach(expr)?, ScalarExpression::Trim { expr, .. } => self.detach(expr)?, ScalarExpression::IsNull { expr, negated, .. } => match expr.as_ref() { - ScalarExpression::ColumnRef(column, _) => { + ScalarExpression::ColumnRef(column) => { if let (Some(col_id), Some(col_table)) = (column.id(), column.table_name()) { if &col_id == self.column_id && col_table.as_str() == self.table_name { return if *negated { @@ -253,7 +253,7 @@ impl<'a> RangeDetacher<'a> { | ScalarExpression::Reference { .. 
} | ScalarExpression::Empty => unreachable!(), }, - ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_, _) => None, + ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_) => None, // FIXME: support [RangeDetacher::_detach] ScalarExpression::Tuple(_) | ScalarExpression::AggCall { .. } diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 44b2e05f..820206e4 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -133,7 +133,7 @@ impl VisitorMut<'_> for Simplify { match (left_expr.unpack_col(false), right_expr.unpack_col(false)) { (Some(col), None) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::ColumnRef(col, false), + column_expr: ScalarExpression::ColumnRef(col), val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -142,7 +142,7 @@ impl VisitorMut<'_> for Simplify { } (None, Some(col)) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::ColumnRef(col, false), + column_expr: ScalarExpression::ColumnRef(col), val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -157,7 +157,7 @@ impl VisitorMut<'_> for Simplify { match (left_expr.unpack_col(true), right_expr.unpack_col(true)) { (Some(col), None) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::ColumnRef(col, false), + column_expr: ScalarExpression::ColumnRef(col), val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -166,7 +166,7 @@ impl VisitorMut<'_> for Simplify { } (None, Some(col)) => { self.replaces.push(Replace::Binary(ReplaceBinary { - column_expr: ScalarExpression::ColumnRef(col, false), + column_expr: ScalarExpression::ColumnRef(col), val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, ty: ty.clone(), @@ -483,7 +483,7 @@ impl ScalarExpression { pub(crate) fn unpack_col(&self, is_deep: bool) -> Option { match self { - ScalarExpression::ColumnRef(col, _) => Some(col.clone()), + ScalarExpression::ColumnRef(col) => Some(col.clone()), ScalarExpression::Alias { expr, .. } => expr.unpack_col(is_deep), ScalarExpression::Unary { expr, .. 
} => expr.unpack_col(is_deep), ScalarExpression::Binary { diff --git a/src/expression/visitor.rs b/src/expression/visitor.rs index 691d45e3..e0484e76 100644 --- a/src/expression/visitor.rs +++ b/src/expression/visitor.rs @@ -254,7 +254,7 @@ pub fn walk_expr<'a, V: Visitor<'a>>( ) -> Result<(), DatabaseError> { match expr { ScalarExpression::Constant(value) => visitor.visit_constant(value), - ScalarExpression::ColumnRef(column_ref, _) => visitor.visit_column_ref(column_ref), + ScalarExpression::ColumnRef(column_ref) => visitor.visit_column_ref(column_ref), ScalarExpression::Alias { expr, alias } => visitor.visit_alias(expr, alias), ScalarExpression::TypeCast { expr, ty } => visitor.visit_type_cast(expr, ty), ScalarExpression::IsNull { negated, expr } => visitor.visit_is_null(*negated, expr), diff --git a/src/expression/visitor_mut.rs b/src/expression/visitor_mut.rs index e3e2317b..87aead51 100644 --- a/src/expression/visitor_mut.rs +++ b/src/expression/visitor_mut.rs @@ -254,7 +254,7 @@ pub fn walk_mut_expr<'a, V: VisitorMut<'a>>( ) -> Result<(), DatabaseError> { match expr { ScalarExpression::Constant(value) => visitor.visit_constant(value), - ScalarExpression::ColumnRef(column_ref, _) => visitor.visit_column_ref(column_ref), + ScalarExpression::ColumnRef(column_ref) => visitor.visit_column_ref(column_ref), ScalarExpression::Alias { expr, alias } => visitor.visit_alias(expr, alias), ScalarExpression::TypeCast { expr, ty } => visitor.visit_type_cast(expr, ty), ScalarExpression::IsNull { negated, expr } => visitor.visit_is_null(*negated, expr), diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index 8426b9c8..5deb4d7f 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -298,7 +298,6 @@ mod test { op: BinaryOperator::Plus, left_expr: Box::new(ScalarExpression::ColumnRef( ColumnRef::from(c1_col), - false )), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, @@ -309,7 +308,6 @@ mod test { }), right_expr: Box::new(ScalarExpression::ColumnRef( ColumnRef::from(c2_col), - false )), evaluator: None, ty: LogicalType::Boolean, diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index 78a25b30..61c57837 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -136,7 +136,7 @@ impl Operator { Operator::TableScan(op) => Some( op.columns .values() - .map(|column| ScalarExpression::ColumnRef(column.clone(), false)) + .map(|column| ScalarExpression::ColumnRef(column.clone())) .collect_vec(), ), Operator::Sort(_) | Operator::Limit(_) => None, @@ -148,7 +148,7 @@ impl Operator { schema_ref .iter() .cloned() - .map(|col| ScalarExpression::ColumnRef(col, false)) + .map(|col| ScalarExpression::ColumnRef(col)) .collect_vec(), ), Operator::FunctionScan(op) => Some( @@ -156,7 +156,7 @@ impl Operator { .inner .output_schema() .iter() - .map(|column| ScalarExpression::ColumnRef(column.clone(), false)) + .map(|column| ScalarExpression::ColumnRef(column.clone())) .collect_vec(), ), Operator::ShowTable diff --git a/src/types/index.rs b/src/types/index.rs index 8762566f..df57e070 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -46,7 +46,7 @@ impl IndexMeta { for column_id in self.column_ids.iter() { if let Some(column) = table.get_column_by_id(column_id) { - exprs.push(ScalarExpression::ColumnRef(column.clone(), false)); + exprs.push(ScalarExpression::ColumnRef(column.clone())); } else { return 
Err(DatabaseError::ColumnNotFound(column_id.to_string())); } From 818a9c5b0e01710320b37e14f421089ec7fb5ab3 Mon Sep 17 00:00:00 2001 From: wszhdshys <1925792291@qq.com> Date: Wed, 23 Jul 2025 23:07:44 +0800 Subject: [PATCH 5/7] feat: implement correlated subqueries via an RBO (rule-based optimization) rule --- src/binder/expr.rs | 7 +- src/binder/mod.rs | 10 +- src/binder/select.rs | 22 +- src/db.rs | 7 +- src/execution/dql/filter.rs | 5 +- src/execution/dql/join/nested_loop_join.rs | 12 +- src/expression/mod.rs | 46 ++-- src/optimizer/heuristic/graph.rs | 3 +- .../rule/normalization/correlated_subquery.rs | 244 ++++++++++++++++++ src/optimizer/rule/normalization/mod.rs | 5 + .../rule/normalization/simplification.rs | 10 +- tests/slt/correlated_subquery.slt | 7 - 12 files changed, 303 insertions(+), 75 deletions(-) create mode 100644 src/optimizer/rule/normalization/correlated_subquery.rs diff --git a/src/binder/expr.rs index 8e7bcbbf..a7bd9ebf 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -369,10 +369,13 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T try_default!(&full_name.0, full_name.1); } if let Some(table) = full_name.0.or(bind_table_name) { - let (source,is_parent) = self.context.bind_source::(self.parent, &table, false)?; + let (source, is_parent) = self.context.bind_source::(self.parent, &table, false)?; if is_parent { - self.parent_table_col.entry(Arc::new(table.clone())).or_insert(HashSet::new()).insert(full_name.1.clone()); + self.parent_table_col + .entry(Arc::new(table.clone())) + .or_default() + .insert(full_name.1.clone()); } let schema_buf = self.table_schema_buf.entry(Arc::new(table)).or_default(); diff --git a/src/binder/mod.rs index 3068a7a3..3c09a310 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -25,7 +25,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use crate::catalog::view::View; -use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableName}; +use crate::catalog::{ColumnRef, TableCatalog, TableName}; use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; @@ -275,9 +275,9 @@ impl<'a, T: Transaction> BinderContext<'a, T> { Ok(source) } - pub fn bind_source<'b: 'a, A: AsRef<[(&'static str, DataValue)]> >( + pub fn bind_source<'b: 'a, A: AsRef<[(&'static str, DataValue)]>>( &self, - parent: Option<&'a Binder<'a,'b,T,A>>, + parent: Option<&'a Binder<'a, 'b, T, A>>, table_name: &str, is_parent: bool, ) -> Result<(&'b Source, bool), DatabaseError> { @@ -287,10 +287,10 @@ impl<'a, T: Transaction> BinderContext<'a, T> { }) { Ok((source.1, is_parent)) } else if let Some(binder) = parent { - binder.context.bind_source(binder.parent, table_name,true) + binder.context.bind_source(binder.parent, table_name, true) } else { Err(DatabaseError::InvalidTable(table_name.into())) - } + } } // Tips: The order of this index is based on Aggregate being bound first. 
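As a rough illustration of what the binder changes above enable (a sketch only; it assumes the t1/t2 fixtures created in tests/slt/correlated_subquery.slt, whose exact schemas may differ), a correlated subquery is one whose inner query references a column of the outer query:

-- t1.id cannot be resolved from t2 alone, so bind_source falls back to the
-- parent binder (returning is_parent = true) and the column is recorded in
-- parent_table_col for the rewrite performed by the new CorrelatedSubquery rule.
SELECT id, v1 FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.id = t1.id);

The CorrelatedSubquery normalization rule added later in this patch then cross-joins a scan of the referenced outer table into the subquery's side of the plan, so the outer column is available when the predicate is evaluated.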
diff --git a/src/binder/select.rs b/src/binder/select.rs index fc1114dc..9c437f94 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -36,7 +36,7 @@ use crate::types::value::Utf8Type; use crate::types::{ColumnId, LogicalType}; use itertools::Itertools; use sqlparser::ast::{ - CharLengthUnits, Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset, + CharLengthUnits, Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, OrderByExpr, Query, Select, SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier, TableAlias, TableFactor, TableWithJoins, }; @@ -101,8 +101,6 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }; let mut select_list = self.normalize_select_item(&select.projection, &plan)?; - plan = self.bind_parent(plan)?; - if let Some(predicate) = &select.selection { plan = self.bind_where(plan, predicate)?; } @@ -536,22 +534,6 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' Ok(()) } - fn bind_parent( - &mut self, - mut plan: LogicalPlan, - ) -> Result { - for (table,columns) in self.parent_table_col.clone().into_iter() { - let parent = self._bind_single_table_ref(None,table.as_str(),None)?; - plan = LJoinOperator::build( - plan, - parent, - JoinCondition::None, - JoinType::Full, - ); - } - Ok(plan) - } - fn bind_join( &mut self, mut left: LogicalPlan, @@ -613,8 +595,6 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let predicate = self.bind_expr(predicate)?; - children = self.bind_parent(children)?; - if let Some(sub_queries) = self.context.sub_queries_at_now() { for sub_query in sub_queries { let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; diff --git a/src/db.rs b/src/db.rs index ff08a4fa..29aed351 100644 --- a/src/db.rs +++ b/src/db.rs @@ -166,13 +166,18 @@ impl State { let best_plan = Self::default_optimizer(source_plan) .find_best(Some(&transaction.meta_loader(meta_cache)))?; - // println!("best_plan plan: {:#?}", best_plan); + //println!("best_plan plan: {:#?}", best_plan); Ok(best_plan) } pub(crate) fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer { HepOptimizer::new(source_plan) + .batch( + "Correlated Subquery".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::CorrelateSubquery], + ) .batch( "Column Pruning".to_string(), HepBatchStrategy::once_topdown(), diff --git a/src/execution/dql/filter.rs b/src/execution/dql/filter.rs index 21ce815a..e55c0f83 100644 --- a/src/execution/dql/filter.rs +++ b/src/execution/dql/filter.rs @@ -35,12 +35,15 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Filter { let schema = input.output_schema().clone(); + //println!("{:#?}114514'\n'1919810{:#?}",predicate,schema); + let mut coroutine = build_read(input, cache, transaction); while let CoroutineState::Yielded(tuple) = Pin::new(&mut coroutine).resume(()) { let tuple = throw!(tuple); - + //println!("-> Coroutine returned: {:?}", tuple); if throw!(throw!(predicate.eval(Some((&tuple, &schema)))).is_true()) { + //println!("-> throw!: {:?}", tuple); yield Ok(tuple); } } diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index 67492caa..e79dca4a 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -505,12 +505,12 @@ mod test { let filter = ScalarExpression::Binary { op: crate::expression::BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::ColumnRef( - 
ColumnRef::from(ColumnCatalog::new("c1".to_owned(), true, desc.clone())), - )), - right_expr: Box::new(ScalarExpression::ColumnRef( - ColumnRef::from(ColumnCatalog::new("c4".to_owned(), true, desc.clone())), - )), + left_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from( + ColumnCatalog::new("c1".to_owned(), true, desc.clone()), + ))), + right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from( + ColumnCatalog::new("c4".to_owned(), true, desc.clone()), + ))), evaluator: Some(BinaryEvaluatorBox(Arc::new(Int32GtBinaryEvaluator))), ty: LogicalType::Boolean, }; diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 0f9f03aa..5ce161da 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -1105,37 +1105,33 @@ mod test { )?; fn_assert( &mut cursor, - ScalarExpression::ColumnRef( - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c3".to_string(), - relation: ColumnRelation::Table { - column_id: c3_column_id, - table_name: Arc::new("t1".to_string()), - is_temp: false, - }, + ScalarExpression::ColumnRef(ColumnRef::from(ColumnCatalog::direct_new( + ColumnSummary { + name: "c3".to_string(), + relation: ColumnRelation::Table { + column_id: c3_column_id, + table_name: Arc::new("t1".to_string()), + is_temp: false, }, - false, - ColumnDesc::new(LogicalType::Integer, None, false, None)?, - false, - )), - ), + }, + false, + ColumnDesc::new(LogicalType::Integer, None, false, None)?, + false, + ))), Some((&transaction, &table_cache)), &mut reference_tables, )?; fn_assert( &mut cursor, - ScalarExpression::ColumnRef( - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c4".to_string(), - relation: ColumnRelation::None, - }, - false, - ColumnDesc::new(LogicalType::Boolean, None, false, None)?, - false, - )), - ), + ScalarExpression::ColumnRef(ColumnRef::from(ColumnCatalog::direct_new( + ColumnSummary { + name: "c4".to_string(), + relation: ColumnRelation::None, + }, + false, + ColumnDesc::new(LogicalType::Boolean, None, false, None)?, + false, + ))), Some((&transaction, &table_cache)), &mut reference_tables, )?; diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index f6de8c61..6695bfdc 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -79,7 +79,7 @@ impl HepGraph { source_id: HepNodeId, children_option: Option, new_node: Operator, - ) { + ) -> HepNodeId { let new_index = self.graph.add_node(new_node); let mut order = self.graph.edges(source_id).count(); @@ -95,6 +95,7 @@ impl HepGraph { self.graph.add_edge(source_id, new_index, order); self.version += 1; + new_index } pub fn replace_node(&mut self, source_id: HepNodeId, new_node: Operator) { diff --git a/src/optimizer/rule/normalization/correlated_subquery.rs b/src/optimizer/rule/normalization/correlated_subquery.rs new file mode 100644 index 00000000..01d96c26 --- /dev/null +++ b/src/optimizer/rule/normalization/correlated_subquery.rs @@ -0,0 +1,244 @@ +use crate::catalog::{ColumnRef, TableName}; +use crate::errors::DatabaseError; +use crate::expression::visitor::Visitor; +use crate::expression::HasCountStar; +use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; +use crate::planner::operator::table_scan::TableScanOperator; +use crate::planner::operator::Operator; +use 
crate::planner::operator::Operator::{Join, TableScan}; +use crate::types::index::IndexInfo; +use crate::types::ColumnId; +use itertools::Itertools; +use std::collections::BTreeMap; +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, LazyLock}; + +static CORRELATED_SUBQUERY_RULE: LazyLock = LazyLock::new(|| Pattern { + predicate: |op| matches!(op, Join(_)), + children: PatternChildrenPredicate::None, +}); + +#[derive(Clone)] +pub struct CorrelatedSubquery; + +macro_rules! trans_references { + ($columns:expr) => {{ + let mut column_references = HashSet::with_capacity($columns.len()); + for column in $columns { + column_references.insert(column); + } + column_references + }}; +} + +impl CorrelatedSubquery { + fn _apply( + column_references: HashSet<&ColumnRef>, + scan_columns: HashMap, HashMap, Vec)>, + node_id: HepNodeId, + graph: &mut HepGraph, + ) -> Result< + HashMap, HashMap, Vec)>, + DatabaseError, + > { + let operator = &graph.operator(node_id).clone(); + + match operator { + Operator::Aggregate(op) => { + let is_distinct = op.is_distinct; + let referenced_columns = operator.referenced_columns(false); + let mut new_column_references = trans_references!(&referenced_columns); + // on distinct + if is_distinct { + for summary in column_references { + new_column_references.insert(summary); + } + } + + Self::recollect_apply(new_column_references, scan_columns, node_id, graph) + } + Operator::Project(op) => { + let mut has_count_star = HasCountStar::default(); + for expr in &op.exprs { + has_count_star.visit(expr)?; + } + let referenced_columns = operator.referenced_columns(false); + let new_column_references = trans_references!(&referenced_columns); + + Self::recollect_apply(new_column_references, scan_columns, node_id, graph) + } + Operator::TableScan(op) => { + let table_column: HashSet<&ColumnRef> = op.columns.values().collect(); + let mut new_scan_columns = scan_columns.clone(); + new_scan_columns.insert( + op.table_name.clone(), + ( + op.primary_keys.clone(), + op.columns + .iter() + .map(|(num, col)| (col.id().unwrap(), *num)) + .collect(), + op.index_infos.clone(), + ), + ); + let mut parent_col = HashMap::new(); + for col in column_references { + match ( + table_column.contains(col), + scan_columns.get(col.table_name().unwrap_or(&Arc::new("".to_string()))), + ) { + (false, Some(..)) => { + parent_col + .entry(col.table_name().unwrap()) + .or_insert(HashSet::new()) + .insert(col); + } + _ => continue, + } + } + for (table_name, table_columns) in parent_col { + let table_columns = table_columns.into_iter().collect_vec(); + let (primary_keys, columns, index_infos) = + scan_columns.get(table_name).unwrap(); + let map: BTreeMap = table_columns + .into_iter() + .map(|col| (*columns.get(&col.id().unwrap()).unwrap(), col.clone())) + .collect(); + let left_operator = graph.operator(node_id).clone(); + let right_operator = TableScan(TableScanOperator { + table_name: table_name.clone(), + primary_keys: primary_keys.clone(), + columns: map, + limit: (None, None), + index_infos: index_infos.clone(), + with_pk: false, + }); + let join_operator = Join(JoinOperator { + on: JoinCondition::None, + join_type: JoinType::Cross, + }); + + match &left_operator { + TableScan(_) => { + graph.replace_node(node_id, join_operator); + graph.add_node(node_id, None, left_operator); + graph.add_node(node_id, None, right_operator); + } + Join(_) => { + let left_id = graph.eldest_child_at(node_id).unwrap(); + let left_id = graph.add_node(node_id, Some(left_id), join_operator); + graph.add_node(left_id, 
None, right_operator); + } + _ => unreachable!(), + } + } + Ok(new_scan_columns) + } + Operator::Sort(_) | Operator::Limit(_) | Operator::Filter(_) | Operator::Union(_) => { + let mut new_scan_columns = scan_columns.clone(); + let temp_columns = operator.referenced_columns(false); + // why? + let mut column_references = column_references; + for column in temp_columns.iter() { + column_references.insert(column); + } + for child_id in graph.children_at(node_id).collect_vec() { + let copy_references = column_references.clone(); + let copy_scan = scan_columns.clone(); + if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) { + new_scan_columns.extend(scan); + }; + } + Ok(new_scan_columns) + } + Operator::Join(_) => { + let mut new_scan_columns = scan_columns.clone(); + for child_id in graph.children_at(node_id).collect_vec() { + let copy_references = column_references.clone(); + let copy_scan = new_scan_columns.clone(); + if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) { + new_scan_columns.extend(scan); + }; + } + Ok(new_scan_columns) + } + // Last Operator + Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => Ok(scan_columns), + Operator::Explain => { + if let Some(child_id) = graph.eldest_child_at(node_id) { + Self::_apply(column_references, scan_columns, child_id, graph) + } else { + unreachable!() + } + } + // DDL Based on Other Plan + Operator::Insert(_) + | Operator::Update(_) + | Operator::Delete(_) + | Operator::Analyze(_) => { + let referenced_columns = operator.referenced_columns(false); + let new_column_references = trans_references!(&referenced_columns); + + if let Some(child_id) = graph.eldest_child_at(node_id) { + Self::recollect_apply(new_column_references, scan_columns, child_id, graph) + } else { + unreachable!(); + } + } + // DDL Single Plan + Operator::CreateTable(_) + | Operator::CreateIndex(_) + | Operator::CreateView(_) + | Operator::DropTable(_) + | Operator::DropView(_) + | Operator::Truncate(_) + | Operator::ShowTable + | Operator::ShowView + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) + | Operator::AddColumn(_) + | Operator::DropColumn(_) + | Operator::Describe(_) => Ok(scan_columns), + } + } + + fn recollect_apply( + referenced_columns: HashSet<&ColumnRef>, + scan_columns: HashMap, HashMap, Vec)>, + node_id: HepNodeId, + graph: &mut HepGraph, + ) -> Result< + HashMap, HashMap, Vec)>, + DatabaseError, + > { + let mut new_scan_columns = scan_columns.clone(); + for child_id in graph.children_at(node_id).collect_vec() { + let copy_references = referenced_columns.clone(); + let copy_scan = scan_columns.clone(); + + if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) { + new_scan_columns.extend(scan); + }; + } + Ok(new_scan_columns) + } +} + +impl MatchPattern for CorrelatedSubquery { + fn pattern(&self) -> &Pattern { + &CORRELATED_SUBQUERY_RULE + } +} + +impl NormalizationRule for CorrelatedSubquery { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { + Self::_apply(HashSet::new(), HashMap::new(), node_id, graph)?; + // mark changed to skip this rule batch + graph.version += 1; + + Ok(()) + } +} diff --git a/src/optimizer/rule/normalization/mod.rs b/src/optimizer/rule/normalization/mod.rs index 8c30fd4a..b7fc6d4b 100644 --- a/src/optimizer/rule/normalization/mod.rs +++ b/src/optimizer/rule/normalization/mod.rs @@ -10,6 +10,7 @@ use crate::optimizer::rule::normalization::combine_operators::{ use 
crate::optimizer::rule::normalization::compilation_in_advance::{ EvaluatorBind, ExpressionRemapper, }; +use crate::optimizer::rule::normalization::correlated_subquery::CorrelatedSubquery; use crate::optimizer::rule::normalization::pushdown_limit::{ LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin, }; @@ -21,6 +22,7 @@ use crate::optimizer::rule::normalization::simplification::SimplifyFilter; mod column_pruning; mod combine_operators; mod compilation_in_advance; +mod correlated_subquery; mod pushdown_limit; mod pushdown_predicates; mod simplification; @@ -32,6 +34,7 @@ pub enum NormalizationRuleImpl { CollapseProject, CollapseGroupByAgg, CombineFilter, + CorrelateSubquery, // PushDown limit LimitProjectTranspose, PushLimitThroughJoin, @@ -55,6 +58,7 @@ impl MatchPattern for NormalizationRuleImpl { NormalizationRuleImpl::CollapseProject => CollapseProject.pattern(), NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.pattern(), NormalizationRuleImpl::CombineFilter => CombineFilter.pattern(), + NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.pattern(), NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(), NormalizationRuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.pattern(), NormalizationRuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.pattern(), @@ -75,6 +79,7 @@ impl NormalizationRule for NormalizationRuleImpl { NormalizationRuleImpl::CollapseProject => CollapseProject.apply(node_id, graph), NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(node_id, graph), NormalizationRuleImpl::CombineFilter => CombineFilter.apply(node_id, graph), + NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.apply(node_id, graph), NormalizationRuleImpl::LimitProjectTranspose => { LimitProjectTranspose.apply(node_id, graph) } diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index 5deb4d7f..02d66ef5 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -296,9 +296,9 @@ mod test { op: UnaryOperator::Minus, expr: Box::new(ScalarExpression::Binary { op: BinaryOperator::Plus, - left_expr: Box::new(ScalarExpression::ColumnRef( - ColumnRef::from(c1_col), - )), + left_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from( + c1_col + ),)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, ty: LogicalType::Integer, @@ -306,9 +306,7 @@ mod test { evaluator: None, ty: LogicalType::Integer, }), - right_expr: Box::new(ScalarExpression::ColumnRef( - ColumnRef::from(c2_col), - )), + right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from(c2_col),)), evaluator: None, ty: LogicalType::Boolean, } diff --git a/tests/slt/correlated_subquery.slt b/tests/slt/correlated_subquery.slt index 90331034..1a6c0d62 100644 --- a/tests/slt/correlated_subquery.slt +++ b/tests/slt/correlated_subquery.slt @@ -73,13 +73,6 @@ SELECT id, v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND E 2 b 3 c -# query ITI rowsort -# SELECT id, v1, ( SELECT COUNT(*) FROM t2 WHERE t2.v2 >= 10 ) as cnt FROM t1 -# ---- -# 1 a 2 -# 2 b 3 -# 3 c 2 - query IT rowsort SELECT id, v1 FROM t1 WHERE v2 - 5 > ( SELECT AVG(v1) FROM t3 WHERE t3.id <= t1.id ) ---- From 015af41077fcb9b8963e5a916a7311f7c21365b4 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Tue, 22 Jul 2025 02:01:15 +0800 Subject: [PATCH 6/7] feat: impl `Drop Index` (#276) * feat: impl 
`Drop Index` * chore: codefmt (cherry picked from commit b710db44398c5ddd63c04f20afee08f4f46645e2) --- docs/features.md | 3 +- src/binder/drop_index.rs | 35 ++++++++ src/binder/mod.rs | 2 + src/execution/ddl/drop_index.rs | 43 ++++++++++ src/execution/ddl/mod.rs | 1 + src/execution/dml/analyze.rs | 15 +++- src/execution/mod.rs | 2 + .../rule/normalization/column_pruning.rs | 1 + .../normalization/compilation_in_advance.rs | 2 + src/planner/mod.rs | 3 + src/planner/operator/drop_index.rs | 43 ++++++++++ src/planner/operator/mod.rs | 6 ++ src/storage/mod.rs | 84 +++++++++++++++++-- src/storage/table_codec.rs | 27 ++++-- tests/slt/create_index.slt | 15 ++++ 15 files changed, 262 insertions(+), 20 deletions(-) create mode 100644 src/binder/drop_index.rs create mode 100644 src/execution/ddl/drop_index.rs create mode 100644 src/planner/operator/drop_index.rs diff --git a/docs/features.md b/docs/features.md index 9f6f88ff..49d859f7 100644 --- a/docs/features.md +++ b/docs/features.md @@ -105,7 +105,8 @@ let kite_sql = DataBaseBuilder::path("./data") - [x] View - Drop - [x] Table - - [ ] Index + - [x] Index + - Tips: `Drop Index table_name.index_name` - [x] View - Alert - [x] Add Column diff --git a/src/binder/drop_index.rs b/src/binder/drop_index.rs new file mode 100644 index 00000000..17bd2c45 --- /dev/null +++ b/src/binder/drop_index.rs @@ -0,0 +1,35 @@ +use crate::binder::{lower_ident, Binder}; +use crate::errors::DatabaseError; +use crate::planner::operator::drop_index::DropIndexOperator; +use crate::planner::operator::Operator; +use crate::planner::{Childrens, LogicalPlan}; +use crate::storage::Transaction; +use crate::types::value::DataValue; +use sqlparser::ast::ObjectName; +use std::sync::Arc; + +impl> Binder<'_, '_, T, A> { + pub(crate) fn bind_drop_index( + &mut self, + name: &ObjectName, + if_exists: &bool, + ) -> Result { + let table_name = name + .0 + .first() + .ok_or(DatabaseError::InvalidTable(name.to_string()))?; + let index_name = name.0.get(1).ok_or(DatabaseError::InvalidIndex)?; + + let table_name = Arc::new(lower_ident(table_name)); + let index_name = lower_ident(index_name); + + Ok(LogicalPlan::new( + Operator::DropIndex(DropIndexOperator { + table_name, + index_name, + if_exists: *if_exists, + }), + Childrens::None, + )) + } +} diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 3c09a310..2f845b9f 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -8,6 +8,7 @@ mod create_view; mod delete; mod describe; mod distinct; +mod drop_index; mod drop_table; mod drop_view; mod explain; @@ -384,6 +385,7 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' match object_type { ObjectType::Table => self.bind_drop_table(&names[0], if_exists)?, ObjectType::View => self.bind_drop_view(&names[0], if_exists)?, + ObjectType::Index => self.bind_drop_index(&names[0], if_exists)?, _ => { return Err(DatabaseError::UnsupportedStmt( "only `Table` and `View` are allowed to be Dropped".to_string(), diff --git a/src/execution/ddl/drop_index.rs b/src/execution/ddl/drop_index.rs new file mode 100644 index 00000000..f0d5f316 --- /dev/null +++ b/src/execution/ddl/drop_index.rs @@ -0,0 +1,43 @@ +use crate::execution::{Executor, WriteExecutor}; +use crate::planner::operator::drop_index::DropIndexOperator; +use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; +use crate::throw; +use crate::types::tuple_builder::TupleBuilder; + +pub struct DropIndex { + op: DropIndexOperator, +} + +impl From for DropIndex { + fn from(op: DropIndexOperator) -> 
Self { + Self { op } + } +} + +impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropIndex { + fn execute_mut( + self, + (table_cache, _, _): (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache), + transaction: *mut T, + ) -> Executor<'a> { + Box::new( + #[coroutine] + move || { + let DropIndexOperator { + table_name, + index_name, + if_exists, + } = self.op; + + throw!(unsafe { &mut (*transaction) }.drop_index( + table_cache, + table_name, + &index_name, + if_exists + )); + + yield Ok(TupleBuilder::build_result(index_name.to_string())); + }, + ) + } +} diff --git a/src/execution/ddl/mod.rs b/src/execution/ddl/mod.rs index 23769f7c..294c6c39 100644 --- a/src/execution/ddl/mod.rs +++ b/src/execution/ddl/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod create_index; pub(crate) mod create_table; pub(crate) mod create_view; pub(crate) mod drop_column; +pub(crate) mod drop_index; pub(crate) mod drop_table; pub(crate) mod drop_view; pub(crate) mod truncate; diff --git a/src/execution/dml/analyze.rs b/src/execution/dml/analyze.rs index 69be1397..3dfd1c90 100644 --- a/src/execution/dml/analyze.rs +++ b/src/execution/dml/analyze.rs @@ -19,6 +19,7 @@ use std::fmt::Formatter; use std::fs::DirEntry; use std::ops::Coroutine; use std::ops::CoroutineState; +use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; use std::{fmt, fs}; @@ -98,10 +99,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Analyze { } drop(coroutine); let mut values = Vec::with_capacity(builders.len()); - let dir_path = dirs::home_dir() - .expect("Your system does not have a Config directory!") - .join(DEFAULT_STATISTICS_META_PATH) - .join(table_name.as_str()); + let dir_path = Self::build_statistics_meta_path(&table_name); // For DEBUG // println!("Statistics Path: {:#?}", dir_path); throw!(fs::create_dir_all(&dir_path).map_err(DatabaseError::IO)); @@ -149,6 +147,15 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Analyze { } } +impl Analyze { + pub fn build_statistics_meta_path(table_name: &TableName) -> PathBuf { + dirs::home_dir() + .expect("Your system does not have a Config directory!") + .join(DEFAULT_STATISTICS_META_PATH) + .join(table_name.as_str()) + } +} + impl fmt::Display for AnalyzeOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let indexes = self.index_metas.iter().map(|index| &index.name).join(", "); diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 9c4ed8c6..1df7b16c 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -10,6 +10,7 @@ use crate::execution::ddl::create_index::CreateIndex; use crate::execution::ddl::create_table::CreateTable; use crate::execution::ddl::create_view::CreateView; use crate::execution::ddl::drop_column::DropColumn; +use crate::execution::ddl::drop_index::DropIndex; use crate::execution::ddl::drop_table::DropTable; use crate::execution::ddl::drop_view::DropView; use crate::execution::ddl::truncate::Truncate; @@ -194,6 +195,7 @@ pub fn build_write<'a, T: Transaction + 'a>( Operator::CreateView(op) => CreateView::from(op).execute_mut(cache, transaction), Operator::DropTable(op) => DropTable::from(op).execute_mut(cache, transaction), Operator::DropView(op) => DropView::from(op).execute_mut(cache, transaction), + Operator::DropIndex(op) => DropIndex::from(op).execute_mut(cache, transaction), Operator::Truncate(op) => Truncate::from(op).execute_mut(cache, transaction), Operator::CopyFromFile(op) => CopyFromFile::from(op).execute_mut(cache, transaction), Operator::CopyToFile(op) => { diff --git 
a/src/optimizer/rule/normalization/column_pruning.rs b/src/optimizer/rule/normalization/column_pruning.rs index bd8eeddb..9626c62d 100644 --- a/src/optimizer/rule/normalization/column_pruning.rs +++ b/src/optimizer/rule/normalization/column_pruning.rs @@ -152,6 +152,7 @@ impl ColumnPruning { | Operator::CreateView(_) | Operator::DropTable(_) | Operator::DropView(_) + | Operator::DropIndex(_) | Operator::Truncate(_) | Operator::ShowTable | Operator::ShowView diff --git a/src/optimizer/rule/normalization/compilation_in_advance.rs b/src/optimizer/rule/normalization/compilation_in_advance.rs index e2b2923b..589c1e0e 100644 --- a/src/optimizer/rule/normalization/compilation_in_advance.rs +++ b/src/optimizer/rule/normalization/compilation_in_advance.rs @@ -104,6 +104,7 @@ impl ExpressionRemapper { | Operator::CreateView(_) | Operator::DropTable(_) | Operator::DropView(_) + | Operator::DropIndex(_) | Operator::Truncate(_) | Operator::CopyFromFile(_) | Operator::CopyToFile(_) @@ -212,6 +213,7 @@ impl EvaluatorBind { | Operator::CreateView(_) | Operator::DropTable(_) | Operator::DropView(_) + | Operator::DropIndex(_) | Operator::Truncate(_) | Operator::CopyFromFile(_) | Operator::CopyToFile(_) diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 1220af74..c6f20e90 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -221,6 +221,9 @@ impl LogicalPlan { Operator::DropView(_) => SchemaOutput::Schema(vec![ColumnRef::from( ColumnCatalog::new_dummy("DROP VIEW SUCCESS".to_string()), )]), + Operator::DropIndex(_) => SchemaOutput::Schema(vec![ColumnRef::from( + ColumnCatalog::new_dummy("DROP INDEX SUCCESS".to_string()), + )]), Operator::Truncate(_) => SchemaOutput::Schema(vec![ColumnRef::from( ColumnCatalog::new_dummy("TRUNCATE TABLE SUCCESS".to_string()), )]), diff --git a/src/planner/operator/drop_index.rs b/src/planner/operator/drop_index.rs new file mode 100644 index 00000000..695cd18f --- /dev/null +++ b/src/planner/operator/drop_index.rs @@ -0,0 +1,43 @@ +use crate::catalog::TableName; +use crate::planner::operator::Operator; +use crate::planner::{Childrens, LogicalPlan}; +use kite_sql_serde_macros::ReferenceSerialization; +use std::fmt; +use std::fmt::Formatter; + +#[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] +pub struct DropIndexOperator { + pub table_name: TableName, + pub index_name: String, + pub if_exists: bool, +} + +impl DropIndexOperator { + pub fn build( + table_name: TableName, + index_name: String, + if_exists: bool, + childrens: Childrens, + ) -> LogicalPlan { + LogicalPlan::new( + Operator::DropIndex(DropIndexOperator { + table_name, + index_name, + if_exists, + }), + childrens, + ) + } +} + +impl fmt::Display for DropIndexOperator { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "Drop Index {} On {}, If Exists: {}", + self.index_name, self.table_name, self.if_exists + )?; + + Ok(()) + } +} diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index 61c57837..ed6d5bd1 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -8,6 +8,7 @@ pub mod create_table; pub mod create_view; pub mod delete; pub mod describe; +pub mod drop_index; pub mod drop_table; pub mod drop_view; pub mod filter; @@ -39,6 +40,7 @@ use crate::planner::operator::create_table::CreateTableOperator; use crate::planner::operator::create_view::CreateViewOperator; use crate::planner::operator::delete::DeleteOperator; use crate::planner::operator::describe::DescribeOperator; +use 
crate::planner::operator::drop_index::DropIndexOperator; use crate::planner::operator::drop_table::DropTableOperator; use crate::planner::operator::drop_view::DropViewOperator; use crate::planner::operator::function_scan::FunctionScanOperator; @@ -85,6 +87,7 @@ pub enum Operator { CreateView(CreateViewOperator), DropTable(DropTableOperator), DropView(DropViewOperator), + DropIndex(DropIndexOperator), Truncate(TruncateOperator), // Copy CopyFromFile(CopyFromFileOperator), @@ -174,6 +177,7 @@ impl Operator { | Operator::CreateView(_) | Operator::DropTable(_) | Operator::DropView(_) + | Operator::DropIndex(_) | Operator::Truncate(_) | Operator::CopyFromFile(_) | Operator::CopyToFile(_) => None, @@ -248,6 +252,7 @@ impl Operator { | Operator::CreateView(_) | Operator::DropTable(_) | Operator::DropView(_) + | Operator::DropIndex(_) | Operator::Truncate(_) | Operator::CopyFromFile(_) | Operator::CopyToFile(_) => vec![], @@ -283,6 +288,7 @@ impl fmt::Display for Operator { Operator::CreateView(op) => write!(f, "{}", op), Operator::DropTable(op) => write!(f, "{}", op), Operator::DropView(op) => write!(f, "{}", op), + Operator::DropIndex(op) => write!(f, "{}", op), Operator::Truncate(op) => write!(f, "{}", op), Operator::CopyFromFile(op) => write!(f, "{}", op), Operator::CopyToFile(op) => write!(f, "{}", op), diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 67aad752..f40f43c5 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -4,6 +4,7 @@ pub(crate) mod table_codec; use crate::catalog::view::View; use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableMeta, TableName}; use crate::errors::DatabaseError; +use crate::execution::dml::analyze::Analyze; use crate::expression::range_detacher::Range; use crate::optimizer::core::statistics_meta::{StatisticMetaLoader, StatisticsMeta}; use crate::serdes::ReferenceTables; @@ -16,10 +17,10 @@ use crate::utils::lru::SharedLruCache; use itertools::Itertools; use std::collections::{BTreeMap, Bound}; use std::io::Cursor; -use std::mem; use std::ops::SubAssign; use std::sync::Arc; use std::vec::IntoIter; +use std::{fs, mem}; use ulid::Generator; pub(crate) type StatisticsMetaCache = SharedLruCache<(TableName, IndexId), StatisticsMeta>; @@ -302,7 +303,7 @@ pub trait Transaction: Sized { self.remove(&index_meta_key)?; let (index_min, index_max) = - unsafe { &*self.table_codec() }.index_bound(table_name, &index_meta.id)?; + unsafe { &*self.table_codec() }.index_bound(table_name, index_meta.id)?; self._drop_data(index_min, index_max)?; self.remove_table_meta(meta_cache, table_name, index_meta.id)?; @@ -410,13 +411,55 @@ pub trait Transaction: Sized { Ok(()) } + fn drop_index( + &mut self, + table_cache: &TableCache, + table_name: TableName, + index_name: &str, + if_exists: bool, + ) -> Result<(), DatabaseError> { + let table = self + .table(table_cache, table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + let Some(index_meta) = table.indexes.iter().find(|index| index.name == index_name) else { + if if_exists { + return Ok(()); + } else { + return Err(DatabaseError::TableNotFound); + } + }; + match index_meta.ty { + IndexType::PrimaryKey { .. 
} | IndexType::Unique => { + return Err(DatabaseError::InvalidIndex) + } + IndexType::Normal | IndexType::Composite => (), + } + + let index_id = index_meta.id; + let index_meta_key = + unsafe { &*self.table_codec() }.encode_index_meta_key(table_name.as_str(), index_id)?; + self.remove(&index_meta_key)?; + + let (index_min, index_max) = + unsafe { &*self.table_codec() }.index_bound(table_name.as_str(), index_id)?; + self._drop_data(index_min, index_max)?; + + let statistics_min_key = unsafe { &*self.table_codec() } + .encode_statistics_path_key(table_name.as_str(), index_id); + self.remove(&statistics_min_key)?; + + table_cache.remove(&table_name); + // When dropping Index, the statistics file corresponding to the Index is not cleaned up and is processed uniformly by the Analyze Table. + + Ok(()) + } + fn drop_table( &mut self, table_cache: &TableCache, table_name: TableName, if_exists: bool, ) -> Result<(), DatabaseError> { - self.drop_name_hash(&table_name)?; if self.table(table_cache, table_name.clone())?.is_none() { if if_exists { return Ok(()); @@ -424,6 +467,7 @@ pub trait Transaction: Sized { return Err(DatabaseError::TableNotFound); } } + self.drop_name_hash(&table_name)?; self.drop_data(table_name.as_str())?; let (column_min, column_max) = @@ -437,6 +481,8 @@ pub trait Transaction: Sized { self.remove(&unsafe { &*self.table_codec() }.encode_root_table_key(table_name.as_str()))?; table_cache.remove(&table_name); + let _ = fs::remove_dir(Analyze::build_statistics_meta_path(&table_name)); + Ok(()) } @@ -1143,7 +1189,7 @@ impl Iter for IndexIter<'_, T> { unsafe { &*self.params.table_codec() }.tuple_bound(table_name) } else { unsafe { &*self.params.table_codec() } - .index_bound(table_name, &index_meta.id)? + .index_bound(table_name, index_meta.id)? }; let mut encode_min = bound_encode(min, false)?; check_bound(&mut encode_min, bound_min); @@ -1548,6 +1594,32 @@ mod test { dbg!(value); assert!(iter.try_next()?.is_none()); } + match transaction.drop_index(&table_cache, Arc::new("t1".to_string()), "pk_index", false) { + Err(DatabaseError::InvalidIndex) => (), + _ => unreachable!(), + } + transaction.drop_index(&table_cache, Arc::new("t1".to_string()), "i1", false)?; + { + let table = transaction + .table(&table_cache, Arc::new("t1".to_string()))? 
+ .unwrap(); + let i2_meta = table.indexes[1].clone(); + assert_eq!(i2_meta.id, 2); + assert_eq!(i2_meta.column_ids, vec![c3_column_id, c2_column_id]); + assert_eq!(i2_meta.table_name, Arc::new("t1".to_string())); + assert_eq!(i2_meta.pk_ty, LogicalType::Integer); + assert_eq!(i2_meta.name, "i2".to_string()); + assert_eq!(i2_meta.ty, IndexType::Composite); + + let (min, max) = table_codec.index_meta_bound("t1"); + let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + + let (_, value) = iter.try_next()?.unwrap(); + dbg!(value); + let (_, value) = iter.try_next()?.unwrap(); + dbg!(value); + assert!(iter.try_next()?.is_none()); + } Ok(()) } @@ -1638,7 +1710,7 @@ mod test { assert_eq!(index_iter.next_tuple()?.unwrap(), tuples[2]); assert_eq!(index_iter.next_tuple()?.unwrap(), tuples[1]); - let (min, max) = table_codec.index_bound("t1", &1)?; + let (min, max) = table_codec.index_bound("t1", 1)?; let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; let (_, value) = iter.try_next()?.unwrap(); @@ -1656,7 +1728,7 @@ mod test { assert_eq!(index_iter.next_tuple()?.unwrap(), tuples[2]); assert_eq!(index_iter.next_tuple()?.unwrap(), tuples[1]); - let (min, max) = table_codec.index_bound("t1", &1)?; + let (min, max) = table_codec.index_bound("t1", 1)?; let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; let (_, value) = iter.try_next()?.unwrap(); diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 8df9e511..3d99c0f0 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -171,13 +171,13 @@ impl TableCodec { pub fn index_bound( &self, table_name: &str, - index_id: &IndexId, + index_id: IndexId, ) -> Result<(BumpBytes, BumpBytes), DatabaseError> { let op = |bound_id| -> Result { let mut key_prefix = self.key_prefix(CodecType::Index, table_name); key_prefix.write_all(&[BOUND_MIN_TAG])?; - key_prefix.write_all(&index_id.to_be_bytes()[..])?; + key_prefix.write_all(&index_id.to_le_bytes()[..])?; key_prefix.write_all(&[bound_id])?; Ok(key_prefix) }; @@ -293,6 +293,18 @@ impl TableCodec { Tuple::deserialize_from(table_types, pk_indices, projections, schema, bytes, with_pk) } + pub fn encode_index_meta_key( + &self, + table_name: &str, + index_id: IndexId, + ) -> Result { + let mut key_prefix = self.key_prefix(CodecType::IndexMeta, table_name); + + key_prefix.write_all(&[BOUND_MIN_TAG])?; + key_prefix.write_all(&index_id.to_le_bytes()[..])?; + Ok(key_prefix) + } + /// Key: {TableName}{INDEX_META_TAG}{BOUND_MIN_TAG}{IndexID} /// Value: IndexMeta pub fn encode_index_meta( @@ -300,15 +312,12 @@ impl TableCodec { table_name: &str, index_meta: &IndexMeta, ) -> Result<(BumpBytes, BumpBytes), DatabaseError> { - let mut key_prefix = self.key_prefix(CodecType::IndexMeta, table_name); - - key_prefix.write_all(&[BOUND_MIN_TAG])?; - key_prefix.write_all(&index_meta.id.to_be_bytes()[..])?; + let key_bytes = self.encode_index_meta_key(table_name, index_meta.id)?; let mut value_bytes = BumpBytes::new_in(&self.arena); index_meta.encode(&mut value_bytes, true, &mut ReferenceTables::new())?; - Ok((key_prefix, value_bytes)) + Ok((key_bytes, value_bytes)) } pub fn decode_index_meta(bytes: &[u8]) -> Result { @@ -347,7 +356,7 @@ impl TableCodec { ) -> Result { let mut key_prefix = self.key_prefix(CodecType::Index, name); key_prefix.push(BOUND_MIN_TAG); - key_prefix.extend_from_slice(&index.id.to_be_bytes()); + key_prefix.extend_from_slice(&index.id.to_le_bytes()); key_prefix.push(BOUND_MIN_TAG); 
index.value.memcomparable_encode(&mut key_prefix)?; @@ -900,7 +909,7 @@ mod tests { println!("{:#?}", set); - let (min, max) = table_codec.index_bound(&table_catalog.name, &1).unwrap(); + let (min, max) = table_codec.index_bound(&table_catalog.name, 1).unwrap(); println!("{:?}", min); println!("{:?}", max); diff --git a/tests/slt/create_index.slt b/tests/slt/create_index.slt index 546fe9a2..666032ff 100644 --- a/tests/slt/create_index.slt +++ b/tests/slt/create_index.slt @@ -4,6 +4,9 @@ create table t(id int primary key, v1 int, v2 int, v3 int); statement ok create index index_1 on t (v1); +statement ok +create index if not exists index_1 on t (v1); + statement error create index index_1 on t (v1); @@ -24,5 +27,17 @@ select * from t; ---- 0 0 0 0 +statement ok +drop index t.index_1 + +statement ok +drop index t.index_2 + +statement error +drop index t.pk_index + +statement error +drop index t.index_3 + statement ok drop table t \ No newline at end of file From 12dc772b6ee9e8910f256ad7f9091a4b4c18bc4b Mon Sep 17 00:00:00 2001 From: wszhdshys <1925792291@qq.com> Date: Thu, 24 Jul 2025 10:40:56 +0800 Subject: [PATCH 7/7] chore: handle the new DropIndex operator in the CorrelatedSubquery rule (rebase on Drop Index support) --- src/optimizer/rule/normalization/correlated_subquery.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/optimizer/rule/normalization/correlated_subquery.rs b/src/optimizer/rule/normalization/correlated_subquery.rs index 01d96c26..32eaca39 100644 --- a/src/optimizer/rule/normalization/correlated_subquery.rs +++ b/src/optimizer/rule/normalization/correlated_subquery.rs @@ -194,6 +194,7 @@ impl CorrelatedSubquery { | Operator::CreateView(_) | Operator::DropTable(_) | Operator::DropView(_) + | Operator::DropIndex(_) | Operator::Truncate(_) | Operator::ShowTable | Operator::ShowView