-
Notifications
You must be signed in to change notification settings - Fork 50
Improve type checking in query pagination and AST manipulation code. #898
base: main
Are you sure you want to change the base?
Changes from 3 commits
4cec241
acc53da
14fc32c
4c9909c
ee73216
e7c7022
21f3742
2cf340a
259d287
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,41 +1,53 @@ | ||
| # Copyright 2019-present Kensho Technologies, LLC. | ||
| from typing import Any, Optional, Type, Union | ||
|
|
||
| from graphql.error import GraphQLSyntaxError | ||
| from graphql.language.ast import ( | ||
| DocumentNode, | ||
| FieldNode, | ||
| InlineFragmentNode, | ||
| ListTypeNode, | ||
| NamedTypeNode, | ||
| Node, | ||
| NonNullTypeNode, | ||
| OperationDefinitionNode, | ||
| OperationType, | ||
| SelectionNode, | ||
| TypeNode, | ||
| ) | ||
| from graphql.language.parser import parse | ||
| from graphql.pyutils import FrozenList | ||
|
|
||
| from .exceptions import GraphQLParsingError | ||
|
|
||
|
|
||
| def get_ast_field_name(ast): | ||
| def get_ast_field_name(ast: FieldNode) -> str: | ||
| """Return the field name for the given AST node.""" | ||
| return ast.name.value | ||
|
|
||
|
|
||
| def get_ast_field_name_or_none(ast): | ||
| def get_ast_field_name_or_none(ast: Union[FieldNode, InlineFragmentNode]) -> Optional[str]: | ||
| """Return the field name for the AST node, or None if the AST is an InlineFragment.""" | ||
| if isinstance(ast, InlineFragmentNode): | ||
| return None | ||
| return get_ast_field_name(ast) | ||
|
|
||
|
|
||
| def get_human_friendly_ast_field_name(ast): | ||
| def get_human_friendly_ast_field_name(ast: Node) -> str: | ||
| """Return a human-friendly name for the AST node, suitable for error messages.""" | ||
| if isinstance(ast, InlineFragmentNode): | ||
| return "type coercion to {}".format(ast.type_condition) | ||
| elif isinstance(ast, OperationDefinitionNode): | ||
| return "{} operation definition".format(ast.operation) | ||
|
|
||
| return get_ast_field_name(ast) | ||
| elif isinstance(ast, FieldNode): | ||
| return get_ast_field_name(ast) | ||
| else: | ||
| # Fall back to Node's __repr__() method. | ||
| # If we need more information for a specific type, we can add another branch in the if-elif. | ||
| return repr(ast) | ||
|
|
||
|
|
||
| def safe_parse_graphql(graphql_string): | ||
| def safe_parse_graphql(graphql_string: str) -> DocumentNode: | ||
| """Return an AST representation of the given GraphQL input, reraising GraphQL library errors.""" | ||
| try: | ||
| ast = parse(graphql_string) | ||
|
|
@@ -45,7 +57,9 @@ def safe_parse_graphql(graphql_string): | |
| return ast | ||
|
|
||
|
|
||
| def get_only_query_definition(document_ast, desired_error_type): | ||
| def get_only_query_definition( | ||
| document_ast: DocumentNode, desired_error_type: Type[Exception] | ||
| ) -> OperationDefinitionNode: | ||
| """Assert that the Document AST contains only a single definition for a query, and return it.""" | ||
| if not isinstance(document_ast, DocumentNode) or not document_ast.definitions: | ||
| raise AssertionError( | ||
|
|
@@ -59,6 +73,12 @@ def get_only_query_definition(document_ast, desired_error_type): | |
| ) | ||
|
|
||
| definition_ast = document_ast.definitions[0] | ||
| if not isinstance(definition_ast, OperationDefinitionNode): | ||
| raise desired_error_type( | ||
| f"Expected a query definition at the start of the GraphQL input, but found an " | ||
| f"unsupported and unrecognized definition instead: {definition_ast}" | ||
| ) | ||
|
|
||
| if definition_ast.operation != OperationType.QUERY: | ||
| raise desired_error_type( | ||
| "Expected a GraphQL document with a single query definition, but instead found a " | ||
|
|
@@ -70,9 +90,12 @@ def get_only_query_definition(document_ast, desired_error_type): | |
| return definition_ast | ||
|
|
||
|
|
||
| def get_only_selection_from_ast(ast, desired_error_type): | ||
| def get_only_selection_from_ast( | ||
| ast: Union[FieldNode, InlineFragmentNode, OperationDefinitionNode], | ||
| desired_error_type: Type[Exception], | ||
| ) -> SelectionNode: | ||
| """Return the selected sub-ast, ensuring that there is precisely one.""" | ||
| selections = [] if ast.selection_set is None else ast.selection_set.selections | ||
| selections = FrozenList([]) if ast.selection_set is None else ast.selection_set.selections | ||
|
|
||
| if len(selections) != 1: | ||
| ast_name = get_human_friendly_ast_field_name(ast) | ||
|
|
@@ -96,22 +119,68 @@ def get_only_selection_from_ast(ast, desired_error_type): | |
| return selections[0] | ||
|
|
||
|
|
||
| def get_ast_with_non_null_stripped(ast): | ||
| def get_ast_with_non_null_stripped(ast: TypeNode) -> Union[ListTypeNode, NamedTypeNode]: | ||
| """Strip a NonNullType layer around the AST if there is one, return the underlying AST.""" | ||
| result: TypeNode | ||
|
|
||
| if isinstance(ast, NonNullTypeNode): | ||
| stripped_ast = ast.type | ||
| if isinstance(stripped_ast, NonNullTypeNode): | ||
| raise AssertionError( | ||
| "NonNullType is unexpectedly found to wrap around another NonNullType in AST " | ||
| "{}, which is not allowed.".format(ast) | ||
| ) | ||
| return stripped_ast | ||
| result = stripped_ast | ||
| else: | ||
| return ast | ||
| result = ast | ||
|
|
||
| if not isinstance(result, (ListTypeNode, NamedTypeNode)): | ||
| raise AssertionError( | ||
| f"Expected the result to be either a ListTypeNode or a NamedTypeNode, but instead " | ||
| f"found: {result}" | ||
| ) | ||
|
|
||
| return result | ||
|
|
||
| def get_ast_with_non_null_and_list_stripped(ast): | ||
|
|
||
| def get_ast_with_non_null_and_list_stripped(ast: TypeNode) -> NamedTypeNode: | ||
| """Strip any NonNullType or List layers around the AST, return the underlying AST.""" | ||
| while isinstance(ast, (NonNullTypeNode, ListTypeNode)): | ||
| ast = ast.type | ||
|
|
||
| if not isinstance(ast, NamedTypeNode): | ||
| raise AssertionError( | ||
| f"Expected the AST value to be a NamedTypeNode, but unexpectedly instead found: {ast}" | ||
| ) | ||
|
|
||
| return ast | ||
|
|
||
|
|
||
| def assert_selection_is_a_field_node(ast: Any) -> FieldNode: | ||
|
||
| """Return the input value, if it is indeed a FieldNode, or raise AssertionError otherwise.""" | ||
| # N.B.: Using Any as the input type, since the inputs here are a variety of generics, unions, | ||
| # and a number of AST types. Since this function asserts the type of the input anyway, | ||
| # this is not a concern. | ||
| if not isinstance(ast, FieldNode): | ||
| raise AssertionError( | ||
| f"Expected AST to be a FieldNode, but instead found {type(ast).__name__}. " | ||
| f"This is a bug. Node value: {ast}" | ||
| ) | ||
|
|
||
| return ast | ||
|
|
||
|
|
||
| def assert_selection_is_a_field_or_inline_fragment_node( | ||
| ast: Any, | ||
| ) -> Union[FieldNode, InlineFragmentNode]: | ||
| """Return the input FieldNode or InlineFragmentNode, or raise AssertionError otherwise.""" | ||
| # N.B.: Using Any as the input type, since the inputs here are a variety of generics, unions, | ||
| # and a number of AST types. Since this function asserts the type of the input anyway, | ||
| # this is not a concern. | ||
| if not isinstance(ast, (FieldNode, InlineFragmentNode)): | ||
| raise AssertionError( | ||
| f"Expected AST to be a FieldNode or InlineFragmentNode, but instead " | ||
| f"found {type(ast).__name__}. This is a bug. Node value: {ast}" | ||
| ) | ||
|
|
||
| return ast | ||
Uh oh!
There was an error while loading. Please reload this page.