diff --git a/applications/text_summarization/unimo-text/export_model.py b/applications/text_summarization/unimo-text/export_model.py
index c9a79c465a14..4dc1a0b58fe5 100644
--- a/applications/text_summarization/unimo-text/export_model.py
+++ b/applications/text_summarization/unimo-text/export_model.py
@@ -11,16 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import argparse
+import os
+from pprint import pprint
 
 import paddle
-from pprint import pprint
-
-from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer
 from paddlenlp.ops import FasterUNIMOText
-
+from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer
 from paddlenlp.utils.log import logger
 
 
 
@@ -82,13 +80,13 @@ def do_predict(args):
         unimo_text,
         input_spec=[
             # input_ids
-            paddle.static.InputSpec(shape=[None, None], dtype="int32"),
+            paddle.static.InputSpec(shape=[None, None], dtype="int64"),
             # token_type_ids
-            paddle.static.InputSpec(shape=[None, None], dtype="int32"),
+            paddle.static.InputSpec(shape=[None, None], dtype="int64"),
             # attention_mask
             paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
             # seq_len
-            paddle.static.InputSpec(shape=[None], dtype="int32"),
+            paddle.static.InputSpec(shape=[None], dtype="int64"),
             args.max_out_len,
             args.min_out_len,
             args.topk,
diff --git a/examples/question_generation/unimo-text/export_model.py b/examples/question_generation/unimo-text/export_model.py
index 44ae51080a0f..ea8d9c1a2e40 100644
--- a/examples/question_generation/unimo-text/export_model.py
+++ b/examples/question_generation/unimo-text/export_model.py
@@ -12,16 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import argparse
+import os
+from pprint import pprint
 
 import paddle
-from pprint import pprint
-
-from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer
 from paddlenlp.ops import FasterUNIMOText
-
+from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer
 from paddlenlp.utils.log import logger
 
 
 
@@ -70,19 +68,19 @@ def do_predict(args):
             paddle.static.InputSpec(shape=[None, None], dtype="int64"),
             # attention_mask
             paddle.static.InputSpec(shape=[None, 1, None, None],
-                                    dtype="float64"),
+                                    dtype="float32"),
             # seq_len
             paddle.static.InputSpec(shape=[None], dtype="int64"),
             args.max_dec_len,
             args.min_dec_len,
             args.topk,
             args.topp,
-            args.num_beams,  # num_beams. Used for beam_search.
+            args.num_beams,  # num_beams. Used for beam_search.
             args.decoding_strategy,
             tokenizer.cls_token_id,  # cls/bos
             tokenizer.mask_token_id,  # mask/eos
             tokenizer.pad_token_id,  # pad
-            args.diversity_rate,  # diversity rate. Used for beam search.
+            args.diversity_rate,  # diversity rate. Used for beam search.
             args.temperature,
             args.num_return_sequences,
             args.length_penalty,