diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py
new file mode 100644
index 000000000000..beac02ee9495
--- /dev/null
+++ b/paddlenlp/taskflow/feature_extraction.py
@@ -0,0 +1,424 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import numpy as np
+import paddle
+from PIL import Image
+
+from ..transformers import AutoModel, AutoProcessor
+from ..utils.log import logger
+from .task import Task
+from .utils import dygraph_mode_guard, static_mode_guard
+
+usage = r"""
+ from paddlenlp import Taskflow
+ from PIL import Image
+
+ # multi modal feature_extraction with ernie_vil-2.0-base-zh
+ vision_language = Taskflow("feature_extraction")
+ image_embeds = vision_language([Image.open("demo/000000039769.jpg")])
+ print(image_embeds)
+ '''
+ Tensor(shape=[1, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+ [[-0.59475428, -0.69795364, 0.22144008, 0.88066685, -0.58184201,
+ -0.73454666, 0.95557910, -0.61410815, 0.23474170, 0.13301648,
+ 0.86196446, 0.12281934, 0.69097638, 1.47614217, 0.07238606,
+ ...
+ '''
+ text_embeds = vision_language(["猫的照片","狗的照片"])
+ text_features = text_embeds["features"]
+ print(text_features)
+ '''
+ Tensor(shape=[2, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+ [[ 0.04250504, -0.41429776, 0.26163983, ..., 0.26221892,
+ 0.34387422, 0.18779707],
+ '''
+ image_features /= image_features.norm(axis=-1, keepdim=True)
+ text_features /= text_features.norm(axis=-1, keepdim=True)
+ logits_per_image = 100 * image_features @ text_features.t()
+ probs = F.softmax(logits_per_image, axis=-1)
+ print(probs)
+ '''
+ Tensor(shape=[1, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+ [[0.99833173, 0.00166824]])
+ '''
+ """
+
+
+class MultimodalFeatureExtractionTask(Task):
+ """
+ Feature extraction task using no model head. This task extracts the hidden states from the base
+ model, which can be used as features in retrieval and clustering tasks.
+ Args:
+ task(string): The name of task.
+ model(string): The model name in the task.
+ kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
+ """
+
+ resource_files_names = {
+ "model_state": "model_state.pdparams",
+ "config": "config.json",
+ "vocab_file": "vocab.txt",
+ "preprocessor_config": "preprocessor_config.json",
+ "special_tokens_map": "special_tokens_map.json",
+ "tokenizer_config": "tokenizer_config.json",
+ }
+ resource_files_urls = {
+ "PaddlePaddle/ernie_vil-2.0-base-zh": {
+ "model_state": [
+ "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/model_state.pdparams",
+ "38d8c8e01f74ba881e87d9a3f669e5ae",
+ ],
+ "config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/config.json",
+ "caf929b450d5638e8df2a95c936519e7",
+ ],
+ "vocab_file": [
+ "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/vocab.txt",
+ "1c1c1f4fd93c5bed3b4eebec4de976a8",
+ ],
+ "preprocessor_config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/preprocessor_config.json",
+ "9a2e8da9f41896fedb86756b79355ee2",
+ ],
+ "special_tokens_map": [
+ "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/special_tokens_map.json",
+ "8b3fb1023167bb4ab9d70708eb05f6ec",
+ ],
+ "tokenizer_config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/tokenizer_config.json",
+ "da5385c23c8f522d33fc3aac829e4375",
+ ],
+ },
+ "OFA-Sys/chinese-clip-vit-base-patch16": {
+ "model_state": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/model_state.pdparams",
+ "d594c94833b8cfeffc4f986712b3ef79",
+ ],
+ "config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/config.json",
+ "3611b5c34ad69dcf91e3c1d03b01a93a",
+ ],
+ "vocab_file": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/vocab.txt",
+ "3b5b76c4aef48ecf8cb3abaafe960f09",
+ ],
+ "preprocessor_config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/preprocessor_config.json",
+ "ba1fb66c75b18b3c9580ea5120e01ced",
+ ],
+ "special_tokens_map": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/special_tokens_map.json",
+ "8b3fb1023167bb4ab9d70708eb05f6ec",
+ ],
+ "tokenizer_config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/tokenizer_config.json",
+ "573ba0466e15cdb5bd423ff7010735ce",
+ ],
+ },
+ "OFA-Sys/chinese-clip-vit-large-patch14": {
+ "model_state": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/model_state.pdparams",
+ "5c0dde02d68179a9cc566173e53966c0",
+ ],
+ "config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/config.json",
+ "a5e35843aa87ab1106e9f60f1e16b96d",
+ ],
+ "vocab_file": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/vocab.txt",
+ "3b5b76c4aef48ecf8cb3abaafe960f09",
+ ],
+ "preprocessor_config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/preprocessor_config.json",
+ "ba1fb66c75b18b3c9580ea5120e01ced",
+ ],
+ "special_tokens_map": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/special_tokens_map.json",
+ "8b3fb1023167bb4ab9d70708eb05f6ec",
+ ],
+ "tokenizer_config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/tokenizer_config.json",
+ "573ba0466e15cdb5bd423ff7010735ce",
+ ],
+ },
+ "OFA-Sys/chinese-clip-vit-large-patch14-336px": {
+ "model_state": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/model_state.pdparams",
+ "ee3eb7f9667cfb06338bea5757c5e0d7",
+ ],
+ "config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/config.json",
+ "cb2794d99bea8c8f45901d177e663e1e",
+ ],
+ "vocab_file": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/vocab.txt",
+ "3b5b76c4aef48ecf8cb3abaafe960f09",
+ ],
+ "preprocessor_config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/preprocessor_config.json",
+ "c52a0b3abe9bdd1c3c5a3d56797f4a03",
+ ],
+ "special_tokens_map": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/special_tokens_map.json",
+ "8b3fb1023167bb4ab9d70708eb05f6ec",
+ ],
+ "tokenizer_config": [
+ "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/tokenizer_config.json",
+ "573ba0466e15cdb5bd423ff7010735ce",
+ ],
+ },
+ }
+
+ def __init__(self, task, model, batch_size=1, is_static_model=True, return_tensors=True, **kwargs):
+ super().__init__(task=task, model=model, **kwargs)
+ self._seed = None
+ # we do not use batch
+ self.export_type = "text"
+ self._batch_size = batch_size
+ self.return_tensors = return_tensors
+ self._check_task_files()
+ self._construct_tokenizer()
+ self.is_static_model = is_static_model
+ self._config_map = {}
+ self.predictor_map = {}
+ self.input_names_map = {}
+ self.input_handles_map = {}
+ self.output_handle_map = {}
+ self._check_predictor_type()
+ if self.is_static_model:
+ self._get_inference_model()
+ else:
+ self._construct_model(model)
+
+ def _construct_model(self, model):
+ """
+ Construct the inference model for the predictor.
+ """
+ self._model = AutoModel.from_pretrained(self._task_path)
+ self._model.eval()
+
+ def _construct_tokenizer(self):
+ """
+ Construct the tokenizer for the predictor.
+ """
+ self._processor = AutoProcessor.from_pretrained(self.model)
+
+ def _batchify(self, data, batch_size):
+ """
+ Generate input batches.
+ """
+
+ def _parse_batch(batch_examples):
+ if isinstance(batch_examples[0], str):
+ batch_texts = batch_examples
+ batch_images = None
+ else:
+ batch_texts = None
+ batch_images = batch_examples
+ if self.is_static_model:
+ tokenized_inputs = self._processor(
+ text=batch_texts, images=batch_images, return_tensors="np", padding="max_length", truncation=True
+ )
+ else:
+ tokenized_inputs = self._processor(
+ text=batch_texts, images=batch_images, return_tensors="pd", padding="max_length", truncation=True
+ )
+ return tokenized_inputs
+
+ # Seperates data into some batches.
+ one_batch = []
+ for example in data:
+ one_batch.append(example)
+ if len(one_batch) == batch_size:
+ yield _parse_batch(one_batch)
+ one_batch = []
+ if one_batch:
+ yield _parse_batch(one_batch)
+
+ def _check_input_text(self, inputs):
+ """
+ Check whether the input text meet the requirement.
+ """
+ inputs = inputs[0]
+ if isinstance(inputs, str):
+ if len(inputs) == 0:
+ raise ValueError("Invalid inputs, input text should not be empty, please check your input.")
+ inputs = [inputs]
+ elif isinstance(inputs, Image.Image):
+ inputs = [inputs]
+ elif isinstance(inputs, list):
+ # and len(inputs[0].strip()) > 0
+ if not (isinstance(inputs[0], (str, Image.Image))):
+ raise TypeError(
+ "Invalid inputs, input text/image should be list of str/PIL.image, and first element of list should not be empty."
+ )
+ else:
+ raise TypeError(
+ "Invalid inputs, input text should be str or list of str, but type of {} found!".format(type(inputs))
+ )
+ return inputs
+
+ def _preprocess(self, inputs):
+ """
+ Transform the raw inputs to the model inputs, two steps involved:
+ 1) Transform the raw text/image to token ids/pixel_values.
+ 2) Generate the other model inputs from the raw text/image and token ids/pixel_values.
+ """
+ inputs = self._check_input_text(inputs)
+ batches = self._batchify(inputs, self._batch_size)
+ outputs = {"batches": batches, "inputs": inputs}
+ return outputs
+
+ def _run_model(self, inputs):
+ """
+ Run the task model from the outputs of the `_preprocess` function.
+ """
+ all_feats = []
+ if self.is_static_model:
+ with static_mode_guard():
+ for batch_inputs in inputs["batches"]:
+ if self._predictor_type == "paddle-inference":
+ if "input_ids" in batch_inputs:
+ self.input_handles_map["text"][0].copy_from_cpu(batch_inputs["input_ids"])
+ self.predictor_map["text"].run()
+ text_features = self.output_handle_map["text"][0].copy_to_cpu()
+ all_feats.append(text_features)
+ elif "pixel_values" in batch_inputs:
+ self.input_handles_map["image"][0].copy_from_cpu(batch_inputs["pixel_values"])
+ self.predictor_map["image"].run()
+ image_features = self.output_handle_map["image"][0].copy_to_cpu()
+ all_feats.append(image_features)
+ else:
+ # onnx mode
+ if "input_ids" in batch_inputs:
+ input_dict = {}
+ input_dict["input_ids"] = batch_inputs["input_ids"]
+ text_features = self.predictor_map["text"].run(None, input_dict)[0].tolist()
+ all_feats.append(text_features)
+ elif "pixel_values" in batch_inputs:
+ input_dict = {}
+ input_dict["pixel_values"] = batch_inputs["pixel_values"]
+ image_features = self.predictor_map["image"].run(None, input_dict)[0].tolist()
+ all_feats.append(image_features)
+ else:
+ for batch_inputs in inputs["batches"]:
+ if "input_ids" in batch_inputs:
+ text_features = self._model.get_text_features(input_ids=batch_inputs["input_ids"])
+ all_feats.append(text_features.numpy())
+ if "pixel_values" in batch_inputs:
+ image_features = self._model.get_image_features(pixel_values=batch_inputs["pixel_values"])
+ all_feats.append(image_features.numpy())
+ inputs.update({"features": all_feats})
+ return inputs
+
+ def _postprocess(self, inputs):
+ inputs["features"] = np.concatenate(inputs["features"], axis=0)
+ if self.return_tensors:
+ inputs["features"] = paddle.to_tensor(inputs["features"])
+ return inputs
+
+ def _construct_input_spec(self):
+ """
+ Construct the input spec for the convert dygraph model to static model.
+ """
+ self._input_text_spec = [
+ paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
+ ]
+
+ self._input_image_spec = [
+ paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="float32", name="pixel_values"),
+ ]
+
+ def _convert_dygraph_to_static(self):
+ """
+ Convert the dygraph model to static model.
+ """
+ assert (
+ self._model is not None
+ ), "The dygraph model must be created before converting the dygraph model to static model."
+ assert (
+ self._input_image_spec is not None or self._input_text_spec is not None
+ ), "The input spec must be created before converting the dygraph model to static model."
+ logger.info("Converting to the inference model cost a little time.")
+
+ static_model = paddle.jit.to_static(self._model.get_text_features, input_spec=self._input_text_spec)
+ self.inference_model_path = self.inference_text_model_path
+ paddle.jit.save(static_model, self.inference_model_path)
+ logger.info("The inference model save in the path:{}".format(self.inference_model_path))
+
+ static_model = paddle.jit.to_static(self._model.get_image_features, input_spec=self._input_image_spec)
+ self.inference_model_path = self.inference_image_model_path
+ paddle.jit.save(static_model, self.inference_model_path)
+ logger.info("The inference model save in the path:{}".format(self.inference_model_path))
+
+ def _get_inference_model(self):
+ """
+ Return the inference program, inputs and outputs in static mode.
+ """
+ _base_path = os.path.join(self._home_path, "taskflow", self.task, self.model)
+ self.inference_image_model_path = os.path.join(_base_path, "static", "get_image_features")
+ self.inference_text_model_path = os.path.join(_base_path, "static", "get_text_features")
+ if (
+ not os.path.exists(self.inference_image_model_path + ".pdiparams")
+ or self._param_updated
+ or not os.path.exists(self.inference_text_model_path + ".pdiparams")
+ ):
+ with dygraph_mode_guard():
+ self._construct_model(self.model)
+ self._construct_input_spec()
+ self._convert_dygraph_to_static()
+ if self._predictor_type == "paddle-inference":
+ # Get text inference model
+ self.inference_model_path = self.inference_text_model_path
+ self._static_model_file = self.inference_model_path + ".pdmodel"
+ self._static_params_file = self.inference_model_path + ".pdiparams"
+ self._config = paddle.inference.Config(self._static_model_file, self._static_params_file)
+ self._prepare_static_mode()
+
+ self.predictor_map["text"] = self.predictor
+ self.input_names_map["text"] = self.input_names
+ self.input_handles_map["text"] = self.input_handles
+ self.output_handle_map["text"] = self.output_handle
+ self._config_map["text"] = self._config
+
+ # Get image inference model
+ self.inference_model_path = self.inference_image_model_path
+ self._static_model_file = self.inference_model_path + ".pdmodel"
+ self._static_params_file = self.inference_model_path + ".pdiparams"
+ self._config = paddle.inference.Config(self._static_model_file, self._static_params_file)
+ self._prepare_static_mode()
+
+ self.predictor_map["image"] = self.predictor
+ self.input_names_map["image"] = self.input_names
+ self.input_handles_map["image"] = self.input_handles
+ self.output_handle_map["image"] = self.output_handle
+ self._config_map["image"] = self._config
+ else:
+ # Get text onnx model
+ self.export_type = "text"
+ self.inference_model_path = self.inference_text_model_path
+ self._static_model_file = self.inference_model_path + ".pdmodel"
+ self._static_params_file = self.inference_model_path + ".pdiparams"
+ self._prepare_onnx_mode()
+ self.predictor_map["text"] = self.predictor
+
+ # Get image onnx model
+ self.export_type = "image"
+ self.inference_model_path = self.inference_image_model_path
+ self._static_model_file = self.inference_model_path + ".pdmodel"
+ self._static_params_file = self.inference_model_path + ".pdiparams"
+ self._prepare_onnx_mode()
+ self.predictor_map["image"] = self.predictor
diff --git a/paddlenlp/taskflow/task.py b/paddlenlp/taskflow/task.py
index a27410af525b..0e8a33c4b1a3 100644
--- a/paddlenlp/taskflow/task.py
+++ b/paddlenlp/taskflow/task.py
@@ -61,6 +61,8 @@ def __init__(self, model, task, priority_path=None, **kwargs):
self._home_path = self.kwargs["home_path"] if "home_path" in self.kwargs else PPNLP_HOME
self._task_flag = self.kwargs["task_flag"] if "task_flag" in self.kwargs else self.model
self.from_hf_hub = kwargs.pop("from_hf_hub", False)
+ # Add mode flag for onnx output path redirection
+ self.export_type = None
if "task_path" in self.kwargs:
self._task_path = self.kwargs["task_path"]
@@ -221,8 +223,12 @@ def _prepare_onnx_mode(self):
logger.warning(
"The inference precision is change to 'fp32', please install the dependencies that required for 'fp16' inference, pip install onnxruntime-gpu onnx onnxconverter-common"
)
+ if self.export_type is None:
+ onnx_dir = os.path.join(self._task_path, "onnx")
+ else:
+ # Compatible multimodal model for saving image and text path
+ onnx_dir = os.path.join(self._task_path, "onnx", self.export_type)
- onnx_dir = os.path.join(self._task_path, "onnx")
if not os.path.exists(onnx_dir):
os.mkdir(onnx_dir)
float_onnx_file = os.path.join(onnx_dir, "model.onnx")
diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py
index b5d4530e8164..da0e1da6f502 100644
--- a/paddlenlp/taskflow/taskflow.py
+++ b/paddlenlp/taskflow/taskflow.py
@@ -23,6 +23,7 @@
from .dependency_parsing import DDParserTask
from .dialogue import DialogueTask
from .document_intelligence import DocPromptTask
+from .feature_extraction import MultimodalFeatureExtractionTask
from .fill_mask import FillMaskTask
from .information_extraction import GPTask, UIETask
from .knowledge_mining import NPTagTask, WordTagTask
@@ -514,6 +515,76 @@
},
"default": {"model": "utc-large"},
},
+ "feature_extraction": {
+ "models": {
+ "PaddlePaddle/ernie_vil-2.0-base-zh": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-PaddlePaddle/ernie_vil-2.0-base-zh",
+ "task_priority_path": "PaddlePaddle/ernie_vil-2.0-base-zh",
+ },
+ "OFA-Sys/chinese-clip-vit-base-patch16": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-OFA-Sys/chinese-clip-vit-base-patch16",
+ "task_priority_path": "OFA-Sys/chinese-clip-vit-base-patch16",
+ },
+ "OFA-Sys/chinese-clip-vit-huge-patch14": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-OFA-Sys/chinese-clip-vit-huge-patch14",
+ "task_priority_path": "OFA-Sys/chinese-clip-vit-huge-patch14",
+ },
+ "OFA-Sys/chinese-clip-vit-large-patch14": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-OFA-Sys/chinese-clip-vit-large-patch14",
+ "task_priority_path": "OFA-Sys/chinese-clip-vit-large-patch14",
+ },
+ "OFA-Sys/chinese-clip-vit-large-patch14-336px": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-OFA-Sys/chinese-clip-vit-large-patch14-336px",
+ "task_priority_path": "OFA-Sys/chinese-clip-vit-large-patch14-336px",
+ },
+ "openai/clip-vit-base-patch32": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-openai/clip-vit-base-patch32",
+ "task_priority_path": "openai/clip-vit-base-patch32",
+ },
+ "openai/clip-vit-base-patch16": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-openai/clip-vit-base-patch16",
+ "task_priority_path": "openai/clip-vit-base-patch16",
+ },
+ "openai/clip-vit-large-patch14": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-openai/clip-vit-large-patch14",
+ "task_priority_path": "openai/clip-vit-large-patch14",
+ },
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
+ "task_priority_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
+ },
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
+ "task_priority_path": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
+ },
+ "openai/clip-rn50": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-openai/clip-rn50",
+ "task_priority_path": "openai/clip-rn50",
+ },
+ "openai/clip-rn101": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-openai/clip-rn101",
+ "task_priority_path": "openai/clip-rn101",
+ },
+ "openai/clip-rn50x4": {
+ "task_class": MultimodalFeatureExtractionTask,
+ "task_flag": "feature_extraction-openai/clip-rn50x4",
+ "task_priority_path": "openai/clip-rn50x4",
+ },
+ },
+ "default": {"model": "PaddlePaddle/ernie_vil-2.0-base-zh"},
+ },
}
support_schema_list = [
diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py
index 3c111eb68537..33d74c1709a5 100644
--- a/paddlenlp/transformers/__init__.py
+++ b/paddlenlp/transformers/__init__.py
@@ -163,6 +163,7 @@
from .opt.modeling import *
from .auto.modeling import *
from .auto.tokenizer import *
+from .auto.processing import *
from .codegen.modeling import *
from .codegen.tokenizer import *
from .codegen.configuration import *
diff --git a/paddlenlp/transformers/auto/processing.py b/paddlenlp/transformers/auto/processing.py
new file mode 100644
index 000000000000..34e896b8ebab
--- /dev/null
+++ b/paddlenlp/transformers/auto/processing.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import io
+import json
+import os
+from collections import OrderedDict
+
+from paddlenlp.utils.downloader import COMMUNITY_MODEL_PREFIX, get_path_from_url
+from paddlenlp.utils.env import MODEL_HOME
+from paddlenlp.utils.import_utils import import_module
+from paddlenlp.utils.log import logger
+
+__all__ = [
+ "AutoProcessor",
+]
+
+PROCESSOR_MAPPING_NAMES = OrderedDict(
+ [
+ ("ChineseCLIPProcessor", "chineseclip"),
+ ("CLIPProcessor", "clip"),
+ ("ErnieViLProcessor", "ernie_vil"),
+ ]
+)
+
+
+def get_configurations():
+ MAPPING_NAMES = OrderedDict()
+ for key, class_name in PROCESSOR_MAPPING_NAMES.items():
+ import_class = importlib.import_module(f"paddlenlp.transformers.{class_name}.procesing")
+ processor_name = getattr(import_class, key)
+ name = tuple(processor_name.pretrained_init_configuration.keys())
+ if MAPPING_NAMES.get(name, None) is None:
+ MAPPING_NAMES[name] = []
+ MAPPING_NAMES[name].append(processor_name)
+ return MAPPING_NAMES
+
+
+class AutoProcessor:
+ """
+ AutoClass can help you automatically retrieve the relevant model given the provided
+ pretrained weights/vocabulary.
+ Autoprocessor is a generic processor class that will be instantiated as one of the
+ base processor classes when created with the Autoprocessor.from_pretrained() classmethod.
+ """
+
+ MAPPING_NAMES = get_configurations()
+ _processor_mapping = MAPPING_NAMES
+ _name_mapping = PROCESSOR_MAPPING_NAMES
+ processor_config_file = "preprocessor_config.json"
+
+ def __init__(self, *args, **kwargs):
+ raise EnvironmentError(
+ f"{self.__class__.__name__} is designed to be instantiated "
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path).`"
+ )
+
+ @classmethod
+ def _get_processor_class_from_config(cls, pretrained_model_name_or_path, config_file_path):
+ with io.open(config_file_path, encoding="utf-8") as f:
+ init_kwargs = json.load(f)
+ # class name corresponds to this configuration
+ init_class = init_kwargs.pop("init_class", None)
+ if init_class is None:
+ init_class = init_kwargs.pop("processor_class", None)
+
+ if init_class:
+ class_name = cls._name_mapping[init_class]
+ import_class = import_module(f"paddlenlp.transformers.{class_name}.procesing")
+ processor_class = getattr(import_class, init_class)
+ return processor_class
+ # If no `init_class`, we use pattern recognition to recognize the processor class.
+ else:
+ logger.info("We use pattern recognition to recognize the processor class.")
+ for key, pattern in cls._name_mapping.items():
+ if pattern in pretrained_model_name_or_path.lower():
+ init_class = key
+ class_name = cls._name_mapping[init_class]
+ import_class = import_module(f"paddlenlp.transformers.{class_name}.processor")
+ processor_class = getattr(import_class, init_class)
+ break
+ return processor_class
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ """
+ Creates an instance of `Autoprocessor`. Related resources are loaded by
+ specifying name of a built-in pretrained model, or a community-contributed
+ pretrained model, or a local file directory path.
+
+ Args:
+ pretrained_model_name_or_path (str): Name of pretrained model or dir path
+ to load from. The string can be:
+
+ - Name of built-in pretrained model
+ - Name of a community-contributed pretrained model.
+ - Local directory path which contains processor related resources
+ and processor config file ("processor_config.json").
+ *args (tuple): position arguments for model `__init__`. If provided,
+ use these as position argument values for processor initialization.
+ **kwargs (dict): keyword arguments for model `__init__`. If provided,
+ use these to update pre-defined keyword argument values for processor
+ initialization.
+
+ Returns:
+ Pretrainedprocessor: An instance of `Pretrainedprocessor`.
+
+
+ Example:
+ .. code-block::
+ from paddlenlp.transformers import AutoProcessor
+ processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ processor.save_pretrained('clip_processor')
+ """
+
+ all_processor_names = []
+ for names, processor_class in cls._processor_mapping.items():
+ for name in names:
+ all_processor_names.append(name)
+ # From built-in pretrained models
+ if pretrained_model_name_or_path in all_processor_names:
+ for names, processor_classes in cls._processor_mapping.items():
+ for pattern in names:
+ if pattern == pretrained_model_name_or_path:
+ actual_processor_class = processor_classes[0]
+ logger.info(
+ "We are using %s to load '%s'." % (actual_processor_class, pretrained_model_name_or_path)
+ )
+ return actual_processor_class.from_pretrained(
+ pretrained_model_name_or_path, *model_args, **kwargs
+ )
+ # From local dir path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ config_file = os.path.join(pretrained_model_name_or_path, cls.processor_config_file)
+ if os.path.exists(config_file):
+ processor_class = cls._get_processor_class_from_config(pretrained_model_name_or_path, config_file)
+ logger.info("We are using %s to load '%s'." % (processor_class, pretrained_model_name_or_path))
+ return processor_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+ # Assuming from community-contributed pretrained models
+ else:
+ community_config_path = "/".join(
+ [COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, cls.processor_config_file]
+ )
+
+ default_root = os.path.join(MODEL_HOME, pretrained_model_name_or_path)
+ try:
+ resolved_vocab_file = get_path_from_url(community_config_path, default_root)
+ except RuntimeError as err:
+ logger.error(err)
+ raise RuntimeError(
+ f"Can't load processor for '{pretrained_model_name_or_path}'.\n"
+ f"Please make sure that '{pretrained_model_name_or_path}' is:\n"
+ "- a correct model-identifier of built-in pretrained models,\n"
+ "- or a correct model-identifier of community-contributed pretrained models,\n"
+ "- or the correct path to a directory containing relevant processor files.\n"
+ )
+
+ if os.path.exists(resolved_vocab_file):
+ processor_class = cls._get_processor_class_from_config(
+ pretrained_model_name_or_path, resolved_vocab_file
+ )
+ logger.info("We are using %s to load '%s'." % (processor_class, pretrained_model_name_or_path))
+ return processor_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
diff --git a/paddlenlp/transformers/chineseclip/procesing.py b/paddlenlp/transformers/chineseclip/procesing.py
index e7e0034ae155..701dfda83efd 100644
--- a/paddlenlp/transformers/chineseclip/procesing.py
+++ b/paddlenlp/transformers/chineseclip/procesing.py
@@ -41,6 +41,13 @@ class ChineseCLIPProcessor(ProcessorMixin):
image_processor_class = "ChineseCLIPImageProcessor"
tokenizer_class = "ChineseCLIPTokenizer"
+ pretrained_init_configuration = {
+ "OFA-Sys/chinese-clip-vit-base-patch16": {"do_lower_case": True},
+ "OFA-Sys/chinese-clip-vit-huge-patch14": {"do_lower_case": True},
+ "OFA-Sys/chinese-clip-vit-large-patch14": {"do_lower_case": True},
+ "OFA-Sys/chinese-clip-vit-large-patch14-336px": {"do_lower_case": True},
+ }
+
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
if "feature_extractor" in kwargs:
warnings.warn(
diff --git a/paddlenlp/transformers/clip/procesing.py b/paddlenlp/transformers/clip/procesing.py
index ffaf9ad661ec..3424f643e193 100644
--- a/paddlenlp/transformers/clip/procesing.py
+++ b/paddlenlp/transformers/clip/procesing.py
@@ -40,6 +40,17 @@ class CLIPProcessor(ProcessorMixin):
image_processor_class = "CLIPImageProcessor"
tokenizer_class = "CLIPTokenizer"
+ pretrained_init_configuration = {
+ "openai/clip-vit-base-patch32": {"do_lower_case": True},
+ "openai/clip-vit-base-patch16": {"do_lower_case": True},
+ "openai/clip-vit-large-patch14": {"do_lower_case": True},
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K": {"do_lower_case": True},
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K": {"do_lower_case": True},
+ "openai/clip-rn50": {"do_lower_case": True},
+ "openai/clip-rn101": {"do_lower_case": True},
+ "openai/clip-rn50x4": {"do_lower_case": True},
+ }
+
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
if "feature_extractor" in kwargs:
warnings.warn(
diff --git a/paddlenlp/transformers/ernie_vil/procesing.py b/paddlenlp/transformers/ernie_vil/procesing.py
index 952cfe901477..e89ab381f4bb 100644
--- a/paddlenlp/transformers/ernie_vil/procesing.py
+++ b/paddlenlp/transformers/ernie_vil/procesing.py
@@ -40,6 +40,10 @@ class ErnieViLProcessor(ProcessorMixin):
image_processor_class = "ErnieViLImageProcessor"
tokenizer_class = "ErnieViLTokenizer"
+ pretrained_init_configuration = {
+ "PaddlePaddle/ernie_vil-2.0-base-zh": {"do_lower_case": True},
+ }
+
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
if "feature_extractor" in kwargs:
warnings.warn(
diff --git a/pipelines/examples/image_text_retrieval/README.md b/pipelines/examples/image_text_retrieval/README.md
new file mode 100644
index 000000000000..313cddfaf71b
--- /dev/null
+++ b/pipelines/examples/image_text_retrieval/README.md
@@ -0,0 +1,191 @@
+# 端到端文图跨模态检索系统
+
+## 1. 场景概述
+
+文图跨模态检索系统目的是通过文字找到最符合描述的图片。传统的方案是用标签和图片的关键字进行匹配,而跨模态检索真正的实现了文本语义和图片语义内容的匹配,这种检索方式更符合人类的逻辑判断,是一种真正意义上的端到端人工智能。文图应用目前可以广泛应用于电商搜索,安防视频,图像检索,抖音等小视频,旅游app应用搜索。有助于提升效率和搜索体验。另外还有一些潜在的领域,比如司法的互联网调查取证,侵权检测,数据增强,文案匹配,各种互联网logo,肖像,风景,海报等图片网站的检索,医药等专业领域的文图搜索等。
+
+## 2. 产品功能介绍
+
+本项目提供了低成本搭建端到端文图跨模态检索系统的能力。用户只需要处理好自己的业务数据,就可以使用本项目预置的文图跨模态检索系统模型快速搭建一个针对自己业务数据的跨模态检索系统,并可以提供 Web 化产品服务。
+
+
+

+
+
+
+### 2.1 系统特色
+
++ 端到端
+ + 提供包括数据建库、模型服务部署、WebUI 可视化一整套端到端文图跨模态检索系统能力
+ + 依托百度领先的NLP技术,包括[ERNIE](https://github.com/PaddlePaddle/ERNIE)语义理解技术,[ERNIE-ViL 2.0](https://arxiv.org/abs/2209.15270)跨模态检索能力
+ + 预置领先的深度学习模型
+
+## 3. 快速开始: 快速搭建文图跨模态检索系统
+
+
+### 3.1 运行环境和安装说明
+
+本实验采用了以下的运行环境进行,详细说明如下,用户也可以在自己 GPU 硬件环境进行:
+
+a. 软件环境:
+- python >= 3.7.0
+- paddlenlp >= 2.5.0
+- paddlepaddle-gpu >=2.4.1
+- CUDA Version: 11.2
+- NVIDIA Driver Version: 440.64.00
+- Ubuntu 16.04.6 LTS (Docker)
+
+b. 硬件环境:
+
+- NVIDIA Tesla V100 16GB x4卡
+- Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
+
+c. 依赖安装:
+首先需要安装PaddlePaddle,PaddlePaddle的安装请参考文档[官方安装文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html),然后安装下面的依赖:
+```bash
+# pip 一键安装
+pip install --upgrade paddle-pipelines -i https://pypi.tuna.tsinghua.edu.cn/simple
+# 或者源码进行安装最新版本
+cd ${HOME}/PaddleNLP/pipelines/
+pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+python setup.py install
+```
+
+```
+# 下载pipelines源代码
+git clone https://github.com/PaddlePaddle/PaddleNLP.git
+cd PaddleNLP/pipelines
+```
+【注意】以下的所有的流程都只需要在`pipelines`根目录下进行,不需要跳转目录
+
+### 3.2 数据说明
+文图跨模态检索数据库的数据来自于[Noah-Wukong数据集](https://wukong-dataset.github.io/wukong-dataset/index.html),并选取了测试集中3056张图片来搭建文图跨模态检索系统。
+
+### 3.3 一键体验文图跨模态检索系统
+
+#### 3.3.1 快速一键启动
+
+我们预置了基于[Noah-Wukong数据集](https://wukong-dataset.github.io/wukong-dataset/index.html)搭建文图跨模态检索系统的代码示例,您可以通过如下命令快速体验文图跨模态检索系统的效果
+```bash
+# 我们建议在 GPU 环境下运行本示例,运行速度较快
+# 设置 1 个空闲的 GPU 卡,此处假设 0 卡为空闲 GPU
+export CUDA_VISIBLE_DEVICES=0
+python examples/image_text_retrieval/image_text_retrieval_example.py --device gpu \
+ --search_engine faiss
+# 如果只有 CPU 机器,可以通过 --device 参数指定 cpu 即可, 运行耗时较长
+unset CUDA_VISIBLE_DEVICES
+python examples/image_text_retrieval/image_text_retrieval_example.py --device cpu \
+ --search_engine faiss
+```
+
+
+### 3.4 构建 Web 可视化文图跨模态检索系统
+
+整个 Web 可视化文图跨模态检索系统主要包含 3 大组件: 1. 基于 ElasticSearch 的 ANN 服务 2. 基于 RestAPI 构建模型服务 3. 基于 Streamlit 构建 WebUI,接下来我们依次搭建这 3 个服务并最终形成可视化的文图跨模态检索系统。
+
+#### 3.4.1 启动 ANN 服务
+1. 参考官方文档下载安装 [elasticsearch-8.3.2](https://www.elastic.co/cn/downloads/elasticsearch) 并解压。
+2. 启动 ES 服务
+首先修改`config/elasticsearch.yml`的配置:
+```
+xpack.security.enabled: false
+```
+然后启动:
+```bash
+./bin/elasticsearch
+```
+3. 检查确保 ES 服务启动成功
+```bash
+curl http://localhost:9200/_aliases?pretty=true
+```
+备注:ES 服务默认开启端口为 9200
+
+#### 3.4.2 文档数据写入 ANN 索引库
+```
+# 以DuReader-Robust 数据集为例建立 ANN 索引库
+python utils/offline_ann_mm.py --index_name wukong_test \
+ --doc_dir data/wukong_test \
+ --search_engine elastic \
+ --delete_index
+```
+可以使用下面的命令来查看数据:
+
+```
+# 打印几条数据
+curl http://localhost:9200/wukong_test/_search
+```
+
+参数含义说明
+* `index_name`: 索引的名称
+* `doc_dir`: txt文本数据的路径
+* `host`: ANN索引引擎的IP地址
+* `port`: ANN索引引擎的端口号
+* `search_engine`: 选择的近似索引引擎elastic,milvus,默认elastic
+* `delete_index`: 是否删除现有的索引和数据,用于清空es的数据,默认为false
+
+删除索引也可以使用下面的命令:
+
+```
+curl -XDELETE http://localhost:9200/wukong_test
+```
+
+#### 3.4.3 启动 RestAPI 模型服务
+```bash
+# 指定文图跨模态检索系统的Yaml配置文件
+export PIPELINE_YAML_PATH=rest_api/pipeline/image_text_retrieval.yaml
+# 使用端口号 8891 启动模型服务
+python rest_api/application.py 8891
+```
+Linux 用户推荐采用 Shell 脚本来启动服务::
+
+```bash
+sh examples/image_text_retrieval/run_search_server.sh
+```
+启动后可以使用curl命令验证是否成功运行:
+
+```
+curl -X POST -k http://localhost:8891/query -H 'Content-Type: application/json' -d '{"query": "云南普者黑现纯白色⒌蒂莲","params": {"Retriever": {"top_k": 5}}}'
+```
+
+更多API接口文档及其调用方式请参考链接[http://127.0.0.1:8891/docs](http://127.0.0.1:8891/docs)
+
+#### 3.4.4 启动 WebUI
+```bash
+# 配置模型服务地址
+export API_ENDPOINT=http://127.0.0.1:8891
+# 在指定端口 8502 启动 WebUI
+python ui/webapp_image_text_retrieval.py --server.port 8502
+```
+Linux 用户推荐采用 Shell 脚本来启动服务::
+
+```bash
+sh examples/image_text_retrieval/run_search_web.sh
+```
+
+到这里您就可以打开浏览器访问 http://127.0.0.1:8502 地址体验文图跨模态检索系统服务了。
+
+#### 3.4.5 数据更新
+
+数据更新使用前面的 `utils/offline_ann_mm.py`进行数据更新,把图片放在特定目录,然后传入该目录即可:
+
+```
+python utils/offline_ann_mm.py --index_name wukong_test \
+ --doc_dir data/wukong_test \
+ --port 9200 \
+ --search_engine elastic \
+ --delete_index
+```
+
+
+如果安装遇见问题可以查看[FAQ文档](../../FAQ.md)
+
+## Reference
+[1]Y. Sun et al., “[ERNIE 3.0: Large-scale Knowledge Enhanced Pre-training for Language Understanding and Generation](https://arxiv.org/pdf/2107.02137.pdf),” arXiv:2107.02137 [cs], Jul. 2021, Accessed: Jan. 17, 2022. [Online]. Available: http://arxiv.org/abs/2107.02137
+
+[2]Shan, Bin, et al. "[ERNIE-ViL 2.0: Multi-view Contrastive Learning for Image-Text Pre-training](https://arxiv.org/abs/2209.15270)." arXiv preprint arXiv:2209.15270 (2022).
+
+## Acknowledge
+
+我们借鉴了 Deepset.ai [Haystack](https://github.com/deepset-ai/haystack) 优秀的框架设计,在此对[Haystack](https://github.com/deepset-ai/haystack)作者及其开源社区表示感谢。
+
+We learn form the excellent framework design of Deepset.ai [Haystack](https://github.com/deepset-ai/haystack), and we would like to express our thanks to the authors of Haystack and their open source community.
diff --git a/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py b/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py
new file mode 100644
index 000000000000..b92117c1d165
--- /dev/null
+++ b/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+from pipelines.document_stores import FAISSDocumentStore
+from pipelines.nodes import MultiModalRetriever
+from pipelines.pipelines import Pipeline
+from pipelines.schema import Document
+from pipelines.utils import fetch_archive_from_http
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.")
+parser.add_argument("--index_name", default='wukong_test', type=str, help="The ann index name of ANN.")
+parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index")
+parser.add_argument("--query_embedding_model", default="PaddlePaddle/ernie_vil-2.0-base-zh", type=str, help="The query_embedding_model path")
+parser.add_argument("--document_embedding_model", default="PaddlePaddle/ernie_vil-2.0-base-zh", type=str, help="The document_embedding_model path")
+args = parser.parse_args()
+# yapf: enable
+
+
+def image_text_retrieval_tutorial():
+ faiss_document_store = "faiss_document_store.db"
+ if os.path.exists(args.index_name) and os.path.exists(faiss_document_store):
+ # Connect to existed FAISS Index
+ document_store = FAISSDocumentStore.load(args.index_name)
+ retriever_mm = MultiModalRetriever(
+ document_store=document_store,
+ query_embedding_model=args.query_embedding_model,
+ query_type="text",
+ document_embedding_models={"image": args.document_embedding_model},
+ )
+ else:
+ doc_dir = "data/wukong_test"
+ wukong_data = "https://paddlenlp.bj.bcebos.com/applications/wukong_test_demo.zip"
+ fetch_archive_from_http(url=wukong_data, output_dir=doc_dir)
+ if os.path.exists(args.index_name):
+ os.remove(args.index_name)
+ if os.path.exists(faiss_document_store):
+ os.remove(faiss_document_store)
+ document_store = FAISSDocumentStore(embedding_dim=args.embedding_dim, faiss_index_factory_str="Flat")
+ docs = [Document(content=f"./{doc_dir}/{filename}", content_type="image") for filename in os.listdir(doc_dir)]
+ retriever_mm = MultiModalRetriever(
+ document_store=document_store,
+ query_embedding_model=args.query_embedding_model,
+ query_type="text",
+ document_embedding_models={"image": args.document_embedding_model},
+ )
+ # Update metadata
+ document_store.write_documents(docs)
+ # Update Embedding
+ document_store.update_embeddings(retriever_mm)
+ # Save index
+ document_store.save(args.index_name)
+ pipe = Pipeline()
+ pipe.add_node(component=retriever_mm, name="Retriever", inputs=["Query"])
+ result = pipe.run(query="云南普者黑现纯白色⒌蒂莲", params={"Retriever": {"top_k": 5}})
+ print(result)
+
+
+if __name__ == "__main__":
+ image_text_retrieval_tutorial()
diff --git a/pipelines/examples/image_text_retrieval/run_search_server.sh b/pipelines/examples/image_text_retrieval/run_search_server.sh
new file mode 100644
index 000000000000..791d7ab53466
--- /dev/null
+++ b/pipelines/examples/image_text_retrieval/run_search_server.sh
@@ -0,0 +1,20 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+unset http_proxy && unset https_proxy
+# 指定语义检索系统的Yaml配置文件
+export CUDA_VISIBLE_DEVICES=0
+export PIPELINE_YAML_PATH=rest_api/pipeline/image_text_retrieval.yaml
+# 使用端口号 8891 启动模型服务
+python rest_api/application.py 8891
\ No newline at end of file
diff --git a/pipelines/examples/image_text_retrieval/run_search_web.sh b/pipelines/examples/image_text_retrieval/run_search_web.sh
new file mode 100644
index 000000000000..fec15e00d197
--- /dev/null
+++ b/pipelines/examples/image_text_retrieval/run_search_web.sh
@@ -0,0 +1,18 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# 配置模型服务地址
+export API_ENDPOINT=http://127.0.0.1:8891
+# 在指定端口 8502 启动 WebUI
+python ui/webapp_image_text_retrieval.py --serving_port 8502
\ No newline at end of file
diff --git a/pipelines/pipelines/document_stores/base.py b/pipelines/pipelines/document_stores/base.py
index 9dddb21dde1d..ce8537eb4b4b 100644
--- a/pipelines/pipelines/document_stores/base.py
+++ b/pipelines/pipelines/document_stores/base.py
@@ -13,25 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Generator, Optional, Dict, List, Set, Union
-
-import logging
import collections
-import numpy as np
-from itertools import islice
+import logging
from abc import abstractmethod
+from itertools import islice
from pathlib import Path
+from typing import Dict, Generator, List, Optional, Set, Union
-try:
- from typing import Literal
-except ImportError:
- from typing_extensions import Literal # type: ignore
+import numpy as np
-from pipelines.schema import Document, Label
-from pipelines.nodes.base import BaseComponent
+from pipelines.document_stores.utils import (
+ eval_data_from_json,
+ eval_data_from_jsonl,
+ squad_json_to_jsonl,
+)
from pipelines.errors import DuplicateDocumentError
+from pipelines.nodes.base import BaseComponent
from pipelines.nodes.preprocessor import PreProcessor
-from pipelines.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
+from pipelines.schema import Document, FilterType, Label
logger = logging.getLogger(__name__)
@@ -290,6 +289,37 @@ def query_by_embedding(
) -> List[Document]:
pass
+ def query_by_embedding_batch(
+ self,
+ query_embs: Union[List[np.ndarray], np.ndarray],
+ filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
+ top_k: int = 10,
+ index: Optional[str] = None,
+ return_embedding: Optional[bool] = None,
+ headers: Optional[Dict[str, str]] = None,
+ ) -> List[List[Document]]:
+ if isinstance(filters, list):
+ if len(filters) != len(query_embs):
+ raise Exception(
+ "Number of filters does not match number of query_embs. Please provide as many filters"
+ " as query_embs or a single filter that will be applied to each query_emb."
+ )
+ else:
+ filters = [filters] * len(query_embs)
+ results = []
+ for query_emb, filter in zip(query_embs, filters):
+ results.append(
+ self.query_by_embedding(
+ query_emb=query_emb,
+ filters=filter,
+ top_k=top_k,
+ index=index,
+ return_embedding=return_embedding,
+ headers=headers,
+ )
+ )
+ return results
+
@abstractmethod
def get_label_count(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> int:
pass
@@ -338,28 +368,28 @@ def add_eval_data(
# TODO improve support for PreProcessor when adding eval data
if preprocessor is not None:
assert preprocessor.split_by != "sentence", (
- f"Split by sentence not supported.\n"
- f"Please set 'split_by' to either 'word' or 'passage' in the supplied PreProcessor."
+ "Split by sentence not supported.\n"
+ "Please set 'split_by' to either 'word' or 'passage' in the supplied PreProcessor."
)
- assert preprocessor.split_respect_sentence_boundary == False, (
- f"split_respect_sentence_boundary not supported yet.\n"
- f"Please set 'split_respect_sentence_boundary' to False in the supplied PreProcessor."
+ assert preprocessor.split_respect_sentence_boundary is False, (
+ "split_respect_sentence_boundary not supported yet.\n"
+ "Please set 'split_respect_sentence_boundary' to False in the supplied PreProcessor."
)
assert preprocessor.split_overlap == 0, (
- f"Overlapping documents are currently not supported when adding eval data.\n"
- f"Please set 'split_overlap=0' in the supplied PreProcessor."
+ "Overlapping documents are currently not supported when adding eval data.\n"
+ "Please set 'split_overlap=0' in the supplied PreProcessor."
)
- assert preprocessor.clean_empty_lines == False, (
- f"clean_empty_lines currently not supported when adding eval data.\n"
- f"Please set 'clean_empty_lines=False' in the supplied PreProcessor."
+ assert preprocessor.clean_empty_lines is False, (
+ "clean_empty_lines currently not supported when adding eval data.\n"
+ "Please set 'clean_empty_lines=False' in the supplied PreProcessor."
)
- assert preprocessor.clean_whitespace == False, (
- f"clean_whitespace is currently not supported when adding eval data.\n"
- f"Please set 'clean_whitespace=False' in the supplied PreProcessor."
+ assert preprocessor.clean_whitespace is False, (
+ "clean_whitespace is currently not supported when adding eval data.\n"
+ "Please set 'clean_whitespace=False' in the supplied PreProcessor."
)
- assert preprocessor.clean_header_footer == False, (
- f"clean_header_footer is currently not supported when adding eval data.\n"
- f"Please set 'clean_header_footer=False' in the supplied PreProcessor."
+ assert preprocessor.clean_header_footer is False, (
+ "clean_header_footer is currently not supported when adding eval data.\n"
+ "Please set 'clean_header_footer=False' in the supplied PreProcessor."
)
file_path = Path(filename)
diff --git a/pipelines/pipelines/nodes/__init__.py b/pipelines/pipelines/nodes/__init__.py
index 505aef8b7b3d..9834232943a1 100644
--- a/pipelines/pipelines/nodes/__init__.py
+++ b/pipelines/pipelines/nodes/__init__.py
@@ -35,7 +35,11 @@
from pipelines.nodes.question_generator import QuestionGenerator
from pipelines.nodes.ranker import BaseRanker, ErnieRanker
from pipelines.nodes.reader import BaseReader, ErnieReader
-from pipelines.nodes.retriever import BaseRetriever, DensePassageRetriever
+from pipelines.nodes.retriever import (
+ BaseRetriever,
+ DensePassageRetriever,
+ MultiModalRetriever,
+)
from pipelines.nodes.sentiment_analysis import (
SentaProcessor,
SentaVisualization,
diff --git a/pipelines/pipelines/nodes/retriever/__init__.py b/pipelines/pipelines/nodes/retriever/__init__.py
index 93dcc59c195f..a8be48ca86f3 100644
--- a/pipelines/pipelines/nodes/retriever/__init__.py
+++ b/pipelines/pipelines/nodes/retriever/__init__.py
@@ -14,4 +14,5 @@
# flake8: noqa
from pipelines.nodes.retriever.base import BaseRetriever
from pipelines.nodes.retriever.dense import DensePassageRetriever
+from pipelines.nodes.retriever.multimodal_retriever import MultiModalRetriever
from pipelines.nodes.retriever.sparse import BM25Retriever
diff --git a/pipelines/pipelines/nodes/retriever/embedder.py b/pipelines/pipelines/nodes/retriever/embedder.py
new file mode 100644
index 000000000000..4cba0f78b0a9
--- /dev/null
+++ b/pipelines/pipelines/nodes/retriever/embedder.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import paddle
+from PIL import Image
+from tqdm.auto import tqdm
+
+from paddlenlp import Taskflow
+from pipelines.schema import Document
+
+logger = logging.getLogger(__name__)
+FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]]
+
+
+DOCUMENT_CONVERTERS = {
+ # NOTE: Keep this '?' cleaning step, it needs to be double-checked for impact on the inference results.
+ "text": lambda doc: doc.content[:-1] if doc.content[-1] == "?" else doc.content,
+ "image": lambda doc: Image.open(doc.content),
+}
+
+CAN_EMBED_META = ["text"]
+
+
+class MultiModalEmbedder:
+ def __init__(
+ self,
+ embedding_models: Dict[str, Union[Path, str]], # replace str with ContentTypes starting from Python3.8
+ feature_extractors_params: Optional[Dict[str, Dict[str, Any]]] = None,
+ batch_size: int = 16,
+ embed_meta_fields: List[str] = ["name"],
+ progress_bar: bool = True,
+ ):
+ """
+ Init the Retriever and all its models from a local or remote model checkpoint.
+ :param embedding_models: A dictionary matching a local path or remote name of encoder checkpoint with
+ the content type it should handle ("text", "image", etc...).
+ Expected input format: `{'text': 'name_or_path_to_text_model', 'image': 'name_or_path_to_image_model', ... }`
+ Keep in mind that the models should output in the same embedding space for this retriever to work.
+ :param feature_extractors_params: A dictionary matching a content type ("text", "image" and so on) with the
+ parameters of its own feature extractor if the model requires one.
+ Expected input format: `{'text': {'param_name': 'param_value', ...}, 'image': {'param_name': 'param_value', ...}, ...}`
+ :param batch_size: Number of questions or passages to encode at once. In case of multiple GPUs, this will be the total batch size.
+ :param embed_meta_fields: Concatenate the provided meta fields and text passage / image to a text pair that is
+ then used to create the embedding.
+ This is the approach used in the original paper and is likely to improve
+ performance if your titles contain meaningful information for retrieval
+ (topic, entities etc.).
+ :param progress_bar: Whether to show a tqdm progress bar or not.
+ Can be helpful to disable in production deployments to keep the logs clean.
+ """
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.progress_bar = progress_bar
+ self.embed_meta_fields = embed_meta_fields
+
+ feature_extractors_params = {
+ content_type: {"max_length": 256, **(feature_extractors_params or {}).get(content_type, {})}
+ for content_type in ["text", "image"] # FIXME get_args(ContentTypes) from Python3.8 on
+ }
+
+ self.models = {} # replace str with ContentTypes starting from Python3.8
+ for content_type, embedding_model in embedding_models.items():
+ if content_type in ["text", "image"]:
+ self.models[content_type] = Taskflow("feature_extraction", model=embedding_model)
+ else:
+ raise ValueError(f"{content_type} is not a supported content.")
+
+ # Check embedding sizes for models: they must all match
+ if len(self.models) > 1:
+ sizes = {model.embedding_dim for model in self.models.values()}
+ if None in sizes:
+ logger.warning(
+ "Pipelines could not find the output embedding dimensions for '%s'. "
+ "Dimensions won't be checked before computing the embeddings.",
+ ", ".join(
+ {
+ str(model.model_name_or_path)
+ for model in self.models.values()
+ if model.embedding_dim is None
+ }
+ ),
+ )
+ elif len(sizes) > 1:
+ embedding_sizes: Dict[int, List[str]] = {}
+ for model in self.models.values():
+ embedding_sizes[model.embedding_dim] = embedding_sizes.get(model.embedding_dim, []) + [
+ str(model.model_name_or_path)
+ ]
+ raise ValueError(f"Not all models have the same embedding size: {embedding_sizes}")
+
+ def embed(self, documents: List[Document], batch_size: Optional[int] = None) -> np.ndarray:
+ """
+ Create embeddings for a list of documents using the relevant encoder for their content type.
+ :param documents: Documents to embed.
+ :return: Embeddings, one per document, in the form of a np.array
+ """
+ batch_size = batch_size if batch_size is not None else self.batch_size
+
+ all_embeddings = []
+ for batch_index in tqdm(
+ iterable=range(0, len(documents), batch_size),
+ unit=" Docs",
+ desc="Create embeddings",
+ position=1,
+ leave=False,
+ disable=not self.progress_bar,
+ ):
+ docs_batch = documents[batch_index : batch_index + batch_size]
+ data_by_type = self._docs_to_data(documents=docs_batch)
+
+ # Get output for each model
+ outputs_by_type: Dict[str, paddle.Tensor] = {} # replace str with ContentTypes starting Python3.8
+ for data_type, data in data_by_type.items():
+
+ model = self.models.get(data_type)
+ if not model:
+ raise Exception(
+ f"Some data of type {data_type} was passed, but no model capable of handling such data was "
+ f"initialized. Initialized models: {', '.join(self.models.keys())}"
+ )
+ outputs_by_type[data_type] = model(data)["features"]
+ # Check the output sizes
+ embedding_sizes = [output.shape[-1] for output in outputs_by_type.values()]
+
+ if not all(embedding_size == embedding_sizes[0] for embedding_size in embedding_sizes):
+ raise Exception(
+ "Some of the models are using a different embedding size. They should all match. "
+ f"Embedding sizes by model: "
+ f"{ {name: output.shape[-1] for name, output in outputs_by_type.items()} }"
+ )
+
+ # Combine the outputs in a single matrix
+ outputs = paddle.stack(list(outputs_by_type.values()))
+ embeddings = outputs.reshape([-1, embedding_sizes[0]])
+ embeddings = embeddings.cpu()
+ all_embeddings.append(embeddings)
+ return np.concatenate(all_embeddings)
+
+ def _docs_to_data(
+ self, documents: List[Document]
+ ) -> Dict[str, List[Any]]: # FIXME replace str to ContentTypes from Python3.8
+ """
+ Extract the data to embed from each document and return them classified by content type.
+ :param documents: The documents to prepare fur multimodal embedding.
+ :return: A dictionary containing one key for each content type, and a list of data extracted
+ from each document, ready to be passed to the feature extractor (for example the content
+ of a text document, a linearized table, a PIL image object, and so on)
+ """
+ docs_data: Dict[str, List[Any]] = { # FIXME replace str to ContentTypes from Python3.8
+ key: [] for key in ["text", "image"]
+ } # FIXME get_args(ContentTypes) from Python3.8 on
+ for doc in documents:
+ try:
+ document_converter = DOCUMENT_CONVERTERS[doc.content_type]
+ except KeyError:
+ raise Exception(
+ f"Unknown content type '{doc.content_type}'. Known types: 'text', 'table', 'image'." # FIXME {', '.join(get_args(ContentTypes))}" from Python3.8 on
+ )
+
+ data = document_converter(doc)
+
+ if doc.content_type in CAN_EMBED_META:
+ meta = [v for k, v in (doc.meta or {}).items() if k in self.embed_meta_fields]
+ data = f"{' '.join(meta)} {data}" if meta else data
+
+ docs_data[doc.content_type].append(data)
+
+ return {key: values for key, values in docs_data.items() if values}
diff --git a/pipelines/pipelines/nodes/retriever/multimodal_retriever.py b/pipelines/pipelines/nodes/retriever/multimodal_retriever.py
new file mode 100644
index 000000000000..f25cef2cff8a
--- /dev/null
+++ b/pipelines/pipelines/nodes/retriever/multimodal_retriever.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+
+from pipelines.document_stores import BaseDocumentStore
+from pipelines.nodes.retriever.base import BaseRetriever
+from pipelines.nodes.retriever.embedder import MultiModalEmbedder
+from pipelines.schema import ContentTypes, Document, FilterType
+
+logger = logging.getLogger(__name__)
+
+
+class MultiModalRetriever(BaseRetriever):
+ def __init__(
+ self,
+ document_store: BaseDocumentStore,
+ query_embedding_model: Union[Path, str],
+ document_embedding_models: Dict[str, Union[Path, str]], # Replace str with ContentTypes starting Python3.8
+ query_type: str = "text", # Replace str with ContentTypes starting Python3.8
+ query_feature_extractor_params: Dict[str, Any] = {"max_length": 64},
+ document_feature_extractors_params: Dict[str, Dict[str, Any]] = {"text": {"max_length": 256}},
+ top_k: int = 10,
+ batch_size: int = 16,
+ embed_meta_fields: List[str] = ["name"],
+ similarity_function: str = "dot_product",
+ progress_bar: bool = True,
+ scale_score: bool = True,
+ ):
+ """
+ Retriever that uses a multiple encoder to jointly retrieve among a database consisting of different
+ data types.
+ :param document_store: An instance of DocumentStore from which to retrieve documents.
+ :param query_embedding_model: Local path or remote name of question encoder checkpoint. The format equals the
+ one used by Hugging Face transformers' modelhub models.
+ :param document_embedding_models: Dictionary matching a local path or remote name of document encoder
+ checkpoint with the content type it should handle ("text", "table", "image", and so on).
+ The format equals the one used by Hugging Face transformers' modelhub models.
+ :param query_type: The content type of the query ("text", "image" and so on).
+ :param query_feature_extraction_params: The parameters to pass to the feature extractor of the query.
+ :param document_feature_extraction_params: The parameters to pass to the feature extractor of the documents.
+ :param top_k: How many documents to return per query.
+ :param batch_size: Number of questions or documents to encode at once. For multiple GPUs, this is
+ the total batch size.
+ :param embed_meta_fields: Concatenate the provided meta fields to a (text) pair that is then used to create
+ the embedding. This is likely to improve performance if your titles contain meaningful information
+ for retrieval (topic, entities, and so on). Note that only text and table documents support this feature.
+ :param similarity_function: Which function to apply for calculating the similarity of query and document
+ embeddings during training. Options: `dot_product` (default) or `cosine`.
+ :param progress_bar: Whether to show a tqdm progress bar or not.
+ Can be helpful to disable in production deployments to keep the logs clean.
+ :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
+ If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value
+ range are scaled to a range of [0,1], where 1 means extremely relevant.
+ Otherwise raw similarity scores (for example, cosine or dot_product) are used.
+ """
+ super().__init__()
+
+ self.similarity_function = similarity_function
+ self.progress_bar = progress_bar
+ self.top_k = top_k
+ self.scale_score = scale_score
+
+ self.document_embedder = MultiModalEmbedder(
+ embedding_models=document_embedding_models,
+ feature_extractors_params=document_feature_extractors_params,
+ batch_size=batch_size,
+ embed_meta_fields=embed_meta_fields,
+ progress_bar=progress_bar,
+ )
+
+ # # Try to reuse the same embedder for queries if there is overlap
+ if document_embedding_models.get(query_type, None) == query_embedding_model:
+ self.query_embedder = self.document_embedder
+ else:
+ self.query_embedder = MultiModalEmbedder(
+ embedding_models={query_type: query_embedding_model},
+ feature_extractors_params={query_type: query_feature_extractor_params},
+ batch_size=batch_size,
+ embed_meta_fields=embed_meta_fields,
+ progress_bar=progress_bar,
+ )
+
+ self.document_store = document_store
+
+ def retrieve( # type: ignore
+ self,
+ query: Any,
+ query_type: ContentTypes = "text",
+ filters: Optional[FilterType] = None,
+ top_k: Optional[int] = None,
+ index: Optional[str] = None,
+ headers: Optional[Dict[str, str]] = None,
+ scale_score: Optional[bool] = None,
+ document_store: Optional[BaseDocumentStore] = None,
+ ) -> List[Document]:
+ """
+ Scan through documents in DocumentStore and return a small number of documents that are most relevant to the
+ supplied query. Returns a list of Documents.
+ :param query: Query value. It might be text, a path, a table, and so on.
+ :param query_type: Type of the query ("text", "table", "image" and so on).
+ :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain
+ conditions. It can be a single filter applied to each query or a list of filters
+ (one filter per query).
+ :param top_k: How many documents to return per query. Must be > 0.
+ :param index: The name of the index in the DocumentStore from which to retrieve documents.
+ :param batch_size: Number of queries to embed at a time. Must be > 0.
+ :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
+ If true, similarity scores (for example, cosine or dot_product) which naturally have a different
+ value range is scaled to a range of [0,1], where 1 means extremely relevant.
+ Otherwise raw similarity scores (for example, cosine or dot_product) are used.
+ """
+ return self.retrieve_batch(
+ queries=[query],
+ queries_type=query_type,
+ filters=[filters],
+ top_k=top_k,
+ index=index,
+ headers=headers,
+ batch_size=1,
+ scale_score=scale_score,
+ document_store=document_store,
+ )[0]
+
+ def retrieve_batch( # type: ignore
+ self,
+ queries: List[Any],
+ queries_type: ContentTypes = "text",
+ filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
+ top_k: Optional[int] = None,
+ index: Optional[str] = None,
+ headers: Optional[Dict[str, str]] = None,
+ batch_size: Optional[int] = None,
+ scale_score: Optional[bool] = None,
+ document_store: Optional[BaseDocumentStore] = None,
+ ) -> List[List[Document]]:
+ """
+ Scan through documents in DocumentStore and return a small number of documents that are most relevant to the
+ supplied queries. Returns a list of lists of Documents (one list per query).
+ This method assumes all queries are of the same data type. Mixed-type query batches (for example one image and one text)
+ are currently not supported. Group the queries by type and call `retrieve()` on uniform batches only.
+ :param queries: List of query values. They might be text, paths, tables, and so on.
+ :param queries_type: Type of the query ("text", "table", "image" and so on)
+ :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain
+ conditions. It can be a single filter that will be applied to each query or a list of filters
+ (one filter per query).
+ :param top_k: How many documents to return per query. Must be > 0.
+ :param index: The name of the index in the DocumentStore from which to retrieve documents.
+ :param batch_size: Number of queries to embed at a time. Must be > 0.
+ :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
+ If True, similarity scores (for example, cosine or dot_product) which naturally have a different
+ value range are scaled to a range of [0,1], where 1 means extremely relevant.
+ Otherwise raw similarity scores (for example, cosine or dot_product) are used.
+ """
+ top_k = top_k or self.top_k
+ document_store = document_store or self.document_store
+ if not document_store:
+ raise ValueError(
+ "This Retriever was not initialized with a Document Store. Provide one to the retrieve() or retrieve_batch() method."
+ )
+ index = index or document_store.index
+ scale_score = scale_score or self.scale_score
+
+ # Embed the queries - we need them into Document format to leverage MultiModalEmbedder.embed()
+ query_docs = [Document(content=query, content_type=queries_type) for query in queries]
+ query_embeddings = self.query_embedder.embed(documents=query_docs, batch_size=batch_size)
+ # Query documents by embedding (the actual retrieval step)
+ documents = document_store.query_by_embedding_batch(
+ query_embs=query_embeddings,
+ top_k=top_k,
+ filters=filters,
+ index=index,
+ headers=headers,
+ )
+ return documents
+
+ def embed_documents(self, docs: List[Document]) -> np.ndarray:
+ return self.document_embedder.embed(documents=docs)
+
+ def embed_queries(self, queries: List[str]) -> np.ndarray:
+ query_documents = [Document(content=query, content_type="text") for query in queries]
+ return self.query_embedder.embed(documents=query_documents)
diff --git a/pipelines/pipelines/schema.py b/pipelines/pipelines/schema.py
index 4cf33be8312e..08c24e6400f4 100644
--- a/pipelines/pipelines/schema.py
+++ b/pipelines/pipelines/schema.py
@@ -14,39 +14,45 @@
# limitations under the License.
from __future__ import annotations
+
import typing
-from typing import Any, Optional, Dict, List, Union, Optional
from dataclasses import asdict
+from typing import Any, Dict, List, Optional, Union
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore
-# We are using Pydantic dataclasses instead of vanilla Python's
-# See #1598 for the reasons behind this choice & performance considerations
-from pydantic.dataclasses import dataclass
if typing.TYPE_CHECKING:
from dataclasses import dataclass # type: ignore
+else:
+ # We are using Pydantic dataclasses instead of vanilla Python's
+ # See #1598 for the reasons behind this choice & performance considerations
+ from pydantic.dataclasses import dataclass
-from pydantic.json import pydantic_encoder
+import ast
+import json
+import logging
+import time
from pathlib import Path
from uuid import uuid4
+
import mmh3
import numpy as np
-import logging
-import time
-import json
import pandas as pd
-import ast
+from pydantic import BaseConfig
+from pydantic.json import pydantic_encoder
logger = logging.getLogger(__name__)
-from pydantic import BaseConfig
-
BaseConfig.arbitrary_types_allowed = True
+#: Types of content_types supported
+ContentTypes = Literal["text", "image"]
+FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]]
+
@dataclass
class Document:
@@ -101,7 +107,7 @@ def __init__(
"""
if content is None:
- raise ValueError(f"Can't create 'Document': Mandatory 'content' field is None")
+ raise ValueError("Can't create 'Document': Mandatory 'content' field is None")
self.content = content
self.content_type = content_type
@@ -143,7 +149,7 @@ def _get_id(self, id_hash_keys: Optional[List[str]] = None):
if final_hash_key == "":
raise ValueError(
- f"Cant't create 'Document': 'id_hash_keys' must contain at least one of ['content', 'meta']"
+ "Cant't create 'Document': 'id_hash_keys' must contain at least one of ['content', 'meta']"
)
return "{:02x}".format(mmh3.hash128(final_hash_key, signed=False))
@@ -253,10 +259,10 @@ class Span:
start: int
end: int
"""
- Defining a sequence of characters (Text span) or cells (Table span) via start and end index.
- For extractive QA: Character where answer starts/ends
+ Defining a sequence of characters (Text span) or cells (Table span) via start and end index.
+ For extractive QA: Character where answer starts/ends
For TableQA: Cell where the answer starts/ends (counted from top left to bottom right of table)
-
+
:param start: Position where the span starts
:param end: Position where the spand ends
"""
@@ -277,24 +283,24 @@ class Answer:
For example, it's used within some Nodes like the Reader, but also in the REST API.
:param answer: The answer string. If there's no possible answer (aka "no_answer" or "is_impossible) this will be an empty string.
- :param type: One of ("generative", "extractive", "other"): Whether this answer comes from an extractive model
- (i.e. we can locate an exact answer string in one of the documents) or from a generative model
- (i.e. no pointer to a specific document, no offsets ...).
+ :param type: One of ("generative", "extractive", "other"): Whether this answer comes from an extractive model
+ (i.e. we can locate an exact answer string in one of the documents) or from a generative model
+ (i.e. no pointer to a specific document, no offsets ...).
:param score: The relevance score of the Answer determined by a model (e.g. Reader or Generator).
In the range of [0,1], where 1 means extremely relevant.
:param context: The related content that was used to create the answer (i.e. a text passage, part of a table, image ...)
:param offsets_in_document: List of `Span` objects with start and end positions of the answer **in the
document** (as stored in the document store).
- For extractive QA: Character where answer starts => `Answer.offsets_in_document[0].start
+ For extractive QA: Character where answer starts => `Answer.offsets_in_document[0].start
For TableQA: Cell where the answer starts (counted from top left to bottom right of table) => `Answer.offsets_in_document[0].start
- (Note that in TableQA there can be multiple cell ranges that are relevant for the answer, thus there can be multiple `Spans` here)
+ (Note that in TableQA there can be multiple cell ranges that are relevant for the answer, thus there can be multiple `Spans` here)
:param offsets_in_context: List of `Span` objects with start and end positions of the answer **in the
context** (i.e. the surrounding text/table of a certain window size).
- For extractive QA: Character where answer starts => `Answer.offsets_in_document[0].start
+ For extractive QA: Character where answer starts => `Answer.offsets_in_document[0].start
For TableQA: Cell where the answer starts (counted from top left to bottom right of table) => `Answer.offsets_in_document[0].start
- (Note that in TableQA there can be multiple cell ranges that are relevant for the answer, thus there can be multiple `Spans` here)
+ (Note that in TableQA there can be multiple cell ranges that are relevant for the answer, thus there can be multiple `Spans` here)
:param document_id: ID of the document that the answer was located it (if any)
- :param meta: Dict that can be used to associate any kind of custom meta data with the answer.
+ :param meta: Dict that can be used to associate any kind of custom meta data with the answer.
In extractive QA, this will carry the meta data of the document where the answer was found.
"""
@@ -423,12 +429,12 @@ def __init__(
# If an Answer is provided we need to make sure that it's consistent with the `no_answer` value
# TODO: reassess if we want to enforce Span.start=0 and Span.end=0 for no_answer=True
if self.answer is not None:
- if no_answer == True:
+ if no_answer is True:
if self.answer.answer != "" or self.answer.context:
raise ValueError(
f"Got no_answer == True while there seems to be an possible Answer: {self.answer}"
)
- elif no_answer == False:
+ elif no_answer is False:
if self.answer.answer == "":
raise ValueError(
f"Got no_answer == False while there seems to be no possible Answer: {self.answer}"
@@ -524,13 +530,13 @@ def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answ
# drop duplicate labels and remove negative labels if needed.
labels = list(set(labels))
if drop_negative_labels:
- is_positive_label = lambda l: (l.is_correct_answer and l.is_correct_document) or (
+ is_positive_label = lambda l: (l.is_correct_answer and l.is_correct_document) or ( # noqa: E731
l.answer is None and l.is_correct_document
)
labels = [l for l in labels if is_positive_label(l)]
if drop_no_answers:
- labels = [l for l in labels if l.no_answer == False]
+ labels = [l for l in labels if l.no_answer is False]
self.labels = labels
diff --git a/pipelines/rest_api/pipeline/image_text_retrieval.yaml b/pipelines/rest_api/pipeline/image_text_retrieval.yaml
new file mode 100644
index 000000000000..29fb716d2129
--- /dev/null
+++ b/pipelines/rest_api/pipeline/image_text_retrieval.yaml
@@ -0,0 +1,25 @@
+version: '1.1.0'
+
+components: # define all the building-blocks for Pipeline
+ - name: DocumentStore
+ type: ElasticsearchDocumentStore # consider using Milvus2DocumentStore or WeaviateDocumentStore for scaling to large number of documents
+ params:
+ host: localhost
+ port: 9200
+ index: wukong
+ embedding_dim: 768
+ - name: Retriever
+ type: MultiModalRetriever
+ params:
+ document_store: DocumentStore # params can reference other components defined in the YAML
+ top_k: 10
+ query_embedding_model: PaddlePaddle/ernie_vil-2.0-base-zh
+ document_embedding_models:
+ image: PaddlePaddle/ernie_vil-2.0-base-zh
+
+pipelines:
+ - name: query
+ type: Query
+ nodes:
+ - name: Retriever
+ inputs: [Query]
\ No newline at end of file
diff --git a/pipelines/ui/utils.py b/pipelines/ui/utils.py
index 2a2b37237dfe..53bdfabd9f8a 100644
--- a/pipelines/ui/utils.py
+++ b/pipelines/ui/utils.py
@@ -230,6 +230,37 @@ def text_to_image_search(
return results, response
+def image_text_search(query, filters={}, top_k_retriever=5) -> Tuple[List[Dict[str, Any]], Dict[str, str]]:
+ """
+ Send a query to the REST API and parse the answer.
+ Returns both a ready-to-use representation of the results and the raw JSON.
+ """
+
+ url = f"{API_ENDPOINT}/{DOC_REQUEST}"
+ params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}}
+ req = {"query": query, "params": params}
+ response_raw = requests.post(url, json=req)
+
+ if response_raw.status_code >= 400 and response_raw.status_code != 503:
+ raise Exception(f"{vars(response_raw)}")
+
+ response = response_raw.json()
+ if "errors" in response:
+ raise Exception(", ".join(response["errors"]))
+
+ # Format response
+ results = []
+ answers = response["documents"]
+ for answer in answers:
+ results.append(
+ {
+ "context": answer["content"],
+ "relevance": round(answer["meta"]["es_ann_score"] * 100, 2),
+ }
+ )
+ return results, response
+
+
def text_to_qa_pair_search(query, is_filter=True) -> Tuple[List[Dict[str, Any]], Dict[str, str]]:
"""
Send a prompt text and corresponding parameters to the REST API
diff --git a/pipelines/ui/webapp_image_text_retrieval.py b/pipelines/ui/webapp_image_text_retrieval.py
new file mode 100644
index 000000000000..995bb4ff7149
--- /dev/null
+++ b/pipelines/ui/webapp_image_text_retrieval.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import gradio as gr
+from utils import image_text_search
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--serving_port", default=8502, type=int, help="Port for the serving.")
+args = parser.parse_args()
+# yapf: enable
+
+
+def infer(query, top_k_retriever):
+ results, response = image_text_search(query, top_k_retriever=top_k_retriever)
+ images = [item["context"] for item in results]
+ return images
+
+
+def main():
+ block = gr.Blocks()
+ title = "
文图跨模态搜索应用
"
+ description = "本项目为ERNIE-ViL 2.0等CLIP中文版模型的DEMO,可用于图文检索和图像、文本的表征提取,应用于文图搜索、文图推荐、零样本分类、视频检索等应用场景。"
+
+ with block:
+ gr.Markdown(title)
+ gr.Markdown(description)
+ with gr.Row():
+ with gr.Column(scale=1):
+ with gr.Column(scale=2):
+ text = gr.Textbox(value="云南普者黑现纯白色⒌蒂莲", label="请填写文本", elem_id=0, interactive=True)
+ top_k = gr.components.Slider(minimum=0, maximum=50, step=1, value=8, label="返回图片数", elem_id=2)
+ btn = gr.Button(
+ "搜索",
+ )
+ with gr.Column(scale=100):
+ out = gr.Gallery(label="检索结果为:").style(grid=4, height=200)
+ inputs = [text, top_k]
+ btn.click(fn=infer, inputs=inputs, outputs=out)
+ return block
+
+
+if __name__ == "__main__":
+ block = main()
+ block.launch(server_name="0.0.0.0", server_port=args.serving_port, share=False)
diff --git a/pipelines/utils/offline_ann_mm.py b/pipelines/utils/offline_ann_mm.py
new file mode 100644
index 000000000000..01a41f92cce6
--- /dev/null
+++ b/pipelines/utils/offline_ann_mm.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+from pipelines.document_stores import ElasticsearchDocumentStore, MilvusDocumentStore
+from pipelines.nodes import MultiModalRetriever
+from pipelines.schema import Document
+from pipelines.utils import fetch_archive_from_http, launch_es
+
+data_dict = {
+ "data/wukong_test": "https://paddlenlp.bj.bcebos.com/applications/wukong_test_demo.zip",
+}
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--index_name", default="wukong_test", type=str, help="The index name of the ANN search engine")
+parser.add_argument("--doc_dir", default="data/wukong_test", type=str, help="The doc path of the corpus")
+parser.add_argument("--search_engine", choices=["elastic", "milvus"], default="elastic", help="The type of ANN search engine.")
+parser.add_argument("--host", type=str, default="127.0.0.1", help="host ip of ANN search engine")
+parser.add_argument("--port", type=str, default="9200", help="port of ANN search engine")
+parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index")
+parser.add_argument("--query_embedding_model", default="PaddlePaddle/ernie_vil-2.0-base-zh", type=str, help="The query_embedding_model path")
+parser.add_argument("--document_embedding_model", default="PaddlePaddle/ernie_vil-2.0-base-zh", type=str, help="The document_embedding_model path")
+parser.add_argument("--delete_index", action="store_true", help="Whether to delete existing index while updating index")
+args = parser.parse_args()
+# yapf: enable
+
+
+def offline_ann(index_name, doc_dir):
+
+ if args.search_engine == "milvus":
+ document_store = MilvusDocumentStore(
+ embedding_dim=args.embedding_dim,
+ host=args.host,
+ index=args.index_name,
+ port=args.port,
+ index_param={"M": 16, "efConstruction": 50},
+ index_type="HNSW",
+ )
+ else:
+ launch_es()
+ document_store = ElasticsearchDocumentStore(
+ host=args.host,
+ port=args.port,
+ username="",
+ password="",
+ embedding_dim=args.embedding_dim,
+ index=index_name,
+ )
+ docs = [
+ Document(content=f"./{args.doc_dir}/{filename}", content_type="image") for filename in os.listdir(args.doc_dir)
+ ]
+
+ print(docs[:3])
+
+ # 文档数据写入数据库
+ document_store.write_documents(docs)
+
+ # 语义索引模型
+ retriever_mm = MultiModalRetriever(
+ document_store=document_store,
+ query_embedding_model=args.query_embedding_model,
+ query_type="text",
+ document_embedding_models={"image": args.document_embedding_model},
+ )
+
+ # 建立索引库
+ document_store.update_embeddings(retriever_mm)
+
+
+def delete_data(index_name):
+ if args.search_engine == "milvus":
+ document_store = MilvusDocumentStore(
+ embedding_dim=args.embedding_dim,
+ host=args.host,
+ index=args.index_name,
+ port=args.port,
+ index_param={"M": 16, "efConstruction": 50},
+ index_type="HNSW",
+ )
+ else:
+ document_store = ElasticsearchDocumentStore(
+ host=args.host,
+ port=args.port,
+ username="",
+ password="",
+ embedding_dim=args.embedding_dim,
+ index=index_name,
+ )
+ document_store.delete_index(index_name)
+ print("Delete an existing elasticsearch index {} Done.".format(index_name))
+
+
+if __name__ == "__main__":
+ if args.doc_dir in data_dict:
+ fetch_archive_from_http(url=data_dict[args.doc_dir], output_dir=args.doc_dir)
+ if args.delete_index:
+ delete_data(args.index_name)
+ offline_ann(args.index_name, args.doc_dir)
diff --git a/tests/taskflow/test_feature_extraction.py b/tests/taskflow/test_feature_extraction.py
new file mode 100644
index 000000000000..bb15d7fb4ff2
--- /dev/null
+++ b/tests/taskflow/test_feature_extraction.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import paddle
+from PIL import Image
+
+from paddlenlp.taskflow import Taskflow
+from paddlenlp.taskflow.feature_extraction import MultimodalFeatureExtractionTask
+
+
+class TestMultimodalFeatureExtractionTask(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.temp_dir = TemporaryDirectory()
+ cls.batch_size = 2
+ cls.max_resolution = 40
+ cls.min_resolution = 30
+ cls.num_channels = 3
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.temp_dir.cleanup()
+
+ def test_small_model_pd(self):
+ feature_extractor = Taskflow(task="feature_extraction")
+ outputs = feature_extractor("This is a test")
+ self.assertEqual(outputs["features"].shape, [1, 768])
+
+ def test_return_tensors_pd(self):
+ feature_extractor = Taskflow(task="feature_extraction", return_tensors=True)
+ outputs = feature_extractor(
+ "This is a test",
+ )
+ self.assertTrue(paddle.is_tensor(outputs["features"]))
+
+ def prepare_inputs(self, equal_resolution=False, numpify=False, paddleify=False):
+ """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
+ or a list of PaddlePaddle tensors if one specifies paddleify=True.
+ """
+
+ assert not (numpify and paddleify), "You cannot specify both numpy and PaddlePaddle tensors at the same time"
+
+ if equal_resolution:
+ image_inputs = []
+ for i in range(self.batch_size):
+ image_inputs.append(
+ np.random.randint(
+ 255, size=(self.num_channels, self.max_resolution, self.max_resolution), dtype=np.uint8
+ )
+ )
+ else:
+ image_inputs = []
+ for i in range(self.batch_size):
+ width, height = np.random.choice(np.arange(self.min_resolution, self.max_resolution), 2)
+ image_inputs.append(np.random.randint(255, size=(self.num_channels, width, height), dtype=np.uint8))
+
+ if not numpify and not paddleify:
+ # PIL expects the channel dimension as last dimension
+ image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+
+ if paddleify:
+ image_inputs = [paddle.to_tensor(x) for x in image_inputs]
+
+ return image_inputs
+
+ def test_feature_extraction_task(self):
+ input_text = (["这是一只猫", "这是一只狗"],)
+ # dygraph text test
+ dygraph_taskflow = MultimodalFeatureExtractionTask(
+ model="PaddlePaddle/ernie_vil-2.0-base-zh",
+ task="feature_extraction",
+ is_static_model=False,
+ return_tensors=False,
+ )
+ dygraph_results = dygraph_taskflow(input_text)
+ shape = dygraph_results["features"].shape
+ self.assertEqual(shape[0], 2)
+ # static text test
+ static_taskflow = MultimodalFeatureExtractionTask(
+ model="PaddlePaddle/ernie_vil-2.0-base-zh",
+ task="feature_extraction",
+ is_static_model=True,
+ return_tensors=False,
+ device_id=0,
+ )
+ static_results = static_taskflow(input_text)
+ self.assertEqual(static_results["features"].shape[0], 2)
+
+ for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]):
+ for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()):
+ self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-5)
+
+ input_image = (self.prepare_inputs(equal_resolution=True, paddleify=False),)
+ # dygraph image test
+ dygraph_results = dygraph_taskflow(input_image)
+ self.assertEqual(dygraph_results["features"].shape[0], 2)
+
+ # static image test
+ static_results = static_taskflow(input_image)
+ self.assertEqual(static_results["features"].shape[0], 2)
+
+ for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]):
+ for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()):
+ self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-5)
+
+ def test_taskflow_task(self):
+ input_text = ["这是一只猫", "这是一只狗"]
+ # dygraph test
+ dygraph_taskflow = Taskflow(
+ task="feature_extraction",
+ is_static_model=False,
+ return_tensors=False,
+ )
+ dygraph_results = dygraph_taskflow(input_text)
+ shape = dygraph_results["features"].shape
+
+ self.assertEqual(shape[0], 2)
+ # static test
+ static_taskflow = Taskflow(
+ task="feature_extraction",
+ is_static_model=True,
+ return_tensors=False,
+ )
+ static_results = static_taskflow(input_text)
+ self.assertEqual(static_results["features"].shape[0], 2)
+ for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]):
+ for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()):
+ self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-5)
+
+ input_image = self.prepare_inputs(equal_resolution=True, paddleify=False)
+ # dygraph image test
+ dygraph_results = dygraph_taskflow(input_image)
+ self.assertEqual(dygraph_results["features"].shape[0], 2)
+
+ # static image test
+ static_results = static_taskflow(input_image)
+ self.assertEqual(static_results["features"].shape[0], 2)
+
+ for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]):
+ for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()):
+ self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-5)
diff --git a/tests/taskflow/test_text_classification.py b/tests/taskflow/test_text_classification.py
index 2fa193138998..9afa1c0c4627 100644
--- a/tests/taskflow/test_text_classification.py
+++ b/tests/taskflow/test_text_classification.py
@@ -140,7 +140,7 @@ def test_classification_task(self, batch_size, problem_type, model):
for dygraph_result, static_result in zip(dygraph_results, static_results):
for dygraph_pred, static_pred in zip(dygraph_result["predictions"], static_result["predictions"]):
self.assertEqual(dygraph_pred["label"], static_pred["label"])
- self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-6)
+ self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-5)
# if multi_label, all predictions should be greater than the threshold
if model == "multi_label":
self.assertGreater(dygraph_pred["score"], dygraph_taskflow.multilabel_threshold)
@@ -197,7 +197,7 @@ def test_taskflow_task(self, batch_size, problem_type, model):
for dygraph_result, static_result in zip(dygraph_results, static_results):
for dygraph_pred, static_pred in zip(dygraph_result["predictions"], static_result["predictions"]):
self.assertEqual(dygraph_pred["label"], static_pred["label"])
- self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-6)
+ self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-5)
# if multi_label, all predictions should be greater than the threshold
if model == "multi_label":
self.assertGreater(dygraph_pred["score"], dygraph_taskflow.task_instance.multilabel_threshold)