Skip to content
22 changes: 21 additions & 1 deletion src/lightning/fabric/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class CSVLogger(Logger):
overwritten.
prefix: A string to put at the beginning of metric keys.
flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).
sub_dir: Sub-directory to group CSV logs. If a ``sub_dir`` argument is passed
then logs are saved in ``/root_dir/name/version/sub_dir/``. Defaults to ``None`` in which case
logs are saved in ``/root_dir/name/version/``.

Example::

Expand All @@ -65,6 +68,7 @@ def __init__(
version: Optional[Union[int, str]] = None,
prefix: str = "",
flush_logs_every_n_steps: int = 100,
sub_dir: Optional[_PATH] = None,
):
super().__init__()
root_dir = os.fspath(root_dir)
Expand All @@ -75,6 +79,7 @@ def __init__(
self._fs = get_filesystem(root_dir)
self._experiment: Optional[_ExperimentWriter] = None
self._flush_logs_every_n_steps = flush_logs_every_n_steps
self._sub_dir = None if sub_dir is None else os.fspath(sub_dir)

@property
@override
Expand Down Expand Up @@ -117,7 +122,22 @@ def log_dir(self) -> str:
"""
# create a pseudo standard path
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
return os.path.join(self._root_dir, self.name, version)
log_dir = os.path.join(self.root_dir, self.name, version)
if isinstance(self.sub_dir, str):
log_dir = os.path.join(log_dir, self.sub_dir)
log_dir = os.path.expandvars(log_dir)
log_dir = os.path.expanduser(log_dir)
return log_dir

@property
def sub_dir(self) -> Optional[str]:
"""Gets the sub directory where the CSV experiments are saved.

Returns:
The local path to the sub directory where the CSV experiments are saved.

"""
return self._sub_dir

@property
@rank_zero_experiment
Expand Down
30 changes: 30 additions & 0 deletions tests/tests_fabric/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,36 @@ def test_no_name(tmp_path, name):
assert os.listdir(tmp_path / "version_0")


def test_csv_log_sub_dir(tmp_path):
# no sub_dir specified
root_dir = tmp_path / "logs"
logger = CSVLogger(root_dir, name="name", version="version")
assert logger.log_dir == os.path.join(root_dir, "name", "version")

# sub_dir specified
logger = CSVLogger(root_dir, name="name", version="version", sub_dir="sub_dir")
assert logger.log_dir == os.path.join(root_dir, "name", "version", "sub_dir")


def test_csv_expand_home():
"""Test that the home dir (`~`) gets expanded properly."""
root_dir = "~/tmp"
explicit_root_dir = os.path.expanduser(root_dir)
logger = CSVLogger(root_dir, name="name", version="version", sub_dir="sub_dir")
assert logger.root_dir == root_dir
assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir")


@mock.patch.dict(os.environ, {"TEST_ENV_DIR": "some_directory"})
def test_tensorboard_expand_env_vars():
"""Test that the env vars in path names (`$`) get handled properly."""
test_env_dir = os.environ["TEST_ENV_DIR"]
root_dir = "$TEST_ENV_DIR/tmp"
explicit_root_dir = f"{test_env_dir}/tmp"
logger = CSVLogger(root_dir, name="name", version="version", sub_dir="sub_dir")
assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir")


@pytest.mark.parametrize("step_idx", [10, None])
def test_log_metrics(tmp_path, step_idx):
logger = CSVLogger(tmp_path)
Expand Down
Loading