Skip to content

Commit 62c30c2

Browse files
authored
Add support for type parameter defaults + string repr improvements (#2797)
1 parent 1b3e7f1 commit 62c30c2

File tree

6 files changed

+183
-28
lines changed

6 files changed

+183
-28
lines changed

ChangeLog

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ Release date: TBA
6060

6161
refs #2789
6262

63+
* Add support for type parameter defaults added in Python 3.13.
64+
65+
* Improve ``as_string()`` representation for ``TypeVar``, ``ParamSpec`` and ``TypeVarTuple`` nodes, as well as
66+
type parameter in ``ClassDef``, ``FuncDef`` and ``TypeAlias`` nodes (PEP 695).
67+
6368

6469
What's New in astroid 3.3.11?
6570
=============================

astroid/nodes/as_string.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,18 +176,27 @@ def visit_call(self, node: nodes.Call) -> str:
176176
args.extend(keywords)
177177
return f"{expr_str}({', '.join(args)})"
178178

179+
def _handle_type_params(
180+
self, type_params: list[nodes.TypeVar | nodes.ParamSpec | nodes.TypeVarTuple]
181+
) -> str:
182+
return (
183+
f"[{', '.join(tp.accept(self) for tp in type_params)}]"
184+
if type_params
185+
else ""
186+
)
187+
179188
def visit_classdef(self, node: nodes.ClassDef) -> str:
180189
"""return an astroid.ClassDef node as string"""
181190
decorate = node.decorators.accept(self) if node.decorators else ""
191+
type_params = self._handle_type_params(node.type_params)
182192
args = [n.accept(self) for n in node.bases]
183193
if node._metaclass and not node.has_metaclass_hack():
184194
args.append("metaclass=" + node._metaclass.accept(self))
185195
args += [n.accept(self) for n in node.keywords]
186196
args_str = f"({', '.join(args)})" if args else ""
187197
docs = self._docs_dedent(node.doc_node)
188-
# TODO: handle type_params
189-
return "\n\n{}class {}{}:{}\n{}\n".format(
190-
decorate, node.name, args_str, docs, self._stmt_list(node.body)
198+
return "\n\n{}class {}{}{}:{}\n{}\n".format(
199+
decorate, node.name, type_params, args_str, docs, self._stmt_list(node.body)
191200
)
192201

193202
def visit_compare(self, node: nodes.Compare) -> str:
@@ -336,17 +345,18 @@ def visit_formattedvalue(self, node: nodes.FormattedValue) -> str:
336345
def handle_functiondef(self, node: nodes.FunctionDef, keyword: str) -> str:
337346
"""return a (possibly async) function definition node as string"""
338347
decorate = node.decorators.accept(self) if node.decorators else ""
348+
type_params = self._handle_type_params(node.type_params)
339349
docs = self._docs_dedent(node.doc_node)
340350
trailer = ":"
341351
if node.returns:
342352
return_annotation = " -> " + node.returns.as_string()
343353
trailer = return_annotation + ":"
344-
# TODO: handle type_params
345-
def_format = "\n%s%s %s(%s)%s%s\n%s"
354+
def_format = "\n%s%s %s%s(%s)%s%s\n%s"
346355
return def_format % (
347356
decorate,
348357
keyword,
349358
node.name,
359+
type_params,
350360
node.args.accept(self),
351361
trailer,
352362
docs,
@@ -455,7 +465,10 @@ def visit_nonlocal(self, node: nodes.Nonlocal) -> str:
455465

456466
def visit_paramspec(self, node: nodes.ParamSpec) -> str:
457467
"""return an astroid.ParamSpec node as string"""
458-
return node.name.accept(self)
468+
default_value_str = (
469+
f" = {node.default_value.accept(self)}" if node.default_value else ""
470+
)
471+
return f"**{node.name.accept(self)}{default_value_str}"
459472

460473
def visit_pass(self, node: nodes.Pass) -> str:
461474
"""return an astroid.Pass node as string"""
@@ -545,15 +558,23 @@ def visit_tuple(self, node: nodes.Tuple) -> str:
545558

546559
def visit_typealias(self, node: nodes.TypeAlias) -> str:
547560
"""return an astroid.TypeAlias node as string"""
548-
return node.name.accept(self) if node.name else "_"
561+
type_params = self._handle_type_params(node.type_params)
562+
return f"type {node.name.accept(self)}{type_params} = {node.value.accept(self)}"
549563

550564
def visit_typevar(self, node: nodes.TypeVar) -> str:
551565
"""return an astroid.TypeVar node as string"""
552-
return node.name.accept(self) if node.name else "_"
566+
bound_str = f": {node.bound.accept(self)}" if node.bound else ""
567+
default_value_str = (
568+
f" = {node.default_value.accept(self)}" if node.default_value else ""
569+
)
570+
return f"{node.name.accept(self)}{bound_str}{default_value_str}"
553571

554572
def visit_typevartuple(self, node: nodes.TypeVarTuple) -> str:
555573
"""return an astroid.TypeVarTuple node as string"""
556-
return "*" + node.name.accept(self) if node.name else ""
574+
default_value_str = (
575+
f" = {node.default_value.accept(self)}" if node.default_value else ""
576+
)
577+
return f"*{node.name.accept(self)}{default_value_str}"
557578

558579
def visit_unaryop(self, node: nodes.UnaryOp) -> str:
559580
"""return an astroid.UnaryOp node as string"""

astroid/nodes/node_classes.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3383,9 +3383,9 @@ class ParamSpec(_base_nodes.AssignTypeNode):
33833383
<ParamSpec l.1 at 0x7f23b2e4e198>
33843384
"""
33853385

3386-
_astroid_fields = ("name",)
3387-
3386+
_astroid_fields = ("name", "default_value")
33883387
name: AssignName
3388+
default_value: NodeNG | None
33893389

33903390
def __init__(
33913391
self,
@@ -3404,8 +3404,9 @@ def __init__(
34043404
parent=parent,
34053405
)
34063406

3407-
def postinit(self, *, name: AssignName) -> None:
3407+
def postinit(self, *, name: AssignName, default_value: NodeNG | None) -> None:
34083408
self.name = name
3409+
self.default_value = default_value
34093410

34103411
def _infer(
34113412
self, context: InferenceContext | None = None, **kwargs: Any
@@ -4141,10 +4142,10 @@ class TypeVar(_base_nodes.AssignTypeNode):
41414142
<TypeVar l.1 at 0x7f23b2e4e198>
41424143
"""
41434144

4144-
_astroid_fields = ("name", "bound")
4145-
4145+
_astroid_fields = ("name", "bound", "default_value")
41464146
name: AssignName
41474147
bound: NodeNG | None
4148+
default_value: NodeNG | None
41484149

41494150
def __init__(
41504151
self,
@@ -4163,9 +4164,16 @@ def __init__(
41634164
parent=parent,
41644165
)
41654166

4166-
def postinit(self, *, name: AssignName, bound: NodeNG | None) -> None:
4167+
def postinit(
4168+
self,
4169+
*,
4170+
name: AssignName,
4171+
bound: NodeNG | None,
4172+
default_value: NodeNG | None = None,
4173+
) -> None:
41674174
self.name = name
41684175
self.bound = bound
4176+
self.default_value = default_value
41694177

41704178
def _infer(
41714179
self, context: InferenceContext | None = None, **kwargs: Any
@@ -4187,9 +4195,9 @@ class TypeVarTuple(_base_nodes.AssignTypeNode):
41874195
<TypeVarTuple l.1 at 0x7f23b2e4e198>
41884196
"""
41894197

4190-
_astroid_fields = ("name",)
4191-
4198+
_astroid_fields = ("name", "default_value")
41924199
name: AssignName
4200+
default_value: NodeNG | None
41934201

41944202
def __init__(
41954203
self,
@@ -4208,8 +4216,11 @@ def __init__(
42084216
parent=parent,
42094217
)
42104218

4211-
def postinit(self, *, name: AssignName) -> None:
4219+
def postinit(
4220+
self, *, name: AssignName, default_value: NodeNG | None = None
4221+
) -> None:
42124222
self.name = name
4223+
self.default_value = default_value
42134224

42144225
def _infer(
42154226
self, context: InferenceContext | None = None, **kwargs: Any

astroid/rebuilder.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from astroid import nodes
2020
from astroid._ast import ParserModule, get_parser_module, parse_function_type_comment
21-
from astroid.const import PY312_PLUS, Context
21+
from astroid.const import PY312_PLUS, PY313_PLUS, Context
2222
from astroid.nodes.utils import Position
2323
from astroid.typing import InferenceResult
2424

@@ -1483,7 +1483,12 @@ def visit_paramspec(
14831483
)
14841484
# Add AssignName node for 'node.name'
14851485
# https://bugs.python.org/issue43994
1486-
newnode.postinit(name=self.visit_assignname(node, newnode, node.name))
1486+
newnode.postinit(
1487+
name=self.visit_assignname(node, newnode, node.name),
1488+
default_value=(
1489+
self.visit(node.default_value, newnode) if PY313_PLUS else None
1490+
),
1491+
)
14871492
return newnode
14881493

14891494
def visit_pass(self, node: ast.Pass, parent: nodes.NodeNG) -> nodes.Pass:
@@ -1679,6 +1684,9 @@ def visit_typevar(self, node: ast.TypeVar, parent: nodes.NodeNG) -> nodes.TypeVa
16791684
newnode.postinit(
16801685
name=self.visit_assignname(node, newnode, node.name),
16811686
bound=self.visit(node.bound, newnode),
1687+
default_value=(
1688+
self.visit(node.default_value, newnode) if PY313_PLUS else None
1689+
),
16821690
)
16831691
return newnode
16841692

@@ -1695,7 +1703,12 @@ def visit_typevartuple(
16951703
)
16961704
# Add AssignName node for 'node.name'
16971705
# https://bugs.python.org/issue43994
1698-
newnode.postinit(name=self.visit_assignname(node, newnode, node.name))
1706+
newnode.postinit(
1707+
name=self.visit_assignname(node, newnode, node.name),
1708+
default_value=(
1709+
self.visit(node.default_value, newnode) if PY313_PLUS else None
1710+
),
1711+
)
16991712
return newnode
17001713

17011714
def visit_unaryop(self, node: ast.UnaryOp, parent: nodes.NodeNG) -> nodes.UnaryOp:

tests/test_nodes.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
IS_PYPY,
3333
PY311_PLUS,
3434
PY312_PLUS,
35+
PY313_PLUS,
3536
PY314_PLUS,
3637
Context,
3738
)
@@ -332,27 +333,76 @@ def test_recursion_error_trapped() -> None:
332333
class AsStringTypeParamNodes(unittest.TestCase):
333334
@staticmethod
334335
def test_as_string_type_alias() -> None:
335-
ast = abuilder.string_build("type Point = tuple[float, float]")
336-
type_alias = ast.body[0]
337-
assert type_alias.as_string().strip() == "Point"
336+
ast1 = abuilder.string_build("type Point = tuple[float, float]")
337+
type_alias1 = ast1.body[0]
338+
assert type_alias1.as_string().strip() == "type Point = tuple[float, float]"
339+
ast2 = abuilder.string_build(
340+
"type Point[T, **P] = tuple[float, T, Callable[P, None]]"
341+
)
342+
type_alias2 = ast2.body[0]
343+
assert (
344+
type_alias2.as_string().strip()
345+
== "type Point[T, **P] = tuple[float, T, Callable[P, None]]"
346+
)
338347

339348
@staticmethod
340349
def test_as_string_type_var() -> None:
341-
ast = abuilder.string_build("type Point[T] = tuple[float, float]")
350+
ast = abuilder.string_build("type Point[T: int | str] = tuple[float, float]")
351+
type_var = ast.body[0].type_params[0]
352+
assert type_var.as_string().strip() == "T: int | str"
353+
354+
@staticmethod
355+
@pytest.mark.skipif(
356+
not PY313_PLUS, reason="Type parameter defaults were added in 313"
357+
)
358+
def test_as_string_type_var_default() -> None:
359+
ast = abuilder.string_build(
360+
"type Point[T: int | str = int] = tuple[float, float]"
361+
)
342362
type_var = ast.body[0].type_params[0]
343-
assert type_var.as_string().strip() == "T"
363+
assert type_var.as_string().strip() == "T: int | str = int"
344364

345365
@staticmethod
346366
def test_as_string_type_var_tuple() -> None:
347367
ast = abuilder.string_build("type Alias[*Ts] = tuple[*Ts]")
348368
type_var_tuple = ast.body[0].type_params[0]
349369
assert type_var_tuple.as_string().strip() == "*Ts"
350370

371+
@staticmethod
372+
@pytest.mark.skipif(
373+
not PY313_PLUS, reason="Type parameter defaults were added in 313"
374+
)
375+
def test_as_string_type_var_tuple_defaults() -> None:
376+
ast = abuilder.string_build("type Alias[*Ts = tuple[int, str]] = tuple[*Ts]")
377+
type_var_tuple = ast.body[0].type_params[0]
378+
assert type_var_tuple.as_string().strip() == "*Ts = tuple[int, str]"
379+
351380
@staticmethod
352381
def test_as_string_param_spec() -> None:
353382
ast = abuilder.string_build("type Alias[**P] = Callable[P, int]")
354383
param_spec = ast.body[0].type_params[0]
355-
assert param_spec.as_string().strip() == "P"
384+
assert param_spec.as_string().strip() == "**P"
385+
386+
@staticmethod
387+
@pytest.mark.skipif(
388+
not PY313_PLUS, reason="Type parameter defaults were added in 313"
389+
)
390+
def test_as_string_param_spec_defaults() -> None:
391+
ast = abuilder.string_build("type Alias[**P = [str, int]] = Callable[P, int]")
392+
param_spec = ast.body[0].type_params[0]
393+
assert param_spec.as_string().strip() == "**P = [str, int]"
394+
395+
@staticmethod
396+
def test_as_string_class_type_params() -> None:
397+
code = abuilder.string_build("class A[T, **P]: ...")
398+
cls_node = code.body[0]
399+
assert cls_node.as_string().strip() == "class A[T, **P]:\n ..."
400+
401+
@staticmethod
402+
def test_as_string_function_type_params() -> None:
403+
code = abuilder.string_build("def func[T, **P](): ...")
404+
func_node = code.body[0]
405+
assert func_node.as_string().strip() == "def func[T, **P]():\n ..."
356406

357407

358408
class _NodeTest(unittest.TestCase):

0 commit comments

Comments
 (0)