from ..transformers import UIE, UIEM, UIEX, AutoModel, AutoTokenizer
from ..utils.doc_parser import DocParser
from ..utils.ie_utils import map_offset, pad_image_data
+ from ..utils.log import logger
from ..utils.tools import get_bool_ids_greater_than, get_span
from .task import Task
from .utils import DataCollatorGP, SchemaTree, dbc2sbc, get_id_and_prob, gp_decode
@@ -376,7 +377,7 @@ class UIETask(Task):
        },
    }

-     def __init__(self, task, model, schema, **kwargs):
+     def __init__(self, task, model, schema=None, **kwargs):
        super().__init__(task=task, model=model, **kwargs)

        self._max_seq_len = kwargs.get("max_seq_len", 512)
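Note on the signature change above: schema is now optional, so an extractor can be constructed first and given its schema afterwards. A minimal usage sketch, assuming the Taskflow wrapper forwards set_schema() to the underlying UIETask instance (the model name and schema below are purely illustrative):

from paddlenlp import Taskflow

# Sketch only: construct the extractor without a schema, then set one later.
ie = Taskflow("information_extraction", model="uie-base")
ie.set_schema(["person", "organization", "time"])  # entities to extract
print(ie("Bill Gates co-founded Microsoft with Paul Allen in 1975."))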
@@ -385,7 +386,7 @@ def __init__(self, task, model, schema, **kwargs):
        self._position_prob = kwargs.get("position_prob", 0.5)
        self._lazy_load = kwargs.get("lazy_load", False)
        self._num_workers = kwargs.get("num_workers", 0)
-         self.use_fast = kwargs.get("use_fast", False)
+         self._use_fast = kwargs.get("use_fast", False)
        self._layout_analysis = kwargs.get("layout_analysis", False)
        self._ocr_lang = kwargs.get("ocr_lang", "ch")
        self._schema_lang = kwargs.get("schema_lang", "ch")
@@ -415,14 +416,31 @@ def __init__(self, task, model, schema, **kwargs):
        else:
            self._summary_token_num = 3  # [CLS] prompt [SEP] text [SEP]

-         self._doc_parser = None
-         self._schema_tree = None
-         self.set_schema(schema)
+         self._parser_map = {
+             "ch": None,  # OCR-CH
+             "en": None,  # OCR-EN
+             "ch-layout": None,  # Layout-CH
+             "en-layout": None,  # Layout-EN
+         }
+         if not schema:
+             logger.warning(
+                 "The schema has not been set yet, please set a schema via set_schema(). "
+                 "More details about the setting of schema please refer to https://github.com/PaddlePaddle/PaddleNLP/blob/develop/applications/information_extraction/taskflow_text.md"
+             )
+             self._schema_tree = None
+         else:
+             self.set_schema(schema)
        self._check_predictor_type()
        self._get_inference_model()
        self._usage = usage
        self._construct_tokenizer()

+     def set_argument(self, argument: dict):
+         for k, v in argument.items():
+             if k == "input":
+                 continue
+             setattr(self, f"_{k}", v)
+
    def set_schema(self, schema):
        if isinstance(schema, dict) or isinstance(schema, str):
            schema = [schema]
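For readers unfamiliar with the pattern, the new set_argument hook simply mirrors a dict of runtime arguments onto underscore-prefixed attributes, skipping the raw input payload. A standalone sketch of the same idea (the class and argument names here are illustrative, not part of the change):

class ArgHolder:
    """Stand-in object demonstrating setattr-based argument mirroring."""

    def set_argument(self, argument: dict):
        for k, v in argument.items():
            if k == "input":  # the input payload is handled elsewhere, not stored
                continue
            setattr(self, f"_{k}", v)


holder = ArgHolder()
holder.set_argument({"input": "raw text", "schema_lang": "en", "position_prob": 0.6})
assert holder._schema_lang == "en" and holder._position_prob == 0.6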
@@ -467,7 +485,7 @@ def _construct_tokenizer(self):
        Construct the tokenizer for the predictor.
        """
        self._tokenizer = AutoTokenizer.from_pretrained(
-             self._task_path, use_fast=self.use_fast, from_hf_hub=self.from_hf_hub
+             self._task_path, use_fast=self._use_fast, from_hf_hub=self.from_hf_hub
        )

    def _preprocess(self, inputs):
@@ -485,6 +503,7 @@ def _check_input_text(self, inputs):
        """
        Check whether the input meet the requirement.
        """
+         self._ocr_lang_choice = (self._ocr_lang + "-layout") if self._layout_analysis else self._ocr_lang
        inputs = inputs[0]
        if isinstance(inputs, dict) or isinstance(inputs, str):
            inputs = [inputs]
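The key added above combines the OCR language with the layout-analysis flag; a quick illustration of how the four possible cache keys are derived:

def parser_key(ocr_lang: str, layout_analysis: bool) -> str:
    # Same conditional expression as in _check_input_text above.
    return (ocr_lang + "-layout") if layout_analysis else ocr_lang


assert parser_key("ch", False) == "ch"
assert parser_key("en", False) == "en"
assert parser_key("ch", True) == "ch-layout"
assert parser_key("en", True) == "en-layout"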
@@ -494,17 +513,17 @@ def _check_input_text(self, inputs):
            data = {}
            if isinstance(example, dict):
                if "doc" in example.keys():
-                     if not self._doc_parser:
-                         self._doc_parser = DocParser(
+                     if not self._parser_map[self._ocr_lang_choice]:
+                         self._parser_map[self._ocr_lang_choice] = DocParser(
                            ocr_lang=self._ocr_lang, layout_analysis=self._layout_analysis
                        )
                    if "layout" in example.keys():
-                         data = self._doc_parser.parse(
+                         data = self._parser_map[self._ocr_lang_choice].parse(
                            {"doc": example["doc"]}, do_ocr=False, expand_to_a4_size=self._expand_to_a4_size
                        )
                        data["layout"] = example["layout"]
                    else:
-                         data = self._doc_parser.parse(
+                         data = self._parser_map[self._ocr_lang_choice].parse(
                            {"doc": example["doc"]}, expand_to_a4_size=self._expand_to_a4_size
                        )
                elif "text" in example.keys():
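Taken together, the edits above turn the single _doc_parser attribute into a small lazy cache: one DocParser per language/layout combination, created on first use and reused afterwards. A minimal sketch of the pattern with a stand-in parser class (the assumption being that DocParser construction is expensive enough to be worth caching per configuration):

class FakeParser:
    """Stand-in for DocParser; records how it was configured."""

    def __init__(self, ocr_lang, layout_analysis):
        self.ocr_lang = ocr_lang
        self.layout_analysis = layout_analysis


parser_map = {"ch": None, "en": None, "ch-layout": None, "en-layout": None}


def get_parser(ocr_lang: str, layout_analysis: bool) -> FakeParser:
    key = (ocr_lang + "-layout") if layout_analysis else ocr_lang
    if not parser_map[key]:
        # Built once per key; later calls with the same settings reuse it.
        parser_map[key] = FakeParser(ocr_lang, layout_analysis)
    return parser_map[key]


assert get_parser("en", True) is get_parser("en", True)
assert get_parser("ch", False) is not get_parser("ch", True)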
@@ -931,7 +950,7 @@ def _parse_inputs(self, inputs):
                        org_box[2] + offset_x,
                        org_box[3] + offset_y,
                    ]
-                     box = self._doc_parser._normalize_box(box, [img_w, img_h], [1000, 1000])
+                     box = self._parser_map[self._ocr_lang_choice]._normalize_box(box, [img_w, img_h], [1000, 1000])
                    text += segment[1]
                    bbox.extend([box] * len(segment[1]))
                _inputs.append({"text": text, "bbox": bbox, "image": d["image"], "layout": d["layout"]})
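_normalize_box is assumed here to rescale a pixel-space box into the fixed 1000x1000 coordinate grid that layout-aware models such as UIE-X expect; the helper below is only an illustrative reconstruction of that behaviour, not the DocParser implementation itself:

def normalize_box(box, image_size, target_size=(1000, 1000)):
    # Rescale [x0, y0, x1, y1] from pixel coordinates into target_size units.
    # Assumption: this mirrors what DocParser._normalize_box does.
    img_w, img_h = image_size
    tgt_w, tgt_h = target_size
    x0, y0, x1, y1 = box
    return [
        int(x0 * tgt_w / img_w),
        int(y0 * tgt_h / img_h),
        int(x1 * tgt_w / img_w),
        int(y1 * tgt_h / img_h),
    ]


assert normalize_box([100, 50, 200, 150], (1000, 500)) == [100, 100, 200, 300]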