Skip to content

Commit f8d36ed

Browse files
authored
fix bug unimo fp16 infer error (#4166)
* fix bug * update fix * update unimo infer with summarization
1 parent c3d6545 commit f8d36ed

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

applications/text_summarization/unimo-text/export_model.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import os
1514
import argparse
15+
import os
16+
from pprint import pprint
1617

1718
import paddle
1819

19-
from pprint import pprint
20-
21-
from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer
2220
from paddlenlp.ops import FasterUNIMOText
23-
21+
from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer
2422
from paddlenlp.utils.log import logger
2523

2624

@@ -82,13 +80,13 @@ def do_predict(args):
8280
unimo_text,
8381
input_spec=[
8482
# input_ids
85-
paddle.static.InputSpec(shape=[None, None], dtype="int32"),
83+
paddle.static.InputSpec(shape=[None, None], dtype="int64"),
8684
# token_type_ids
87-
paddle.static.InputSpec(shape=[None, None], dtype="int32"),
85+
paddle.static.InputSpec(shape=[None, None], dtype="int64"),
8886
# attention_mask
8987
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
9088
# seq_len
91-
paddle.static.InputSpec(shape=[None], dtype="int32"),
89+
paddle.static.InputSpec(shape=[None], dtype="int64"),
9290
args.max_out_len,
9391
args.min_out_len,
9492
args.topk,

examples/question_generation/unimo-text/export_model.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615
import argparse
16+
import os
17+
from pprint import pprint
1718

1819
import paddle
1920

20-
from pprint import pprint
21-
22-
from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer
2321
from paddlenlp.ops import FasterUNIMOText
24-
22+
from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer
2523
from paddlenlp.utils.log import logger
2624

2725

@@ -70,19 +68,19 @@ def do_predict(args):
7068
paddle.static.InputSpec(shape=[None, None], dtype="int64"),
7169
# attention_mask
7270
paddle.static.InputSpec(shape=[None, 1, None, None],
73-
dtype="float64"),
71+
dtype="float32"),
7472
# seq_len
7573
paddle.static.InputSpec(shape=[None], dtype="int64"),
7674
args.max_dec_len,
7775
args.min_dec_len,
7876
args.topk,
7977
args.topp,
80-
args.num_beams, # num_beams. Used for beam_search.
78+
args.num_beams, # num_beams. Used for beam_search.
8179
args.decoding_strategy,
8280
tokenizer.cls_token_id, # cls/bos
8381
tokenizer.mask_token_id, # mask/eos
8482
tokenizer.pad_token_id, # pad
85-
args.diversity_rate, # diversity rate. Used for beam search.
83+
args.diversity_rate, # diversity rate. Used for beam search.
8684
args.temperature,
8785
args.num_return_sequences,
8886
args.length_penalty,

0 commit comments

Comments
 (0)