[FasterGeneration] MBart supports dy2sta #3356

Merged · 3 commits · Sep 23, 2022

10 changes: 5 additions & 5 deletions faster_generation/samples/mbart_sample.py
@@ -14,11 +14,10 @@
import paddle
from paddlenlp.transformers import MBartForConditionalGeneration, MBartTokenizer

-model_name = "mbart-large-50-one-to-many-mmt"
+model_name = "mbart-large-50-many-to-many-mmt"

-tokenizer = MBartTokenizer.from_pretrained(model_name)
-model = MBartForConditionalGeneration.from_pretrained(model_name,
-src_lang="en_XX")
+tokenizer = MBartTokenizer.from_pretrained(model_name, src_lang="en_XX")
+model = MBartForConditionalGeneration.from_pretrained(model_name)
model.eval()


@@ -41,7 +40,7 @@ def postprocess_response(seq, bos_idx, eos_idx):

inputs = "PaddleNLP is a powerful NLP library with Awesome pre-trained models and easy-to-use interface, supporting wide-range of NLP tasks from research to industrial applications."
input_ids = tokenizer(inputs)["input_ids"]
-input_ids = paddle.to_tensor(input_ids, dtype='int64').unsqueeze(0)
+input_ids = paddle.to_tensor(input_ids, dtype='int32').unsqueeze(0)

outputs, _ = model.generate(input_ids=input_ids,
forced_bos_token_id=bos_id,
@@ -53,5 +52,6 @@ def postprocess_response(seq, bos_idx, eos_idx):
result = postprocess_response(outputs[0].numpy().tolist(), bos_id, eos_id)

print("Model input:", inputs)

print("Result:", result)
# PaddleNLP是一个强大的NLP库,具有超乎寻常的预训练模型和易于使用的接口,支持从研究到工业应用的广泛的NLP任务。
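
A note on the two sample changes: `src_lang` moves onto the tokenizer because it selects the source-language token on the input side rather than configuring model weights, and the input ids switch from int64 to int32, presumably to match the integer width the FasterGeneration custom ops consume (the new inference sample below likewise feeds int32 ids).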
147 changes: 147 additions & 0 deletions paddlenlp/ops/faster_transformer/sample/mbart_export_model_sample.py
@@ -0,0 +1,147 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import paddle
from pprint import pprint
from paddlenlp.transformers import MBartForConditionalGeneration, MBartTokenizer
from paddlenlp.ops import FasterMBART
from paddlenlp.utils.log import logger


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path",
default="mbart-large-50-many-to-many-mmt",
type=str,
help="The model name to specify the bart to use. ")
parser.add_argument("--inference_model_dir",
default="./infer_model/",
type=str,
help="Path to save inference model of bart. ")
parser.add_argument(
"--topk",
default=4,
type=int,
help="The number of candidate to procedure top_k sampling. ")
parser.add_argument(
"--topp",
default=1.0,
type=float,
help="The probability threshold to procedure top_p sampling. ")
parser.add_argument("--max_out_len",
default=64,
type=int,
help="Maximum output length. ")
parser.add_argument("--temperature",
default=1.0,
type=float,
help="The temperature to set. ")
parser.add_argument("--num_return_sequences",
default=1,
type=int,
help="The number of returned sequences. ")
parser.add_argument("--use_fp16_decoding",
action="store_true",
help="Whether to use fp16 decoding to predict. ")
parser.add_argument("--decoding_strategy",
default="beam_search",
choices=["sampling", "beam_search"],
type=str,
help="The main strategy to decode. ")
parser.add_argument(
"--num_beams",
default=5,
type=int,
help="The number of candidate to procedure beam search. ")
parser.add_argument("--diversity_rate",
default=0.0,
type=float,
help="The diversity rate to procedure beam search. ")
parser.add_argument("--repetition_penalty",
default=1.0,
type=float,
help="The repetition_penalty to set. ")
parser.add_argument("--length_penalty",
default=0.0,
type=float,
help="The length penalty to decode. ")
parser.add_argument("--early_stopping",
action="store_true",
help="Whether to do early stopping. ")

args = parser.parse_args()
return args


def do_predict(args):
place = "gpu"
place = paddle.set_device(place)

model = MBartForConditionalGeneration.from_pretrained(
args.model_name_or_path, src_lang="en_XX")
tokenizer = MBartTokenizer.from_pretrained(args.model_name_or_path)

bos_id = tokenizer.lang_code_to_id["zh_CN"]
eos_id = model.mbart.config["eos_token_id"]

# Set eval mode (required for enabling the faster encoder).
model.eval()

faster_mbart = FasterMBART(model=model,
use_fp16_decoding=args.use_fp16_decoding)
# Set evaluate mode
faster_mbart.eval()

# Convert dygraph model to static graph model
faster_mbart = paddle.jit.to_static(
faster_mbart,
input_spec=[
# input_ids
paddle.static.InputSpec(shape=[None, None], dtype="int32"),
# encoder_output
None,
# seq_len
None,
bos_id, # forced_bos_token_id
args.num_beams, # num_beams.
args.topk, # top_k
args.topp, # top_p
args.decoding_strategy, # decode_strategy
tokenizer.bos_token_id, # bos_token_id
tokenizer.eos_token_id, # eos_token_id
tokenizer.pad_token_id, # pad_token_id
model.mbart.config["decoder_start_token_id"],  # decoder_start_token_id
args.max_out_len, # max_length
args.diversity_rate, # diversity_rate
args.length_penalty, # length_penalty
args.temperature, # temperature
args.num_return_sequences, # num_return_sequences
args.early_stopping, # early_stopping
tokenizer.eos_token_id,  # forced_eos_token_id
])

# Save converted static graph model
paddle.jit.save(faster_mbart, os.path.join(args.inference_model_dir,
"mbart"))
logger.info("MBART has been saved to {}.".format(args.inference_model_dir))


if __name__ == "__main__":
args = parse_args()
pprint(args)

do_predict(args)
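
For reference, a typical export invocation would be `python paddlenlp/ops/faster_transformer/sample/mbart_export_model_sample.py --inference_model_dir ./infer_model/` (every flag has the default shown in the parser above). One consequence of this design: the decoding hyperparameters go through `input_spec` as plain Python scalars, so `paddle.jit.to_static` appears to bake them into the saved program as constants, and changing `num_beams`, `max_out_len`, and the like later means re-exporting.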
97 changes: 97 additions & 0 deletions paddlenlp/ops/faster_transformer/sample/mbart_inference.py
@@ -0,0 +1,97 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import numpy as np
from pprint import pprint

import paddle
import paddle.inference as paddle_infer

from paddlenlp.transformers import MBartTokenizer
from paddlenlp.ops.ext_utils import load


def setup_args():
"""Setup arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--inference_model_dir",
default="./infer_model/",
type=str,
help="Path to save inference model of BART. ")

args = parser.parse_args()

return args


def postprocess_response(tokenizer, seq, bos_idx, eos_idx):
"""Post-process the decoded sequence."""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1] if idx != bos_idx and idx != eos_idx
]
res = tokenizer.convert_ids_to_string(seq)
return res


def infer(args):
model_name = "mbart-large-50-many-to-many-mmt"
tokenizer = MBartTokenizer.from_pretrained(model_name)

bos_id = tokenizer.lang_code_to_id["zh_CN"]
eos_id = tokenizer.eos_token_id

inputs = "PaddleNLP is a powerful NLP library with Awesome pre-trained models and easy-to-use interface, supporting wide-range of NLP tasks from research to industrial applications."
input_ids = tokenizer(inputs)["input_ids"]
input_ids = np.asarray(input_ids, dtype="int32").reshape(1, -1)

# Load FasterTransformer lib.
load("FasterTransformer", verbose=True)

config = paddle_infer.Config(
os.path.join(args.inference_model_dir, "mbart.pdmodel"),
os.path.join(args.inference_model_dir, "mbart.pdiparams"))

config.enable_use_gpu(100, 0)
config.disable_glog_info()
predictor = paddle_infer.create_predictor(config)

input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(input_ids.astype("int32"))

predictor.run()

output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()

result = postprocess_response(
tokenizer,
output_data.transpose([1, 2, 0]).tolist()[0][0], bos_id, eos_id)
print("Model input:", inputs)
print("Result:", result)


if __name__ == "__main__":
args = setup_args()
pprint(args)

infer(args)
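
Run the export sample first so that ./infer_model/mbart.pdmodel and ./infer_model/mbart.pdiparams exist, then run `python paddlenlp/ops/faster_transformer/sample/mbart_inference.py --inference_model_dir ./infer_model/`. The `transpose([1, 2, 0])` before decoding suggests the FasterTransformer output buffer is laid out as [max_len, batch_size, num_return_sequences]; indexing `[0][0]` then keeps the first returned sequence for the first batch item.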
3 changes: 2 additions & 1 deletion paddlenlp/ops/faster_transformer/transformer/decoding.py
@@ -2515,7 +2515,8 @@ def __init__(self,
self.pos_emb = [model.decoder.decoder_embed_positions.weight]
self.word_emb = [model.decoder.embed_tokens.weight]

-self.linear_weight = [model.lm_head_weight.t()]
+setattr(self, "lm_head_weight_", model.lm_head_weight.t())
+self.linear_weight = [getattr(self, "lm_head_weight_")]
self.linear_bias = [model.final_logits_bias]

def forward(self,
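
The `setattr`/`getattr` indirection above reads like a dy2sta accommodation: binding the transposed lm-head weight to the layer as a named attribute lets `paddle.jit.to_static` track the tensor as part of the layer's state, rather than as an anonymous temporary that exists only inside a Python list. A minimal sketch of the pattern with hypothetical names (not the PaddleNLP API):

import paddle

class TiedLMHead(paddle.nn.Layer):
    # Hypothetical layer illustrating the attribute-binding pattern above.
    def __init__(self, lm_head_weight):
        super(TiedLMHead, self).__init__()
        # Bind the transposed weight as a named attribute so the
        # dygraph-to-static converter sees it as part of this layer's state.
        setattr(self, "lm_head_weight_", lm_head_weight.t())

    def forward(self, hidden_states):
        # Project decoder hidden states to vocabulary logits.
        return paddle.matmul(hidden_states, getattr(self, "lm_head_weight_"))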
14 changes: 11 additions & 3 deletions paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
@@ -1379,8 +1379,13 @@ def forward(self,


class FasterMBART(MBartPretrainedModel):
+enable_faster_encoder_func = enable_faster_encoder

-def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
+def __init__(self,
+model,
+decoding_lib=None,
+use_fp16_decoding=False,
+enable_faster_encoder=False):
super(FasterMBART, self).__init__()
self.use_fp16_decoding = use_fp16_decoding
self._model = model
@@ -1393,13 +1398,18 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
self.encoder = model.mbart.get_encoder()
self.decoder = model.mbart.get_decoder()
self.pad_token_id = model.mbart.config['pad_token_id']
+self.enable_faster_encoder = enable_faster_encoder

self.decoding = InferMBartDecoding(
model=self._model,
decoding_lib=decoding_lib,
use_fp16_decoding=use_fp16_decoding,
hidden_act=model.mbart.config['activation_function'])

+if self.enable_faster_encoder:
+# Must call `enable_faster_encoder` in `__init__` when converting dygraph to static graph.
+self.encoder = FasterMBART.enable_faster_encoder_func(self.encoder)

def get_encoder(self):
return self.encoder

@@ -1439,11 +1449,9 @@ def forward(self,

# (gongenlei) Do not enable_faster_encoder here temporarily.
if encoder_output is None:
-self.encoder = enable_faster_encoder(self.encoder)
assert input_ids is not None, "You have to specify either input_ids or encoder_output."
encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
input_ids, model_kwargs)["encoder_output"]
-self.encoder = disable_faster_encoder(self.encoder)
batch_size = paddle.shape(encoder_output)[0]
if seq_len is None:
assert input_ids is not None, "You have to specify input_ids when generating seq_len."
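
With the new constructor flag, a dy2sta export opts in to the faster encoder once, at construction time, instead of swapping it in and out around every forward pass (the swap inside forward is what the deleted lines above did). A minimal sketch, assuming `model` is the `MBartForConditionalGeneration` instance from the export sample:

from paddlenlp.ops import FasterMBART

# Swap in the faster encoder during __init__, before tracing, so that
# paddle.jit.to_static captures the already-converted encoder.
faster_mbart = FasterMBART(model=model,
                           use_fp16_decoding=False,
                           enable_faster_encoder=True)
faster_mbart.eval()
# faster_mbart can now be converted with paddle.jit.to_static as in the sample.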
4 changes: 2 additions & 2 deletions paddlenlp/transformers/mbart/modeling.py
@@ -203,7 +203,7 @@ def forward(self, input_ids_shape, past_key_values_length=0):
positions = paddle.arange(past_key_values_length,
past_key_values_length + seq_len,
dtype="int64")
-return super().forward(positions + self.offset)
+return Embedding.forward(self, positions + self.offset)
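
Replacing `super().forward(...)` with an unbound `Embedding.forward(self, ...)` call is behavior-preserving in dygraph; the likely motivation is that `paddle.jit.to_static` rewrites method source code, and zero-argument `super()` inside rewritten code was a known dy2sta pitfall. A generic sketch of the rewrite (illustrative, not the MBART class):

import paddle

class OffsetEmbedding(paddle.nn.Embedding):
    def forward(self, ids):
        # Equivalent to `return super().forward(ids)`, but the explicit
        # unbound call survives dygraph-to-static source rewriting.
        return paddle.nn.Embedding.forward(self, ids)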


class MBartEncoder(MBartPretrainedModel):
@@ -270,7 +270,7 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
if input_ids is None:
raise ValueError("Input_ids cannot be None.")
inputs_embeds = self.d_model**0.5 * self.embed_tokens(input_ids)
-inputs_embed_pos = self.encoder_embed_positions(input_ids.shape)
+inputs_embed_pos = self.encoder_embed_positions(paddle.shape(input_ids))
hidden_states = inputs_embeds + inputs_embed_pos
hidden_states = self.encoder_layernorm_embedding(hidden_states)
encoder_input = self.encoder_dropout(hidden_states)
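
The `input_ids.shape` to `paddle.shape(input_ids)` change follows the same dy2sta theme: the Python `.shape` attribute is evaluated once at trace time and can hold None/-1 placeholders for dynamic dimensions, whereas `paddle.shape` emits a shape op whose result is a tensor computed at run time. A minimal illustration under that assumption:

import paddle

@paddle.jit.to_static(input_spec=[
    paddle.static.InputSpec(shape=[None, None], dtype="int32")
])
def seq_len_of(input_ids):
    # Resolved per call, even though both dims were traced as None.
    return paddle.shape(input_ids)[1]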