
[Trainer] Fix fp16 for paddle #4283

Merged
merged 1 commit into from
Dec 29, 2022

48 changes: 42 additions & 6 deletions paddlenlp/trainer/trainer.py
@@ -111,6 +111,28 @@ def is_datasets_available():
 if is_datasets_available():
     import datasets


+@contextlib.contextmanager
+def device_guard(device="cpu", dev_id=0):
+    origin_device = paddle.device.get_device()
+    if device == "cpu":
+        paddle.set_device(device)
+    elif device in ["gpu", "xpu", "npu"]:
+        paddle.set_device("{}:{}".format(device, dev_id))
+    try:
+        yield
+    finally:
+        paddle.set_device(origin_device)


+def paddlenlp_load(path, return_numpy=False):
+    if return_numpy:
+        with device_guard():
+            return paddle.load(path)
+    else:
+        return paddle.load(path, return_numpy=return_numpy)


 __all__ = ["Trainer"]
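
As a usage sketch (not part of the diff): the two helpers above are meant to deserialize checkpoints on the CPU so that a large fp16/bf16 state dict is not materialized on the accelerator while it is being loaded. The import path, checkpoint directory, and file names below are assumptions for illustration only.

import os

import paddle

from paddlenlp.trainer.trainer import device_guard, paddlenlp_load  # assumed import path

ckpt_dir = "./checkpoint-1000"  # hypothetical checkpoint directory

# return_numpy=True routes the load through device_guard("cpu"), so the
# state dict is created in host memory instead of on the current GPU.
opt_state = paddlenlp_load(os.path.join(ckpt_dir, "optimizer.pdopt"), return_numpy=True)

# The guard can also wrap any paddle.load call directly.
with device_guard("cpu"):
    model_state = paddle.load(os.path.join(ckpt_dir, "model_state.pdparams"))

Restoring the original device in the finally block keeps the guard safe even if deserialization raises.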


@@ -267,7 +289,11 @@ def __init__(
             self.amp_dtype = "float16" if args.fp16 else "bfloat16"
             # fix for load saved fp16 or bf16 ckpt, decorate model first.
             if self.args.fp16_opt_level == "O2":
-                paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)
+                if self.amp_dtype == "bfloat16":
+                    # fix for paddlepaddle < 2.4.1, which does not support bf16
+                    paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype)
+                else:
+                    paddle.amp.decorate(models=model, level=self.args.fp16_opt_level)

             if self.sharding is not None:
                 self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
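
For reference (not part of the PR), the guard above boils down to: only forward the dtype argument when bf16 is requested, because fp16 is already paddle.amp.decorate's default and, per the comment, Paddle releases before 2.4.1 do not support bf16. A minimal sketch with a hypothetical helper name:

import paddle


def amp_decorate_compat(model, optimizer=None, level="O2", amp_dtype="float16"):
    # Hypothetical helper, for illustration only: pass dtype only for bf16;
    # fp16 is the default, so older Paddle releases keep working.
    if amp_dtype == "bfloat16":
        return paddle.amp.decorate(models=model, optimizers=optimizer, level=level, dtype=amp_dtype)
    return paddle.amp.decorate(models=model, optimizers=optimizer, level=level)
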
@@ -1130,9 +1156,16 @@ def _wrap_model(self, model, training=True):
         # Mixed precision training
         if training and self.do_grad_scaling:  # self.args.fp16_opt_level=="O2":
             # model, self.optimizer
-            decorated = paddle.amp.decorate(
-                models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level, dtype=self.amp_dtype
-            )
+            if self.amp_dtype == "bfloat16":
+                # fix for paddlepaddle < 2.4.1, which does not support bf16
+                decorated = paddle.amp.decorate(
+                    models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level, dtype=self.amp_dtype
+                )
+            else:
+                decorated = paddle.amp.decorate(
+                    models=model, optimizers=self.optimizer, level=self.args.fp16_opt_level
+                )

             if self.optimizer is None:
                 model = decorated
             else:
@@ -1459,15 +1492,18 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         # Load in optimizer and scheduler states
         if self.sharding is not None:
             self.optimizer.set_state_dict(
-                paddle.load(
+                paddlenlp_load(
                     os.path.join(checkpoint, OPTIMIZER_NAME + f"_shard{self.sharding_group.rank}"),
                     return_numpy=True,
                 )
             )
             empty_dict = paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True)
             assert len(empty_dict) == 0, "Optimizer file of sharding, should be empty!"
         else:
-            self.optimizer.set_state_dict(paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True))
+            self.optimizer.set_state_dict(
+                paddlenlp_load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True)
+            )

         self.lr_scheduler.set_state_dict(paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)))
         if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
             self.scaler.load_state_dict(paddle.load(os.path.join(checkpoint, SCALER_NAME), return_numpy=True))
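
As a sketch of the reworked sharded loading path (not part of the diff; the paths, file name, and rank below are hypothetical, and paddlenlp_load is assumed to be in scope as defined earlier): each sharding rank reads only its own optimizer shard, materialized on the CPU, while the consolidated optimizer file is expected to be an empty placeholder.

import os

import paddle

checkpoint = "./checkpoint-1000"    # hypothetical checkpoint directory
OPTIMIZER_NAME = "optimizer.pdopt"  # assumed file name, mirroring the constant used above
shard_rank = 0                      # this process's sharding rank

# Load only this rank's shard; device_guard("cpu") inside paddlenlp_load keeps
# the tensors in host memory until set_state_dict copies them onto the device.
shard_state = paddlenlp_load(
    os.path.join(checkpoint, OPTIMIZER_NAME + f"_shard{shard_rank}"), return_numpy=True
)

# With sharding enabled, the unsharded optimizer file should be an empty dict.
empty_dict = paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME), return_numpy=True)
assert len(empty_dict) == 0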