Skip to content

Commit 8bf99c7

Browse files
committed
fix: validate default answers
1 parent 0d5bf8d commit 8bf99c7

File tree

3 files changed

+98
-98
lines changed

3 files changed

+98
-98
lines changed

copier/_main.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -555,11 +555,9 @@ def _ask(self) -> None: # noqa: C901
555555
if var_name in self.answers.last:
556556
try:
557557
answer = question.parse_answer(self.answers.last[var_name])
558+
question.validate_answer(answer)
558559
except Exception:
559560
del self.answers.last[var_name]
560-
else:
561-
if question.validate_answer(answer):
562-
del self.answers.last[var_name]
563561
# Skip a question when the skip condition is met.
564562
if not question.get_when():
565563
# Omit its answer from the answers file.
@@ -573,14 +571,10 @@ def _ask(self) -> None: # noqa: C901
573571
if question.default is MISSING:
574572
continue
575573
if var_name in self.answers.init:
576-
# Try to parse the answer value.
574+
# Try to parse and validate (if the question has a validator)
575+
# the answer value.
577576
answer = question.parse_answer(self.answers.init[var_name])
578-
# Try to validate the answer value if the question has a
579-
# validator.
580-
if err_msg := question.validate_answer(answer):
581-
raise ValueError(
582-
f"Validation error for question '{var_name}': {err_msg}"
583-
)
577+
question.validate_answer(answer)
584578
# At this point, the answer value is valid. Do not ask the
585579
# question again, but set answer as the user's answer instead.
586580
self.answers.user[var_name] = answer

copier/_user_data.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
from collections import ChainMap
88
from collections.abc import Mapping, Sequence
9+
from copy import deepcopy
910
from dataclasses import field
1011
from datetime import datetime
1112
from functools import cached_property
@@ -275,7 +276,13 @@ def get_default(self) -> Any:
275276
result = self.render_value(
276277
self.settings.defaults.get(self.var_name, self.default)
277278
)
278-
result = self.cast_answer(result)
279+
result = self.parse_answer(result)
280+
# Computed values (i.e., `when: false`) are intentionally not validated
281+
# at the moment.
282+
# https://github.com/copier-org/copier/issues/1779#issuecomment-2365006990
283+
# https://github.com/copier-org/copier/pull/1785
284+
if self.get_when():
285+
self.validate_answer(result)
279286
return result
280287

281288
def get_default_rendered(self) -> bool | str | Choice | None | MissingType:
@@ -319,7 +326,6 @@ def _formatted_choices(self) -> Sequence[Choice]:
319326
"""Obtain choices rendered and properly formatted."""
320327
result = []
321328
choices = self.choices
322-
default = self.get_default()
323329
if isinstance(choices, str):
324330
choices = parse_yaml_string(self.render_value(self.choices))
325331
if isinstance(choices, dict):
@@ -344,15 +350,7 @@ def _formatted_choices(self) -> Sequence[Choice]:
344350

345351
disabled = self.render_value(value.get("validator", ""))
346352
value = value["value"]
347-
# The value can be templated
348-
value = self.render_value(value)
349-
checked = (
350-
self.multiselect
351-
and isinstance(default, list)
352-
and self.cast_answer(value) in default
353-
or None
354-
)
355-
c = Choice(name, value, disabled=disabled, checked=checked)
353+
c = Choice(name, self.render_value(value), disabled=disabled)
356354
# Try to cast the value according to the question's type to raise
357355
# an error in case the value is incompatible.
358356
self.cast_answer(c.value)
@@ -382,7 +380,11 @@ def _validate(answer: str) -> str | Literal[True]:
382380
ans = self.parse_answer(answer)
383381
except Exception:
384382
return "Invalid input"
385-
return self.validate_answer(ans) or True
383+
try:
384+
self.validate_answer(ans)
385+
except Exception as exc:
386+
return str(exc)
387+
return True
386388

387389
lexer = None
388390
result: AnyByStrDict = {
@@ -405,7 +407,14 @@ def _validate(answer: str) -> str | Literal[True]:
405407
result["default"] = False
406408
if self.choices:
407409
questionary_type = "checkbox" if self.multiselect else "select"
408-
result["choices"] = self._formatted_choices
410+
choices = self._formatted_choices
411+
# Select default choices for a multiselect question.
412+
if self.multiselect and isinstance(
413+
default_choices := self.get_default(), list
414+
):
415+
for choice in (choices := deepcopy(choices)):
416+
choice.checked = self.cast_answer(choice.value) in default_choices
417+
result["choices"] = choices
409418
if questionary_type == "input":
410419
if self.secret:
411420
questionary_type = "password"
@@ -436,15 +445,16 @@ def get_multiline(self) -> bool:
436445
"""Get the value for multiline."""
437446
return cast_to_bool(self.render_value(self.multiline))
438447

439-
def validate_answer(self, answer: Any) -> str:
448+
def validate_answer(self, answer: Any) -> None:
440449
"""Validate user answer."""
441450
try:
442451
err_msg = self.render_value(self.validator, {self.var_name: answer}).strip()
443452
except Exception as error:
444-
return str(error)
453+
err_msg = str(error)
445454
if err_msg:
446-
return err_msg
447-
return ""
455+
raise ValueError(
456+
f"Validation error for question '{self.var_name}': {err_msg}"
457+
)
448458

449459
def get_when(self) -> bool:
450460
"""Get skip condition for question."""

tests/test_copy.py

Lines changed: 67 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -852,93 +852,89 @@ def test_required_choice_question_without_data(
852852

853853

854854
@pytest.mark.parametrize(
855-
"type_name, default, expected",
855+
"spec, expected",
856856
[
857-
("str", "string", does_not_raise()),
858-
("str", "1.0", does_not_raise()),
859-
("str", 1.0, does_not_raise()),
860-
("str", None, pytest.raises(TypeError)),
857+
({"type": "str", "default": "string"}, does_not_raise()),
858+
({"type": "str", "default": "1.0"}, does_not_raise()),
859+
({"type": "str", "default": 1.0}, does_not_raise()),
860+
({"type": "str", "default": None}, pytest.raises(TypeError)),
861861
(
862862
{
863863
"type": "str",
864+
"default": "",
865+
"validator": "[% if q|length < 3 %]too short[% endif %]",
866+
},
867+
pytest.raises(ValueError),
868+
),
869+
(
870+
{
871+
"type": "str",
872+
"default": "",
864873
"secret": True,
865874
"validator": "[% if q|length < 3 %]too short[% endif %]",
866875
},
867-
"",
868876
pytest.raises(ValueError),
869877
),
870-
("int", 1, does_not_raise()),
871-
("int", 1.0, does_not_raise()),
872-
("int", "1", does_not_raise()),
873-
("int", "1.0", pytest.raises(ValueError)),
874-
("int", "no-int", pytest.raises(ValueError)),
875-
("int", None, pytest.raises(TypeError)),
876-
("int", {}, pytest.raises(TypeError)),
877-
("int", [], pytest.raises(TypeError)),
878-
("float", 1.1, does_not_raise()),
879-
("float", 1, does_not_raise()),
880-
("float", "1.1", does_not_raise()),
881-
("float", "no-float", pytest.raises(ValueError)),
882-
("float", None, pytest.raises(TypeError)),
883-
("float", {}, pytest.raises(TypeError)),
884-
("float", [], pytest.raises(TypeError)),
885-
("bool", True, does_not_raise()),
886-
("bool", False, does_not_raise()),
887-
("bool", "y", does_not_raise()),
888-
("bool", "n", does_not_raise()),
889-
("bool", None, pytest.raises(TypeError)),
890-
("json", '"string"', does_not_raise()),
891-
("json", "1", does_not_raise()),
892-
("json", 1, does_not_raise()),
893-
("json", "1.1", does_not_raise()),
894-
("json", 1.1, does_not_raise()),
895-
("json", "true", does_not_raise()),
896-
("json", True, does_not_raise()),
897-
("json", "false", does_not_raise()),
898-
("json", False, does_not_raise()),
899-
("json", "{}", does_not_raise()),
900-
("json", {}, does_not_raise()),
901-
("json", "[]", does_not_raise()),
902-
("json", [], does_not_raise()),
903-
("json", "null", does_not_raise()),
904-
("json", None, does_not_raise()),
905-
("yaml", '"string"', does_not_raise()),
906-
("yaml", "string", does_not_raise()),
907-
("yaml", "1", does_not_raise()),
908-
("yaml", 1, does_not_raise()),
909-
("yaml", "1.1", does_not_raise()),
910-
("yaml", 1.1, does_not_raise()),
911-
("yaml", "true", does_not_raise()),
912-
("yaml", True, does_not_raise()),
913-
("yaml", "false", does_not_raise()),
914-
("yaml", False, does_not_raise()),
915-
("yaml", "{}", does_not_raise()),
916-
("yaml", {}, does_not_raise()),
917-
("yaml", "[]", does_not_raise()),
918-
("yaml", [], does_not_raise()),
919-
("yaml", "null", does_not_raise()),
920-
("yaml", None, does_not_raise()),
878+
({"type": "int", "default": 1}, does_not_raise()),
879+
({"type": "int", "default": 1.0}, does_not_raise()),
880+
({"type": "int", "default": "1"}, does_not_raise()),
881+
({"type": "int", "default": "1.0"}, pytest.raises(ValueError)),
882+
({"type": "int", "default": "no-int"}, pytest.raises(ValueError)),
883+
({"type": "int", "default": None}, pytest.raises(TypeError)),
884+
({"type": "int", "default": {}}, pytest.raises(TypeError)),
885+
({"type": "int", "default": []}, pytest.raises(TypeError)),
886+
({"type": "float", "default": 1.1}, does_not_raise()),
887+
({"type": "float", "default": 1}, does_not_raise()),
888+
({"type": "float", "default": "1.1"}, does_not_raise()),
889+
({"type": "float", "default": "no-float"}, pytest.raises(ValueError)),
890+
({"type": "float", "default": None}, pytest.raises(TypeError)),
891+
({"type": "float", "default": {}}, pytest.raises(TypeError)),
892+
({"type": "float", "default": []}, pytest.raises(TypeError)),
893+
({"type": "bool", "default": True}, does_not_raise()),
894+
({"type": "bool", "default": False}, does_not_raise()),
895+
({"type": "bool", "default": "y"}, does_not_raise()),
896+
({"type": "bool", "default": "n"}, does_not_raise()),
897+
({"type": "bool", "default": None}, pytest.raises(TypeError)),
898+
({"type": "json", "default": '"string"'}, does_not_raise()),
899+
({"type": "json", "default": "1"}, does_not_raise()),
900+
({"type": "json", "default": 1}, does_not_raise()),
901+
({"type": "json", "default": "1.1"}, does_not_raise()),
902+
({"type": "json", "default": 1.1}, does_not_raise()),
903+
({"type": "json", "default": "true"}, does_not_raise()),
904+
({"type": "json", "default": True}, does_not_raise()),
905+
({"type": "json", "default": "false"}, does_not_raise()),
906+
({"type": "json", "default": False}, does_not_raise()),
907+
({"type": "json", "default": "{}"}, does_not_raise()),
908+
({"type": "json", "default": {}}, does_not_raise()),
909+
({"type": "json", "default": "[]"}, does_not_raise()),
910+
({"type": "json", "default": []}, does_not_raise()),
911+
({"type": "json", "default": "null"}, does_not_raise()),
912+
({"type": "json", "default": None}, does_not_raise()),
913+
({"type": "yaml", "default": '"string"'}, does_not_raise()),
914+
({"type": "yaml", "default": "string"}, does_not_raise()),
915+
({"type": "yaml", "default": "1"}, does_not_raise()),
916+
({"type": "yaml", "default": 1}, does_not_raise()),
917+
({"type": "yaml", "default": "1.1"}, does_not_raise()),
918+
({"type": "yaml", "default": 1.1}, does_not_raise()),
919+
({"type": "yaml", "default": "true"}, does_not_raise()),
920+
({"type": "yaml", "default": True}, does_not_raise()),
921+
({"type": "yaml", "default": "false"}, does_not_raise()),
922+
({"type": "yaml", "default": False}, does_not_raise()),
923+
({"type": "yaml", "default": "{}"}, does_not_raise()),
924+
({"type": "yaml", "default": {}}, does_not_raise()),
925+
({"type": "yaml", "default": "[]"}, does_not_raise()),
926+
({"type": "yaml", "default": []}, does_not_raise()),
927+
({"type": "yaml", "default": "null"}, does_not_raise()),
928+
({"type": "yaml", "default": None}, does_not_raise()),
921929
],
922930
)
923931
def test_validate_default_value(
924932
tmp_path_factory: pytest.TempPathFactory,
925-
type_name: str,
926-
default: Any,
933+
spec: AnyByStrDict,
927934
expected: AbstractContextManager[None],
928935
) -> None:
929936
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
930-
build_file_tree(
931-
{
932-
(src / "copier.yml"): yaml.dump(
933-
{
934-
"q": {
935-
"type": type_name,
936-
"default": default,
937-
}
938-
}
939-
)
940-
}
941-
)
937+
build_file_tree({(src / "copier.yml"): yaml.dump({"q": spec})})
942938
with expected:
943939
copier.run_copy(str(src), dst, defaults=True)
944940

0 commit comments

Comments
 (0)