Skip to content

Commit c3d6545

Browse files
linjieccc and sijunhe authored
Support set_argument method for UIE (#4163)
* Support set_argument method for UIE
* set schema to None by default
* revert
* update

Co-authored-by: Sijun He <sijun.he@hotmail.com>
1 parent f653728 commit c3d6545

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

paddlenlp/taskflow/information_extraction.py

Lines changed: 30 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
from ..transformers import UIE, UIEM, UIEX, AutoModel, AutoTokenizer
2828
from ..utils.doc_parser import DocParser
2929
from ..utils.ie_utils import map_offset, pad_image_data
30+
from ..utils.log import logger
3031
from ..utils.tools import get_bool_ids_greater_than, get_span
3132
from .task import Task
3233
from .utils import DataCollatorGP, SchemaTree, dbc2sbc, get_id_and_prob, gp_decode
@@ -376,7 +377,7 @@ class UIETask(Task):
376377
},
377378
}
378379

379-
def __init__(self, task, model, schema, **kwargs):
380+
def __init__(self, task, model, schema=None, **kwargs):
380381
super().__init__(task=task, model=model, **kwargs)
381382

382383
self._max_seq_len = kwargs.get("max_seq_len", 512)
@@ -385,7 +386,7 @@ def __init__(self, task, model, schema, **kwargs):
385386
self._position_prob = kwargs.get("position_prob", 0.5)
386387
self._lazy_load = kwargs.get("lazy_load", False)
387388
self._num_workers = kwargs.get("num_workers", 0)
388-
self.use_fast = kwargs.get("use_fast", False)
389+
self._use_fast = kwargs.get("use_fast", False)
389390
self._layout_analysis = kwargs.get("layout_analysis", False)
390391
self._ocr_lang = kwargs.get("ocr_lang", "ch")
391392
self._schema_lang = kwargs.get("schema_lang", "ch")
@@ -415,14 +416,31 @@ def __init__(self, task, model, schema, **kwargs):
415416
else:
416417
self._summary_token_num = 3 # [CLS] prompt [SEP] text [SEP]
417418

418-
self._doc_parser = None
419-
self._schema_tree = None
420-
self.set_schema(schema)
419+
self._parser_map = {
420+
"ch": None, # OCR-CH
421+
"en": None, # OCR-EN
422+
"ch-layout": None, # Layout-CH
423+
"en-layout": None, # Layout-EN
424+
}
425+
if not schema:
426+
logger.warning(
427+
"The schema has not been set yet, please set a schema via set_schema(). "
428+
"More details about the setting of schema please refer to https://github.com/PaddlePaddle/PaddleNLP/blob/develop/applications/information_extraction/taskflow_text.md"
429+
)
430+
self._schema_tree = None
431+
else:
432+
self.set_schema(schema)
421433
self._check_predictor_type()
422434
self._get_inference_model()
423435
self._usage = usage
424436
self._construct_tokenizer()
425437

438+
def set_argument(self, argument: dict):
439+
for k, v in argument.items():
440+
if k == "input":
441+
continue
442+
setattr(self, f"_{k}", v)
443+
426444
def set_schema(self, schema):
427445
if isinstance(schema, dict) or isinstance(schema, str):
428446
schema = [schema]
@@ -467,7 +485,7 @@ def _construct_tokenizer(self):
467485
Construct the tokenizer for the predictor.
468486
"""
469487
self._tokenizer = AutoTokenizer.from_pretrained(
470-
self._task_path, use_fast=self.use_fast, from_hf_hub=self.from_hf_hub
488+
self._task_path, use_fast=self._use_fast, from_hf_hub=self.from_hf_hub
471489
)
472490

473491
def _preprocess(self, inputs):
@@ -485,6 +503,7 @@ def _check_input_text(self, inputs):
485503
"""
486504
Check whether the input meet the requirement.
487505
"""
506+
self._ocr_lang_choice = (self._ocr_lang + "-layout") if self._layout_analysis else self._ocr_lang
488507
inputs = inputs[0]
489508
if isinstance(inputs, dict) or isinstance(inputs, str):
490509
inputs = [inputs]
@@ -494,17 +513,17 @@ def _check_input_text(self, inputs):
494513
data = {}
495514
if isinstance(example, dict):
496515
if "doc" in example.keys():
497-
if not self._doc_parser:
498-
self._doc_parser = DocParser(
516+
if not self._parser_map[self._ocr_lang_choice]:
517+
self._parser_map[self._ocr_lang_choice] = DocParser(
499518
ocr_lang=self._ocr_lang, layout_analysis=self._layout_analysis
500519
)
501520
if "layout" in example.keys():
502-
data = self._doc_parser.parse(
521+
data = self._parser_map[self._ocr_lang_choice].parse(
503522
{"doc": example["doc"]}, do_ocr=False, expand_to_a4_size=self._expand_to_a4_size
504523
)
505524
data["layout"] = example["layout"]
506525
else:
507-
data = self._doc_parser.parse(
526+
data = self._parser_map[self._ocr_lang_choice].parse(
508527
{"doc": example["doc"]}, expand_to_a4_size=self._expand_to_a4_size
509528
)
510529
elif "text" in example.keys():
@@ -931,7 +950,7 @@ def _parse_inputs(self, inputs):
931950
org_box[2] + offset_x,
932951
org_box[3] + offset_y,
933952
]
934-
box = self._doc_parser._normalize_box(box, [img_w, img_h], [1000, 1000])
953+
box = self._parser_map[self._ocr_lang_choice]._normalize_box(box, [img_w, img_h], [1000, 1000])
935954
text += segment[1]
936955
bbox.extend([box] * len(segment[1]))
937956
_inputs.append({"text": text, "bbox": bbox, "image": d["image"], "layout": d["layout"]})

paddlenlp/taskflow/taskflow.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -510,6 +510,17 @@
510510
"openai/disco-diffusion-clip-rn50",
511511
"openai/disco-diffusion-clip-rn101",
512512
"disco_diffusion_ernie_vil-2.0-base-zh",
513+
"uie-base",
514+
"uie-medium",
515+
"uie-mini",
516+
"uie-micro",
517+
"uie-nano",
518+
"uie-tiny",
519+
"uie-medical-base",
520+
"uie-base-en",
521+
"uie-m-large",
522+
"uie-m-base",
523+
"uie-x-base",
513524
]
514525

515526

@@ -617,11 +628,11 @@ def interactive_mode(self, max_turn):
617628
def set_schema(self, schema):
618629
assert (
619630
self.task_instance.model in support_schema_list
620-
), "This method can only be used by the task with the model of uie or wordtag."
631+
), "This method can only be used by the task based on the model of uie or wordtag."
621632
self.task_instance.set_schema(schema)
622633

623634
def set_argument(self, argument):
624635
assert (
625636
self.task_instance.model in support_argument_list
626-
), "This method can only be used by the task with the model of text_to_image generation."
637+
), "This method can only be used by the task of text-to-image generation or information extraction."
627638
self.task_instance.set_argument(argument)

0 commit comments

Comments
 (0)