From 3dd510880fed3d44a5712d5a3583b6c542c4ed19 Mon Sep 17 00:00:00 2001 From: Frank Colson Date: Fri, 17 May 2024 19:11:25 -0600 Subject: [PATCH] Add AND and OR support --- src/query.rs | 20 ++++++++++++++++++++ tests/tantivy_test.py | 26 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/src/query.rs b/src/query.rs index d38a7475..a6e5b565 100644 --- a/src/query.rs +++ b/src/query.rs @@ -65,6 +65,26 @@ impl Query { Ok(format!("Query({:?})", self.get())) } + pub(crate) fn __and__(&self, other: Query) -> Query { + let inner = tv::query::BooleanQuery::from(vec![ + (tv::query::Occur::Must, self.inner.box_clone()), + (tv::query::Occur::Must, other.inner.box_clone()), + ]); + Query { + inner: Box::new(inner), + } + } + + pub(crate) fn __or__(&self, other: Query) -> Query { + let inner = tv::query::BooleanQuery::from(vec![ + (tv::query::Occur::Should, self.inner.box_clone()), + (tv::query::Occur::Should, other.inner.box_clone()), + ]); + Query { + inner: Box::new(inner), + } + } + /// Construct a Tantivy's TermQuery #[staticmethod] #[pyo3(signature = (schema, field_name, field_value, index_option = "position"))] diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index e2a77eb5..68ab7b30 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -64,6 +64,32 @@ def test_and_query(self, ram_index): assert len(result.hits) == 1 + def test_combine_queries(self, ram_index): + index = ram_index + + query1 = ram_index.parse_query("title:men", ["title"]) + query2 = ram_index.parse_query("body:summer", ["body"]) + + combined_and = query1 & query2 + + searcher = index.searcher() + result = searcher.search(combined_and, 10) + + # This is an AND query, so it should return 0 results since summer isn't present + assert len(result.hits) == 0 + + combined_or = query1 | query2 + + result = searcher.search(combined_or, 10) + + assert len(result.hits) == 1 + + double_combined = (query1 & query2) | query1 + + result = searcher.search(double_combined, 10) + + assert len(result.hits) == 1 + def test_and_query_numeric_fields(self, ram_index_numeric_fields): index = ram_index_numeric_fields searcher = index.searcher()