Commit 9a00c15

[Fluid] Fix fluid sync params buffers. (#4878)
1 parent 75f3130 commit 9a00c15
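
In short, the commit stops importing `sync_params_buffers` unconditionally from the legacy `paddle.fluid.dygraph.parallel` module and falls back to `paddle.distributed.parallel` when the fluid path is unavailable (the in-diff comment points at the new API location in PaddlePaddle v2.5). A minimal sketch of the pattern the diff applies, with the call site shown only as a comment:

# Compatibility import: prefer the legacy fluid location, fall back to the
# new paddle.distributed.parallel path used by newer Paddle releases.
try:
    from paddle.fluid.dygraph.parallel import sync_params_buffers
except ImportError:
    from paddle.distributed.parallel import sync_params_buffers

# The call site is unchanged either way: broadcast params/buffers from the
# first rank of the data-parallel group, e.g.
# sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])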

2 files changed: 35 additions, 21 deletions

examples/language_model/gpt-3/dygraph/run_pretrain.py

Lines changed: 29 additions & 20 deletions
@@ -12,36 +12,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import argparse
-import math
 import os
 import random
-import time
 import sys
+import time
 
 import numpy as np
 import paddle
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
+    DygraphShardingOptimizer,
+)
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
+from paddle.distributed.fleet.utils.hybrid_parallel_util import (
+    fused_allreduce_gradients,
+)
+from paddle.distributed.sharding import group_sharded_parallel
 from visualdl import LogWriter
-from modeling import GPTModel, GPTForPretraining, GPTPretrainingCriterion, GPTForPretrainingPipe
-from paddlenlp.transformers import GPTTokenizer, GPTChineseTokenizer
-from paddlenlp.utils.log import logger
+
+from paddlenlp.transformers import GPTChineseTokenizer, GPTTokenizer
 from paddlenlp.utils import profiler
+from paddlenlp.utils.log import logger
 
 # to import data_tools
 filepath = os.path.abspath(os.path.dirname(__file__))
 sys.path.insert(0, os.path.join(filepath, "../"))
-
-from dataset import create_pretrained_dataset
-from args import parse_args
-import lr
-from paddle.distributed import fleet
-from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
-from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer
-from paddle.fluid.dygraph.parallel import sync_params_buffers
-from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
-
-# add sharding stage2/3
-from paddle.distributed.sharding import group_sharded_parallel
+import lr  # noqa e402
+from args import parse_args  # noqa e402
+from dataset import create_pretrained_dataset  # noqa e402
+from modeling import (  # noqa e402
+    GPTForPretraining,
+    GPTForPretrainingPipe,
+    GPTModel,
+    GPTPretrainingCriterion,
+)
 
 MODEL_CLASSES = {
     "gpt": (GPTForPretraining, GPTTokenizer),
@@ -268,6 +272,11 @@ def do_train(args):
     # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
     if args.sharding_stage in [2, 3]:
         if args.dp_degree > 1:
+            try:
+                from paddle.fluid.dygraph.parallel import sync_params_buffers
+            except ImportError:
+                from paddle.distributed.parallel import sync_params_buffers
+
             sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])
 
         scaler = scaler if args.use_pure_fp16 else None
@@ -287,7 +296,7 @@ def do_train(args):
             logger.warning("No optimizer checkpoint file found in %s." % opt_path)
 
     global_step = 0
-    tic_train = time.time()
+    # tic_train = time.time()
     for epoch in range(args.num_train_epochs):
         files = get_train_data_file(args)
         files.sort()
@@ -414,7 +423,7 @@ def do_train(args):
                 log_writer.add_scalar("loss", float(loss), global_step)
                 log_writer.add_scalar("learning_rate", optimizer.get_lr(), global_step)
 
-                tic_train = time.time()
+                # tic_train = time.time()
                 train_reader_cost = 0.0
                 train_run_cost = 0.0
 

paddlenlp/trainer/trainer.py

Lines changed: 6 additions & 1 deletion
@@ -41,7 +41,6 @@
 from paddle.distributed.fleet.utils.hybrid_parallel_util import (
     fused_allreduce_gradients,
 )
-from paddle.fluid.dygraph.parallel import sync_params_buffers
 from paddle.io import DataLoader, Dataset, DistributedBatchSampler
 from tqdm.auto import tqdm
 
@@ -1201,6 +1200,12 @@ def _wrap_model(self, model, training=True):
             else:
                 # sync params (broadcast) buffers in dp group
                 if self.args.dp_degree > 1:
+                    try:
+                        from paddle.fluid.dygraph.parallel import sync_params_buffers
+                    except ImportError:
+                        # fix for new api in paddlepaddle v2.5
+                        from paddle.distributed.parallel import sync_params_buffers
+
                     hcg = fleet.get_hybrid_communicate_group()
                     dp_group = hcg.get_data_parallel_group()
                     sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])
