@@ -12,36 +12,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import argparse
-import math
 import os
 import random
-import time
 import sys
+import time

 import numpy as np
 import paddle
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
+    DygraphShardingOptimizer,
+)
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
+from paddle.distributed.fleet.utils.hybrid_parallel_util import (
+    fused_allreduce_gradients,
+)
+from paddle.distributed.sharding import group_sharded_parallel
 from visualdl import LogWriter
-from modeling import GPTModel, GPTForPretraining, GPTPretrainingCriterion, GPTForPretrainingPipe
-from paddlenlp.transformers import GPTTokenizer, GPTChineseTokenizer
-from paddlenlp.utils.log import logger
+
+from paddlenlp.transformers import GPTChineseTokenizer, GPTTokenizer
 from paddlenlp.utils import profiler
+from paddlenlp.utils.log import logger

 # to import data_tools
 filepath = os.path.abspath(os.path.dirname(__file__))
 sys.path.insert(0, os.path.join(filepath, "../"))
-
-from dataset import create_pretrained_dataset
-from args import parse_args
-import lr
-from paddle.distributed import fleet
-from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
-from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer
-from paddle.fluid.dygraph.parallel import sync_params_buffers
-from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
-
-# add sharding stage2/3
-from paddle.distributed.sharding import group_sharded_parallel
+import lr  # noqa e402
+from args import parse_args  # noqa e402
+from dataset import create_pretrained_dataset  # noqa e402
+from modeling import (  # noqa e402
+    GPTForPretraining,
+    GPTForPretrainingPipe,
+    GPTModel,
+    GPTPretrainingCriterion,
+)

 MODEL_CLASSES = {
     "gpt": (GPTForPretraining, GPTTokenizer),
@@ -268,6 +272,11 @@ def do_train(args):
     # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
     if args.sharding_stage in [2, 3]:
         if args.dp_degree > 1:
+            try:
+                from paddle.fluid.dygraph.parallel import sync_params_buffers
+            except ImportError:
+                from paddle.distributed.parallel import sync_params_buffers
+
             sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])

     scaler = scaler if args.use_pure_fp16 else None
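
For reference, a minimal standalone sketch of the guarded import introduced in the hunk above. It assumes, as the diff implies, that older Paddle releases expose `sync_params_buffers` under `paddle.fluid.dygraph.parallel` while newer ones provide it in `paddle.distributed.parallel`; the wrapper name `broadcast_dp_parameters` is hypothetical and only mirrors the call made in the diff.

```python
# Compatibility import: prefer the legacy paddle.fluid location and fall back
# to the newer paddle.distributed.parallel path if paddle.fluid is unavailable.
try:
    from paddle.fluid.dygraph.parallel import sync_params_buffers
except ImportError:
    from paddle.distributed.parallel import sync_params_buffers


def broadcast_dp_parameters(model, dp_group):
    """Broadcast parameters from the first data-parallel rank so every
    replica starts training from identical weights (same call as in the diff)."""
    sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])
```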
@@ -287,7 +296,7 @@ def do_train(args):
             logger.warning("No optimizer checkpoint file found in %s." % opt_path)

     global_step = 0
-    tic_train = time.time()
+    # tic_train = time.time()
     for epoch in range(args.num_train_epochs):
         files = get_train_data_file(args)
         files.sort()
@@ -414,7 +423,7 @@ def do_train(args):
                     log_writer.add_scalar("loss", float(loss), global_step)
                     log_writer.add_scalar("learning_rate", optimizer.get_lr(), global_step)

-                    tic_train = time.time()
+                    # tic_train = time.time()
                     train_reader_cost = 0.0
                     train_run_cost = 0.0