From 4a18d13c8a70dae79e1a30563753e1f8afcdb39d Mon Sep 17 00:00:00 2001 From: Iwan Aucamp Date: Sun, 21 Aug 2022 20:12:11 +0200 Subject: [PATCH] feat: add type hints to `rdflib.plugins.sparql.{algebra,operators}` More or less complete type hints for `rdflib.plugins.sparql.algebra` and `rdflib.plugins.sparql.operators`. This does not change runtime behaviour. Other changes: - Fixed line endings of `test/test_issues/test_issue1043.py` and `test/test_issues/test_issue910.py`. - Removed a type hint comment that was present in rdflib/plugins/sparql/algebra.py Related issues: - Closes . --- CHANGELOG.md | 4 + Taskfile.yml | 11 +- rdflib/plugins/sparql/algebra.py | 98 ++++++---- rdflib/plugins/sparql/operators.py | 277 ++++++++++++++++++----------- test/test_issues/test_issue1043.py | 60 +++---- test/test_issues/test_issue910.py | 132 +++++++------- 6 files changed, 346 insertions(+), 236 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37fc6dafb..f2775d246 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -170,6 +170,10 @@ and will be removed for release. [PR #2057](https://github.com/RDFLib/rdflib/pull/2057). - `rdflib.graph` have mostly complete type hints. [PR #2080](https://github.com/RDFLib/rdflib/pull/2080). + - `rdflib.plugins.sparql.algebra` and `rdflib.plugins.sparql.operators` have + mostly complete type hints. + [PR #2094](https://github.com/RDFLib/rdflib/pull/2094). 
+ diff --git a/Taskfile.yml b/Taskfile.yml index d46ca1f8d..c4028a45a 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -132,12 +132,17 @@ tasks: desc: Run tests cmds: - '{{.TEST_HARNESS}}{{print .VENV_BINPREFIX "pytest" | shellQuote}} {{if (mustFromJson .WITH_COVERAGE)}}--cov --cov-report={{end}} {{.CLI_ARGS}}' - flake8: desc: Run flake8 cmds: - - "{{._PYTHON | shellQuote}} -m flakeheaven lint {{.CLI_ARGS}}" - + - | + if {{._PYTHON | shellQuote}} -c 'import importlib; exit(0 if importlib.util.find_spec("flakeheaven") is not None else 1)' + then + 1>&2 echo "running flakeheaven" + {{._PYTHON | shellQuote}} -m flakeheaven lint {{.CLI_ARGS}} + else + 1>&2 echo "skipping flakeheaven as it is not installed, likely because python version is older than 3.8" + fi black: desc: Run black cmds: diff --git a/rdflib/plugins/sparql/algebra.py b/rdflib/plugins/sparql/algebra.py index 01dc17511..cb5d04945 100644 --- a/rdflib/plugins/sparql/algebra.py +++ b/rdflib/plugins/sparql/algebra.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Converting the 'parse-tree' output of pyparsing to a SPARQL Algebra expression @@ -48,9 +50,7 @@ def OrderBy(p: CompValue, expr: List[CompValue]) -> CompValue: return CompValue("OrderBy", p=p, expr=expr) -def ToMultiSet( - p: typing.Union[List[Dict[Variable, Identifier]], CompValue] -) -> CompValue: +def ToMultiSet(p: typing.Union[List[Dict[Variable, str]], CompValue]) -> CompValue: return CompValue("ToMultiSet", p=p) @@ -66,11 +66,13 @@ def Minus(p1: CompValue, p2: CompValue) -> CompValue: return CompValue("Minus", p1=p1, p2=p2) -def Graph(term, graph) -> CompValue: +def Graph(term: Identifier, graph: CompValue) -> CompValue: return CompValue("Graph", term=term, p=graph) -def BGP(triples=None) -> CompValue: +def BGP( + triples: Optional[List[Tuple[Identifier, Identifier, Identifier]]] = None +) -> CompValue: return CompValue("BGP", triples=triples or []) @@ -78,19 +80,21 @@ def LeftJoin(p1: CompValue, p2: CompValue, expr) -> CompValue: 
return CompValue("LeftJoin", p1=p1, p2=p2, expr=expr) -def Filter(expr, p: CompValue) -> CompValue: +def Filter(expr: Expr, p: CompValue) -> CompValue: return CompValue("Filter", expr=expr, p=p) -def Extend(p: CompValue, expr, var) -> CompValue: +def Extend( + p: CompValue, expr: typing.Union[Identifier, Expr], var: Variable +) -> CompValue: return CompValue("Extend", p=p, expr=expr, var=var) -def Values(res) -> CompValue: +def Values(res: List[Dict[Variable, str]]) -> CompValue: return CompValue("values", res=res) -def Project(p: CompValue, PV) -> CompValue: +def Project(p: CompValue, PV: List[Variable]) -> CompValue: return CompValue("Project", p=p, PV=PV) @@ -102,7 +106,7 @@ def _knownTerms( triple: Tuple[Identifier, Identifier, Identifier], varsknown: Set[typing.Union[BNode, Variable]], varscount: Dict[Identifier, int], -): +) -> Tuple[int, int, bool]: return ( len( [ @@ -124,7 +128,7 @@ def reorderTriples( ones with most bindings first """ - def _addvar(term, varsknown): + def _addvar(term: str, varsknown: Set[typing.Union[Variable, BNode]]): if isinstance(term, (Variable, BNode)): varsknown.add(term) @@ -180,20 +184,25 @@ def triples( return reorderTriples((l[x], l[x + 1], l[x + 2]) for x in range(0, len(l), 3)) # type: ignore[misc] -def translatePName(p: typing.Union[CompValue, str], prologue: Prologue): +# type error: Missing return statement +def translatePName( # type: ignore[return] + p: typing.Union[CompValue, str], prologue: Prologue +) -> Optional[Identifier]: """ Expand prefixed/relative URIs """ if isinstance(p, CompValue): if p.name == "pname": - return prologue.absolutize(p) + # type error: Incompatible return value type (got "Union[CompValue, str, None]", expected "Optional[Identifier]") + return prologue.absolutize(p) # type: ignore[return-value] if p.name == "literal": # type error: Argument "datatype" to "Literal" has incompatible type "Union[CompValue, str, None]"; expected "Optional[str]" return Literal( p.string, lang=p.lang, 
datatype=prologue.absolutize(p.datatype) # type: ignore[arg-type] ) elif isinstance(p, URIRef): - return prologue.absolutize(p) + # type error: Incompatible return value type (got "Union[CompValue, str, None]", expected "Optional[Identifier]") + return prologue.absolutize(p) # type: ignore[return-value] @overload @@ -253,8 +262,8 @@ def translatePath(p: typing.Union[CompValue, URIRef]) -> Optional["Path"]: # ty def translateExists( - e: typing.Union[Expr, Literal, Variable] -) -> typing.Union[Expr, Literal, Variable]: + e: typing.Union[Expr, Literal, Variable, URIRef] +) -> typing.Union[Expr, Literal, Variable, URIRef]: """ Translate the graph pattern used by EXISTS and NOT EXISTS http://www.w3.org/TR/sparql11-query/#sparqlCollectFilters @@ -273,7 +282,7 @@ def _c(n): return e -def collectAndRemoveFilters(parts): +def collectAndRemoveFilters(parts: List[CompValue]) -> Optional[Expr]: """ FILTER expressions apply to the whole group graph pattern in which @@ -294,7 +303,8 @@ def collectAndRemoveFilters(parts): i += 1 if filters: - return and_(*filters) + # type error: Argument 1 to "and_" has incompatible type "*List[Union[Expr, Literal, Variable]]"; expected "Expr" + return and_(*filters) # type: ignore[arg-type] return None @@ -380,7 +390,7 @@ def translateGroupGraphPattern(graphPattern: CompValue) -> CompValue: class StopTraversal(Exception): # noqa: N818 - def __init__(self, rv): + def __init__(self, rv: bool): self.rv = rv @@ -444,7 +454,7 @@ def traverse( visitPre: Callable[[Any], Any] = lambda n: None, visitPost: Callable[[Any], Any] = lambda n: None, complete: Optional[bool] = None, -): +) -> Any: """ Traverse tree, visit each node with visit function visit function may raise StopTraversal to stop traversal @@ -504,7 +514,7 @@ def _findVars(x, res: Set[Variable]) -> Optional[CompValue]: # type: ignore[ret return x -def _addVars(x, children) -> Set[Variable]: +def _addVars(x, children: List[Set[Variable]]) -> Set[Variable]: """ find which variables may be 
bound by this part of the query """ @@ -549,7 +559,7 @@ def _sample(e: typing.Union[CompValue, List[Expr], Expr, List[str], Variable], v return CompValue("Aggregate_Sample", vars=e) -def _simplifyFilters(e): +def _simplifyFilters(e: Any) -> Any: if isinstance(e, Expr): return simplifyFilters(e) @@ -592,11 +602,11 @@ def translateAggregates( def translateValues( v: CompValue, -) -> typing.Union[List[Dict[Variable, Identifier]], CompValue]: +) -> typing.Union[List[Dict[Variable, str]], CompValue]: # if len(v.var)!=len(v.value): # raise Exception("Unmatched vars and values in ValueClause: "+str(v)) - res: List[Dict[Variable, Identifier]] = [] + res: List[Dict[Variable, str]] = [] if not v.var: return res if not v.value: @@ -722,7 +732,7 @@ def translate(q: CompValue) -> Tuple[CompValue, List[Variable]]: # type error: Missing return statement -def simplify(n) -> Optional[CompValue]: # type: ignore[return] +def simplify(n: Any) -> Optional[CompValue]: # type: ignore[return] """Remove joins to empty BGPs""" if isinstance(n, CompValue): if n.name == "Join": @@ -735,7 +745,7 @@ def simplify(n) -> Optional[CompValue]: # type: ignore[return] return n -def analyse(n, children): +def analyse(n: Any, children: Any) -> bool: """ Some things can be lazily joined. 
This propegates whether they can up the tree @@ -757,7 +767,7 @@ def analyse(n, children): def translatePrologue( p: ParseResults, base: Optional[str], - initNs: Optional[Mapping[str, str]] = None, + initNs: Optional[Mapping[str, Any]] = None, prologue: Optional[Prologue] = None, ) -> Prologue: @@ -780,7 +790,12 @@ def translatePrologue( return prologue -def translateQuads(quads: CompValue): +def translateQuads( + quads: CompValue, +) -> Tuple[ + List[Tuple[Identifier, Identifier, Identifier]], + DefaultDict[str, List[Tuple[Identifier, Identifier, Identifier]]], +]: if quads.triples: alltriples = triples(quads.triples) else: @@ -825,7 +840,7 @@ def translateUpdate1(u: CompValue, prologue: Prologue) -> CompValue: def translateUpdate( q: CompValue, base: Optional[str] = None, - initNs: Optional[Mapping[str, str]] = None, + initNs: Optional[Mapping[str, Any]] = None, ) -> Update: """ Returns a list of SPARQL Update Algebra expressions @@ -854,7 +869,7 @@ def translateUpdate( def translateQuery( q: ParseResults, base: Optional[str] = None, - initNs: Optional[Mapping[str, str]] = None, + initNs: Optional[Mapping[str, Any]] = None, ) -> Query: """ Translate a query-parsetree to a SPARQL Algebra Expression @@ -901,7 +916,7 @@ def translateAlgebra(query_algebra: Query) -> str: """ import os - def overwrite(text): + def overwrite(text: str): file = open("query.txt", "w+") file.write(text) file.close() @@ -938,19 +953,26 @@ def find_nth(haystack, needle, n): with open("query.txt", "w") as file: file.write(filedata) - aggr_vars = collections.defaultdict(list) # type: dict + aggr_vars: DefaultDict[Identifier, List[Identifier]] = collections.defaultdict(list) - def convert_node_arg(node_arg): + def convert_node_arg( + node_arg: typing.Union[Identifier, CompValue, Expr, str] + ) -> str: if isinstance(node_arg, Identifier): if node_arg in aggr_vars.keys(): - grp_var = aggr_vars[node_arg].pop(0).n3() + # type error: "Identifier" has no attribute "n3" + grp_var = 
aggr_vars[node_arg].pop(0).n3() # type: ignore[attr-defined] return grp_var else: - return node_arg.n3() + # type error: "Identifier" has no attribute "n3" + return node_arg.n3() # type: ignore[attr-defined] elif isinstance(node_arg, CompValue): return "{" + node_arg.name + "}" - elif isinstance(node_arg, Expr): - return "{" + node_arg.name + "}" + # type error notes: this is because Expr is a subclass of CompValue + # type error: Subclass of "str" and "Expr" cannot exist: would have incompatible method signatures + elif isinstance(node_arg, Expr): # type: ignore[unreachable] + # type error: Statement is unreachable + return "{" + node_arg.name + "}" # type: ignore[unreachable] elif isinstance(node_arg, str): return node_arg else: @@ -1529,7 +1551,7 @@ def sparql_query_text(node): return query_from_algebra -def pprintAlgebra(q): +def pprintAlgebra(q) -> None: def pp(p, ind=" "): # if isinstance(p, list): # print "[ " diff --git a/rdflib/plugins/sparql/operators.py b/rdflib/plugins/sparql/operators.py index c176691cb..fe463579f 100644 --- a/rdflib/plugins/sparql/operators.py +++ b/rdflib/plugins/sparql/operators.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ This contains evaluation functions for expressions @@ -16,12 +18,13 @@ import warnings from decimal import ROUND_HALF_UP, Decimal, InvalidOperation from functools import reduce +from typing import Any, Callable, Dict, NoReturn, Optional, Tuple, Union, overload from urllib.parse import quote import isodate from pyparsing import ParseResults -from rdflib import RDF, XSD, BNode, Literal, URIRef, Variable +from rdflib.namespace import RDF, XSD from rdflib.plugins.sparql.datatypes import ( XSD_DateTime_DTs, XSD_DTs, @@ -29,11 +32,24 @@ type_promotion, ) from rdflib.plugins.sparql.parserutils import CompValue, Expr -from rdflib.plugins.sparql.sparql import SPARQLError, SPARQLTypeError -from rdflib.term import Node +from rdflib.plugins.sparql.sparql import ( + FrozenBindings, + QueryContext, + 
SPARQLError, + SPARQLTypeError, +) +from rdflib.term import ( + BNode, + IdentifiedNode, + Identifier, + Literal, + Node, + URIRef, + Variable, +) -def Builtin_IRI(expr, ctx): +def Builtin_IRI(expr: Expr, ctx: FrozenBindings) -> URIRef: """ http://www.w3.org/TR/sparql11-query/#func-iri """ @@ -43,24 +59,26 @@ def Builtin_IRI(expr, ctx): if isinstance(a, URIRef): return a if isinstance(a, Literal): - return ctx.prologue.absolutize(URIRef(a)) + # type error: Item "None" of "Optional[Prologue]" has no attribute "absolutize" + # type error: Incompatible return value type (got "Union[CompValue, str, None, Any]", expected "URIRef") + return ctx.prologue.absolutize(URIRef(a)) # type: ignore[union-attr,return-value] raise SPARQLError("IRI function only accepts URIRefs or Literals/Strings!") -def Builtin_isBLANK(expr, ctx): +def Builtin_isBLANK(expr: Expr, ctx: FrozenBindings) -> Literal: return Literal(isinstance(expr.arg, BNode)) -def Builtin_isLITERAL(expr, ctx): +def Builtin_isLITERAL(expr, ctx) -> Literal: return Literal(isinstance(expr.arg, Literal)) -def Builtin_isIRI(expr, ctx): +def Builtin_isIRI(expr, ctx) -> Literal: return Literal(isinstance(expr.arg, URIRef)) -def Builtin_isNUMERIC(expr, ctx): +def Builtin_isNUMERIC(expr, ctx) -> Literal: try: numeric(expr.arg) return Literal(True) @@ -68,7 +86,7 @@ def Builtin_isNUMERIC(expr, ctx): return Literal(False) -def Builtin_BNODE(expr, ctx): +def Builtin_BNODE(expr, ctx) -> BNode: """ http://www.w3.org/TR/sparql11-query/#func-bnode """ @@ -84,7 +102,7 @@ def Builtin_BNODE(expr, ctx): raise SPARQLError("BNode function only accepts no argument or literal/string") -def Builtin_ABS(expr, ctx): +def Builtin_ABS(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-abs """ @@ -92,7 +110,7 @@ def Builtin_ABS(expr, ctx): return Literal(abs(numeric(expr.arg))) -def Builtin_IF(expr, ctx): +def Builtin_IF(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-if """ @@ -100,7 +118,7 @@ def 
Builtin_IF(expr, ctx): return expr.arg2 if EBV(expr.arg1) else expr.arg3 -def Builtin_RAND(expr, ctx): +def Builtin_RAND(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#idp2133952 """ @@ -108,7 +126,7 @@ def Builtin_RAND(expr, ctx): return Literal(random.random()) -def Builtin_UUID(expr, ctx): +def Builtin_UUID(expr: Expr, ctx) -> URIRef: """ http://www.w3.org/TR/sparql11-query/#func-strdt """ @@ -116,7 +134,7 @@ def Builtin_UUID(expr, ctx): return URIRef(uuid.uuid4().urn) -def Builtin_STRUUID(expr, ctx): +def Builtin_STRUUID(expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-strdt """ @@ -124,32 +142,32 @@ def Builtin_STRUUID(expr, ctx): return Literal(str(uuid.uuid4())) -def Builtin_MD5(expr, ctx): +def Builtin_MD5(expr: Expr, ctx) -> Literal: s = string(expr.arg).encode("utf-8") return Literal(hashlib.md5(s).hexdigest()) -def Builtin_SHA1(expr, ctx): +def Builtin_SHA1(expr: Expr, ctx) -> Literal: s = string(expr.arg).encode("utf-8") return Literal(hashlib.sha1(s).hexdigest()) -def Builtin_SHA256(expr, ctx): +def Builtin_SHA256(expr: Expr, ctx) -> Literal: s = string(expr.arg).encode("utf-8") return Literal(hashlib.sha256(s).hexdigest()) -def Builtin_SHA384(expr, ctx): +def Builtin_SHA384(expr: Expr, ctx) -> Literal: s = string(expr.arg).encode("utf-8") return Literal(hashlib.sha384(s).hexdigest()) -def Builtin_SHA512(expr, ctx): +def Builtin_SHA512(expr: Expr, ctx) -> Literal: s = string(expr.arg).encode("utf-8") return Literal(hashlib.sha512(s).hexdigest()) -def Builtin_COALESCE(expr, ctx): +def Builtin_COALESCE(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-coalesce """ @@ -159,7 +177,7 @@ def Builtin_COALESCE(expr, ctx): raise SPARQLError("COALESCE got no arguments that did not evaluate to an error") -def Builtin_CEIL(expr, ctx): +def Builtin_CEIL(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-ceil """ @@ -168,7 +186,7 @@ def Builtin_CEIL(expr, ctx): return 
Literal(int(math.ceil(numeric(l_))), datatype=l_.datatype) -def Builtin_FLOOR(expr, ctx): +def Builtin_FLOOR(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-floor """ @@ -176,7 +194,7 @@ def Builtin_FLOOR(expr, ctx): return Literal(int(math.floor(numeric(l_))), datatype=l_.datatype) -def Builtin_ROUND(expr, ctx): +def Builtin_ROUND(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-round """ @@ -191,7 +209,7 @@ def Builtin_ROUND(expr, ctx): return Literal(v, datatype=l_.datatype) -def Builtin_REGEX(expr, ctx): +def Builtin_REGEX(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-regex Invokes the XPath fn:matches function to match text against a regular @@ -214,7 +232,7 @@ def Builtin_REGEX(expr, ctx): return Literal(bool(re.search(str(pattern), text, cFlag))) -def Builtin_REPLACE(expr, ctx): +def Builtin_REPLACE(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-substr """ @@ -224,7 +242,8 @@ def Builtin_REPLACE(expr, ctx): flags = expr.flags # python uses \1, xpath/sparql uses $1 - replacement = re.sub("\\$([0-9]*)", r"\\\1", replacement) + # type error: Incompatible types in assignment (expression has type "str", variable has type "Literal") + replacement = re.sub("\\$([0-9]*)", r"\\\1", replacement) # type: ignore[assignment] cFlag = 0 if flags: @@ -242,7 +261,7 @@ def Builtin_REPLACE(expr, ctx): ) -def Builtin_STRDT(expr, ctx): +def Builtin_STRDT(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-strdt """ @@ -250,7 +269,7 @@ def Builtin_STRDT(expr, ctx): return Literal(str(expr.arg1), datatype=expr.arg2) -def Builtin_STRLANG(expr, ctx): +def Builtin_STRLANG(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-strlang """ @@ -264,7 +283,7 @@ def Builtin_STRLANG(expr, ctx): return Literal(str(s), lang=str(expr.arg2).lower()) -def Builtin_CONCAT(expr, ctx): +def Builtin_CONCAT(expr: Expr, ctx) -> Literal: """ 
http://www.w3.org/TR/sparql11-query/#func-concat """ @@ -272,15 +291,20 @@ def Builtin_CONCAT(expr, ctx): # dt/lang passed on only if they all match dt = set(x.datatype for x in expr.arg if isinstance(x, Literal)) - dt = dt.pop() if len(dt) == 1 else None + # type error: Incompatible types in assignment (expression has type "Optional[str]", variable has type "Set[Optional[str]]") + dt = dt.pop() if len(dt) == 1 else None # type: ignore[assignment] lang = set(x.language for x in expr.arg if isinstance(x, Literal)) - lang = lang.pop() if len(lang) == 1 else None + # type error: Incompatible types in assignment (expression has type "Optional[str]", variable has type "Set[Optional[str]]") + lang = lang.pop() if len(lang) == 1 else None # type: ignore[assignment] - return Literal("".join(string(x) for x in expr.arg), datatype=dt, lang=lang) + # NOTE on type errors: this is because the same variable is used for two incompatible types + # type error: Argument "datatype" to "Literal" has incompatible type "Set[Any]"; expected "Optional[str]" [arg-type] + # type error: Argument "lang" to "Literal" has incompatible type "Set[Any]"; expected "Optional[str]" + return Literal("".join(string(x) for x in expr.arg), datatype=dt, lang=lang) # type: ignore[arg-type] -def _compatibleStrings(a, b): +def _compatibleStrings(a: Literal, b: Literal) -> None: string(a) string(b) @@ -288,7 +312,7 @@ def _compatibleStrings(a, b): raise SPARQLError("incompatible arguments to str functions") -def Builtin_STRSTARTS(expr, ctx): +def Builtin_STRSTARTS(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-strstarts """ @@ -300,7 +324,7 @@ def Builtin_STRSTARTS(expr, ctx): return Literal(a.startswith(b)) -def Builtin_STRENDS(expr, ctx): +def Builtin_STRENDS(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-strends """ @@ -312,7 +336,7 @@ def Builtin_STRENDS(expr, ctx): return Literal(a.endswith(b)) -def Builtin_STRBEFORE(expr, ctx): +def 
Builtin_STRBEFORE(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-strbefore """ @@ -328,7 +352,7 @@ def Builtin_STRBEFORE(expr, ctx): return Literal(a[:i], lang=a.language, datatype=a.datatype) -def Builtin_STRAFTER(expr, ctx): +def Builtin_STRAFTER(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-strafter """ @@ -344,7 +368,7 @@ def Builtin_STRAFTER(expr, ctx): return Literal(a[i + len(b) :], lang=a.language, datatype=a.datatype) -def Builtin_CONTAINS(expr, ctx): +def Builtin_CONTAINS(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-strcontains """ @@ -356,11 +380,11 @@ def Builtin_CONTAINS(expr, ctx): return Literal(b in a) -def Builtin_ENCODE_FOR_URI(expr, ctx): +def Builtin_ENCODE_FOR_URI(expr: Expr, ctx) -> Literal: return Literal(quote(string(expr.arg).encode("utf-8"))) -def Builtin_SUBSTR(expr, ctx): +def Builtin_SUBSTR(expr: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-substr """ @@ -376,26 +400,26 @@ def Builtin_SUBSTR(expr, ctx): return Literal(a[start:length], lang=a.language, datatype=a.datatype) -def Builtin_STRLEN(e, ctx): +def Builtin_STRLEN(e: Expr, ctx) -> Literal: l_ = string(e.arg) return Literal(len(l_)) -def Builtin_STR(e, ctx): +def Builtin_STR(e: Expr, ctx) -> Literal: arg = e.arg if isinstance(arg, SPARQLError): raise arg return Literal(str(arg)) # plain literal -def Builtin_LCASE(e, ctx): +def Builtin_LCASE(e: Expr, ctx) -> Literal: l_ = string(e.arg) return Literal(l_.lower(), datatype=l_.datatype, lang=l_.language) -def Builtin_LANGMATCHES(e, ctx): +def Builtin_LANGMATCHES(e: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-langMatches @@ -410,39 +434,39 @@ def Builtin_LANGMATCHES(e, ctx): return Literal(_lang_range_check(langRange, langTag)) -def Builtin_NOW(e, ctx): +def Builtin_NOW(e: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-now """ return Literal(ctx.now) -def Builtin_YEAR(e, ctx): +def 
Builtin_YEAR(e: Expr, ctx) -> Literal: d = date(e.arg) return Literal(d.year) -def Builtin_MONTH(e, ctx): +def Builtin_MONTH(e: Expr, ctx) -> Literal: d = date(e.arg) return Literal(d.month) -def Builtin_DAY(e, ctx): +def Builtin_DAY(e: Expr, ctx) -> Literal: d = date(e.arg) return Literal(d.day) -def Builtin_HOURS(e, ctx): +def Builtin_HOURS(e: Expr, ctx) -> Literal: d = datetime(e.arg) return Literal(d.hour) -def Builtin_MINUTES(e, ctx): +def Builtin_MINUTES(e: Expr, ctx) -> Literal: d = datetime(e.arg) return Literal(d.minute) -def Builtin_SECONDS(e, ctx): +def Builtin_SECONDS(e: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-seconds """ @@ -450,7 +474,7 @@ def Builtin_SECONDS(e, ctx): return Literal(d.second, datatype=XSD.decimal) -def Builtin_TIMEZONE(e, ctx): +def Builtin_TIMEZONE(e: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-timezone @@ -463,8 +487,10 @@ def Builtin_TIMEZONE(e, ctx): delta = dt.utcoffset() - d = delta.days - s = delta.seconds + # type error: Item "None" of "Optional[timedelta]" has no attribute "days" + d = delta.days # type: ignore[union-attr] + # type error: Item "None" of "Optional[timedelta]" has no attribute "seconds" + s = delta.seconds # type: ignore[union-attr] neg = "" if d < 0: @@ -487,7 +513,7 @@ def Builtin_TIMEZONE(e, ctx): return Literal(tzdelta, datatype=XSD.dayTimeDuration) -def Builtin_TZ(e, ctx): +def Builtin_TZ(e: Expr, ctx) -> Literal: d = datetime(e.arg) if not d.tzinfo: return Literal("") @@ -497,13 +523,13 @@ def Builtin_TZ(e, ctx): return Literal(n) -def Builtin_UCASE(e, ctx): +def Builtin_UCASE(e: Expr, ctx) -> Literal: l_ = string(e.arg) return Literal(l_.upper(), datatype=l_.datatype, lang=l_.language) -def Builtin_LANG(e, ctx): +def Builtin_LANG(e: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-lang @@ -516,7 +542,7 @@ def Builtin_LANG(e, ctx): return Literal(l_.language or "") -def Builtin_DATATYPE(e, ctx): +def Builtin_DATATYPE(e: Expr, ctx) 
-> Optional[str]: l_ = e.arg if not isinstance(l_, Literal): raise SPARQLError("Can only get datatype of literal: %r" % l_) @@ -527,13 +553,13 @@ def Builtin_DATATYPE(e, ctx): return l_.datatype -def Builtin_sameTerm(e, ctx): +def Builtin_sameTerm(e: Expr, ctx) -> Literal: a = e.arg1 b = e.arg2 return Literal(a == b) -def Builtin_BOUND(e, ctx): +def Builtin_BOUND(e: Expr, ctx) -> Literal: """ http://www.w3.org/TR/sparql11-query/#func-bound """ @@ -542,22 +568,28 @@ def Builtin_BOUND(e, ctx): return Literal(not isinstance(n, Variable)) -def Builtin_EXISTS(e, ctx): +def Builtin_EXISTS(e: Expr, ctx: FrozenBindings) -> Literal: # damn... from rdflib.plugins.sparql.evaluate import evalPart exists = e.name == "Builtin_EXISTS" - ctx = ctx.ctx.thaw(ctx) # hmm - for x in evalPart(ctx, e.graph): + # type error: Incompatible types in assignment (expression has type "QueryContext", variable has type "FrozenBindings") + ctx = ctx.ctx.thaw(ctx) # type: ignore[assignment] # hmm + # type error: Argument 1 to "evalPart" has incompatible type "FrozenBindings"; expected "QueryContext" + for x in evalPart(ctx, e.graph): # type: ignore[arg-type] return Literal(exists) return Literal(not exists) -_CUSTOM_FUNCTIONS = {} +_CustomFunction = Callable[[Expr, FrozenBindings], Node] +_CUSTOM_FUNCTIONS: Dict[URIRef, Tuple[_CustomFunction, bool]] = {} -def register_custom_function(uri, func, override=False, raw=False): + +def register_custom_function( + uri: URIRef, func: _CustomFunction, override: bool = False, raw: bool = False +) -> None: """ Register a custom SPARQL function. @@ -571,19 +603,23 @@ def register_custom_function(uri, func, override=False, raw=False): _CUSTOM_FUNCTIONS[uri] = (func, raw) -def custom_function(uri, override=False, raw=False): +def custom_function( + uri: URIRef, override: bool = False, raw: bool = False +) -> Callable[[_CustomFunction], _CustomFunction]: """ Decorator version of :func:`register_custom_function`. 
""" - def decorator(func): + def decorator(func: _CustomFunction) -> _CustomFunction: register_custom_function(uri, func, override=override, raw=raw) return func return decorator -def unregister_custom_function(uri, func=None): +def unregister_custom_function( + uri: URIRef, func: Optional[Callable[..., Any]] = None +) -> None: """ The 'func' argument is included for compatibility with existing code. A previous implementation checked that the function associated with @@ -596,7 +632,7 @@ def unregister_custom_function(uri, func=None): warnings.warn("This function is not registered as %s" % uri.n3()) -def Function(e, ctx): +def Function(e: Expr, ctx: FrozenBindings) -> Node: """ Custom functions and casts """ @@ -624,7 +660,7 @@ def Function(e, ctx): @custom_function(XSD.decimal, raw=True) @custom_function(XSD.integer, raw=True) @custom_function(XSD.boolean, raw=True) -def default_cast(e, ctx): +def default_cast(e: Expr, ctx: FrozenBindings) -> Literal: # type: ignore[return] if not e.expr: raise SPARQLError("Nothing given to cast.") if len(e.expr) > 1: @@ -688,19 +724,21 @@ def default_cast(e, ctx): raise SPARQLError("Cannot interpret '%r' as bool" % x) -def UnaryNot(expr, ctx): +def UnaryNot(expr: Expr, ctx: FrozenBindings) -> Literal: return Literal(not EBV(expr.expr)) -def UnaryMinus(expr, ctx): +def UnaryMinus(expr: Expr, ctx: FrozenBindings) -> Literal: return Literal(-numeric(expr.expr)) -def UnaryPlus(expr, ctx): +def UnaryPlus(expr: Expr, ctx: FrozenBindings) -> Literal: return Literal(+numeric(expr.expr)) -def MultiplicativeExpression(e, ctx): +def MultiplicativeExpression( + e: Expr, ctx: Union[QueryContext, FrozenBindings] +) -> Literal: expr = e.expr other = e.other @@ -710,6 +748,7 @@ def MultiplicativeExpression(e, ctx): if other is None: return expr try: + res: Union[Decimal, float] res = Decimal(numeric(expr)) for op, f in zip(e.op, other): f = numeric(f) @@ -727,7 +766,8 @@ def MultiplicativeExpression(e, ctx): return Literal(res) -def 
AdditiveExpression(e, ctx): +# type error: Missing return statement +def AdditiveExpression(e: Expr, ctx: Union[QueryContext, FrozenBindings]) -> Literal: # type: ignore[return] expr = e.expr other = e.other @@ -756,7 +796,8 @@ def AdditiveExpression(e, ctx): # ( dateTime1 - dateTime2 - dateTime3 ) is an invalid operation if len(other) > 1: error_message = "Can't evaluate multiple %r arguments" - raise SPARQLError(error_message, dt.datatype) + # type error: Too many arguments for "SPARQLError" + raise SPARQLError(error_message, dt.datatype) # type: ignore[call-arg] else: n = dateTimeObjects(term) res = calculateDuration(res, n) @@ -802,7 +843,7 @@ def AdditiveExpression(e, ctx): return Literal(res, datatype=dt) -def RelationalExpression(e, ctx): +def RelationalExpression(e: Expr, ctx: Union[QueryContext, FrozenBindings]) -> Literal: expr = e.expr other = e.other @@ -830,7 +871,7 @@ def RelationalExpression(e, ctx): res = op == "NOT IN" - error = False + error: Union[bool, SPARQLError] = False if other == RDF.nil: other = [] @@ -844,7 +885,9 @@ def RelationalExpression(e, ctx): if not error: return Literal(False ^ res) else: - raise error + # Note on type error: this is because variable is Union[bool, SPARQLError] + # type error: Exception must be derived from BaseException + raise error # type: ignore[misc] if op not in ("=", "!=", "IN", "NOT IN"): if not isinstance(expr, Literal): @@ -882,7 +925,9 @@ def RelationalExpression(e, ctx): return Literal(r) -def ConditionalAndExpression(e, ctx): +def ConditionalAndExpression( + e: Expr, ctx: Union[QueryContext, FrozenBindings] +) -> Literal: # TODO: handle returned errors @@ -897,7 +942,9 @@ def ConditionalAndExpression(e, ctx): return Literal(all(EBV(x) for x in [expr] + other)) -def ConditionalOrExpression(e, ctx): +def ConditionalOrExpression( + e: Expr, ctx: Union[QueryContext, FrozenBindings] +) -> Literal: # TODO: handle errors @@ -923,11 +970,11 @@ def ConditionalOrExpression(e, ctx): return Literal(False) -def 
not_(arg): +def not_(arg) -> Expr: return Expr("UnaryNot", UnaryNot, expr=arg) -def and_(*args): +def and_(*args: Expr) -> Expr: if len(args) == 1: return args[0] @@ -942,13 +989,14 @@ def and_(*args): TrueFilter = Expr("TrueFilter", lambda _1, _2: Literal(True)) -def simplify(expr): +def simplify(expr: Any) -> Any: if isinstance(expr, ParseResults) and len(expr) == 1: return simplify(expr[0]) if isinstance(expr, (list, ParseResults)): return list(map(simplify, expr)) - if not isinstance(expr, CompValue): + # type error: Statement is unreachable + if not isinstance(expr, CompValue): # type: ignore[unreachable] return expr if expr.name.endswith("Expression"): if expr.other is None: @@ -962,13 +1010,13 @@ def simplify(expr): return expr -def literal(s): +def literal(s: Literal) -> Literal: if not isinstance(s, Literal): raise SPARQLError("Non-literal passed as string: %r" % s) return s -def datetime(e): +def datetime(e: Literal) -> py_datetime.datetime: if not isinstance(e, Literal): raise SPARQLError("Non-literal passed as datetime: %r" % e) if not e.datatype == XSD.dateTime: @@ -976,7 +1024,7 @@ def datetime(e): return e.toPython() -def date(e) -> py_datetime.date: +def date(e: Literal) -> py_datetime.date: if not isinstance(e, Literal): raise SPARQLError("Non-literal passed as date: %r" % e) if e.datatype not in (XSD.date, XSD.dateTime): @@ -987,7 +1035,7 @@ def date(e) -> py_datetime.date: return result -def string(s): +def string(s: Literal) -> Literal: """ Make sure the passed thing is a string literal i.e. 
plain literal, xsd:string literal or lang-tagged literal @@ -999,7 +1047,7 @@ def string(s): return s -def numeric(expr): +def numeric(expr: Literal) -> Any: """ return a number from a literal http://www.w3.org/TR/xpath20/#promotion @@ -1033,7 +1081,7 @@ def numeric(expr): return expr.toPython() -def dateTimeObjects(expr): +def dateTimeObjects(expr: Literal) -> Any: """ return a dataTime/date/time/duration/dayTimeDuration/yearMonthDuration python objects from a literal @@ -1041,7 +1089,13 @@ def dateTimeObjects(expr): return expr.toPython() -def isCompatibleDateTimeDatatype(obj1, dt1, obj2, dt2): +# type error: Missing return statement +def isCompatibleDateTimeDatatype( # type: ignore[return] + obj1: Union[py_datetime.date, py_datetime.datetime], + dt1: URIRef, + obj2: Union[isodate.Duration, py_datetime.timedelta], + dt2: URIRef, +) -> bool: """ Returns a boolean indicating if first object is compatible with operation(+/-) over second object. @@ -1075,18 +1129,28 @@ def isCompatibleDateTimeDatatype(obj1, dt1, obj2, dt2): return True -def calculateDuration(obj1, obj2): +def calculateDuration( + obj1: Union[py_datetime.date, py_datetime.datetime], + obj2: Union[py_datetime.date, py_datetime.datetime], +) -> Literal: """ returns the duration Literal between two datetime """ date1 = obj1 date2 = obj2 - difference = date1 - date2 + # type error: No overload variant of "__sub__" of "datetime" matches argument type "date" + difference = date1 - date2 # type: ignore[operator] return Literal(difference, datatype=XSD.duration) -def calculateFinalDateTime(obj1, dt1, obj2, dt2, operation): +def calculateFinalDateTime( + obj1: Union[py_datetime.date, py_datetime.datetime], + dt1: URIRef, + obj2: Union[isodate.Duration, py_datetime.timedelta], + dt2: URIRef, + operation: str, +) -> Literal: """ Calculates the final dateTime/date/time resultant after addition/ subtraction of duration/dayTimeDuration/yearMonthDuration @@ -1106,7 +1170,22 @@ def calculateFinalDateTime(obj1, dt1, 
obj2, dt2, operation): raise SPARQLError("Incompatible Data types to DateTime Operations") -def EBV(rt): +@overload +def EBV(rt: Literal) -> bool: + ... + + +@overload +def EBV(rt: Union[Variable, IdentifiedNode, SPARQLError, Expr]) -> NoReturn: + ... + + +@overload +def EBV(rt: Union[Identifier, SPARQLError, Expr]) -> Union[bool, NoReturn]: + ... + + +def EBV(rt: Union[Identifier, SPARQLError, Expr]) -> bool: """ Effective Boolean Value (EBV) @@ -1151,7 +1230,7 @@ def EBV(rt): ) -def _lang_range_check(range, lang): +def _lang_range_check(range: Literal, lang: Literal) -> bool: """ Implementation of the extended filtering algorithm, as defined in point 3.3.2, of U{RFC 4647}, on @@ -1169,7 +1248,7 @@ def _lang_range_check(range, lang): """ - def _match(r, l_): + def _match(r: str, l_: str) -> bool: """ Matching of a range and language item: either range is a wildcard or the two are equal diff --git a/test/test_issues/test_issue1043.py b/test/test_issues/test_issue1043.py index bd6e9a34c..896529e5d 100644 --- a/test/test_issues/test_issue1043.py +++ b/test/test_issues/test_issue1043.py @@ -1,30 +1,30 @@ -import io -import sys -import unittest - -from rdflib import RDFS, XSD, Graph, Literal, Namespace - - -class TestIssue1043(unittest.TestCase): - def test_issue_1043(self): - expected = """@prefix rdfs: . -@prefix xsd: . - - rdfs:label 4e-08 . - - -""" - capturedOutput = io.StringIO() - sys.stdout = capturedOutput - g = Graph() - g.bind("xsd", XSD) - g.bind("rdfs", RDFS) - n = Namespace("http://example.org/") - g.add((n.number, RDFS.label, Literal(0.00000004, datatype=XSD.decimal))) - g.print() - sys.stdout = sys.__stdout__ - self.assertEqual(capturedOutput.getvalue(), expected) - - -if __name__ == "__main__": - unittest.main() +import io +import sys +import unittest + +from rdflib import RDFS, XSD, Graph, Literal, Namespace + + +class TestIssue1043(unittest.TestCase): + def test_issue_1043(self): + expected = """@prefix rdfs: . +@prefix xsd: . + + rdfs:label 4e-08 . 
+ + +""" + capturedOutput = io.StringIO() + sys.stdout = capturedOutput + g = Graph() + g.bind("xsd", XSD) + g.bind("rdfs", RDFS) + n = Namespace("http://example.org/") + g.add((n.number, RDFS.label, Literal(0.00000004, datatype=XSD.decimal))) + g.print() + sys.stdout = sys.__stdout__ + self.assertEqual(capturedOutput.getvalue(), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_issues/test_issue910.py b/test/test_issues/test_issue910.py index c3640d3c2..2d6082fa9 100644 --- a/test/test_issues/test_issue910.py +++ b/test/test_issues/test_issue910.py @@ -1,66 +1,66 @@ -import unittest - -from rdflib import Graph - - -class TestIssue910(unittest.TestCase): - def testA(self): - g = Graph() - q = g.query( - """ - SELECT * { - { BIND ("a" AS ?a) } - UNION - { BIND ("a" AS ?a) } - } - """ - ) - self.assertEqual(len(q) == 2, True) - - def testB(self): - g = Graph() - q = g.query( - """ - SELECT * { - { BIND ("a" AS ?a) } - UNION - { VALUES ?a { "a" } } - UNION - { SELECT ("a" AS ?a) {} } - } - """ - ) - self.assertEqual(len(q) == 3, True) - - def testC(self): - g = Graph() - q = g.query( - """ - SELECT * { - { BIND ("a" AS ?a) } - UNION - { VALUES ?a { "a" } } - UNION - { SELECT ("b" AS ?a) {} } - } - """ - ) - self.assertEqual(len(q) == 3, True) - - def testD(self): - g = Graph() - q = g.query( - """SELECT * { - { BIND ("a" AS ?a) } - UNION - { VALUES ?a { "b" } } - UNION - { SELECT ("c" AS ?a) {} } - } - """ - ) - self.assertEqual(len(q) == 3, True) - - -if __name__ == "__main__": - unittest.main() +import unittest + +from rdflib import Graph + + +class TestIssue910(unittest.TestCase): + def testA(self): + g = Graph() + q = g.query( + """ + SELECT * { + { BIND ("a" AS ?a) } + UNION + { BIND ("a" AS ?a) } + } + """ + ) + self.assertEqual(len(q) == 2, True) + + def testB(self): + g = Graph() + q = g.query( + """ + SELECT * { + { BIND ("a" AS ?a) } + UNION + { VALUES ?a { "a" } } + UNION + { SELECT ("a" AS ?a) {} } + } + """ + ) + 
self.assertEqual(len(q) == 3, True) + + def testC(self): + g = Graph() + q = g.query( + """ + SELECT * { + { BIND ("a" AS ?a) } + UNION + { VALUES ?a { "a" } } + UNION + { SELECT ("b" AS ?a) {} } + } + """ + ) + self.assertEqual(len(q) == 3, True) + + def testD(self): + g = Graph() + q = g.query( + """SELECT * { + { BIND ("a" AS ?a) } + UNION + { VALUES ?a { "b" } } + UNION + { SELECT ("c" AS ?a) {} } + } + """ + ) + self.assertEqual(len(q) == 3, True) + + +if __name__ == "__main__": + unittest.main()