Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/lightning/fabric/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,19 @@ def sub_dir(self) -> Optional[str]:
@property
@rank_zero_experiment
def experiment(self) -> "SummaryWriter":
"""Actual tensorboard object. To use TensorBoard features anywhere in your code, do the following.
"""Returns the underlying TensorBoard summary writer object. Allows you to use TensorBoard logging features
directly in your :class:`~lightning.pytorch.core.LightningModule` or anywhere else in your code with:

`logger.experiment.some_tensorboard_function()`

Example::

logger.experiment.some_tensorboard_function()
class LitModel(LightningModule):
def training_step(self, batch, batch_idx):
# log a image
self.logger.experiment.add_image('my_image', batch['image'], self.global_step)
# log a histogram
self.logger.experiment.add_histogram('my_histogram', batch['data'], self.global_step)

"""
if self._experiment is not None:
Expand Down
33 changes: 31 additions & 2 deletions src/lightning/pytorch/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

import os
from argparse import Namespace
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
from lightning.fabric.loggers.tensorboard import TensorBoardLogger as FabricTensorBoardLogger
from lightning.fabric.utilities.cloud_io import _is_dir
from lightning.fabric.utilities.logger import _convert_params
Expand All @@ -35,6 +35,14 @@
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn

_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
if TYPE_CHECKING:
# assumes at least one will be installed when type checking
if _TENSORBOARD_AVAILABLE:
from torch.utils.tensorboard import SummaryWriter
else:
from tensorboardX import SummaryWriter # type: ignore[no-redef]


class TensorBoardLogger(Logger, FabricTensorBoardLogger):
r"""Log to local or remote file system in `TensorBoard <https://www.tensorflow.org/tensorboard>`_ format.
Expand Down Expand Up @@ -260,3 +268,24 @@ def _get_next_version(self) -> int:
return 0

return max(existing_versions) + 1

@property
@override
@rank_zero_only
def experiment(self) -> "SummaryWriter":
"""Returns the underlying TensorBoard summary writer object. Allows you to use TensorBoard logging features
directly in your :class:`~lightning.pytorch.core.LightningModule` or anywhere else in your code with:

`logger.experiment.some_tensorboard_function()`

Example::

class LitModel(LightningModule):
def training_step(self, batch, batch_idx):
# log a image
self.logger.experiment.add_image('my_image', batch['image'], self.global_step)
# log a histogram
self.logger.experiment.add_histogram('my_histogram', batch['data'], self.global_step)

"""
return super().experiment
Loading