Skip to content

Commit e619b6d

Browse files
authored
exp run: fix failure when dependencies are from inside submodules (#10831)
1 parent 396e71f commit e619b6d

File tree

6 files changed

+62
-9
lines changed

6 files changed

+62
-9
lines changed

dvc/repo/experiments/executor/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,9 @@ def save(
298298
stages = dvc.commit([], recursive=recursive, force=True, relink=False)
299299
exp_hash = cls.hash_exp(stages)
300300
if include_untracked:
301-
dvc.scm.add(include_untracked, force=True) # type: ignore[call-arg]
301+
from dvc.scm import add_no_submodules
302+
303+
add_no_submodules(dvc.scm, include_untracked, force=True) # type: ignore[call-arg]
302304

303305
with cls.auto_push(dvc):
304306
cls.commit(
@@ -513,7 +515,9 @@ def reproduce(
513515
stages = dvc.reproduce(*args, **kwargs)
514516
if paths := cls._get_top_level_paths(dvc):
515517
logger.debug("Staging top-level files: %s", paths)
516-
dvc.scm_context.add(paths)
518+
from dvc.scm import add_no_submodules
519+
520+
add_no_submodules(dvc.scm, paths)
517521

518522
exp_hash = cls.hash_exp(stages)
519523
if not repro_dry:

dvc/repo/scm_context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ def _make_git_add_cmd(paths: Union[str, Iterable[str]]) -> str:
4141
def add(self, paths: Union[str, Iterable[str]]) -> None:
4242
from scmrepo.exceptions import UnsupportedIndexFormat
4343

44+
from dvc.scm import add_no_submodules
45+
4446
try:
45-
return self.scm.add(paths)
47+
add_no_submodules(self.scm, paths)
4648
except UnsupportedIndexFormat:
4749
link = "https://github.com/iterative/dvc/issues/610"
4850
add_cmd = self._make_git_add_cmd([relpath(path) for path in paths])

dvc/scm.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,27 @@
11
"""Manages source control systems (e.g. Git)."""
22

33
import os
4-
from collections.abc import Iterator, Mapping
4+
from collections.abc import Iterable, Iterator, Mapping
55
from contextlib import contextmanager
66
from functools import partial
77
from typing import TYPE_CHECKING, Literal, Optional, Union, overload
88

99
from funcy import group_by
10-
from scmrepo.base import Base # noqa: F401
10+
from scmrepo.base import Base # noqa: TC002
1111
from scmrepo.git import Git
1212
from scmrepo.noscm import NoSCM
1313

1414
from dvc.exceptions import DvcException
15+
from dvc.log import logger
1516
from dvc.progress import Tqdm
1617

1718
if TYPE_CHECKING:
1819
from scmrepo.progress import GitProgressEvent
1920

2021
from dvc.fs import FileSystem
2122

23+
logger = logger.getChild(__name__)
24+
2225

2326
class SCMError(DvcException):
2427
"""Base class for source control management errors."""
@@ -283,3 +286,32 @@ def lfs_prefetch(fs: "FileSystem", paths: list[str]):
283286
include=[(path if path.startswith("/") else f"/{path}") for path in paths],
284287
progress=pbar.update_git,
285288
)
289+
290+
291+
def add_no_submodules(
    scm: "Base",
    paths: Union[str, Iterable[str]],
    **kwargs,
) -> None:
    """Stage paths to Git, excluding those inside submodules.

    A path is skipped when its absolute form equals a submodule root or
    lies beneath one. Skipped paths are reported in a single debug log
    message; everything else is handed to ``scm.add``.

    Args:
        scm: SCM object providing ``root_dir``, ``list_submodules()`` and
            ``add()``.
        paths: A single path or an iterable of paths to stage.
        **kwargs: Forwarded unchanged to ``scm.add``.
    """
    if isinstance(paths, str):
        paths = [paths]

    # Normalize so the roots compare consistently with os.path.abspath() output
    # (which is itself normalized).
    submodule_roots = {
        os.path.normpath(os.path.join(scm.root_dir, sub))
        for sub in scm.list_submodules()
    }
    # Match descendants against "<root><sep>" rather than the bare root:
    # a plain string-prefix test would also swallow siblings that merely share
    # a name prefix (e.g. "submodule2" when the submodule is "submodule").
    submodule_prefixes = tuple(os.path.join(root, "") for root in submodule_roots)

    repo_paths: list[str] = []
    skipped_paths: list[str] = []

    for p in paths:
        abs_path = os.path.abspath(p)
        if abs_path in submodule_roots or abs_path.startswith(submodule_prefixes):
            skipped_paths.append(p)
        else:
            repo_paths.append(p)

    if skipped_paths:
        msg = "Skipping staging for path(s) inside submodules: %s"
        logger.debug(msg, ", ".join(skipped_paths))

    scm.add(repo_paths, **kwargs)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ dependencies = [
6767
"requests>=2.22",
6868
"rich>=12",
6969
"ruamel.yaml>=0.17.11",
70-
"scmrepo>=3.3.8,<4",
70+
"scmrepo>=3.5.2,<4",
7171
"shortuuid>=0.5",
7272
"shtab<2,>=1.3.4",
7373
"tabulate>=0.8.7",

tests/func/experiments/test_experiments.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,3 +793,18 @@ def test_custom_commit_message(tmp_dir, scm, dvc, tmp):
793793
)
794794
)
795795
assert scm.resolve_commit(exp).message == "custom commit message"
796+
797+
798+
@pytest.mark.parametrize("dep", ["submodule", "submodule/file"])
def test_experiments_run_with_submodule_dependencies(dvc, scm, make_tmp_dir, dep):
    """Run an experiment whose stage depends on a path in a Git submodule.

    ``dep`` is parametrized as either the submodule directory itself or a
    file inside it; the test passes when ``dvc.experiments.run()`` returns a
    truthy result.
    """
    # Separate Git repository with one committed file, used as submodule source.
    upstream = make_tmp_dir("external_repo", scm=True)
    upstream.scm_gen("file", "content", commit="add file")

    # Attach it under "submodule" and commit the resulting .gitmodules.
    registry = scm.pygit2.repo.submodules
    registry.add(os.fspath(upstream), "submodule")
    registry.update(init=True)
    scm.add_commit([".gitmodules"], message="add submodule")

    dvc.stage.add(cmd="echo foo", deps=[dep], name="foo")

    assert dvc.experiments.run()

tests/unit/repo/test_scm_context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_scm_track_changed_files(scm_context):
3939

4040
scm_context.track_file("foo")
4141
scm_context.track_changed_files()
42-
scm_context.scm.add.assert_called_once_with({"foo"})
42+
scm_context.scm.add.assert_called_once_with(["foo"])
4343

4444

4545
def test_ignore(scm_context):
@@ -73,7 +73,7 @@ def test_scm_context_autostage_changed_files(scm_context):
7373

7474
assert not scm_context.files_to_track
7575
assert not scm_context.ignored_paths
76-
scm_context.scm.add.assert_called_once_with({"foo"})
76+
scm_context.scm.add.assert_called_once_with(["foo"])
7777

7878

7979
def test_scm_context_clears_ignores_on_error(scm_context):
@@ -141,4 +141,4 @@ def test_method(repo, *args, **kwargs):
141141
method = mocker.MagicMock(wraps=test_method)
142142
decorator(method, autostage=True)(repo, "arg", kw=1)
143143
method.assert_called_once_with(repo, "arg", kw=1)
144-
scm_context.scm.add.assert_called_once_with({"foo"})
144+
scm_context.scm.add.assert_called_once_with(["foo"])

0 commit comments

Comments
 (0)