diff --git a/CHANGES.md b/CHANGES.md index 876f3531..d5e47a56 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,6 +8,8 @@ Note: Minor version `0.X.0` update might break the API, It's recommended to pin ## [unreleased] +* switch from pygeofilter to cql2 + ## [1.2.1] - 2025-08-26 * update `starlette-cramjam` requirement to `>=0.4,<0.6` diff --git a/pyproject.toml b/pyproject.toml index dde7057a..77644e9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ dynamic = ["version"] dependencies = [ "orjson", + "cql2", "asyncpg>=0.23.0", "buildpg>=0.3", "fastapi>=0.100.0", @@ -31,7 +32,6 @@ dependencies = [ "pydantic>=2.4,<3.0", "pydantic-settings~=2.0", "geojson-pydantic>=1.0,<3.0", - "pygeofilter>=0.2.0,<0.3.0", "ciso8601~=2.3", "starlette-cramjam>=0.4,<0.6", ] diff --git a/tipg/collections.py b/tipg/collections.py index 82d61169..8edc00d9 100644 --- a/tipg/collections.py +++ b/tipg/collections.py @@ -3,7 +3,7 @@ import abc import datetime import re -from functools import lru_cache +from functools import lru_cache, reduce from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union from buildpg import RawDangerous as raw @@ -11,9 +11,10 @@ from buildpg import funcs as pg_funcs from buildpg import logic, render from ciso8601 import parse_rfc3339 +from cql2 import Expr +from geojson_pydantic.geometries import Polygon from morecantile import Tile, TileMatrixSet from pydantic import BaseModel, Field, model_validator -from pygeofilter.ast import AstType from pyproj import Transformer from tipg.errors import ( @@ -24,8 +25,6 @@ InvalidPropertyName, MissingDatetimeColumn, ) -from tipg.filter.evaluate import to_filter -from tipg.filter.filters import bbox_to_wkt from tipg.logger import logger from tipg.model import Extent from tipg.settings import ( @@ -44,6 +43,12 @@ TransformerFromCRS = lru_cache(Transformer.from_crs) +def bbox_to_wkt(bbox: List[float], srid: int = 4326) -> str: + """Return WKT representation of a BBOX.""" + poly = Polygon.from_bounds(*bbox) # type:ignore + return f"SRID={srid};{poly.wkt}" + + def debug_query(q, *p): """Utility to print raw statement to use for debugging.""" @@ -323,6 +328,139 @@ def queryables(self) -> Dict: return {**geoms, **props} + def cql_where( # noqa: C901 + self, + ids: Optional[List[str]] = None, + datetime: Optional[List[str]] = None, + bbox: Optional[List[float]] = None, + properties: Optional[List[Tuple[str, Any]]] = None, + cql: Optional[Expr] = None, + geom: Optional[str] = None, + dt: Optional[str] = None, + tile: Optional[Tile] = None, + tms: Optional[TileMatrixSet] = None, + ) -> Expr: + """Construct WHERE query.""" + exprs = [] + + if cql: + exprs.append(cql) + + # `ids` filter + if ids: + # REF: https://github.com/developmentseed/cql2-rs/issues/91 + if len(ids) == 1: + exprs.append(Expr(f"{self.id_column.name} = {ids[0]}")) + else: + id_list = ", ".join(f"'{id_}'" for id_ in ids) + exprs.append(Expr(f"{self.id_column.name} IN {id_list}")) + + # `properties` filter + if properties is not None: + for prop, val in properties: + col = self.get_column(prop) + if not col: + raise InvalidPropertyName(f"Invalid property name: {prop}") + exprs.append(Expr(f"{prop}='{val}'")) + + # `bbox` filter + geometry_column = self.get_geometry_column(geom) + if bbox is not None and geometry_column is not None: + # TODO: should we use bbox_to_wkt(bbox) + exprs.append( + Expr( + f"S_INTERSECTS({geometry_column.name}, {', '.join(map(str, bbox))})" + ) + ) + print(exprs[0].reduce().to_sql()) + + # `datetime` filter + if datetime: + if not self.datetime_columns: + raise MissingDatetimeColumn( + "Must have timestamp/timestamptz/date typed column to filter with datetime." + ) + + datetime_column = self.get_datetime_column(dt) + if not datetime_column: + raise InvalidDatetimeColumnName(f"Invalid Datetime Column: {dt}.") + + if len(datetime) == 1: + # NOTE: should we do parse_rfc3339(datetime[0]) + exprs.append(Expr(f"{datetime_column.name}=TIMESTAMP('{datetime[0]}')")) + + else: + start = ( + parse_rfc3339(datetime[0]) + if datetime[0] not in ["..", ""] + else None + ) + end = ( + parse_rfc3339(datetime[1]) + if datetime[1] not in ["..", ""] + else None + ) + if start is None and end is None: + raise InvalidDatetime( + "Double open-ended datetime intervals are not allowed." + ) + + if start is not None and end is not None and start > end: + raise InvalidDatetime( + "Start datetime cannot be before end datetime." + ) + + if (start and end) and start > end: + raise ValueError("Invalid datetime range: start must be <= end") + + startstr, endstr = datetime[:2] + if startstr not in ["..", ""]: + exprs.append( + Expr(f"{datetime_column.name}>=TIMESTAMP('{startstr}')") + ) + if endstr: + exprs.append(Expr(f"{datetime_column.name}<=TIMESTAMP('{endstr}')")) + + # if tile and tms and geometry_column: + # # Get tile bounds in the TMS coordinate system + # bbox = tms.xy_bounds(tile) + # left, bottom, right, top = bbox + + # # If the geometry column’s SRID does not match the TMS CRS, transform the bounds: + # # Use a fallback of 4326 if tms.crs.to_epsg() returns a falsey value. + # tms_epsg = tms.crs.to_epsg() or 4326 + # if geometry_column.srid != tms_epsg: + # transformer = TransformerFromCRS( + # tms_epsg, geometry_column.srid, always_xy=True + # ) + + # left, bottom, right, top = transformer.transform_bounds( + # left, bottom, right, top + # ) + + # wheres.append( + # logic.Func( + # "ST_Intersects", + # logic.Func( + # "ST_Segmentize", + # logic.Func( + # "ST_MakeEnvelope", + # left, + # bottom, + # right, + # top, + # geometry_column.srid, + # ), + # right - left, + # ), + # logic.V(geometry_column.name), + # ) + # ) + if exprs: + return reduce(lambda x, y: x + y, exprs).reduce() + + return None + @abc.abstractmethod async def features(self, *args, **kwargs) -> ItemList: """Get Items.""" @@ -516,152 +654,6 @@ def _geom( return g - def _where( # noqa: C901 - self, - ids: Optional[List[str]] = None, - datetime: Optional[List[str]] = None, - bbox: Optional[List[float]] = None, - properties: Optional[List[Tuple[str, Any]]] = None, - cql: Optional[AstType] = None, - geom: Optional[str] = None, - dt: Optional[str] = None, - tile: Optional[Tile] = None, - tms: Optional[TileMatrixSet] = None, - ): - """Construct WHERE query.""" - wheres = [logic.S(True)] - - # `ids` filter - if ids is not None: - if len(ids) == 1: - wheres.append( - logic.V(self.id_column.name) - == pg_funcs.cast(pg_funcs.cast(ids[0], "text"), self.id_column.type) - ) - else: - w = [ - logic.V(self.id_column.name) - == logic.S( - pg_funcs.cast(pg_funcs.cast(i, "text"), self.id_column.type) - ) - for i in ids - ] - wheres.append(pg_funcs.OR(*w)) - - # `properties filter - if properties is not None: - w = [] - for prop, val in properties: - col = self.get_column(prop) - if not col: - raise InvalidPropertyName(f"Invalid property name: {prop}") - - w.append( - logic.V(col.name) - == logic.S(pg_funcs.cast(pg_funcs.cast(val, "text"), col.type)) - ) - - if w: - wheres.append(pg_funcs.AND(*w)) - - # `bbox` filter - geometry_column = self.get_geometry_column(geom) - if bbox is not None and geometry_column is not None: - wheres.append( - logic.Func( - "ST_Intersects", - logic.S(bbox_to_wkt(bbox)), - logic.V(geometry_column.name), - ) - ) - - # `datetime` filter - if datetime: - if not self.datetime_columns: - raise MissingDatetimeColumn( - "Must have timestamp/timestamptz/date typed column to filter with datetime." - ) - - datetime_column = self.get_datetime_column(dt) - if not datetime_column: - raise InvalidDatetimeColumnName(f"Invalid Datetime Column: {dt}.") - - wheres.append(self._datetime_filter_to_sql(datetime, datetime_column.name)) - - # `CQL` filter - if cql is not None: - wheres.append(to_filter(cql, [p.name for p in self.properties])) - - if tile and tms and geometry_column: - # Get tile bounds in the TMS coordinate system - bbox = tms.xy_bounds(tile) - left, bottom, right, top = bbox - - # If the geometry column’s SRID does not match the TMS CRS, transform the bounds: - # Use a fallback of 4326 if tms.crs.to_epsg() returns a falsey value. - tms_epsg = tms.crs.to_epsg() or 4326 - if geometry_column.srid != tms_epsg: - transformer = TransformerFromCRS( - tms_epsg, geometry_column.srid, always_xy=True - ) - - left, bottom, right, top = transformer.transform_bounds( - left, bottom, right, top - ) - - wheres.append( - logic.Func( - "ST_Intersects", - logic.Func( - "ST_Segmentize", - logic.Func( - "ST_MakeEnvelope", - left, - bottom, - right, - top, - geometry_column.srid, - ), - right - left, - ), - logic.V(geometry_column.name), - ) - ) - - return clauses.Where(pg_funcs.AND(*wheres)) - - def _datetime_filter_to_sql(self, interval: List[str], dt_name: str): - if len(interval) == 1: - return logic.V(dt_name) == logic.S( - pg_funcs.cast(parse_rfc3339(interval[0]), "timestamptz") - ) - - else: - start = ( - parse_rfc3339(interval[0]) if interval[0] not in ["..", ""] else None - ) - end = parse_rfc3339(interval[1]) if interval[1] not in ["..", ""] else None - - if start is None and end is None: - raise InvalidDatetime( - "Double open-ended datetime intervals are not allowed." - ) - - if start is not None and end is not None and start > end: - raise InvalidDatetime("Start datetime cannot be before end datetime.") - - if not start: - return logic.V(dt_name) <= logic.S(pg_funcs.cast(end, "timestamptz")) - - elif not end: - return logic.V(dt_name) >= logic.S(pg_funcs.cast(start, "timestamptz")) - - else: - return pg_funcs.AND( - logic.V(dt_name) >= logic.S(pg_funcs.cast(start, "timestamptz")), - logic.V(dt_name) < logic.S(pg_funcs.cast(end, "timestamptz")), - ) - def _sortby(self, sortby: Optional[str]): sorts = [] if sortby: @@ -690,15 +682,10 @@ async def _features_query( self, conn: asyncpg.Connection, *, - ids_filter: Optional[List[str]] = None, - bbox_filter: Optional[List[float]] = None, - datetime_filter: Optional[List[str]] = None, - properties_filter: Optional[List[Tuple[str, str]]] = None, - cql_filter: Optional[AstType] = None, + where: Optional[str] = None, sortby: Optional[str] = None, properties: Optional[List[str]] = None, geom: Optional[str] = None, - dt: Optional[str] = None, limit: Optional[int] = None, offset: Optional[int] = None, bbox_only: Optional[bool] = None, @@ -719,15 +706,7 @@ async def _features_query( geom_as_wkt=geom_as_wkt, ), self._from(function_parameters), - self._where( - ids=ids_filter, - datetime=datetime_filter, - bbox=bbox_filter, - properties=properties_filter, - cql=cql_filter, - geom=geom, - dt=dt, - ), + clauses.Where(where or logic.S(True)), self._sortby(sortby), clauses.Limit(limit), clauses.Offset(offset), @@ -745,28 +724,14 @@ async def _features_count_query( self, conn: asyncpg.Connection, *, - ids_filter: Optional[List[str]] = None, - bbox_filter: Optional[List[float]] = None, - datetime_filter: Optional[List[str]] = None, - properties_filter: Optional[List[Tuple[str, str]]] = None, - cql_filter: Optional[AstType] = None, - geom: Optional[str] = None, - dt: Optional[str] = None, + where: Optional[str], function_parameters: Optional[Dict[str, str]], ) -> int: """Build features COUNT query.""" c = clauses.Clauses( self._select_count(), self._from(function_parameters), - self._where( - ids=ids_filter, - datetime=datetime_filter, - bbox=bbox_filter, - properties=properties_filter, - cql=cql_filter, - geom=geom, - dt=dt, - ), + clauses.Where(where or logic.S(True)), ) q, p = render(":c", c=c) @@ -781,7 +746,7 @@ async def features( bbox_filter: Optional[List[float]] = None, datetime_filter: Optional[List[str]] = None, properties_filter: Optional[List[Tuple[str, str]]] = None, - cql_filter: Optional[AstType] = None, + cql_filter: Optional[Expr] = None, sortby: Optional[str] = None, properties: Optional[List[str]] = None, geom: Optional[str] = None, @@ -807,31 +772,30 @@ async def features( f"Limit can not be set higher than the `tipg_max_features_per_query` setting of {features_settings.max_features_per_query}" ) + where_filter = self.cql_where( + ids=ids_filter, + datetime=datetime_filter, + bbox=bbox_filter, + properties=properties_filter, + cql=cql_filter, + geom=geom, + dt=dt, + ) + matched = await self._features_count_query( conn, - ids_filter=ids_filter, - datetime_filter=datetime_filter, - bbox_filter=bbox_filter, - properties_filter=properties_filter, + where=where_filter.to_sql() if where_filter else None, function_parameters=function_parameters, - cql_filter=cql_filter, - geom=geom, - dt=dt, ) features = [ f async for f in self._features_query( conn, - ids_filter=ids_filter, - datetime_filter=datetime_filter, - bbox_filter=bbox_filter, - properties_filter=properties_filter, - cql_filter=cql_filter, + where=where_filter.to_sql() if where_filter else None, sortby=sortby, properties=properties, geom=geom, - dt=dt, limit=limit, offset=offset, bbox_only=bbox_only, @@ -860,7 +824,7 @@ async def get_tile( datetime_filter: Optional[List[str]] = None, properties_filter: Optional[List[Tuple[str, str]]] = None, function_parameters: Optional[Dict[str, str]] = None, - cql_filter: Optional[AstType] = None, + cql_filter: Optional[Expr] = None, sortby: Optional[str] = None, properties: Optional[List[str]] = None, geom: Optional[str] = None, @@ -879,6 +843,18 @@ async def get_tile( f"Limit can not be set higher than the `tipg_max_features_per_tile` setting of {mvt_settings.max_features_per_tile}" ) + where_filter = self.cql_where( + ids=ids_filter, + datetime=datetime_filter, + bbox=bbox_filter, + properties=properties_filter, + cql=cql_filter, + geom=geom, + dt=dt, + tms=tms, + tile=tile, + ) + c = clauses.Clauses( self._select_mvt( properties=properties, @@ -887,17 +863,7 @@ async def get_tile( tile=tile, ), self._from(function_parameters), - self._where( - ids=ids_filter, - datetime=datetime_filter, - bbox=bbox_filter, - properties=properties_filter, - cql=cql_filter, - geom=geom, - dt=dt, - tms=tms, - tile=tile, - ), + clauses.Where(where_filter.to_sql() if where_filter else logic.S(True)), clauses.Limit(limit), ) diff --git a/tipg/dependencies.py b/tipg/dependencies.py index a82feeb7..70fbfe48 100644 --- a/tipg/dependencies.py +++ b/tipg/dependencies.py @@ -4,11 +4,9 @@ from typing import Annotated, Dict, List, Literal, Optional, Tuple, get_args from ciso8601 import parse_rfc3339 +from cql2 import Expr from morecantile import Tile from morecantile import tms as default_tms -from pygeofilter.ast import AstType -from pygeofilter.parsers.cql2_json import parse as cql2_json_parser -from pygeofilter.parsers.cql2_text import parse as cql2_text_parser from tipg.collections import Catalog, Collection, CollectionList from tipg.errors import InvalidBBox, MissingCollectionCatalog, MissingFunctionParameter @@ -289,14 +287,10 @@ def filter_query( alias="filter-lang", ), ] = None, -) -> Optional[AstType]: +) -> Optional[Expr]: """Parse Filter Query.""" if query is not None: - if filter_lang == "cql2-json": - return cql2_json_parser(query) - - # default to cql2-text - return cql2_text_parser(query) + return Expr(query) return None diff --git a/tipg/factory.py b/tipg/factory.py index cf072811..1732ad33 100644 --- a/tipg/factory.py +++ b/tipg/factory.py @@ -19,10 +19,10 @@ import jinja2 import orjson +from cql2 import Expr from morecantile import Tile, TileMatrixSet from morecantile import tms as default_tms from morecantile.defaults import TileMatrixSets -from pygeofilter.ast import AstType from tipg import model from tipg.collections import Collection, CollectionList @@ -777,7 +777,7 @@ async def items( # noqa: C901 bbox_filter: Annotated[Optional[List[float]], Depends(bbox_query)], datetime_filter: Annotated[Optional[List[str]], Depends(datetime_query)], properties: Annotated[Optional[List[str]], Depends(properties_query)], - cql_filter: Annotated[Optional[AstType], Depends(filter_query)], + cql_filter: Annotated[Optional[Expr], Depends(filter_query)], sortby: Annotated[Optional[str], Depends(sortby_query)], geom_column: Annotated[ Optional[str], @@ -1086,9 +1086,9 @@ async def item( async with request.app.state.pool.acquire() as conn: item_list = await collection.features( conn, + cql_filter=Expr(f"{collection.id_column.name} = {itemId}"), bbox_only=bbox_only, simplify=simplify, - ids_filter=[itemId], properties=properties, function_parameters=function_parameters_query(request, collection), geom=geom_column, @@ -1629,7 +1629,7 @@ async def collection_get_tile( properties: Annotated[ Optional[List[str]], Depends(properties_query) ] = None, - cql_filter: Annotated[Optional[AstType], Depends(filter_query)] = None, + cql_filter: Annotated[Optional[Expr], Depends(filter_query)] = None, sortby: Annotated[Optional[str], Depends(sortby_query)] = None, geom_column: Annotated[ Optional[str], diff --git a/tipg/filter/__init__.py b/tipg/filter/__init__.py deleted file mode 100644 index 98f1bebd..00000000 --- a/tipg/filter/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""tipg.filter""" diff --git a/tipg/filter/evaluate.py b/tipg/filter/evaluate.py deleted file mode 100644 index 5852ec25..00000000 --- a/tipg/filter/evaluate.py +++ /dev/null @@ -1,151 +0,0 @@ -"""tipg.filter.evaluate.""" - -from datetime import date, datetime, time, timedelta - -from pygeofilter import ast, values -from pygeofilter.backends.evaluator import Evaluator, handle - -from tipg.filter import filters - -LITERALS = (str, float, int, bool, datetime, date, time, timedelta) - - -class BuildPGEvaluator(Evaluator): # noqa: D101 - def __init__(self, field_mapping): # noqa: D107 - self.field_mapping = field_mapping - - @handle(ast.Not) - def not_(self, node, sub): # noqa: D102 - return filters.negate(sub) - - @handle(ast.And, ast.Or) - def combination(self, node, lhs, rhs): # noqa: D102 - return filters.combine((lhs, rhs), node.op.value) - - @handle(ast.Comparison, subclasses=True) - def comparison(self, node, lhs, rhs): # noqa: D102 - return filters.runop( - lhs, - rhs, - node.op.value, - ) - - @handle(ast.Between) - def between(self, node, lhs, low, high): # noqa: D102 - return filters.between(lhs, low, high, node.not_) - - @handle(ast.Like) - def like(self, node, lhs): # noqa: D102 - return filters.like( - lhs, - node.pattern, - not node.nocase, - node.not_, - ) - - @handle(ast.In) - def in_(self, node, lhs, *options): # noqa: D102 - return filters.runop( - lhs, - options, - "in", - node.not_, - ) - - @handle(ast.IsNull) - def null(self, node, lhs): # noqa: D102 - if isinstance(lhs, list): - lhs = filters.attribute(lhs[0].name, self.field_mapping) - return filters.isnull(lhs) - - # @handle(ast.ExistsPredicateNode) - # def exists(self, node, lhs): - # if self.use_getattr: - # result = hasattr(self.obj, node.lhs.name) - # else: - # result = lhs in self.obj - - # if node.not_: - # result = not result - # return result - - @handle(ast.TemporalPredicate, subclasses=True) - def temporal(self, node, lhs, rhs): # noqa: D102 - return filters.temporal( - lhs, - rhs, - node.op.value, - ) - - @handle(ast.SpatialComparisonPredicate, subclasses=True) - def spatial_operation(self, node, lhs, rhs): # noqa: D102 - return filters.spatial( - lhs, - rhs, - node.op.name, - ) - - @handle(ast.Relate) - def spatial_pattern(self, node, lhs, rhs): # noqa: D102 - return filters.spatial( - lhs, - rhs, - "RELATE", - pattern=node.pattern, - ) - - @handle(ast.SpatialDistancePredicate, subclasses=True) - def spatial_distance(self, node, lhs, rhs): # noqa: D102 - return filters.spatial( - lhs, - rhs, - node.op.value, - distance=node.distance, - units=node.units, - ) - - @handle(ast.BBox) - def bbox(self, node, lhs): # noqa: D102 - return filters.bbox(lhs, node.minx, node.miny, node.maxx, node.maxy, node.crs) - - @handle(ast.Attribute) - def attribute(self, node): # noqa: D102 - return filters.attribute(node.name, self.field_mapping) - - @handle(ast.Arithmetic, subclasses=True) - def arithmetic(self, node, lhs, rhs): # noqa: D102 - return filters.runop(lhs, rhs, node.op.value) - - @handle(ast.Function) - def function(self, node, *arguments): # noqa: D102 - return filters.func(node.name, *arguments) - - @handle(*values.LITERALS) - def literal(self, node): # noqa: D102 - return filters.literal(node) - - @handle(values.Interval) - def interval(self, node, start, end): # noqa: D102 - return filters.literal((start, end)) - - @handle(values.Geometry) - def geometry(self, node): # noqa: D102 - return filters.parse_geometry(node.__geo_interface__) - - @handle(values.Envelope) - def envelope(self, node): # noqa: D102 - return filters.parse_bbox([node.x1, node.y1, node.x2, node.y2]) - - -def to_filter(ast, field_mapping=None): # noqa: D102 - """Helper function to translate ECQL AST to Django Query expressions. - - :param ast: the abstract syntax tree - :param field_mapping: a dict mapping from the filter name to the Django field lookup. - :param mapping_choices: a dict mapping field lookups to choices. - :type ast: :class:`Node` - :returns: a Django query object - :rtype: :class:`django.db.models.Q` - - """ - return BuildPGEvaluator(field_mapping).evaluate(ast) diff --git a/tipg/filter/filters.py b/tipg/filter/filters.py deleted file mode 100644 index 2e306e8e..00000000 --- a/tipg/filter/filters.py +++ /dev/null @@ -1,325 +0,0 @@ -"""tipg.filter.filters""" - -import re -from datetime import timedelta -from functools import reduce -from inspect import signature -from typing import Any, Callable, Dict, List - -from buildpg import V -from buildpg.funcs import AND as and_ -from buildpg.funcs import NOT as not_ -from buildpg.funcs import OR as or_ -from buildpg.funcs import any -from buildpg.logic import Func -from geojson_pydantic.geometries import Polygon, parse_geometry_obj - - -def bbox_to_wkt(bbox: List[float], srid: int = 4326) -> str: - """Return WKT representation of a BBOX.""" - poly = Polygon.from_bounds(*bbox) # type:ignore - return f"SRID={srid};{poly.wkt}" - - -def parse_geometry(geom: Dict[str, Any]) -> str: - """Parse geometry object and return WKT.""" - wkt = parse_geometry_obj(geom).wkt # type:ignore - sridtxt = "" if wkt.startswith("SRID=") else "SRID=4326;" - return f"{sridtxt}{wkt}" - - -# ------------------------------------------------------------------------------ -# Filters -# ------------------------------------------------------------------------------ -class Operator: - """Filter Operators.""" - - OPERATORS: Dict[str, Callable] = { - "==": lambda f, a: f == a, - "=": lambda f, a: f == a, - "eq": lambda f, a: f == a, - "!=": lambda f, a: f != a, - "<>": lambda f, a: f != a, - "ne": lambda f, a: f != a, - ">": lambda f, a: f > a, - "gt": lambda f, a: f > a, - "<": lambda f, a: f < a, - "lt": lambda f, a: f < a, - ">=": lambda f, a: f >= a, - "ge": lambda f, a: f >= a, - "<=": lambda f, a: f <= a, - "le": lambda f, a: f <= a, - "like": lambda f, a: f.like(a), - "ilike": lambda f, a: f.ilike(a), - "not_ilike": lambda f, a: ~f.ilike(a), - "in": lambda f, a: f == any(a), - "not_in": lambda f, a: ~f == any(a), - "any": lambda f, a: f.any(a), - "not_any": lambda f, a: f.not_(f.any(a)), - "INTERSECTS": lambda f, a: Func( - "st_intersects", - f, - Func("st_transform", a, Func("st_srid", f)), - ), - "DISJOINT": lambda f, a: Func( - "st_disjoint", f, Func("st_transform", a, Func("st_srid", f)) - ), - "CONTAINS": lambda f, a: Func( - "st_contains", f, Func("st_transform", a, Func("st_srid", f)) - ), - "WITHIN": lambda f, a: Func( - "st_within", f, Func("st_transform", a, Func("st_srid", f)) - ), - "TOUCHES": lambda f, a: Func( - "st_touches", f, Func("st_transform", a, Func("st_srid", f)) - ), - "CROSSES": lambda f, a: Func( - "st_crosses", - f, - Func("st_transform", a, Func("st_srid", f)), - ), - "OVERLAPS": lambda f, a: Func( - "st_overlaps", - f, - Func("st_transform", a, Func("st_srid", f)), - ), - "EQUALS": lambda f, a: Func( - "st_equals", - f, - Func("st_transform", a, Func("st_srid", f)), - ), - "RELATE": lambda f, a, pattern: Func( - "st_relate", f, Func("st_transform", a, Func("st_srid", f)), pattern - ), - "DWITHIN": lambda f, a, distance: Func( - "st_dwithin", f, Func("st_transform", a, Func("st_srid", f)), distance - ), - "BEYOND": lambda f, a, distance: ~Func( - "st_dwithin", f, Func("st_transform", a, Func("st_srid", f)), distance - ), - "+": lambda f, a: f + a, - "-": lambda f, a: f - a, - "*": lambda f, a: f * a, - "/": lambda f, a: f / a, - } - - def __init__(self, operator: str = None): - """Init.""" - if not operator: - operator = "==" - - if operator not in self.OPERATORS: - raise Exception("Operator `{}` not valid.".format(operator)) - - self.operator = operator - self.function = self.OPERATORS[operator] - self.arity = len(signature(self.function).parameters) - - -def func(name, *args): - """Return results of running SQL function with arguments.""" - return Func(name, *args) - - -def combine(sub_filters, combinator: str = "AND"): - """Combine filters using a logical combinator - - :param sub_filters: the filters to combine - :param combinator: a string: "AND" / "OR" - :return: the combined filter - - """ - assert combinator in ("AND", "OR") - _op = and_ if combinator == "AND" else or_ - - def test(acc, q): - return _op(acc, q) - - return reduce(test, sub_filters) - - -def negate(sub_filter): - """Negate a filter, opposing its meaning. - - :param sub_filter: the filter to negate - :return: the negated filter - - """ - return not_(sub_filter) - - -def runop(lhs, rhs=None, op: str = "=", negate: bool = False): - """Compare a filter with an expression using a comparison operation. - - :param lhs: the field to compare - :param rhs: the filter expression - :param op: a string denoting the operation. - :return: a comparison expression object - - """ - _op = Operator(op) - - if negate: - return not_(_op.function(lhs, rhs)) - return _op.function(lhs, rhs) - - -def between(lhs, low, high, negate=False): - """Create a filter to match elements that have a value within a certain range. - - :param lhs: the field to compare - :param low: the lower value of the range - :param high: the upper value of the range - :param not_: whether the range shall be inclusive (the default) or exclusive - :return: a comparison expression object - - """ - l_op = Operator("<=") - g_op = Operator(">=") - if negate: - return not_(and_(g_op.function(lhs, low), l_op.function(lhs, high))) - - return and_(g_op.function(lhs, low), l_op.function(lhs, high)) - - -def like(lhs, rhs, case=False, negate=False): - """Create a filter to filter elements according to a string attribute using wildcard expressions. - - :param lhs: the field to compare - :param rhs: the wildcard pattern: a string containing any number of '%' characters as wildcards. - :param case: whether the lookup shall be done case sensitively or not - :param not_: whether the range shall be inclusive (the default) or exclusive - :return: a comparison expression object - - """ - if case: - _op = Operator("like") - else: - _op = Operator("ilike") - - if negate: - return not_(_op.function(lhs, rhs)) - - return _op.function(lhs, rhs) - - -def temporal(lhs, time_or_period, op): - """Create a temporal filter for the given temporal attribute. - - :param lhs: the field to compare - :type lhs: :class:`django.db.models.F` - :param time_or_period: the time instant or time span to use as a filter - :type time_or_period: :class:`datetime.datetime` or a tuple of two datetimes or a tuple of one datetime and one :class:`datetime.timedelta` - :param op: the comparison operation. one of ``"BEFORE"``, ``"BEFORE OR DURING"``, ``"DURING"``, ``"DURING OR AFTER"``, ``"AFTER"``. - :type op: str - :return: a comparison expression object - :rtype: :class:`django.db.models.Q` - - """ - low = None - high = None - equal = None - if op in ("BEFORE", "AFTER"): - if op == "BEFORE": - high = time_or_period - else: - low = time_or_period - elif op == "TEQUALS": - equal = time_or_period - else: - low, high = time_or_period - - if isinstance(low, timedelta): - low = high - low - if isinstance(high, timedelta): - high = low + high - if low is not None or high is not None: - if low is not None and high is not None: - return between(lhs, low, high) - elif low is not None: - return runop(lhs, low, ">=") - else: - return runop(lhs, high, "<=") - elif equal is not None: - return runop(lhs, equal, "==") - - -UNITS_LOOKUP = {"kilometers": "km", "meters": "m"} - - -def spatial(lhs, rhs, op, pattern=None, distance=None, units=None): - """Create a spatial filter for the given spatial attribute. - - :param lhs: the field to compare - :param rhs: the time instant or time span to use as a filter - :param op: the comparison operation. one of ``"INTERSECTS"``, ``"DISJOINT"``, `"CONTAINS"``, ``"WITHIN"``, ``"TOUCHES"``, ``"CROSSES"``, ``"OVERLAPS"``, ``"EQUALS"``, ``"RELATE"``, ``"DWITHIN"``, ``"BEYOND"`` - :param pattern: the spatial relation pattern - :param distance: the distance value for distance based lookups: ``"DWITHIN"`` and ``"BEYOND"`` - :param units: the units the distance is expressed in - :return: a comparison expression object - - """ - - _op = Operator(op) - if op == "RELATE": - return _op.function(lhs, rhs, pattern) - elif op in ("DWITHIN", "BEYOND"): - if units == "kilometers": - distance = distance / 1000 - elif units == "miles": - distance = distance / 1609 - return _op.function(lhs, rhs, distance) - else: - return _op.function(lhs, rhs) - - -def bbox(lhs, minx, miny, maxx, maxy, crs: int = 4326): - """Create a bounding box filter for the given spatial attribute. - - :param lhs: the field to compare - :param minx: the lower x part of the bbox - :param miny: the lower y part of the bbox - :param maxx: the upper x part of the bbox - :param maxy: the upper y part of the bbox - :param crs: the CRS the bbox is expressed in - :return: a comparison expression object - - """ - - return Func("st_intersects", lhs, bbox_to_wkt([minx, miny, maxx, maxy], crs)) - - -def quote_ident(s: str) -> str: - """quote.""" - if re.match(r"^[a-z]+$", s): - return s - if re.match(r"^[a-zA-Z][\w\d_]*$", s): - return f'"{s}"' - raise TypeError(f"{s} is not a valid identifier") - - -def attribute(name: str, fields: List[str]): - """Create an attribute lookup expression using a field mapping dictionary. - - :param name: the field filter name - :param field_mapping: the dictionary to use as a lookup. - - """ - if name in fields: - return V(name) - elif name.lower() == "true": - return True - elif name.lower() == "false": - return False - else: - raise TypeError(f"Field {name} not in table.") - - -def isnull(lhs): - """null value.""" - return lhs.is_(V("NULL")) - - -def literal(value): - """literal value.""" - return value