From e15cc3d839463fc34b7480eb39543614714154e7 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 6 Sep 2022 11:49:23 +0000 Subject: [PATCH] update recompute interface. --- examples/language_model/gpt-3/dygraph/modeling.py | 8 +++++++- examples/language_model/moe/dygraph/modeling.py | 8 ++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/language_model/gpt-3/dygraph/modeling.py b/examples/language_model/gpt-3/dygraph/modeling.py index 112b0b63b058..28fadcefe812 100644 --- a/examples/language_model/gpt-3/dygraph/modeling.py +++ b/examples/language_model/gpt-3/dygraph/modeling.py @@ -1178,4 +1178,10 @@ def _logits_helper(embedding, output): loss_fn=GPTPretrainingCriterionPipe(), topology=topology, seg_method="layer:TransformerDecoderLayer", - recompute_interval=1 if use_recompute else 0) + recompute_interval=1 if use_recompute else 0, + recompute_ctx={ + "mp_group": + fleet.fleet._hcg.get_model_parallel_group(), + "offload": False, + "partition": False + }) diff --git a/examples/language_model/moe/dygraph/modeling.py b/examples/language_model/moe/dygraph/modeling.py index 66d9743e328c..64c1f220ca1d 100644 --- a/examples/language_model/moe/dygraph/modeling.py +++ b/examples/language_model/moe/dygraph/modeling.py @@ -1165,5 +1165,9 @@ def _logits_helper(embedding, output): topology=topology, seg_method="layer:TransformerDecoderLayer", recompute_interval=recompute_interval, - recompute_partition=False, - recompute_offload=False) + recompute_ctx={ + "mp_group": + fleet.fleet._hcg.get_model_parallel_group(), + "offload": False, + "partition": False + })