Skip to content

Commit 1b2e89d

Browse files
committed
Change to is_static_model
1 parent d07f8c4 commit 1b2e89d

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

paddlenlp/taskflow/feature_extraction.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class MultimodalFeatureExtractionTask(Task):
182182
},
183183
}
184184

185-
def __init__(self, task, model, batch_size=1, _static_mode=True, return_tensors=True, **kwargs):
185+
def __init__(self, task, model, batch_size=1, is_static_model=True, return_tensors=True, **kwargs):
186186
super().__init__(task=task, model=model, **kwargs)
187187
self._seed = None
188188
# we do not use batch
@@ -191,14 +191,14 @@ def __init__(self, task, model, batch_size=1, _static_mode=True, return_tensors=
191191
self.return_tensors = return_tensors
192192
self._check_task_files()
193193
self._construct_tokenizer()
194-
self._static_mode = _static_mode
194+
self.is_static_model = is_static_model
195195
self._config_map = {}
196196
self.predictor_map = {}
197197
self.input_names_map = {}
198198
self.input_handles_map = {}
199199
self.output_handle_map = {}
200200
self._check_predictor_type()
201-
if self._static_mode:
201+
if self.is_static_model:
202202
self._get_inference_model()
203203
else:
204204
self._construct_model(model)
@@ -228,7 +228,7 @@ def _parse_batch(batch_examples):
228228
else:
229229
batch_texts = None
230230
batch_images = batch_examples
231-
if self._static_mode:
231+
if self.is_static_model:
232232
tokenized_inputs = self._processor(
233233
text=batch_texts, images=batch_images, return_tensors="np", padding="max_length", truncation=True
234234
)
@@ -287,7 +287,7 @@ def _run_model(self, inputs):
287287
Run the task model from the outputs of the `_preprocess` function.
288288
"""
289289
all_feats = []
290-
if self._static_mode:
290+
if self.is_static_model:
291291
with static_mode_guard():
292292
for batch_inputs in inputs["batches"]:
293293
if self._predictor_type == "paddle-inference":

tests/taskflow/test_feature_extraction.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_feature_extraction_task(self):
8484
dygraph_taskflow = MultimodalFeatureExtractionTask(
8585
model="PaddlePaddle/ernie_vil-2.0-base-zh",
8686
task="feature_extraction",
87-
_static_mode=False,
87+
is_static_model=False,
8888
return_tensors=False,
8989
)
9090
dygraph_results = dygraph_taskflow(input_text)
@@ -94,7 +94,7 @@ def test_feature_extraction_task(self):
9494
static_taskflow = MultimodalFeatureExtractionTask(
9595
model="PaddlePaddle/ernie_vil-2.0-base-zh",
9696
task="feature_extraction",
97-
_static_mode=True,
97+
is_static_model=True,
9898
return_tensors=False,
9999
device_id=0,
100100
)
@@ -123,7 +123,7 @@ def test_taskflow_task(self):
123123
# dygraph test
124124
dygraph_taskflow = Taskflow(
125125
task="feature_extraction",
126-
_static_mode=False,
126+
is_static_model=False,
127127
return_tensors=False,
128128
)
129129
dygraph_results = dygraph_taskflow(input_text)
@@ -133,7 +133,7 @@ def test_taskflow_task(self):
133133
# static test
134134
static_taskflow = Taskflow(
135135
task="feature_extraction",
136-
_static_mode=True,
136+
is_static_model=True,
137137
return_tensors=False,
138138
)
139139
static_results = static_taskflow(input_text)

0 commit comments

Comments
 (0)