Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ Release date: TBA

refs #2789

* Add support for type parameter defaults added in Python 3.13.

* Improve ``as_string()`` representation for ``TypeVar``, ``ParamSpec`` and ``TypeVarTuple`` nodes, as well as
type parameter in ``ClassDef``, ``FuncDef`` and ``TypeAlias`` nodes (PEP 695).


What's New in astroid 3.3.11?
=============================
Expand Down
39 changes: 30 additions & 9 deletions astroid/nodes/as_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,27 @@ def visit_call(self, node: nodes.Call) -> str:
args.extend(keywords)
return f"{expr_str}({', '.join(args)})"

def _handle_type_params(
self, type_params: list[nodes.TypeVar | nodes.ParamSpec | nodes.TypeVarTuple]
) -> str:
return (
f"[{', '.join(tp.accept(self) for tp in type_params)}]"
if type_params
else ""
)

def visit_classdef(self, node: nodes.ClassDef) -> str:
"""return an astroid.ClassDef node as string"""
decorate = node.decorators.accept(self) if node.decorators else ""
type_params = self._handle_type_params(node.type_params)
args = [n.accept(self) for n in node.bases]
if node._metaclass and not node.has_metaclass_hack():
args.append("metaclass=" + node._metaclass.accept(self))
args += [n.accept(self) for n in node.keywords]
args_str = f"({', '.join(args)})" if args else ""
docs = self._docs_dedent(node.doc_node)
# TODO: handle type_params
return "\n\n{}class {}{}:{}\n{}\n".format(
decorate, node.name, args_str, docs, self._stmt_list(node.body)
return "\n\n{}class {}{}{}:{}\n{}\n".format(
decorate, node.name, type_params, args_str, docs, self._stmt_list(node.body)
)

def visit_compare(self, node: nodes.Compare) -> str:
Expand Down Expand Up @@ -336,17 +345,18 @@ def visit_formattedvalue(self, node: nodes.FormattedValue) -> str:
def handle_functiondef(self, node: nodes.FunctionDef, keyword: str) -> str:
"""return a (possibly async) function definition node as string"""
decorate = node.decorators.accept(self) if node.decorators else ""
type_params = self._handle_type_params(node.type_params)
docs = self._docs_dedent(node.doc_node)
trailer = ":"
if node.returns:
return_annotation = " -> " + node.returns.as_string()
trailer = return_annotation + ":"
# TODO: handle type_params
def_format = "\n%s%s %s(%s)%s%s\n%s"
def_format = "\n%s%s %s%s(%s)%s%s\n%s"
return def_format % (
decorate,
keyword,
node.name,
type_params,
node.args.accept(self),
trailer,
docs,
Expand Down Expand Up @@ -455,7 +465,10 @@ def visit_nonlocal(self, node: nodes.Nonlocal) -> str:

def visit_paramspec(self, node: nodes.ParamSpec) -> str:
"""return an astroid.ParamSpec node as string"""
return node.name.accept(self)
default_value_str = (
f" = {node.default_value.accept(self)}" if node.default_value else ""
)
return f"**{node.name.accept(self)}{default_value_str}"

def visit_pass(self, node: nodes.Pass) -> str:
"""return an astroid.Pass node as string"""
Expand Down Expand Up @@ -545,15 +558,23 @@ def visit_tuple(self, node: nodes.Tuple) -> str:

def visit_typealias(self, node: nodes.TypeAlias) -> str:
"""return an astroid.TypeAlias node as string"""
return node.name.accept(self) if node.name else "_"
type_params = self._handle_type_params(node.type_params)
return f"type {node.name.accept(self)}{type_params} = {node.value.accept(self)}"

def visit_typevar(self, node: nodes.TypeVar) -> str:
"""return an astroid.TypeVar node as string"""
return node.name.accept(self) if node.name else "_"
bound_str = f": {node.bound.accept(self)}" if node.bound else ""
default_value_str = (
f" = {node.default_value.accept(self)}" if node.default_value else ""
)
return f"{node.name.accept(self)}{bound_str}{default_value_str}"

def visit_typevartuple(self, node: nodes.TypeVarTuple) -> str:
"""return an astroid.TypeVarTuple node as string"""
return "*" + node.name.accept(self) if node.name else ""
default_value_str = (
f" = {node.default_value.accept(self)}" if node.default_value else ""
)
return f"*{node.name.accept(self)}{default_value_str}"

def visit_unaryop(self, node: nodes.UnaryOp) -> str:
"""return an astroid.UnaryOp node as string"""
Expand Down
29 changes: 20 additions & 9 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3383,9 +3383,9 @@ class ParamSpec(_base_nodes.AssignTypeNode):
<ParamSpec l.1 at 0x7f23b2e4e198>
"""

_astroid_fields = ("name",)

_astroid_fields = ("name", "default_value")
name: AssignName
default_value: NodeNG | None

def __init__(
self,
Expand All @@ -3404,8 +3404,9 @@ def __init__(
parent=parent,
)

def postinit(self, *, name: AssignName) -> None:
def postinit(self, *, name: AssignName, default_value: NodeNG | None) -> None:
self.name = name
self.default_value = default_value

def _infer(
self, context: InferenceContext | None = None, **kwargs: Any
Expand Down Expand Up @@ -4141,10 +4142,10 @@ class TypeVar(_base_nodes.AssignTypeNode):
<TypeVar l.1 at 0x7f23b2e4e198>
"""

_astroid_fields = ("name", "bound")

_astroid_fields = ("name", "bound", "default_value")
name: AssignName
bound: NodeNG | None
default_value: NodeNG | None

def __init__(
self,
Expand All @@ -4163,9 +4164,16 @@ def __init__(
parent=parent,
)

def postinit(self, *, name: AssignName, bound: NodeNG | None) -> None:
def postinit(
self,
*,
name: AssignName,
bound: NodeNG | None,
default_value: NodeNG | None = None,
) -> None:
self.name = name
self.bound = bound
self.default_value = default_value

def _infer(
self, context: InferenceContext | None = None, **kwargs: Any
Expand All @@ -4187,9 +4195,9 @@ class TypeVarTuple(_base_nodes.AssignTypeNode):
<TypeVarTuple l.1 at 0x7f23b2e4e198>
"""

_astroid_fields = ("name",)

_astroid_fields = ("name", "default_value")
name: AssignName
default_value: NodeNG | None

def __init__(
self,
Expand All @@ -4208,8 +4216,11 @@ def __init__(
parent=parent,
)

def postinit(self, *, name: AssignName) -> None:
def postinit(
self, *, name: AssignName, default_value: NodeNG | None = None
) -> None:
self.name = name
self.default_value = default_value

def _infer(
self, context: InferenceContext | None = None, **kwargs: Any
Expand Down
19 changes: 16 additions & 3 deletions astroid/rebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from astroid import nodes
from astroid._ast import ParserModule, get_parser_module, parse_function_type_comment
from astroid.const import PY312_PLUS, Context
from astroid.const import PY312_PLUS, PY313_PLUS, Context
from astroid.nodes.utils import Position
from astroid.typing import InferenceResult

Expand Down Expand Up @@ -1483,7 +1483,12 @@ def visit_paramspec(
)
# Add AssignName node for 'node.name'
# https://bugs.python.org/issue43994
newnode.postinit(name=self.visit_assignname(node, newnode, node.name))
newnode.postinit(
name=self.visit_assignname(node, newnode, node.name),
default_value=(
self.visit(node.default_value, newnode) if PY313_PLUS else None
),
)
return newnode

def visit_pass(self, node: ast.Pass, parent: nodes.NodeNG) -> nodes.Pass:
Expand Down Expand Up @@ -1679,6 +1684,9 @@ def visit_typevar(self, node: ast.TypeVar, parent: nodes.NodeNG) -> nodes.TypeVa
newnode.postinit(
name=self.visit_assignname(node, newnode, node.name),
bound=self.visit(node.bound, newnode),
default_value=(
self.visit(node.default_value, newnode) if PY313_PLUS else None
),
)
return newnode

Expand All @@ -1695,7 +1703,12 @@ def visit_typevartuple(
)
# Add AssignName node for 'node.name'
# https://bugs.python.org/issue43994
newnode.postinit(name=self.visit_assignname(node, newnode, node.name))
newnode.postinit(
name=self.visit_assignname(node, newnode, node.name),
default_value=(
self.visit(node.default_value, newnode) if PY313_PLUS else None
),
)
return newnode

def visit_unaryop(self, node: ast.UnaryOp, parent: nodes.NodeNG) -> nodes.UnaryOp:
Expand Down
62 changes: 56 additions & 6 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
IS_PYPY,
PY311_PLUS,
PY312_PLUS,
PY313_PLUS,
PY314_PLUS,
Context,
)
Expand Down Expand Up @@ -332,27 +333,76 @@ def test_recursion_error_trapped() -> None:
class AsStringTypeParamNodes(unittest.TestCase):
@staticmethod
def test_as_string_type_alias() -> None:
ast = abuilder.string_build("type Point = tuple[float, float]")
type_alias = ast.body[0]
assert type_alias.as_string().strip() == "Point"
ast1 = abuilder.string_build("type Point = tuple[float, float]")
type_alias1 = ast1.body[0]
assert type_alias1.as_string().strip() == "type Point = tuple[float, float]"
ast2 = abuilder.string_build(
"type Point[T, **P] = tuple[float, T, Callable[P, None]]"
)
type_alias2 = ast2.body[0]
assert (
type_alias2.as_string().strip()
== "type Point[T, **P] = tuple[float, T, Callable[P, None]]"
)

@staticmethod
def test_as_string_type_var() -> None:
ast = abuilder.string_build("type Point[T] = tuple[float, float]")
ast = abuilder.string_build("type Point[T: int | str] = tuple[float, float]")
type_var = ast.body[0].type_params[0]
assert type_var.as_string().strip() == "T: int | str"

@staticmethod
@pytest.mark.skipif(
not PY313_PLUS, reason="Type parameter defaults were added in 313"
)
def test_as_string_type_var_default() -> None:
ast = abuilder.string_build(
"type Point[T: int | str = int] = tuple[float, float]"
)
type_var = ast.body[0].type_params[0]
assert type_var.as_string().strip() == "T"
assert type_var.as_string().strip() == "T: int | str = int"

@staticmethod
def test_as_string_type_var_tuple() -> None:
ast = abuilder.string_build("type Alias[*Ts] = tuple[*Ts]")
type_var_tuple = ast.body[0].type_params[0]
assert type_var_tuple.as_string().strip() == "*Ts"

@staticmethod
@pytest.mark.skipif(
not PY313_PLUS, reason="Type parameter defaults were added in 313"
)
def test_as_string_type_var_tuple_defaults() -> None:
ast = abuilder.string_build("type Alias[*Ts = tuple[int, str]] = tuple[*Ts]")
type_var_tuple = ast.body[0].type_params[0]
assert type_var_tuple.as_string().strip() == "*Ts = tuple[int, str]"

@staticmethod
def test_as_string_param_spec() -> None:
ast = abuilder.string_build("type Alias[**P] = Callable[P, int]")
param_spec = ast.body[0].type_params[0]
assert param_spec.as_string().strip() == "P"
assert param_spec.as_string().strip() == "**P"

@staticmethod
@pytest.mark.skipif(
not PY313_PLUS, reason="Type parameter defaults were added in 313"
)
def test_as_string_param_spec_defaults() -> None:
ast = abuilder.string_build("type Alias[**P = [str, int]] = Callable[P, int]")
param_spec = ast.body[0].type_params[0]
assert param_spec.as_string().strip() == "**P = [str, int]"

@staticmethod
def test_as_string_class_type_params() -> None:
code = abuilder.string_build("class A[T, **P]: ...")
cls_node = code.body[0]
assert cls_node.as_string().strip() == "class A[T, **P]:\n ..."

@staticmethod
def test_as_string_function_type_params() -> None:
code = abuilder.string_build("def func[T, **P](): ...")
func_node = code.body[0]
assert func_node.as_string().strip() == "def func[T, **P]():\n ..."


class _NodeTest(unittest.TestCase):
Expand Down
Loading
Loading