diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index bfd8706643ea..f11cabe2b2ef 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -111,6 +111,28 @@ def is_datasets_available():
 if is_datasets_available():
     import datasets
 
+
+@contextlib.contextmanager
+def device_guard(device="cpu", dev_id=0):
+    origin_device = paddle.device.get_device()
+    if device == "cpu":
+        paddle.set_device(device)
+    elif device in ["gpu", "xpu", "npu"]:
+        paddle.set_device("{}:{}".format(device, dev_id))
+    try:
+        yield
+    finally:
+        paddle.set_device(origin_device)
+
+
+def paddlenlp_load(path, return_numpy=False):
+    if return_numpy:
+        with device_guard():
+            return paddle.load(path)
+    else:
+        return paddle.load(path, return_numpy=return_numpy)
+
+
 __all__ = ["Trainer"]
@@ -267,7 +289,11 @@ def __init__(
             self.amp_dtype = "float16" if args.fp16 else "bfloat16"
             # fix for load saved fp16 or bf16 ckpt, decorate model first.
             if self.args.fp16_opt_level == "O2":
-                paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)
+                if self.amp_dtype == "bfloat16":
+                    # fix for paddlepaddle < 2.4.1, not support for bf16
+                    paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)
+                else:
+                    paddle.amp.decorate(models=model, level=self.args.fp16_opt_level)
 
             if self.sharding is not None:
                 self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
@@ -1130,9 +1156,16 @@ def _wrap_model(self, model, training=True):
         # Mixed precision training
         if training and self.do_grad_scaling:  # self.args.fp16_opt_level=="O2":
             # model, self.optimizer
-            decorated = paddle.amp.decorate(
-                models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level, dtype=self.amp_dtype
-            )
+            if self.amp_dtype == "bfloat16":
+                # fix for paddlepaddle < 2.4.1, not support for bf16
+                decorated = paddle.amp.decorate(
+                    models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level, dtype=self.amp_dtype
+                )
+            else:
+                decorated = paddle.amp.decorate(
+                    models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level
+                )
+
             if self.optimizer is None:
                 model = decorated
             else:
@@ -1459,7 +1492,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             # Load in optimizer and scheduler states
            if self.sharding is not None:
                 self.optimizer.set_state_dict(
-                    paddle.load(
+                    paddlenlp_load(
                         os.path.join(checkpoint, OPTIMIZER_NAME + f"_shard{self.sharding_group.rank}"),
                         return_numpy=True,
                     )
@@ -1467,7 +1500,10 @@
                 empty_dict = paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True)
                 assert len(empty_dict) == 0, "Optimizer file of sharding, should be empty!"
             else:
-                self.optimizer.set_state_dict(paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True))
+                self.optimizer.set_state_dict(
+                    paddlenlp_load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True)
+                )
+
             self.lr_scheduler.set_state_dict(paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
                 self.scaler.load_state_dict(paddle.load(os.path.join(checkpoint, SCALER_NAME), return_numpy=True))
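
Note: the sketch below illustrates how the two helpers introduced in this diff are intended to be used. It assumes device_guard and paddlenlp_load are importable from paddlenlp.trainer.trainer (where the diff defines them); the checkpoint path is hypothetical and only serves as an example.

    import paddle

    from paddlenlp.trainer.trainer import device_guard, paddlenlp_load

    CKPT = "output/checkpoint-500/optimizer.pdopt"  # hypothetical path, not from the diff

    # device_guard() temporarily switches Paddle's default device, so every tensor
    # created inside the block is allocated on the CPU; the previous device is
    # restored when the block exits.
    with device_guard("cpu"):
        opt_state = paddle.load(CKPT)

    # paddlenlp_load() wraps the same idea: with return_numpy=True the checkpoint is
    # deserialized under device_guard("cpu"), so a large optimizer state does not
    # have to be materialized on the accelerator while it is being loaded.
    opt_state = paddlenlp_load(CKPT, return_numpy=True)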