Skip to content

Commit 3dad856

Browse files
shihaoyinBordaawaelchlipre-commit-ci[bot]
authored andcommitted
Fix TensorBoardLogger.log_graph not recording the graph (#17926)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: awaelchli <aedu.waelchli@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit c31ef77)
1 parent ba6e19d commit 3dad856

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1717

1818

1919

20+
- Fixed `TensorBoardLogger.log_graph` not unwrapping the `_FabricModule` ([#17844](https://github.com/Lightning-AI/lightning/pull/17844))
21+
22+
2023
## [2.0.5] - 2023-07-07
2124

2225
### Added

src/lightning/fabric/loggers/tensorboard.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from lightning.fabric.utilities.logger import _sanitize_params as _utils_sanitize_params
2929
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
3030
from lightning.fabric.utilities.types import _PATH
31+
from lightning.fabric.wrappers import _unwrap_objects
3132

3233
log = logging.getLogger(__name__)
3334

@@ -247,6 +248,7 @@ def log_hyperparams( # type: ignore[override]
247248
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
248249
model_example_input = getattr(model, "example_input_array", None)
249250
input_array = model_example_input if input_array is None else input_array
251+
model = _unwrap_objects(model)
250252

251253
if input_array is None:
252254
rank_zero_warn(
@@ -263,8 +265,10 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None
263265
getattr(model, "_apply_batch_transfer_handler", None)
264266
):
265267
# this is probably is a LightningModule
266-
input_array = model._on_before_batch_transfer(input_array) # type: ignore[operator]
267-
input_array = model._apply_batch_transfer_handler(input_array) # type: ignore[operator]
268+
input_array = model._on_before_batch_transfer(input_array)
269+
input_array = model._apply_batch_transfer_handler(input_array)
270+
self.experiment.add_graph(model, input_array)
271+
else:
268272
self.experiment.add_graph(model, input_array)
269273

270274
@rank_zero_only

tests/tests_fabric/loggers/test_tensorboard.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from lightning.fabric.loggers import TensorBoardLogger
2525
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
26+
from lightning.fabric.wrappers import _FabricModule
2627
from tests_fabric.test_fabric import BoringModel
2728

2829

@@ -153,8 +154,18 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
153154
if example_input_array is not None:
154155
model.example_input_array = None
155156

156-
logger = TensorBoardLogger(tmpdir, log_graph=True)
157+
logger = TensorBoardLogger(tmpdir)
158+
logger._experiment = Mock()
157159
logger.log_graph(model, example_input_array)
160+
if example_input_array is not None:
161+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
162+
logger._experiment.reset_mock()
163+
164+
# model wrapped in `FabricModule`
165+
wrapped = _FabricModule(model, precision=Mock())
166+
logger.log_graph(wrapped, example_input_array)
167+
if example_input_array is not None:
168+
logger.experiment.add_graph.assert_called_with(model, example_input_array)
158169

159170

160171
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))

0 commit comments

Comments
 (0)