Skip to content

Commit c31ef77

Browse files
shihaoyin, Borda, awaelchli, and pre-commit-ci[bot]
authored
Fix TensorBoardLogger.log_graph not recording the graph (#17926)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a9269ae commit c31ef77

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
150150
- Fixed FSDP re-applying activation checkpointing when the user had manually applied it already ([#18006](https://github.com/Lightning-AI/lightning/pull/18006))
151151

152152

153+
- Fixed `TensorBoardLogger.log_graph` not unwrapping the `_FabricModule` ([#17844](https://github.com/Lightning-AI/lightning/pull/17844))
154+
155+
153156
## [2.0.5] - 2023-07-07
154157

155158
### Added

src/lightning/fabric/loggers/tensorboard.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from lightning.fabric.utilities.logger import _sanitize_params as _utils_sanitize_params
2828
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
2929
from lightning.fabric.utilities.types import _PATH
30+
from lightning.fabric.wrappers import _unwrap_objects
3031

3132
log = logging.getLogger(__name__)
3233

@@ -246,6 +247,7 @@ def log_hyperparams( # type: ignore[override]
246247
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
247248
model_example_input = getattr(model, "example_input_array", None)
248249
input_array = model_example_input if input_array is None else input_array
250+
model = _unwrap_objects(model)
249251

250252
if input_array is None:
251253
rank_zero_warn(
@@ -262,8 +264,10 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None
262264
getattr(model, "_apply_batch_transfer_handler", None)
263265
):
264266
# this is probably is a LightningModule
265-
input_array = model._on_before_batch_transfer(input_array) # type: ignore[operator]
266-
input_array = model._apply_batch_transfer_handler(input_array) # type: ignore[operator]
267+
input_array = model._on_before_batch_transfer(input_array)
268+
input_array = model._apply_batch_transfer_handler(input_array)
269+
self.experiment.add_graph(model, input_array)
270+
else:
267271
self.experiment.add_graph(model, input_array)
268272

269273
@rank_zero_only

tests/tests_fabric/loggers/test_tensorboard.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from lightning.fabric.loggers import TensorBoardLogger
2525
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
26+
from lightning.fabric.wrappers import _FabricModule
2627
from tests_fabric.test_fabric import BoringModel
2728

2829

@@ -153,8 +154,18 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
153154
if example_input_array is not None:
154155
model.example_input_array = None
155156

156-
logger = TensorBoardLogger(tmpdir, log_graph=True)
157+
logger = TensorBoardLogger(tmpdir)
158+
logger._experiment = Mock()
157159
logger.log_graph(model, example_input_array)
160+
if example_input_array is not None:
161+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
162+
logger._experiment.reset_mock()
163+
164+
# model wrapped in `FabricModule`
165+
wrapped = _FabricModule(model, precision=Mock())
166+
logger.log_graph(wrapped, example_input_array)
167+
if example_input_array is not None:
168+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
158169

159170

160171
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))

0 commit comments

Comments (0)