From e29e4a506a4e440b56efffdbf28a2e8ae9873c0c Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 18 Jan 2023 20:59:43 +0800 Subject: [PATCH 01/18] Add vision language taskflow API --- paddlenlp/taskflow/taskflow.py | 10 ++ .../taskflow/vision_language_embedding.py | 116 ++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 paddlenlp/taskflow/vision_language_embedding.py diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index 10c83444d744..e4d2c410116a 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -42,6 +42,7 @@ TextToImageGenerationTask, TextToImageStableDiffusionTask, ) +from .vision_language_embedding import VisionLanguageTask from .word_segmentation import SegJiebaTask, SegLACTask, SegWordTagTask from .zero_shot_text_classification import ZeroShotTextClassificationTask @@ -486,6 +487,15 @@ }, "default": {"model": "utc-large"}, }, + "vision_language": { + "models": { + "PaddlePaddle/ernie_vil-2.0-base-zh": { + "task_class": VisionLanguageTask, + "task_flag": "vision_language_embeddings-2.0-base-zh", + }, + }, + "default": {"model": "PaddlePaddle/ernie_vil-2.0-base-zh"}, + }, } support_schema_list = [ diff --git a/paddlenlp/taskflow/vision_language_embedding.py b/paddlenlp/taskflow/vision_language_embedding.py new file mode 100644 index 000000000000..97f74a352ec4 --- /dev/null +++ b/paddlenlp/taskflow/vision_language_embedding.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from PIL import Image + +from ..transformers import ErnieViLModel, ErnieViLProcessor +from .task import Task + + +class VisionLanguageTask(Task): + """ + The text_to_image generation model to generate the image. + Args: + task(string): The name of task. + model(string): The model name in the task. + kwargs (dict, optional): Additional keyword arguments passed along to the specific task. + """ + + def __init__(self, task, model, **kwargs): + super().__init__(task=task, model=model, **kwargs) + self._seed = None + # we do not use batch + self._batch_size = 1 + self._construct_tokenizer(image_model=model, text_model="ernie_vil-2.0-base-zh") + self._construct_model(model) + + def _construct_model(self, model): + """ + Construct the inference model for the predictor. + """ + self._model = ErnieViLModel.from_pretrained(model) + self._model.eval() + + def _construct_tokenizer(self, image_model, text_model): + """ + Construct the tokenizer for the predictor. + """ + self._processor = ErnieViLProcessor.from_pretrained(image_model) + + def _batchify(self, data, batch_size): + """ + Generate input batches. 
+ """ + + def _parse_batch(batch_examples): + batch_texts = batch_examples["texts"] + batch_images = [Image.open(item) for item in batch_examples["images"]] + + tokenizerd_inputs = self._processor( + text=batch_texts, images=batch_images, return_tensors="pd", padding="max_length", truncation=True + ) + + return tokenizerd_inputs + + # Seperates data into some batches. + # breakpoint() + yield _parse_batch(data[0]) + # one_batch = [] + # for example in data: + # one_batch.append(example) + # if len(one_batch) == batch_size: + # yield _parse_batch(one_batch) + # one_batch = [] + # if one_batch: + # yield _parse_batch(one_batch) + + def _preprocess(self, inputs): + """ + Transform the raw text to the model inputs, two steps involved: + 1) Transform the raw text to token ids. + 2) Generate the other model inputs from the raw text and token ids. + """ + # inputs = self._check_input_text(inputs) + batches = self._batchify(inputs, self._batch_size) + outputs = {"batches": batches, "text": inputs} + return outputs + + def _run_model(self, inputs): + """ + Run the task model from the outputs of the `_preprocess` function. + """ + all_texts = [] + all_images = [] + for batch_inputs in inputs["batches"]: + if len(batch_inputs["input_ids"]) > 0: + text_features = self._model.get_text_features(input_ids=batch_inputs["input_ids"]) + all_texts.append(text_features) + if len(batch_inputs["pixel_values"]) > 0: + image_features = self._model.get_image_features(pixel_values=batch_inputs["pixel_values"]) + all_images.append(image_features) + inputs.update({"text_features": all_texts}) + inputs.update({"image_features": all_images}) + return inputs + + def _postprocess(self, inputs): + return inputs + + def _construct_input_spec(self): + """ + Construct the input spec for the convert dygraph model to static model. + """ + self._input_spec = [ + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"), + ] From 88ea9fb278fb7e7c46d3eab9cde9dced738646f9 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 31 Jan 2023 20:07:07 +0800 Subject: [PATCH 02/18] Update image text retrieval taskflow api --- ...e_embedding.py => image_text_retrieval.py} | 83 +++++---- paddlenlp/taskflow/taskflow.py | 68 ++++++- paddlenlp/transformers/__init__.py | 1 + paddlenlp/transformers/auto/processing.py | 175 ++++++++++++++++++ .../transformers/chineseclip/procesing.py | 7 + paddlenlp/transformers/clip/procesing.py | 11 ++ paddlenlp/transformers/ernie_vil/procesing.py | 4 + 7 files changed, 314 insertions(+), 35 deletions(-) rename paddlenlp/taskflow/{vision_language_embedding.py => image_text_retrieval.py} (56%) create mode 100644 paddlenlp/transformers/auto/processing.py diff --git a/paddlenlp/taskflow/vision_language_embedding.py b/paddlenlp/taskflow/image_text_retrieval.py similarity index 56% rename from paddlenlp/taskflow/vision_language_embedding.py rename to paddlenlp/taskflow/image_text_retrieval.py index 97f74a352ec4..8bfec444b63c 100644 --- a/paddlenlp/taskflow/vision_language_embedding.py +++ b/paddlenlp/taskflow/image_text_retrieval.py @@ -15,11 +15,11 @@ import paddle from PIL import Image -from ..transformers import ErnieViLModel, ErnieViLProcessor +from ..transformers import AutoModel, AutoProcessor from .task import Task -class VisionLanguageTask(Task): +class ImageTextRetrievalTask(Task): """ The text_to_image generation model to generate the image. 
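    Example (illustrative usage only; the task name and default model follow the Taskflow
    registration added later in this patch, and the input strings are placeholders):
        .. code-block::

            from paddlenlp import Taskflow

            retrieval = Taskflow("image_text_retrieval")  # default model: PaddlePaddle/ernie_vil-2.0-base-zh
            # A list of texts is encoded by the text tower into the returned "features" tensor ...
            text_features = retrieval(["一只可爱的猫"])["features"]
            # ... while a list of PIL.Image objects is encoded by the image tower, e.g.
            # from PIL import Image
            # image_features = retrieval([Image.open("cat.jpg")])["features"]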
Args: @@ -33,21 +33,21 @@ def __init__(self, task, model, **kwargs): self._seed = None # we do not use batch self._batch_size = 1 - self._construct_tokenizer(image_model=model, text_model="ernie_vil-2.0-base-zh") + self._construct_tokenizer(model_name=model) self._construct_model(model) def _construct_model(self, model): """ Construct the inference model for the predictor. """ - self._model = ErnieViLModel.from_pretrained(model) + self._model = AutoModel.from_pretrained(model) self._model.eval() - def _construct_tokenizer(self, image_model, text_model): + def _construct_tokenizer(self, model_name): """ Construct the tokenizer for the predictor. """ - self._processor = ErnieViLProcessor.from_pretrained(image_model) + self._processor = AutoProcessor.from_pretrained(model_name) def _batchify(self, data, batch_size): """ @@ -55,26 +55,48 @@ def _batchify(self, data, batch_size): """ def _parse_batch(batch_examples): - batch_texts = batch_examples["texts"] - batch_images = [Image.open(item) for item in batch_examples["images"]] - - tokenizerd_inputs = self._processor( + if isinstance(batch_examples[0], str): + batch_texts = batch_examples + batch_images = None + else: + batch_texts = None + batch_images = batch_examples + + tokenized_inputs = self._processor( text=batch_texts, images=batch_images, return_tensors="pd", padding="max_length", truncation=True ) - - return tokenizerd_inputs + return tokenized_inputs # Seperates data into some batches. - # breakpoint() - yield _parse_batch(data[0]) - # one_batch = [] - # for example in data: - # one_batch.append(example) - # if len(one_batch) == batch_size: - # yield _parse_batch(one_batch) - # one_batch = [] - # if one_batch: - # yield _parse_batch(one_batch) + one_batch = [] + for example in data: + one_batch.append(example) + if len(one_batch) == batch_size: + yield _parse_batch(one_batch) + one_batch = [] + if one_batch: + yield _parse_batch(one_batch) + + def _check_input_text(self, inputs): + """ + Check whether the input text meet the requirement. + """ + inputs = inputs[0] + if isinstance(inputs, (str, Image.Image)): + if len(inputs) == 0: + raise ValueError("Invalid inputs, input text/image should not be empty, please check your input.") + inputs = [inputs] + elif isinstance(inputs, list): + # and len(inputs[0].strip()) > 0 + if not (isinstance(inputs[0], (str, Image.Image))): + raise TypeError( + "Invalid inputs, input text/image should be list of str/PIL.image, and first element of list should not be empty." + ) + else: + raise TypeError( + "Invalid inputs, input text should be str or list of str, but type of {} found!".format(type(inputs)) + ) + return inputs def _preprocess(self, inputs): """ @@ -82,7 +104,7 @@ def _preprocess(self, inputs): 1) Transform the raw text to token ids. 2) Generate the other model inputs from the raw text and token ids. """ - # inputs = self._check_input_text(inputs) + inputs = self._check_input_text(inputs) batches = self._batchify(inputs, self._batch_size) outputs = {"batches": batches, "text": inputs} return outputs @@ -91,20 +113,19 @@ def _run_model(self, inputs): """ Run the task model from the outputs of the `_preprocess` function. 
""" - all_texts = [] - all_images = [] + all_feats = [] for batch_inputs in inputs["batches"]: - if len(batch_inputs["input_ids"]) > 0: + if "input_ids" in batch_inputs: text_features = self._model.get_text_features(input_ids=batch_inputs["input_ids"]) - all_texts.append(text_features) - if len(batch_inputs["pixel_values"]) > 0: + all_feats.append(text_features) + if "pixel_values" in batch_inputs: image_features = self._model.get_image_features(pixel_values=batch_inputs["pixel_values"]) - all_images.append(image_features) - inputs.update({"text_features": all_texts}) - inputs.update({"image_features": all_images}) + all_feats.append(image_features) + inputs.update({"features": all_feats}) return inputs def _postprocess(self, inputs): + inputs["features"] = paddle.concat(inputs["features"], axis=0) return inputs def _construct_input_spec(self): diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index e4d2c410116a..3ef4bacbb633 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -24,6 +24,7 @@ from .dialogue import DialogueTask from .document_intelligence import DocPromptTask from .fill_mask import FillMaskTask +from .image_text_retrieval import ImageTextRetrievalTask from .information_extraction import GPTask, UIETask from .knowledge_mining import NPTagTask, WordTagTask from .lexical_analysis import LacTask @@ -42,7 +43,6 @@ TextToImageGenerationTask, TextToImageStableDiffusionTask, ) -from .vision_language_embedding import VisionLanguageTask from .word_segmentation import SegJiebaTask, SegLACTask, SegWordTagTask from .zero_shot_text_classification import ZeroShotTextClassificationTask @@ -487,11 +487,71 @@ }, "default": {"model": "utc-large"}, }, - "vision_language": { + "image_text_retrieval": { "models": { "PaddlePaddle/ernie_vil-2.0-base-zh": { - "task_class": VisionLanguageTask, - "task_flag": "vision_language_embeddings-2.0-base-zh", + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-2.0-base-zh", + }, + "OFA-Sys/chinese-clip-vit-base-patch16": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-base-patch16", + "task_priority_path": "OFA-Sys/chinese-clip-vit-base-patch16", + }, + "OFA-Sys/chinese-clip-vit-huge-patch14": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-huge-patch14", + "task_priority_path": "OFA-Sys/chinese-clip-vit-huge-patch14", + }, + "OFA-Sys/chinese-clip-vit-large-patch14": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-large-patch14", + "task_priority_path": "OFA-Sys/chinese-clip-vit-large-patch14", + }, + "OFA-Sys/chinese-clip-vit-large-patch14-336px": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-large-patch14-336px", + "task_priority_path": "OFA-Sys/chinese-clip-vit-large-patch14-336px", + }, + "openai/clip-vit-base-patch32": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-openai/clip-vit-base-patch32", + "task_priority_path": "openai/clip-vit-base-patch32", + }, + "openai/clip-vit-base-patch16": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-openai/clip-vit-base-patch16", + "task_priority_path": "openai/clip-vit-base-patch16", + }, + "openai/clip-vit-large-patch14": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-openai/clip-vit-large-patch14", + 
"task_priority_path": "openai/clip-vit-large-patch14", + }, + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-laion/CLIP-ViT-H-14-laion2B-s32B-b79K", + "task_priority_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", + }, + "laion/CLIP-ViT-B-32-laion2B-s34B-b79K": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-laion/CLIP-ViT-B-32-laion2B-s34B-b79K", + "task_priority_path": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", + }, + "openai/clip-rn50": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-openai/clip-rn50", + "task_priority_path": "openai/clip-rn50", + }, + "openai/clip-rn101": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-openai/clip-rn101", + "task_priority_path": "openai/clip-rn101", + }, + "openai/clip-rn50x4": { + "task_class": ImageTextRetrievalTask, + "task_flag": "image_text_retrieval-openai/clip-rn50x4", + "task_priority_path": "openai/clip-rn50x4", }, }, "default": {"model": "PaddlePaddle/ernie_vil-2.0-base-zh"}, diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index e5df9cfcc7b9..6b41bb7c47db 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -156,6 +156,7 @@ from .opt.modeling import * from .auto.modeling import * from .auto.tokenizer import * +from .auto.processing import * from .codegen.modeling import * from .codegen.tokenizer import * from .codegen.configuration import * diff --git a/paddlenlp/transformers/auto/processing.py b/paddlenlp/transformers/auto/processing.py new file mode 100644 index 000000000000..026ef51be767 --- /dev/null +++ b/paddlenlp/transformers/auto/processing.py @@ -0,0 +1,175 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import io +import json +import os +from collections import OrderedDict + +from paddlenlp.utils.downloader import COMMUNITY_MODEL_PREFIX, get_path_from_url +from paddlenlp.utils.env import MODEL_HOME +from paddlenlp.utils.import_utils import import_module +from paddlenlp.utils.log import logger + +__all__ = [ + "AutoProcessor", +] + +PROCESSOR_MAPPING_NAMES = OrderedDict( + [ + ("ChineseCLIPProcessor", "chineseclip"), + ("CLIPProcessor", "clip"), + ("ErnieViLProcessor", "ernie_vil"), + ] +) + + +def get_configurations(): + MAPPING_NAMES = OrderedDict() + for key, class_name in PROCESSOR_MAPPING_NAMES.items(): + import_class = importlib.import_module(f"paddlenlp.transformers.{class_name}.procesing") + processor_name = getattr(import_class, key) + name = tuple(processor_name.pretrained_init_configuration.keys()) + if MAPPING_NAMES.get(name, None) is None: + MAPPING_NAMES[name] = [] + MAPPING_NAMES[name].append(processor_name) + return MAPPING_NAMES + + +class AutoProcessor: + """ + AutoClass can help you automatically retrieve the relevant model given the provided + pretrained weights/vocabulary. 
+ Autoprocessor is a generic processor class that will be instantiated as one of the + base processor classes when created with the Autoprocessor.from_pretrained() classmethod. + """ + + MAPPING_NAMES = get_configurations() + _processor_mapping = MAPPING_NAMES + _name_mapping = PROCESSOR_MAPPING_NAMES + processor_config_file = "preprocessor_config.json" + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path).`" + ) + + @classmethod + def _get_processor_class_from_config(cls, pretrained_model_name_or_path, config_file_path): + with io.open(config_file_path, encoding="utf-8") as f: + init_kwargs = json.load(f) + # class name corresponds to this configuration + init_class = init_kwargs.pop("init_class", None) + if init_class is None: + init_class = init_kwargs.pop("processor_class", None) + + if init_class: + class_name = cls._name_mapping[init_class] + import_class = import_module(f"paddlenlp.transformers.{class_name}.procesing") + processor_class = getattr(import_class, init_class) + return processor_class + # If no `init_class`, we use pattern recognition to recognize the processor class. + else: + logger.info("We use pattern recognition to recognize the processor class.") + for key, pattern in cls._name_mapping.items(): + if pattern in pretrained_model_name_or_path.lower(): + init_class = key + class_name = cls._name_mapping[init_class] + import_class = import_module(f"paddlenlp.transformers.{class_name}.processor") + processor_class = getattr(import_class, init_class) + break + return processor_class + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + """ + Creates an instance of `Autoprocessor`. Related resources are loaded by + specifying name of a built-in pretrained model, or a community-contributed + pretrained model, or a local file directory path. + + Args: + pretrained_model_name_or_path (str): Name of pretrained model or dir path + to load from. The string can be: + + - Name of built-in pretrained model + - Name of a community-contributed pretrained model. + - Local directory path which contains processor related resources + and processor config file ("processor_config.json"). + *args (tuple): position arguments for model `__init__`. If provided, + use these as position argument values for processor initialization. + **kwargs (dict): keyword arguments for model `__init__`. If provided, + use these to update pre-defined keyword argument values for processor + initialization. + + Returns: + Pretrainedprocessor: An instance of `Pretrainedprocessor`. + + + Example: + .. code-block:: + from paddlenlp.transformers import AutoProcessor + processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + processor.save_pretrained('clip_processor') + """ + + all_processor_names = [] + for names, processor_class in cls._processor_mapping.items(): + for name in names: + all_processor_names.append(name) + # From built-in pretrained models + if pretrained_model_name_or_path in all_processor_names: + for names, processor_classes in cls._processor_mapping.items(): + for pattern in names: + if pattern == pretrained_model_name_or_path: + actual_processor_class = processor_classes[0] + logger.info( + "We are using %s to load '%s'." 
% (actual_processor_class, pretrained_model_name_or_path) + ) + return actual_processor_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **kwargs + ) + # From local dir path + elif os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, cls.processor_config_file) + if os.path.exists(config_file): + processor_class = cls._get_processor_class_from_config(pretrained_model_name_or_path, config_file) + logger.info("We are using %s to load '%s'." % (processor_class, pretrained_model_name_or_path)) + return processor_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + # Assuming from community-contributed pretrained models + else: + community_config_path = "/".join( + [COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, cls.processor_config_file] + ) + + default_root = os.path.join(MODEL_HOME, pretrained_model_name_or_path) + try: + resolved_vocab_file = get_path_from_url(community_config_path, default_root) + except RuntimeError as err: + logger.error(err) + raise RuntimeError( + f"Can't load processor for '{pretrained_model_name_or_path}'.\n" + f"Please make sure that '{pretrained_model_name_or_path}' is:\n" + "- a correct model-identifier of built-in pretrained models,\n" + "- or a correct model-identifier of community-contributed pretrained models,\n" + "- or the correct path to a directory containing relevant processor files.\n" + ) + + if os.path.exists(resolved_vocab_file): + processor_class = cls._get_processor_class_from_config( + pretrained_model_name_or_path, resolved_vocab_file + ) + logger.info("We are using %s to load '%s'." % (processor_class, pretrained_model_name_or_path)) + return processor_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) diff --git a/paddlenlp/transformers/chineseclip/procesing.py b/paddlenlp/transformers/chineseclip/procesing.py index e7e0034ae155..701dfda83efd 100644 --- a/paddlenlp/transformers/chineseclip/procesing.py +++ b/paddlenlp/transformers/chineseclip/procesing.py @@ -41,6 +41,13 @@ class ChineseCLIPProcessor(ProcessorMixin): image_processor_class = "ChineseCLIPImageProcessor" tokenizer_class = "ChineseCLIPTokenizer" + pretrained_init_configuration = { + "OFA-Sys/chinese-clip-vit-base-patch16": {"do_lower_case": True}, + "OFA-Sys/chinese-clip-vit-huge-patch14": {"do_lower_case": True}, + "OFA-Sys/chinese-clip-vit-large-patch14": {"do_lower_case": True}, + "OFA-Sys/chinese-clip-vit-large-patch14-336px": {"do_lower_case": True}, + } + def __init__(self, image_processor=None, tokenizer=None, **kwargs): if "feature_extractor" in kwargs: warnings.warn( diff --git a/paddlenlp/transformers/clip/procesing.py b/paddlenlp/transformers/clip/procesing.py index ffaf9ad661ec..3424f643e193 100644 --- a/paddlenlp/transformers/clip/procesing.py +++ b/paddlenlp/transformers/clip/procesing.py @@ -40,6 +40,17 @@ class CLIPProcessor(ProcessorMixin): image_processor_class = "CLIPImageProcessor" tokenizer_class = "CLIPTokenizer" + pretrained_init_configuration = { + "openai/clip-vit-base-patch32": {"do_lower_case": True}, + "openai/clip-vit-base-patch16": {"do_lower_case": True}, + "openai/clip-vit-large-patch14": {"do_lower_case": True}, + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K": {"do_lower_case": True}, + "laion/CLIP-ViT-B-32-laion2B-s34B-b79K": {"do_lower_case": True}, + "openai/clip-rn50": {"do_lower_case": True}, + "openai/clip-rn101": {"do_lower_case": True}, + "openai/clip-rn50x4": {"do_lower_case": True}, + } + def __init__(self, 
image_processor=None, tokenizer=None, **kwargs): if "feature_extractor" in kwargs: warnings.warn( diff --git a/paddlenlp/transformers/ernie_vil/procesing.py b/paddlenlp/transformers/ernie_vil/procesing.py index 952cfe901477..e89ab381f4bb 100644 --- a/paddlenlp/transformers/ernie_vil/procesing.py +++ b/paddlenlp/transformers/ernie_vil/procesing.py @@ -40,6 +40,10 @@ class ErnieViLProcessor(ProcessorMixin): image_processor_class = "ErnieViLImageProcessor" tokenizer_class = "ErnieViLTokenizer" + pretrained_init_configuration = { + "PaddlePaddle/ernie_vil-2.0-base-zh": {"do_lower_case": True}, + } + def __init__(self, image_processor=None, tokenizer=None, **kwargs): if "feature_extractor" in kwargs: warnings.warn( From 29709007aa0ef7702089e0696255052b6064260b Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 1 Feb 2023 21:53:15 +0800 Subject: [PATCH 03/18] Add multimodal_retriever of pipelines --- paddlenlp/taskflow/image_text_retrieval.py | 4 +- pipelines/pipelines/document_stores/base.py | 86 +++++--- pipelines/pipelines/nodes/__init__.py | 6 +- .../pipelines/nodes/retriever/__init__.py | 1 + .../pipelines/nodes/retriever/embedder.py | 188 +++++++++++++++++ .../nodes/retriever/multimodal_retriever.py | 197 ++++++++++++++++++ pipelines/pipelines/schema.py | 62 +++--- 7 files changed, 485 insertions(+), 59 deletions(-) create mode 100644 pipelines/pipelines/nodes/retriever/embedder.py create mode 100644 pipelines/pipelines/nodes/retriever/multimodal_retriever.py diff --git a/paddlenlp/taskflow/image_text_retrieval.py b/paddlenlp/taskflow/image_text_retrieval.py index 8bfec444b63c..c8a4447020d7 100644 --- a/paddlenlp/taskflow/image_text_retrieval.py +++ b/paddlenlp/taskflow/image_text_retrieval.py @@ -28,11 +28,11 @@ class ImageTextRetrievalTask(Task): kwargs (dict, optional): Additional keyword arguments passed along to the specific task. """ - def __init__(self, task, model, **kwargs): + def __init__(self, task, model, batch_size=1, **kwargs): super().__init__(task=task, model=model, **kwargs) self._seed = None # we do not use batch - self._batch_size = 1 + self._batch_size = batch_size self._construct_tokenizer(model_name=model) self._construct_model(model) diff --git a/pipelines/pipelines/document_stores/base.py b/pipelines/pipelines/document_stores/base.py index 9dddb21dde1d..ce8537eb4b4b 100644 --- a/pipelines/pipelines/document_stores/base.py +++ b/pipelines/pipelines/document_stores/base.py @@ -13,25 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Generator, Optional, Dict, List, Set, Union - -import logging import collections -import numpy as np -from itertools import islice +import logging from abc import abstractmethod +from itertools import islice from pathlib import Path +from typing import Dict, Generator, List, Optional, Set, Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal # type: ignore +import numpy as np -from pipelines.schema import Document, Label -from pipelines.nodes.base import BaseComponent +from pipelines.document_stores.utils import ( + eval_data_from_json, + eval_data_from_jsonl, + squad_json_to_jsonl, +) from pipelines.errors import DuplicateDocumentError +from pipelines.nodes.base import BaseComponent from pipelines.nodes.preprocessor import PreProcessor -from pipelines.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl +from pipelines.schema import Document, FilterType, Label logger = logging.getLogger(__name__) @@ -290,6 +289,37 @@ def query_by_embedding( ) -> List[Document]: pass + def query_by_embedding_batch( + self, + query_embs: Union[List[np.ndarray], np.ndarray], + filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None, + top_k: int = 10, + index: Optional[str] = None, + return_embedding: Optional[bool] = None, + headers: Optional[Dict[str, str]] = None, + ) -> List[List[Document]]: + if isinstance(filters, list): + if len(filters) != len(query_embs): + raise Exception( + "Number of filters does not match number of query_embs. Please provide as many filters" + " as query_embs or a single filter that will be applied to each query_emb." + ) + else: + filters = [filters] * len(query_embs) + results = [] + for query_emb, filter in zip(query_embs, filters): + results.append( + self.query_by_embedding( + query_emb=query_emb, + filters=filter, + top_k=top_k, + index=index, + return_embedding=return_embedding, + headers=headers, + ) + ) + return results + @abstractmethod def get_label_count(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> int: pass @@ -338,28 +368,28 @@ def add_eval_data( # TODO improve support for PreProcessor when adding eval data if preprocessor is not None: assert preprocessor.split_by != "sentence", ( - f"Split by sentence not supported.\n" - f"Please set 'split_by' to either 'word' or 'passage' in the supplied PreProcessor." + "Split by sentence not supported.\n" + "Please set 'split_by' to either 'word' or 'passage' in the supplied PreProcessor." ) - assert preprocessor.split_respect_sentence_boundary == False, ( - f"split_respect_sentence_boundary not supported yet.\n" - f"Please set 'split_respect_sentence_boundary' to False in the supplied PreProcessor." + assert preprocessor.split_respect_sentence_boundary is False, ( + "split_respect_sentence_boundary not supported yet.\n" + "Please set 'split_respect_sentence_boundary' to False in the supplied PreProcessor." ) assert preprocessor.split_overlap == 0, ( - f"Overlapping documents are currently not supported when adding eval data.\n" - f"Please set 'split_overlap=0' in the supplied PreProcessor." + "Overlapping documents are currently not supported when adding eval data.\n" + "Please set 'split_overlap=0' in the supplied PreProcessor." ) - assert preprocessor.clean_empty_lines == False, ( - f"clean_empty_lines currently not supported when adding eval data.\n" - f"Please set 'clean_empty_lines=False' in the supplied PreProcessor." 
+ assert preprocessor.clean_empty_lines is False, ( + "clean_empty_lines currently not supported when adding eval data.\n" + "Please set 'clean_empty_lines=False' in the supplied PreProcessor." ) - assert preprocessor.clean_whitespace == False, ( - f"clean_whitespace is currently not supported when adding eval data.\n" - f"Please set 'clean_whitespace=False' in the supplied PreProcessor." + assert preprocessor.clean_whitespace is False, ( + "clean_whitespace is currently not supported when adding eval data.\n" + "Please set 'clean_whitespace=False' in the supplied PreProcessor." ) - assert preprocessor.clean_header_footer == False, ( - f"clean_header_footer is currently not supported when adding eval data.\n" - f"Please set 'clean_header_footer=False' in the supplied PreProcessor." + assert preprocessor.clean_header_footer is False, ( + "clean_header_footer is currently not supported when adding eval data.\n" + "Please set 'clean_header_footer=False' in the supplied PreProcessor." ) file_path = Path(filename) diff --git a/pipelines/pipelines/nodes/__init__.py b/pipelines/pipelines/nodes/__init__.py index 505aef8b7b3d..9834232943a1 100644 --- a/pipelines/pipelines/nodes/__init__.py +++ b/pipelines/pipelines/nodes/__init__.py @@ -35,7 +35,11 @@ from pipelines.nodes.question_generator import QuestionGenerator from pipelines.nodes.ranker import BaseRanker, ErnieRanker from pipelines.nodes.reader import BaseReader, ErnieReader -from pipelines.nodes.retriever import BaseRetriever, DensePassageRetriever +from pipelines.nodes.retriever import ( + BaseRetriever, + DensePassageRetriever, + MultiModalRetriever, +) from pipelines.nodes.sentiment_analysis import ( SentaProcessor, SentaVisualization, diff --git a/pipelines/pipelines/nodes/retriever/__init__.py b/pipelines/pipelines/nodes/retriever/__init__.py index 93dcc59c195f..a8be48ca86f3 100644 --- a/pipelines/pipelines/nodes/retriever/__init__.py +++ b/pipelines/pipelines/nodes/retriever/__init__.py @@ -14,4 +14,5 @@ # flake8: noqa from pipelines.nodes.retriever.base import BaseRetriever from pipelines.nodes.retriever.dense import DensePassageRetriever +from pipelines.nodes.retriever.multimodal_retriever import MultiModalRetriever from pipelines.nodes.retriever.sparse import BM25Retriever diff --git a/pipelines/pipelines/nodes/retriever/embedder.py b/pipelines/pipelines/nodes/retriever/embedder.py new file mode 100644 index 000000000000..8d14bbb1e038 --- /dev/null +++ b/pipelines/pipelines/nodes/retriever/embedder.py @@ -0,0 +1,188 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import paddle +from PIL import Image +from tqdm.auto import tqdm + +from paddlenlp import Taskflow +from pipelines.schema import Document + +logger = logging.getLogger(__name__) +FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]] + + +# TODO the keys should match with ContentTypes (currently 'audio' is missing) +DOCUMENT_CONVERTERS = { + # NOTE: Keep this '?' cleaning step, it needs to be double-checked for impact on the inference results. + "text": lambda doc: doc.content[:-1] if doc.content[-1] == "?" else doc.content, + "table": lambda doc: " ".join( + doc.content.columns.tolist() + [cell for row in doc.content.values.tolist() for cell in row] + ), + "image": lambda doc: Image.open(doc.content), +} + +CAN_EMBED_META = ["text", "table"] + + +class MultiModalEmbedder: + def __init__( + self, + embedding_models: Dict[str, Union[Path, str]], # replace str with ContentTypes starting from Python3.8 + feature_extractors_params: Optional[Dict[str, Dict[str, Any]]] = None, + batch_size: int = 16, + embed_meta_fields: List[str] = ["name"], + progress_bar: bool = True, + ): + """ + Init the Retriever and all its models from a local or remote model checkpoint. + The checkpoint format matches the Hugging Face transformers' model format. + :param embedding_models: A dictionary matching a local path or remote name of encoder checkpoint with + the content type it should handle ("text", "table", "image", etc...). + The format is the one that Hugging Face Hub models use. + Expected input format: `{'text': 'name_or_path_to_text_model', 'image': 'name_or_path_to_image_model', ... }` + Keep in mind that the models should output in the same embedding space for this retriever to work. + :param feature_extractors_params: A dictionary matching a content type ("text", "table", "image" and so on) with the + parameters of its own feature extractor if the model requires one. + Expected input format: `{'text': {'param_name': 'param_value', ...}, 'image': {'param_name': 'param_value', ...}, ...}` + :param batch_size: Number of questions or passages to encode at once. In case of multiple GPUs, this will be the total batch size. + :param embed_meta_fields: Concatenate the provided meta fields and text passage / image to a text pair that is + then used to create the embedding. + This is the approach used in the original paper and is likely to improve + performance if your titles contain meaningful information for retrieval + (topic, entities etc.). + :param progress_bar: Whether to show a tqdm progress bar or not. + Can be helpful to disable in production deployments to keep the logs clean. 
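        Example (a minimal sketch; the model name and parameter values below are assumptions,
        not enforced defaults):
            .. code-block::

                embedder = MultiModalEmbedder(
                    embedding_models={
                        "text": "PaddlePaddle/ernie_vil-2.0-base-zh",
                        "image": "PaddlePaddle/ernie_vil-2.0-base-zh",
                    },
                    batch_size=16,
                    embed_meta_fields=["name"],
                )
                # embedder.embed(documents) returns one embedding per Document as an np.ndarray.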
+ """ + super().__init__() + + self.batch_size = batch_size + self.progress_bar = progress_bar + self.embed_meta_fields = embed_meta_fields + + feature_extractors_params = { + content_type: {"max_length": 256, **(feature_extractors_params or {}).get(content_type, {})} + for content_type in ["text", "table", "image", "audio"] # FIXME get_args(ContentTypes) from Python3.8 on + } + + self.models = {} # replace str with ContentTypes starting from Python3.8 + for content_type, embedding_model in embedding_models.items(): + self.models[content_type] = Taskflow("image_text_retrieval") + + # Check embedding sizes for models: they must all match + if len(self.models) > 1: + sizes = {model.embedding_dim for model in self.models.values()} + if None in sizes: + logger.warning( + "Haystack could not find the output embedding dimensions for '%s'. " + "Dimensions won't be checked before computing the embeddings.", + ", ".join( + { + str(model.model_name_or_path) + for model in self.models.values() + if model.embedding_dim is None + } + ), + ) + elif len(sizes) > 1: + embedding_sizes: Dict[int, List[str]] = {} + for model in self.models.values(): + embedding_sizes[model.embedding_dim] = embedding_sizes.get(model.embedding_dim, []) + [ + str(model.model_name_or_path) + ] + raise ValueError(f"Not all models have the same embedding size: {embedding_sizes}") + + def embed(self, documents: List[Document], batch_size: Optional[int] = None) -> np.ndarray: + """ + Create embeddings for a list of documents using the relevant encoder for their content type. + :param documents: Documents to embed. + :return: Embeddings, one per document, in the form of a np.array + """ + batch_size = batch_size if batch_size is not None else self.batch_size + + all_embeddings = [] + for batch_index in tqdm( + iterable=range(0, len(documents), batch_size), + unit=" Docs", + desc="Create embeddings", + position=1, + leave=False, + disable=not self.progress_bar, + ): + docs_batch = documents[batch_index : batch_index + batch_size] + data_by_type = self._docs_to_data(documents=docs_batch) + + # Get output for each model + outputs_by_type: Dict[str, paddle.Tensor] = {} # replace str with ContentTypes starting Python3.8 + for data_type, data in data_by_type.items(): + + model = self.models.get(data_type) + if not model: + raise Exception( + f"Some data of type {data_type} was passed, but no model capable of handling such data was " + f"initialized. Initialized models: {', '.join(self.models.keys())}" + ) + outputs_by_type[data_type] = model(data)["features"] + # Check the output sizes + embedding_sizes = [output.shape[-1] for output in outputs_by_type.values()] + + if not all(embedding_size == embedding_sizes[0] for embedding_size in embedding_sizes): + raise Exception( + "Some of the models are using a different embedding size. They should all match. " + f"Embedding sizes by model: " + f"{ {name: output.shape[-1] for name, output in outputs_by_type.items()} }" + ) + + # Combine the outputs in a single matrix + outputs = paddle.stack(list(outputs_by_type.values())) + embeddings = outputs.reshape([-1, embedding_sizes[0]]) + embeddings = embeddings.cpu() + all_embeddings.append(embeddings) + return np.concatenate(all_embeddings) + + def _docs_to_data( + self, documents: List[Document] + ) -> Dict[str, List[Any]]: # FIXME replace str to ContentTypes from Python3.8 + """ + Extract the data to embed from each document and return them classified by content type. + :param documents: The documents to prepare fur multimodal embedding. 
+ :return: A dictionary containing one key for each content type, and a list of data extracted + from each document, ready to be passed to the feature extractor (for example the content + of a text document, a linearized table, a PIL image object, and so on) + """ + docs_data: Dict[str, List[Any]] = { # FIXME replace str to ContentTypes from Python3.8 + key: [] for key in ["text", "table", "image", "audio"] + } # FIXME get_args(ContentTypes) from Python3.8 on + for doc in documents: + try: + document_converter = DOCUMENT_CONVERTERS[doc.content_type] + except KeyError: + raise Exception( + f"Unknown content type '{doc.content_type}'. Known types: 'text', 'table', 'image'." # FIXME {', '.join(get_args(ContentTypes))}" from Python3.8 on + ) + + data = document_converter(doc) + + if doc.content_type in CAN_EMBED_META: + meta = [v for k, v in (doc.meta or {}).items() if k in self.embed_meta_fields] + data = f"{' '.join(meta)} {data}" if meta else data + + docs_data[doc.content_type].append(data) + + return {key: values for key, values in docs_data.items() if values} diff --git a/pipelines/pipelines/nodes/retriever/multimodal_retriever.py b/pipelines/pipelines/nodes/retriever/multimodal_retriever.py new file mode 100644 index 000000000000..f25cef2cff8a --- /dev/null +++ b/pipelines/pipelines/nodes/retriever/multimodal_retriever.py @@ -0,0 +1,197 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from pipelines.document_stores import BaseDocumentStore +from pipelines.nodes.retriever.base import BaseRetriever +from pipelines.nodes.retriever.embedder import MultiModalEmbedder +from pipelines.schema import ContentTypes, Document, FilterType + +logger = logging.getLogger(__name__) + + +class MultiModalRetriever(BaseRetriever): + def __init__( + self, + document_store: BaseDocumentStore, + query_embedding_model: Union[Path, str], + document_embedding_models: Dict[str, Union[Path, str]], # Replace str with ContentTypes starting Python3.8 + query_type: str = "text", # Replace str with ContentTypes starting Python3.8 + query_feature_extractor_params: Dict[str, Any] = {"max_length": 64}, + document_feature_extractors_params: Dict[str, Dict[str, Any]] = {"text": {"max_length": 256}}, + top_k: int = 10, + batch_size: int = 16, + embed_meta_fields: List[str] = ["name"], + similarity_function: str = "dot_product", + progress_bar: bool = True, + scale_score: bool = True, + ): + """ + Retriever that uses a multiple encoder to jointly retrieve among a database consisting of different + data types. + :param document_store: An instance of DocumentStore from which to retrieve documents. + :param query_embedding_model: Local path or remote name of question encoder checkpoint. The format equals the + one used by Hugging Face transformers' modelhub models. 
+ :param document_embedding_models: Dictionary matching a local path or remote name of document encoder + checkpoint with the content type it should handle ("text", "table", "image", and so on). + The format equals the one used by Hugging Face transformers' modelhub models. + :param query_type: The content type of the query ("text", "image" and so on). + :param query_feature_extraction_params: The parameters to pass to the feature extractor of the query. + :param document_feature_extraction_params: The parameters to pass to the feature extractor of the documents. + :param top_k: How many documents to return per query. + :param batch_size: Number of questions or documents to encode at once. For multiple GPUs, this is + the total batch size. + :param embed_meta_fields: Concatenate the provided meta fields to a (text) pair that is then used to create + the embedding. This is likely to improve performance if your titles contain meaningful information + for retrieval (topic, entities, and so on). Note that only text and table documents support this feature. + :param similarity_function: Which function to apply for calculating the similarity of query and document + embeddings during training. Options: `dot_product` (default) or `cosine`. + :param progress_bar: Whether to show a tqdm progress bar or not. + Can be helpful to disable in production deployments to keep the logs clean. + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value + range are scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (for example, cosine or dot_product) are used. + """ + super().__init__() + + self.similarity_function = similarity_function + self.progress_bar = progress_bar + self.top_k = top_k + self.scale_score = scale_score + + self.document_embedder = MultiModalEmbedder( + embedding_models=document_embedding_models, + feature_extractors_params=document_feature_extractors_params, + batch_size=batch_size, + embed_meta_fields=embed_meta_fields, + progress_bar=progress_bar, + ) + + # # Try to reuse the same embedder for queries if there is overlap + if document_embedding_models.get(query_type, None) == query_embedding_model: + self.query_embedder = self.document_embedder + else: + self.query_embedder = MultiModalEmbedder( + embedding_models={query_type: query_embedding_model}, + feature_extractors_params={query_type: query_feature_extractor_params}, + batch_size=batch_size, + embed_meta_fields=embed_meta_fields, + progress_bar=progress_bar, + ) + + self.document_store = document_store + + def retrieve( # type: ignore + self, + query: Any, + query_type: ContentTypes = "text", + filters: Optional[FilterType] = None, + top_k: Optional[int] = None, + index: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + scale_score: Optional[bool] = None, + document_store: Optional[BaseDocumentStore] = None, + ) -> List[Document]: + """ + Scan through documents in DocumentStore and return a small number of documents that are most relevant to the + supplied query. Returns a list of Documents. + :param query: Query value. It might be text, a path, a table, and so on. + :param query_type: Type of the query ("text", "table", "image" and so on). + :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. 
It can be a single filter applied to each query or a list of filters + (one filter per query). + :param top_k: How many documents to return per query. Must be > 0. + :param index: The name of the index in the DocumentStore from which to retrieve documents. + :param batch_size: Number of queries to embed at a time. Must be > 0. + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true, similarity scores (for example, cosine or dot_product) which naturally have a different + value range is scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (for example, cosine or dot_product) are used. + """ + return self.retrieve_batch( + queries=[query], + queries_type=query_type, + filters=[filters], + top_k=top_k, + index=index, + headers=headers, + batch_size=1, + scale_score=scale_score, + document_store=document_store, + )[0] + + def retrieve_batch( # type: ignore + self, + queries: List[Any], + queries_type: ContentTypes = "text", + filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None, + top_k: Optional[int] = None, + index: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + batch_size: Optional[int] = None, + scale_score: Optional[bool] = None, + document_store: Optional[BaseDocumentStore] = None, + ) -> List[List[Document]]: + """ + Scan through documents in DocumentStore and return a small number of documents that are most relevant to the + supplied queries. Returns a list of lists of Documents (one list per query). + This method assumes all queries are of the same data type. Mixed-type query batches (for example one image and one text) + are currently not supported. Group the queries by type and call `retrieve()` on uniform batches only. + :param queries: List of query values. They might be text, paths, tables, and so on. + :param queries_type: Type of the query ("text", "table", "image" and so on) + :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain + conditions. It can be a single filter that will be applied to each query or a list of filters + (one filter per query). + :param top_k: How many documents to return per query. Must be > 0. + :param index: The name of the index in the DocumentStore from which to retrieve documents. + :param batch_size: Number of queries to embed at a time. Must be > 0. + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If True, similarity scores (for example, cosine or dot_product) which naturally have a different + value range are scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (for example, cosine or dot_product) are used. + """ + top_k = top_k or self.top_k + document_store = document_store or self.document_store + if not document_store: + raise ValueError( + "This Retriever was not initialized with a Document Store. Provide one to the retrieve() or retrieve_batch() method." 
+ ) + index = index or document_store.index + scale_score = scale_score or self.scale_score + + # Embed the queries - we need them into Document format to leverage MultiModalEmbedder.embed() + query_docs = [Document(content=query, content_type=queries_type) for query in queries] + query_embeddings = self.query_embedder.embed(documents=query_docs, batch_size=batch_size) + # Query documents by embedding (the actual retrieval step) + documents = document_store.query_by_embedding_batch( + query_embs=query_embeddings, + top_k=top_k, + filters=filters, + index=index, + headers=headers, + ) + return documents + + def embed_documents(self, docs: List[Document]) -> np.ndarray: + return self.document_embedder.embed(documents=docs) + + def embed_queries(self, queries: List[str]) -> np.ndarray: + query_documents = [Document(content=query, content_type="text") for query in queries] + return self.query_embedder.embed(documents=query_documents) diff --git a/pipelines/pipelines/schema.py b/pipelines/pipelines/schema.py index 4cf33be8312e..639c591a7ea3 100644 --- a/pipelines/pipelines/schema.py +++ b/pipelines/pipelines/schema.py @@ -14,39 +14,45 @@ # limitations under the License. from __future__ import annotations + import typing -from typing import Any, Optional, Dict, List, Union, Optional from dataclasses import asdict +from typing import Any, Dict, List, Optional, Union try: from typing import Literal except ImportError: from typing_extensions import Literal # type: ignore -# We are using Pydantic dataclasses instead of vanilla Python's -# See #1598 for the reasons behind this choice & performance considerations -from pydantic.dataclasses import dataclass if typing.TYPE_CHECKING: from dataclasses import dataclass # type: ignore +else: + # We are using Pydantic dataclasses instead of vanilla Python's + # See #1598 for the reasons behind this choice & performance considerations + from pydantic.dataclasses import dataclass -from pydantic.json import pydantic_encoder +import ast +import json +import logging +import time from pathlib import Path from uuid import uuid4 + import mmh3 import numpy as np -import logging -import time -import json import pandas as pd -import ast +from pydantic import BaseConfig +from pydantic.json import pydantic_encoder logger = logging.getLogger(__name__) -from pydantic import BaseConfig - BaseConfig.arbitrary_types_allowed = True +#: Types of content_types supported +ContentTypes = Literal["text", "table", "image", "audio"] +FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]] + @dataclass class Document: @@ -101,7 +107,7 @@ def __init__( """ if content is None: - raise ValueError(f"Can't create 'Document': Mandatory 'content' field is None") + raise ValueError("Can't create 'Document': Mandatory 'content' field is None") self.content = content self.content_type = content_type @@ -143,7 +149,7 @@ def _get_id(self, id_hash_keys: Optional[List[str]] = None): if final_hash_key == "": raise ValueError( - f"Cant't create 'Document': 'id_hash_keys' must contain at least one of ['content', 'meta']" + "Cant't create 'Document': 'id_hash_keys' must contain at least one of ['content', 'meta']" ) return "{:02x}".format(mmh3.hash128(final_hash_key, signed=False)) @@ -253,10 +259,10 @@ class Span: start: int end: int """ - Defining a sequence of characters (Text span) or cells (Table span) via start and end index. 
- For extractive QA: Character where answer starts/ends + Defining a sequence of characters (Text span) or cells (Table span) via start and end index. + For extractive QA: Character where answer starts/ends For TableQA: Cell where the answer starts/ends (counted from top left to bottom right of table) - + :param start: Position where the span starts :param end: Position where the spand ends """ @@ -277,24 +283,24 @@ class Answer: For example, it's used within some Nodes like the Reader, but also in the REST API. :param answer: The answer string. If there's no possible answer (aka "no_answer" or "is_impossible) this will be an empty string. - :param type: One of ("generative", "extractive", "other"): Whether this answer comes from an extractive model - (i.e. we can locate an exact answer string in one of the documents) or from a generative model - (i.e. no pointer to a specific document, no offsets ...). + :param type: One of ("generative", "extractive", "other"): Whether this answer comes from an extractive model + (i.e. we can locate an exact answer string in one of the documents) or from a generative model + (i.e. no pointer to a specific document, no offsets ...). :param score: The relevance score of the Answer determined by a model (e.g. Reader or Generator). In the range of [0,1], where 1 means extremely relevant. :param context: The related content that was used to create the answer (i.e. a text passage, part of a table, image ...) :param offsets_in_document: List of `Span` objects with start and end positions of the answer **in the document** (as stored in the document store). - For extractive QA: Character where answer starts => `Answer.offsets_in_document[0].start + For extractive QA: Character where answer starts => `Answer.offsets_in_document[0].start For TableQA: Cell where the answer starts (counted from top left to bottom right of table) => `Answer.offsets_in_document[0].start - (Note that in TableQA there can be multiple cell ranges that are relevant for the answer, thus there can be multiple `Spans` here) + (Note that in TableQA there can be multiple cell ranges that are relevant for the answer, thus there can be multiple `Spans` here) :param offsets_in_context: List of `Span` objects with start and end positions of the answer **in the context** (i.e. the surrounding text/table of a certain window size). - For extractive QA: Character where answer starts => `Answer.offsets_in_document[0].start + For extractive QA: Character where answer starts => `Answer.offsets_in_document[0].start For TableQA: Cell where the answer starts (counted from top left to bottom right of table) => `Answer.offsets_in_document[0].start - (Note that in TableQA there can be multiple cell ranges that are relevant for the answer, thus there can be multiple `Spans` here) + (Note that in TableQA there can be multiple cell ranges that are relevant for the answer, thus there can be multiple `Spans` here) :param document_id: ID of the document that the answer was located it (if any) - :param meta: Dict that can be used to associate any kind of custom meta data with the answer. + :param meta: Dict that can be used to associate any kind of custom meta data with the answer. In extractive QA, this will carry the meta data of the document where the answer was found. 
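    Example (illustrative values only):
        .. code-block::

            answer = Answer(
                answer="Berlin",
                type="extractive",
                score=0.95,
                context="Berlin is the capital of Germany.",
                offsets_in_context=[Span(start=0, end=6)],
                document_id="doc-42",
            )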
""" @@ -423,12 +429,12 @@ def __init__( # If an Answer is provided we need to make sure that it's consistent with the `no_answer` value # TODO: reassess if we want to enforce Span.start=0 and Span.end=0 for no_answer=True if self.answer is not None: - if no_answer == True: + if no_answer is True: if self.answer.answer != "" or self.answer.context: raise ValueError( f"Got no_answer == True while there seems to be an possible Answer: {self.answer}" ) - elif no_answer == False: + elif no_answer is False: if self.answer.answer == "": raise ValueError( f"Got no_answer == False while there seems to be no possible Answer: {self.answer}" @@ -524,13 +530,13 @@ def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answ # drop duplicate labels and remove negative labels if needed. labels = list(set(labels)) if drop_negative_labels: - is_positive_label = lambda l: (l.is_correct_answer and l.is_correct_document) or ( + is_positive_label = lambda l: (l.is_correct_answer and l.is_correct_document) or ( # noqa: E731 l.answer is None and l.is_correct_document ) labels = [l for l in labels if is_positive_label(l)] if drop_no_answers: - labels = [l for l in labels if l.no_answer == False] + labels = [l for l in labels if l.no_answer is False] self.labels = labels From 6840d7142c66d492c360d06babdb60abbe91f7ec Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 3 Feb 2023 11:06:18 +0000 Subject: [PATCH 04/18] Add image text retrieval pipelines application --- paddlenlp/transformers/auto/processing.py | 1 + .../examples/image_text_retrieval/README.md | 191 ++++++++++++++++++ .../image_text_retrieval_example.py | 64 ++++++ .../image_text_retrieval/run_search_server.sh | 20 ++ .../image_text_retrieval/run_search_web.sh | 18 ++ .../pipelines/nodes/retriever/embedder.py | 2 +- .../pipeline/image_text_retrieval.yaml | 25 +++ pipelines/ui/utils.py | 31 +++ pipelines/ui/webapp_image_text_retrieval.py | 58 ++++++ pipelines/utils/offline_ann_mm.py | 112 ++++++++++ 10 files changed, 521 insertions(+), 1 deletion(-) create mode 100644 pipelines/examples/image_text_retrieval/README.md create mode 100644 pipelines/examples/image_text_retrieval/image_text_retrieval_example.py create mode 100644 pipelines/examples/image_text_retrieval/run_search_server.sh create mode 100644 pipelines/examples/image_text_retrieval/run_search_web.sh create mode 100644 pipelines/rest_api/pipeline/image_text_retrieval.yaml create mode 100644 pipelines/ui/webapp_image_text_retrieval.py create mode 100644 pipelines/utils/offline_ann_mm.py diff --git a/paddlenlp/transformers/auto/processing.py b/paddlenlp/transformers/auto/processing.py index 026ef51be767..34e896b8ebab 100644 --- a/paddlenlp/transformers/auto/processing.py +++ b/paddlenlp/transformers/auto/processing.py @@ -1,4 +1,5 @@ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pipelines/examples/image_text_retrieval/README.md b/pipelines/examples/image_text_retrieval/README.md new file mode 100644 index 000000000000..313cddfaf71b --- /dev/null +++ b/pipelines/examples/image_text_retrieval/README.md @@ -0,0 +1,191 @@ +# 端到端文图跨模态检索系统 + +## 1. 
场景概述 + +文图跨模态检索系统目的是通过文字找到最符合描述的图片。传统的方案是用标签和图片的关键字进行匹配,而跨模态检索真正的实现了文本语义和图片语义内容的匹配,这种检索方式更符合人类的逻辑判断,是一种真正意义上的端到端人工智能。文图应用目前可以广泛应用于电商搜索,安防视频,图像检索,抖音等小视频,旅游app应用搜索。有助于提升效率和搜索体验。另外还有一些潜在的领域,比如司法的互联网调查取证,侵权检测,数据增强,文案匹配,各种互联网logo,肖像,风景,海报等图片网站的检索,医药等专业领域的文图搜索等。 + +## 2. 产品功能介绍 + +本项目提供了低成本搭建端到端文图跨模态检索系统的能力。用户只需要处理好自己的业务数据,就可以使用本项目预置的文图跨模态检索系统模型快速搭建一个针对自己业务数据的跨模态检索系统,并可以提供 Web 化产品服务。 + +
+ +
+ + +### 2.1 系统特色 + ++ 端到端 + + 提供包括数据建库、模型服务部署、WebUI 可视化一整套端到端文图跨模态检索系统能力 + + 依托百度领先的NLP技术,包括[ERNIE](https://github.com/PaddlePaddle/ERNIE)语义理解技术,[ERNIE-ViL 2.0](https://arxiv.org/abs/2209.15270)跨模态检索能力 + + 预置领先的深度学习模型 + +## 3. 快速开始: 快速搭建文图跨模态检索系统 + + +### 3.1 运行环境和安装说明 + +本实验采用了以下的运行环境进行,详细说明如下,用户也可以在自己 GPU 硬件环境进行: + +a. 软件环境: +- python >= 3.7.0 +- paddlenlp >= 2.5.0 +- paddlepaddle-gpu >=2.4.1 +- CUDA Version: 11.2 +- NVIDIA Driver Version: 440.64.00 +- Ubuntu 16.04.6 LTS (Docker) + +b. 硬件环境: + +- NVIDIA Tesla V100 16GB x4卡 +- Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz + +c. 依赖安装: +首先需要安装PaddlePaddle,PaddlePaddle的安装请参考文档[官方安装文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html),然后安装下面的依赖: +```bash +# pip 一键安装 +pip install --upgrade paddle-pipelines -i https://pypi.tuna.tsinghua.edu.cn/simple +# 或者源码进行安装最新版本 +cd ${HOME}/PaddleNLP/pipelines/ +pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple +python setup.py install +``` + +``` +# 下载pipelines源代码 +git clone https://github.com/PaddlePaddle/PaddleNLP.git +cd PaddleNLP/pipelines +``` +【注意】以下的所有的流程都只需要在`pipelines`根目录下进行,不需要跳转目录 + +### 3.2 数据说明 +文图跨模态检索数据库的数据来自于[Noah-Wukong数据集](https://wukong-dataset.github.io/wukong-dataset/index.html),并选取了测试集中3056张图片来搭建文图跨模态检索系统。 + +### 3.3 一键体验文图跨模态检索系统 + +#### 3.3.1 快速一键启动 + +我们预置了基于[Noah-Wukong数据集](https://wukong-dataset.github.io/wukong-dataset/index.html)搭建文图跨模态检索系统的代码示例,您可以通过如下命令快速体验文图跨模态检索系统的效果 +```bash +# 我们建议在 GPU 环境下运行本示例,运行速度较快 +# 设置 1 个空闲的 GPU 卡,此处假设 0 卡为空闲 GPU +export CUDA_VISIBLE_DEVICES=0 +python examples/image_text_retrieval/image_text_retrieval_example.py --device gpu \ + --search_engine faiss +# 如果只有 CPU 机器,可以通过 --device 参数指定 cpu 即可, 运行耗时较长 +unset CUDA_VISIBLE_DEVICES +python examples/image_text_retrieval/image_text_retrieval_example.py --device cpu \ + --search_engine faiss +``` + + +### 3.4 构建 Web 可视化文图跨模态检索系统 + +整个 Web 可视化文图跨模态检索系统主要包含 3 大组件: 1. 基于 ElasticSearch 的 ANN 服务 2. 基于 RestAPI 构建模型服务 3. 基于 Streamlit 构建 WebUI,接下来我们依次搭建这 3 个服务并最终形成可视化的文图跨模态检索系统。 + +#### 3.4.1 启动 ANN 服务 +1. 参考官方文档下载安装 [elasticsearch-8.3.2](https://www.elastic.co/cn/downloads/elasticsearch) 并解压。 +2. 启动 ES 服务 +首先修改`config/elasticsearch.yml`的配置: +``` +xpack.security.enabled: false +``` +然后启动: +```bash +./bin/elasticsearch +``` +3. 
检查确保 ES 服务启动成功 +```bash +curl http://localhost:9200/_aliases?pretty=true +``` +备注:ES 服务默认开启端口为 9200 + +#### 3.4.2 文档数据写入 ANN 索引库 +``` +# 以DuReader-Robust 数据集为例建立 ANN 索引库 +python utils/offline_ann_mm.py --index_name wukong_test \ + --doc_dir data/wukong_test \ + --search_engine elastic \ + --delete_index +``` +可以使用下面的命令来查看数据: + +``` +# 打印几条数据 +curl http://localhost:9200/wukong_test/_search +``` + +参数含义说明 +* `index_name`: 索引的名称 +* `doc_dir`: txt文本数据的路径 +* `host`: ANN索引引擎的IP地址 +* `port`: ANN索引引擎的端口号 +* `search_engine`: 选择的近似索引引擎elastic,milvus,默认elastic +* `delete_index`: 是否删除现有的索引和数据,用于清空es的数据,默认为false + +删除索引也可以使用下面的命令: + +``` +curl -XDELETE http://localhost:9200/wukong_test +``` + +#### 3.4.3 启动 RestAPI 模型服务 +```bash +# 指定文图跨模态检索系统的Yaml配置文件 +export PIPELINE_YAML_PATH=rest_api/pipeline/image_text_retrieval.yaml +# 使用端口号 8891 启动模型服务 +python rest_api/application.py 8891 +``` +Linux 用户推荐采用 Shell 脚本来启动服务:: + +```bash +sh examples/image_text_retrieval/run_search_server.sh +``` +启动后可以使用curl命令验证是否成功运行: + +``` +curl -X POST -k http://localhost:8891/query -H 'Content-Type: application/json' -d '{"query": "云南普者黑现纯白色⒌蒂莲","params": {"Retriever": {"top_k": 5}}}' +``` + +更多API接口文档及其调用方式请参考链接[http://127.0.0.1:8891/docs](http://127.0.0.1:8891/docs) + +#### 3.4.4 启动 WebUI +```bash +# 配置模型服务地址 +export API_ENDPOINT=http://127.0.0.1:8891 +# 在指定端口 8502 启动 WebUI +python ui/webapp_image_text_retrieval.py --server.port 8502 +``` +Linux 用户推荐采用 Shell 脚本来启动服务:: + +```bash +sh examples/image_text_retrieval/run_search_web.sh +``` + +到这里您就可以打开浏览器访问 http://127.0.0.1:8502 地址体验文图跨模态检索系统服务了。 + +#### 3.4.5 数据更新 + +数据更新使用前面的 `utils/offline_ann_mm.py`进行数据更新,把图片放在特定目录,然后传入该目录即可: + +``` +python utils/offline_ann_mm.py --index_name wukong_test \ + --doc_dir data/wukong_test \ + --port 9200 \ + --search_engine elastic \ + --delete_index +``` + + +如果安装遇见问题可以查看[FAQ文档](../../FAQ.md) + +## Reference +[1]Y. Sun et al., “[ERNIE 3.0: Large-scale Knowledge Enhanced Pre-training for Language Understanding and Generation](https://arxiv.org/pdf/2107.02137.pdf),” arXiv:2107.02137 [cs], Jul. 2021, Accessed: Jan. 17, 2022. [Online]. Available: http://arxiv.org/abs/2107.02137 + +[2]Shan, Bin, et al. "[ERNIE-ViL 2.0: Multi-view Contrastive Learning for Image-Text Pre-training](https://arxiv.org/abs/2209.15270)." arXiv preprint arXiv:2209.15270 (2022). + +## Acknowledge + +我们借鉴了 Deepset.ai [Haystack](https://github.com/deepset-ai/haystack) 优秀的框架设计,在此对[Haystack](https://github.com/deepset-ai/haystack)作者及其开源社区表示感谢。 + +We learn form the excellent framework design of Deepset.ai [Haystack](https://github.com/deepset-ai/haystack), and we would like to express our thanks to the authors of Haystack and their open source community. diff --git a/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py b/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py new file mode 100644 index 000000000000..671d30251159 --- /dev/null +++ b/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +from pipelines.document_stores import FAISSDocumentStore +from pipelines.nodes import MultiModalRetriever +from pipelines.pipelines import Pipeline +from pipelines.schema import Document + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.") +parser.add_argument("--index_name", default='wukong_test', type=str, help="The ann index name of ANN.") +parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index") +parser.add_argument("--query_embedding_model", default="PaddlePaddle/ernie_vil-2.0-base-zh", type=str, help="The query_embedding_model path") +parser.add_argument("--document_embedding_model", default="PaddlePaddle/ernie_vil-2.0-base-zh", type=str, help="The document_embedding_model path") +args = parser.parse_args() +# yapf: enable + + +def image_text_retrieval_tutorial(): + faiss_document_store = "faiss_document_store.db" + if os.path.exists(args.index_name) and os.path.exists(faiss_document_store): + # connect to existed FAISS Index + document_store = FAISSDocumentStore.load(args.index_name) + retriever_mm = MultiModalRetriever( + document_store=document_store, + query_embedding_model=args.query_embedding_model, + query_type="text", + document_embedding_models={"image": args.document_embedding_model}, + ) + else: + doc_dir = "data/wukong_test" + document_store = FAISSDocumentStore(embedding_dim=args.embedding_dim, faiss_index_factory_str="Flat") + docs = [Document(content=f"./{doc_dir}/{filename}", content_type="image") for filename in os.listdir(doc_dir)] + retriever_mm = MultiModalRetriever( + document_store=document_store, + query_embedding_model=args.query_embedding_model, + query_type="text", + document_embedding_models={"image": args.document_embedding_model}, + ) + document_store.write_documents(docs) + document_store.update_embeddings(retriever_mm) + pipe = Pipeline() + pipe.add_node(component=retriever_mm, name="Retriever", inputs=["Query"]) + result = pipe.run(query="云南普者黑现纯白色⒌蒂莲", params={"Retriever": {"top_k": 5}}) + print(result) + + +if __name__ == "__main__": + image_text_retrieval_tutorial() diff --git a/pipelines/examples/image_text_retrieval/run_search_server.sh b/pipelines/examples/image_text_retrieval/run_search_server.sh new file mode 100644 index 000000000000..791d7ab53466 --- /dev/null +++ b/pipelines/examples/image_text_retrieval/run_search_server.sh @@ -0,0 +1,20 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +unset http_proxy && unset https_proxy +# 指定语义检索系统的Yaml配置文件 +export CUDA_VISIBLE_DEVICES=0 +export PIPELINE_YAML_PATH=rest_api/pipeline/image_text_retrieval.yaml +# 使用端口号 8891 启动模型服务 +python rest_api/application.py 8891 \ No newline at end of file diff --git a/pipelines/examples/image_text_retrieval/run_search_web.sh b/pipelines/examples/image_text_retrieval/run_search_web.sh new file mode 100644 index 000000000000..4f6027cd82d9 --- /dev/null +++ b/pipelines/examples/image_text_retrieval/run_search_web.sh @@ -0,0 +1,18 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# 配置模型服务地址 +export API_ENDPOINT=http://127.0.0.1:8891 +# 在指定端口 8502 启动 WebUI +python ui/webapp_image_text_retrieval.py --serving_port 8502 \ No newline at end of file diff --git a/pipelines/pipelines/nodes/retriever/embedder.py b/pipelines/pipelines/nodes/retriever/embedder.py index 8d14bbb1e038..3dffabfd2506 100644 --- a/pipelines/pipelines/nodes/retriever/embedder.py +++ b/pipelines/pipelines/nodes/retriever/embedder.py @@ -83,7 +83,7 @@ def __init__( self.models = {} # replace str with ContentTypes starting from Python3.8 for content_type, embedding_model in embedding_models.items(): - self.models[content_type] = Taskflow("image_text_retrieval") + self.models[content_type] = Taskflow("image_text_retrieval", model=embedding_model) # Check embedding sizes for models: they must all match if len(self.models) > 1: diff --git a/pipelines/rest_api/pipeline/image_text_retrieval.yaml b/pipelines/rest_api/pipeline/image_text_retrieval.yaml new file mode 100644 index 000000000000..29fb716d2129 --- /dev/null +++ b/pipelines/rest_api/pipeline/image_text_retrieval.yaml @@ -0,0 +1,25 @@ +version: '1.1.0' + +components: # define all the building-blocks for Pipeline + - name: DocumentStore + type: ElasticsearchDocumentStore # consider using Milvus2DocumentStore or WeaviateDocumentStore for scaling to large number of documents + params: + host: localhost + port: 9200 + index: wukong + embedding_dim: 768 + - name: Retriever + type: MultiModalRetriever + params: + document_store: DocumentStore # params can reference other components defined in the YAML + top_k: 10 + query_embedding_model: PaddlePaddle/ernie_vil-2.0-base-zh + document_embedding_models: + image: PaddlePaddle/ernie_vil-2.0-base-zh + +pipelines: + - name: query + type: Query + nodes: + - name: Retriever + inputs: [Query] \ No newline at end of file diff --git a/pipelines/ui/utils.py b/pipelines/ui/utils.py index 2a2b37237dfe..53bdfabd9f8a 100644 --- a/pipelines/ui/utils.py +++ b/pipelines/ui/utils.py @@ -230,6 +230,37 @@ def text_to_image_search( return results, response +def image_text_search(query, filters={}, top_k_retriever=5) -> Tuple[List[Dict[str, Any]], Dict[str, str]]: + """ + Send a query to the REST API and parse the answer. + Returns both a ready-to-use representation of the results and the raw JSON. 
+ """ + + url = f"{API_ENDPOINT}/{DOC_REQUEST}" + params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}} + req = {"query": query, "params": params} + response_raw = requests.post(url, json=req) + + if response_raw.status_code >= 400 and response_raw.status_code != 503: + raise Exception(f"{vars(response_raw)}") + + response = response_raw.json() + if "errors" in response: + raise Exception(", ".join(response["errors"])) + + # Format response + results = [] + answers = response["documents"] + for answer in answers: + results.append( + { + "context": answer["content"], + "relevance": round(answer["meta"]["es_ann_score"] * 100, 2), + } + ) + return results, response + + def text_to_qa_pair_search(query, is_filter=True) -> Tuple[List[Dict[str, Any]], Dict[str, str]]: """ Send a prompt text and corresponding parameters to the REST API diff --git a/pipelines/ui/webapp_image_text_retrieval.py b/pipelines/ui/webapp_image_text_retrieval.py new file mode 100644 index 000000000000..995bb4ff7149 --- /dev/null +++ b/pipelines/ui/webapp_image_text_retrieval.py @@ -0,0 +1,58 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import gradio as gr +from utils import image_text_search + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--serving_port", default=8502, type=int, help="Port for the serving.") +args = parser.parse_args() +# yapf: enable + + +def infer(query, top_k_retriever): + results, response = image_text_search(query, top_k_retriever=top_k_retriever) + images = [item["context"] for item in results] + return images + + +def main(): + block = gr.Blocks() + title = "
文图跨模态搜索应用
" + description = "本项目为ERNIE-ViL 2.0等CLIP中文版模型的DEMO,可用于图文检索和图像、文本的表征提取,应用于文图搜索、文图推荐、零样本分类、视频检索等应用场景。" + + with block: + gr.Markdown(title) + gr.Markdown(description) + with gr.Row(): + with gr.Column(scale=1): + with gr.Column(scale=2): + text = gr.Textbox(value="云南普者黑现纯白色⒌蒂莲", label="请填写文本", elem_id=0, interactive=True) + top_k = gr.components.Slider(minimum=0, maximum=50, step=1, value=8, label="返回图片数", elem_id=2) + btn = gr.Button( + "搜索", + ) + with gr.Column(scale=100): + out = gr.Gallery(label="检索结果为:").style(grid=4, height=200) + inputs = [text, top_k] + btn.click(fn=infer, inputs=inputs, outputs=out) + return block + + +if __name__ == "__main__": + block = main() + block.launch(server_name="0.0.0.0", server_port=args.serving_port, share=False) diff --git a/pipelines/utils/offline_ann_mm.py b/pipelines/utils/offline_ann_mm.py new file mode 100644 index 000000000000..b7d79e61a981 --- /dev/null +++ b/pipelines/utils/offline_ann_mm.py @@ -0,0 +1,112 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +from pipelines.document_stores import ElasticsearchDocumentStore, MilvusDocumentStore +from pipelines.nodes import MultiModalRetriever +from pipelines.schema import Document +from pipelines.utils import fetch_archive_from_http, launch_es + +data_dict = { + "data/wukong_test": "https://paddlenlp.bj.bcebos.com/applications/wukong_test_demo.zip", +} + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--index_name", default="wukong_test", type=str, help="The index name of the ANN search engine") +parser.add_argument("--doc_dir", default="data/wukong_test", type=str, help="The doc path of the corpus") +parser.add_argument("--search_engine", choices=["elastic", "milvus"], default="elastic", help="The type of ANN search engine.") +parser.add_argument("--host", type=str, default="127.0.0.1", help="host ip of ANN search engine") +parser.add_argument("--port", type=str, default="9200", help="port of ANN search engine") +parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index") +parser.add_argument("--query_embedding_model", default="PaddlePaddle/ernie_vil-2.0-base-zh", type=str, help="The query_embedding_model path") +parser.add_argument("--document_embedding_model", default="PaddlePaddle/ernie_vil-2.0-base-zh", type=str, help="The document_embedding_model path") +parser.add_argument("--delete_index", action="store_true", help="Whether to delete existing index while updating index") +args = parser.parse_args() +# yapf: enable + + +def offline_ann(index_name, doc_dir): + + if args.search_engine == "milvus": + document_store = MilvusDocumentStore( + embedding_dim=args.embedding_dim, + host=args.host, + index=args.index_name, + port=args.port, + index_param={"M": 16, "efConstruction": 50}, + index_type="HNSW", + ) + else: + launch_es() + document_store = ElasticsearchDocumentStore( + host=args.host, + port=args.port, + username="", + password="", + 
embedding_dim=args.embedding_dim, + index=index_name, + ) + docs = [ + Document(content=f"./{args.doc_dir}/{filename}", content_type="image") for filename in os.listdir(args.doc_dir) + ] + + print(docs[:3]) + + # 文档数据写入数据库 + document_store.write_documents(docs) + + # 语义索引模型 + retriever_mm = MultiModalRetriever( + document_store=document_store, + query_embedding_model=args.query_embedding_model, + query_type="text", + document_embedding_models={"image": args.document_embedding_model}, + ) + + # 建立索引库 + document_store.update_embeddings(retriever_mm) + + +def delete_data(index_name): + if args.search_engine == "milvus": + document_store = MilvusDocumentStore( + embedding_dim=args.embedding_dim, + host=args.host, + index=args.index_name, + port=args.port, + index_param={"M": 16, "efConstruction": 50}, + index_type="HNSW", + ) + else: + document_store = ElasticsearchDocumentStore( + host=args.host, + port=args.port, + username="", + password="", + embedding_dim=args.embedding_dim, + index=index_name, + ) + document_store.delete_index(index_name) + print("Delete an existing elasticsearch index {} Done.".format(index_name)) + + +if __name__ == "__main__": + if args.doc_dir in data_dict: + fetch_archive_from_http(url=data_dict[args.doc_dir], output_dir=args.doc_dir) + if args.delete_index: + delete_data(args.index_name) + offline_ann(args.index_name, args.doc_dir) From 3bcb2120d66c4aae0a8a69767069307c1902a8a5 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 6 Feb 2023 12:18:55 +0000 Subject: [PATCH 05/18] Change image text retrieval to feature extraction --- paddlenlp/taskflow/feature_extraction.py | 243 ++++++++++++++++++ paddlenlp/taskflow/image_text_retrieval.py | 137 ---------- paddlenlp/taskflow/taskflow.py | 30 +-- .../pipelines/nodes/retriever/embedder.py | 2 +- 4 files changed, 259 insertions(+), 153 deletions(-) create mode 100644 paddlenlp/taskflow/feature_extraction.py delete mode 100644 paddlenlp/taskflow/image_text_retrieval.py diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py new file mode 100644 index 000000000000..b6a45ec9d1b2 --- /dev/null +++ b/paddlenlp/taskflow/feature_extraction.py @@ -0,0 +1,243 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import paddle +from PIL import Image + +from ..transformers import AutoModel, AutoProcessor +from ..utils.log import logger +from .task import Task +from .utils import dygraph_mode_guard, static_mode_guard + + +class MultimodalFeatureExtractionTask(Task): + """ + The text_to_image generation model to generate the image. + Args: + task(string): The name of task. + model(string): The model name in the task. + kwargs (dict, optional): Additional keyword arguments passed along to the specific task. 
+ """ + + def __init__(self, task, model, batch_size=1, _static_mode=True, **kwargs): + super().__init__(task=task, model=model, **kwargs) + self._seed = None + # we do not use batch + self._batch_size = batch_size + self._construct_tokenizer(model_name=model) + self._static_mode = _static_mode + self._config_map = {} + self.predictor_map = {} + self.input_names_map = {} + self.input_handles_map = {} + self.output_handle_map = {} + if self._static_mode: + self._get_inference_model() + else: + self._construct_model(model) + + def _construct_model(self, model): + """ + Construct the inference model for the predictor. + """ + self._model = AutoModel.from_pretrained(model) + self._model.eval() + + def _construct_tokenizer(self, model_name): + """ + Construct the tokenizer for the predictor. + """ + self._processor = AutoProcessor.from_pretrained(model_name) + + def _batchify(self, data, batch_size): + """ + Generate input batches. + """ + + def _parse_batch(batch_examples): + if isinstance(batch_examples[0], str): + batch_texts = batch_examples + batch_images = None + else: + batch_texts = None + batch_images = batch_examples + if self._static_mode: + tokenized_inputs = self._processor( + text=batch_texts, images=batch_images, return_tensors="np", padding="max_length", truncation=True + ) + else: + tokenized_inputs = self._processor( + text=batch_texts, images=batch_images, return_tensors="pd", padding="max_length", truncation=True + ) + return tokenized_inputs + + # Seperates data into some batches. + one_batch = [] + for example in data: + one_batch.append(example) + if len(one_batch) == batch_size: + yield _parse_batch(one_batch) + one_batch = [] + if one_batch: + yield _parse_batch(one_batch) + + def _check_input_text(self, inputs): + """ + Check whether the input text meet the requirement. + """ + inputs = inputs[0] + if isinstance(inputs, (str, Image.Image)): + if len(inputs) == 0: + raise ValueError("Invalid inputs, input text/image should not be empty, please check your input.") + inputs = [inputs] + elif isinstance(inputs, list): + # and len(inputs[0].strip()) > 0 + if not (isinstance(inputs[0], (str, Image.Image))): + raise TypeError( + "Invalid inputs, input text/image should be list of str/PIL.image, and first element of list should not be empty." + ) + else: + raise TypeError( + "Invalid inputs, input text should be str or list of str, but type of {} found!".format(type(inputs)) + ) + return inputs + + def _preprocess(self, inputs): + """ + Transform the raw text to the model inputs, two steps involved: + 1) Transform the raw text to token ids. + 2) Generate the other model inputs from the raw text and token ids. + """ + inputs = self._check_input_text(inputs) + batches = self._batchify(inputs, self._batch_size) + outputs = {"batches": batches, "text": inputs} + return outputs + + def _run_model(self, inputs): + """ + Run the task model from the outputs of the `_preprocess` function. 
+ """ + all_feats = [] + if self._static_mode: + with static_mode_guard(): + for batch_inputs in inputs["batches"]: + if "input_ids" in batch_inputs: + self.input_handles_map["text"][0].copy_from_cpu(batch_inputs["input_ids"]) + self.predictor_map["text"].run() + text_features = self.output_handle_map["text"][0].copy_to_cpu() + all_feats.append(text_features) + elif "pixel_values" in batch_inputs: + self.input_handles_map["image"][0].copy_from_cpu(batch_inputs["pixel_values"]) + self.predictor_map["image"].run() + image_features = self.output_handle_map["image"][0].copy_to_cpu() + all_feats.append(image_features) + else: + for batch_inputs in inputs["batches"]: + if "input_ids" in batch_inputs: + text_features = self._model.get_text_features(input_ids=batch_inputs["input_ids"]) + all_feats.append(text_features) + if "pixel_values" in batch_inputs: + image_features = self._model.get_image_features(pixel_values=batch_inputs["pixel_values"]) + all_feats.append(image_features) + inputs.update({"features": all_feats}) + return inputs + + def _postprocess(self, inputs): + if self._static_mode: + inputs["features"] = paddle.to_tensor(np.concatenate(inputs["features"], axis=0)) + else: + inputs["features"] = paddle.concat(inputs["features"], axis=0) + return inputs + + def _construct_input_spec(self): + """ + Construct the input spec for the convert dygraph model to static model. + """ + + self._input_text_spec = [ + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"), + ] + + self._input_image_spec = [ + paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="float32", name="pixel_values"), + ] + + def _convert_dygraph_to_static(self): + """ + Convert the dygraph model to static model. + """ + assert ( + self._model is not None + ), "The dygraph model must be created before converting the dygraph model to static model." + assert ( + self._input_image_spec is not None or self._input_text_spec is not None + ), "The input spec must be created before converting the dygraph model to static model." + logger.info("Converting to the inference model cost a little time.") + + static_model = paddle.jit.to_static(self._model.get_text_features, input_spec=self._input_text_spec) + self.inference_model_path = self.inference_text_model_path + paddle.jit.save(static_model, self.inference_model_path) + logger.info("The inference model save in the path:{}".format(self.inference_model_path)) + + static_model = paddle.jit.to_static(self._model.get_image_features, input_spec=self._input_image_spec) + self.inference_model_path = self.inference_image_model_path + paddle.jit.save(static_model, self.inference_model_path) + logger.info("The inference model save in the path:{}".format(self.inference_model_path)) + + def _get_inference_model(self): + """ + Return the inference program, inputs and outputs in static mode. 
+ """ + _base_path = os.path.join(self._home_path, "taskflow", self.task, self.model) + self.inference_image_model_path = os.path.join(_base_path, "static", "get_image_features") + self.inference_text_model_path = os.path.join(_base_path, "static", "get_text_features") + if ( + not os.path.exists(self.inference_image_model_path + ".pdiparams") + or self._param_updated + or not os.path.exists(self.inference_text_model_path + ".pdiparams") + ): + with dygraph_mode_guard(): + self._construct_model(self.model) + self._construct_input_spec() + self._convert_dygraph_to_static() + if self._predictor_type == "paddle-inference": + # Get text inference model + self.inference_model_path = self.inference_text_model_path + self._static_model_file = self.inference_model_path + ".pdmodel" + self._static_params_file = self.inference_model_path + ".pdiparams" + self._config = paddle.inference.Config(self._static_model_file, self._static_params_file) + self._prepare_static_mode() + + self.predictor_map["text"] = self.predictor + self.input_names_map["text"] = self.input_names + self.input_handles_map["text"] = self.input_handles + self.output_handle_map["text"] = self.output_handle + self._config_map["text"] = self._config + + # Get image inference model + self.inference_model_path = self.inference_image_model_path + self._static_model_file = self.inference_model_path + ".pdmodel" + self._static_params_file = self.inference_model_path + ".pdiparams" + self._config = paddle.inference.Config(self._static_model_file, self._static_params_file) + self._prepare_static_mode() + + self.predictor_map["image"] = self.predictor + self.input_names_map["image"] = self.input_names + self.input_handles_map["image"] = self.input_handles + self.output_handle_map["image"] = self.output_handle + self._config_map["image"] = self._config + else: + self._prepare_onnx_mode() diff --git a/paddlenlp/taskflow/image_text_retrieval.py b/paddlenlp/taskflow/image_text_retrieval.py deleted file mode 100644 index c8a4447020d7..000000000000 --- a/paddlenlp/taskflow/image_text_retrieval.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -from PIL import Image - -from ..transformers import AutoModel, AutoProcessor -from .task import Task - - -class ImageTextRetrievalTask(Task): - """ - The text_to_image generation model to generate the image. - Args: - task(string): The name of task. - model(string): The model name in the task. - kwargs (dict, optional): Additional keyword arguments passed along to the specific task. - """ - - def __init__(self, task, model, batch_size=1, **kwargs): - super().__init__(task=task, model=model, **kwargs) - self._seed = None - # we do not use batch - self._batch_size = batch_size - self._construct_tokenizer(model_name=model) - self._construct_model(model) - - def _construct_model(self, model): - """ - Construct the inference model for the predictor. 
- """ - self._model = AutoModel.from_pretrained(model) - self._model.eval() - - def _construct_tokenizer(self, model_name): - """ - Construct the tokenizer for the predictor. - """ - self._processor = AutoProcessor.from_pretrained(model_name) - - def _batchify(self, data, batch_size): - """ - Generate input batches. - """ - - def _parse_batch(batch_examples): - if isinstance(batch_examples[0], str): - batch_texts = batch_examples - batch_images = None - else: - batch_texts = None - batch_images = batch_examples - - tokenized_inputs = self._processor( - text=batch_texts, images=batch_images, return_tensors="pd", padding="max_length", truncation=True - ) - return tokenized_inputs - - # Seperates data into some batches. - one_batch = [] - for example in data: - one_batch.append(example) - if len(one_batch) == batch_size: - yield _parse_batch(one_batch) - one_batch = [] - if one_batch: - yield _parse_batch(one_batch) - - def _check_input_text(self, inputs): - """ - Check whether the input text meet the requirement. - """ - inputs = inputs[0] - if isinstance(inputs, (str, Image.Image)): - if len(inputs) == 0: - raise ValueError("Invalid inputs, input text/image should not be empty, please check your input.") - inputs = [inputs] - elif isinstance(inputs, list): - # and len(inputs[0].strip()) > 0 - if not (isinstance(inputs[0], (str, Image.Image))): - raise TypeError( - "Invalid inputs, input text/image should be list of str/PIL.image, and first element of list should not be empty." - ) - else: - raise TypeError( - "Invalid inputs, input text should be str or list of str, but type of {} found!".format(type(inputs)) - ) - return inputs - - def _preprocess(self, inputs): - """ - Transform the raw text to the model inputs, two steps involved: - 1) Transform the raw text to token ids. - 2) Generate the other model inputs from the raw text and token ids. - """ - inputs = self._check_input_text(inputs) - batches = self._batchify(inputs, self._batch_size) - outputs = {"batches": batches, "text": inputs} - return outputs - - def _run_model(self, inputs): - """ - Run the task model from the outputs of the `_preprocess` function. - """ - all_feats = [] - for batch_inputs in inputs["batches"]: - if "input_ids" in batch_inputs: - text_features = self._model.get_text_features(input_ids=batch_inputs["input_ids"]) - all_feats.append(text_features) - if "pixel_values" in batch_inputs: - image_features = self._model.get_image_features(pixel_values=batch_inputs["pixel_values"]) - all_feats.append(image_features) - inputs.update({"features": all_feats}) - return inputs - - def _postprocess(self, inputs): - inputs["features"] = paddle.concat(inputs["features"], axis=0) - return inputs - - def _construct_input_spec(self): - """ - Construct the input spec for the convert dygraph model to static model. 
- """ - self._input_spec = [ - paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"), - ] diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index 3ef4bacbb633..a5aa5fa43537 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -23,8 +23,8 @@ from .dependency_parsing import DDParserTask from .dialogue import DialogueTask from .document_intelligence import DocPromptTask +from .feature_extraction import MultimodalFeatureExtractionTask from .fill_mask import FillMaskTask -from .image_text_retrieval import ImageTextRetrievalTask from .information_extraction import GPTask, UIETask from .knowledge_mining import NPTagTask, WordTagTask from .lexical_analysis import LacTask @@ -487,69 +487,69 @@ }, "default": {"model": "utc-large"}, }, - "image_text_retrieval": { + "feature_extraction": { "models": { "PaddlePaddle/ernie_vil-2.0-base-zh": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-2.0-base-zh", }, "OFA-Sys/chinese-clip-vit-base-patch16": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-base-patch16", "task_priority_path": "OFA-Sys/chinese-clip-vit-base-patch16", }, "OFA-Sys/chinese-clip-vit-huge-patch14": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-huge-patch14", "task_priority_path": "OFA-Sys/chinese-clip-vit-huge-patch14", }, "OFA-Sys/chinese-clip-vit-large-patch14": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-large-patch14", "task_priority_path": "OFA-Sys/chinese-clip-vit-large-patch14", }, "OFA-Sys/chinese-clip-vit-large-patch14-336px": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-large-patch14-336px", "task_priority_path": "OFA-Sys/chinese-clip-vit-large-patch14-336px", }, "openai/clip-vit-base-patch32": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-openai/clip-vit-base-patch32", "task_priority_path": "openai/clip-vit-base-patch32", }, "openai/clip-vit-base-patch16": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-openai/clip-vit-base-patch16", "task_priority_path": "openai/clip-vit-base-patch16", }, "openai/clip-vit-large-patch14": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-openai/clip-vit-large-patch14", "task_priority_path": "openai/clip-vit-large-patch14", }, "laion/CLIP-ViT-H-14-laion2B-s32B-b79K": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-laion/CLIP-ViT-H-14-laion2B-s32B-b79K", "task_priority_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", }, "laion/CLIP-ViT-B-32-laion2B-s34B-b79K": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-laion/CLIP-ViT-B-32-laion2B-s34B-b79K", "task_priority_path": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", }, "openai/clip-rn50": { - "task_class": ImageTextRetrievalTask, + "task_class": 
MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-openai/clip-rn50", "task_priority_path": "openai/clip-rn50", }, "openai/clip-rn101": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-openai/clip-rn101", "task_priority_path": "openai/clip-rn101", }, "openai/clip-rn50x4": { - "task_class": ImageTextRetrievalTask, + "task_class": MultimodalFeatureExtractionTask, "task_flag": "image_text_retrieval-openai/clip-rn50x4", "task_priority_path": "openai/clip-rn50x4", }, diff --git a/pipelines/pipelines/nodes/retriever/embedder.py b/pipelines/pipelines/nodes/retriever/embedder.py index 3dffabfd2506..99a9c2c2f5b1 100644 --- a/pipelines/pipelines/nodes/retriever/embedder.py +++ b/pipelines/pipelines/nodes/retriever/embedder.py @@ -83,7 +83,7 @@ def __init__( self.models = {} # replace str with ContentTypes starting from Python3.8 for content_type, embedding_model in embedding_models.items(): - self.models[content_type] = Taskflow("image_text_retrieval", model=embedding_model) + self.models[content_type] = Taskflow("feature_extraction", model=embedding_model) # Check embedding sizes for models: they must all match if len(self.models) > 1: From dd4c83b76304be12d45c0f7dd7dc36740977ced2 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 9 Feb 2023 06:02:23 +0000 Subject: [PATCH 06/18] Add feature extraction docs --- docs/model_zoo/taskflow.md | 77 +++++++++++++++++++ paddlenlp/taskflow/feature_extraction.py | 51 ++++++++++-- .../image_text_retrieval/run_search_web.sh | 2 +- .../pipelines/nodes/retriever/embedder.py | 6 +- pipelines/utils/offline_ann_mm.py | 2 +- 5 files changed, 126 insertions(+), 12 deletions(-) diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md index 42602b3ca854..9dce5bda49f3 100644 --- a/docs/model_zoo/taskflow.md +++ b/docs/model_zoo/taskflow.md @@ -47,6 +47,7 @@ PaddleNLP提供**开箱即用**的产业级NLP预置任务能力,无需训练 | [文档智能](#文档智能) | `Taskflow("document_intelligence")` | ✅ | ✅ | ✅ | ✅ | | 以多语言跨模态布局增强文档预训练模型ERNIE-Layout为核心底座 | | [问题生成](#问题生成) | `Taskflow("question_generation")` | ✅ | ✅ | ✅ | ✅ | | 问题生成大模型 | | [零样本文本分类](#零样本文本分类) | `Taskflow("zero_shot_text_classification")` | ✅ | ✅ | ✅ | | ✅ | 集成多场景的通用文本分类工具 | +| [通用特征提取](#通用特征提取) | `Taskflow("feature_extraction")` | ✅ | ✅ | ✅ | | | 集成文本,图片的特征抽取工具 | ## QuickStart @@ -1778,6 +1779,82 @@ from paddlenlp import Taskflow * `pred_threshold`:模型对标签预测的概率在0~1之间,返回结果去掉小于这个阈值的结果,默认为0.5。 * `precision`:选择模型精度,默认为`fp32`,可选有`fp16`和`fp32`。`fp16`推理速度更快。如果选择`fp16`,请先确保机器正确安装NVIDIA相关驱动和基础软件,**确保CUDA>=11.2,cuDNN>=8.1.1**,初次使用需按照提示安装相关依赖。其次,需要确保GPU设备的CUDA计算能力(CUDA Compute Capability)大于7.0,典型的设备包括V100、T4、A10、A100、GTX 20系列和30系列显卡等。更多关于CUDA Compute Capability和精度支持情况请参考NVIDIA文档:[GPU硬件与支持精度对照表](https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-840-ea/support-matrix/index.html#hardware-precision-matrix)。 + + +### 模型特征提取 + +
+&emsp;基于百度自研中文图文跨模态预训练模型ERNIE-ViL 2.0
+ +#### 支持单条、批量预测 + +```python +>>> from paddlenlp import Taskflow +>>> from PIL import Image +>>> import paddle.nn.functional as F +# 单条输入 +>>> image_embeds = vision_language(Image.open("demo/000000039769.jpg")) +>>> image_embeds["features"] +Tensor(shape=[1, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[-0.59475428, -0.69795364, 0.22144008, 0.88066685, -0.58184201, +# 单条输入 +>>> text_embeds = vision_language("猫的照片") +>>> text_embeds['features'] +Tensor(shape=[1, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[ 0.04250504, -0.41429776, 0.26163983, 0.29910022, 0.39019185, + -0.41884750, -0.19893740, 0.44328332, 0.08186490, 0.10953025, + ...... + +# 多条输入 +>>> image_embeds = vision_language([Image.open("demo/000000039769.jpg")]) +>>> image_embeds["features"] +Tensor(shape=[1, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[-0.59475428, -0.69795364, 0.22144008, 0.88066685, -0.58184201, + ...... +# 多条输入 +>>> text_embeds = vision_language(["猫的照片","狗的照片"]) +>>> text_embeds["features"] +Tensor(shape=[2, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[ 0.04250504, -0.41429776, 0.26163983, ..., 0.26221892, + 0.34387422, 0.18779707], + [ 0.06672225, -0.41456309, 0.13787819, ..., 0.21791610, + 0.36693242, 0.34208685]]) +>>> image_features = image_embeds["features"] +>>> text_features = text_embeds["features"] +>>> image_features /= image_features.norm(axis=-1, keepdim=True) +>>> text_features /= text_features.norm(axis=-1, keepdim=True) +>>> logits_per_image = 100 * image_features @ text_features.t() +>>> probs = F.softmax(logits_per_image, axis=-1) +>>> probs +Tensor(shape=[1, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[0.99833173, 0.00166824]]) +``` +#### 模型选择 + +- 多模型选择,满足精度、速度要求 + + | 模型 | 视觉| 文本 | 语言 | + | :---: | :--------: | :--------: | :--------: | + | `PaddlePaddle/ernie_vil-2.0-base-zh` (默认) | ViT | ERNIE | 中文 | + | `OFA-Sys/chinese-clip-vit-base-patch16` | ViT-B/16 |RoBERTa-wwm-Base| 中文 | + | `OFA-Sys/chinese-clip-vit-huge-patch14` | ViT-H/14 |RoBERTa-wwm-Large | 中文 | + | `OFA-Sys/chinese-clip-vit-large-patch14` | ViT-L/14 | RoBERTa-wwm-Base | 中文 | + | `OFA-Sys/chinese-clip-vit-large-patch14-336px` | ViT-L/14 | RoBERTa-wwm-Base | 中文 | + | `openai/clip-vit-base-patch32` | ViT-B/32 | transformer结构| 英文 | + | `openai/clip-vit-base-patch16` | ViT-B/16| transformer结构 | 英文 | + | `openai/clip-vit-large-patch14` | ViT-L/14 | transformer结构 | 英文 | + | `laion/CLIP-ViT-H-14-laion2B-s32B-b79K` | ViT-H/14 | transformer结构 | 英文 | + | `laion/CLIP-ViT-B-32-laion2B-s34B-b79K` | ViT-B/32 | transformer结构 | 英文 | + | `openai/clip-rn50` | RN50 | transformer结构 | 英文 | + | `openai/clip-rn101` | RN101 | transformer结构 | 英文 | + | `openai/clip-rn50x4` | RN50*4 | transformer结构 | 英文 | + +#### 可配置参数说明 +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 +* `_static_mode`:静态图模式,默认开启。 +* `model`:选择任务使用的模型,默认为`PaddlePaddle/ernie_vil-2.0-base-zh`。 + +
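+#### 检索打分示例
+
+下面给出一个最小示意代码(仅供参考,其中的图片路径为假设的示例文件,请替换为自己的数据),演示如何利用上文抽取到的特征,按余弦相似度对候选图片进行排序,即文到图检索的打分过程:
+
+```python
+import paddle
+from PIL import Image
+from paddlenlp import Taskflow
+
+# 加载默认的 ERNIE-ViL 2.0 多模态特征抽取 Taskflow
+feature_extractor = Taskflow("feature_extraction")
+
+# 候选图片与文本查询(以下路径仅为示例)
+image_paths = ["demo/cat.jpg", "demo/dog.jpg", "demo/car.jpg"]
+images = [Image.open(p) for p in image_paths]
+query = "一只猫的照片"
+
+# 分别抽取图片与文本特征,并做 L2 归一化
+image_features = feature_extractor(images)["features"]
+text_features = feature_extractor(query)["features"]
+image_features = image_features / image_features.norm(axis=-1, keepdim=True)
+text_features = text_features / text_features.norm(axis=-1, keepdim=True)
+
+# 计算文本与每张图片的余弦相似度,并按相似度从高到低输出
+scores = (text_features @ image_features.t())[0]
+for idx in paddle.argsort(scores, descending=True).tolist():
+    print(image_paths[idx], float(scores[idx]))
+```
+
+在实际检索系统中(如 pipelines 中的 `MultiModalRetriever`),上述相似度计算会与 ANN 索引配合自动完成,这里仅演示特征结果的直接用法。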
+ ## PART Ⅱ   定制化训练
适配任务列表
diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py index b6a45ec9d1b2..8486b57e5842 100644 --- a/paddlenlp/taskflow/feature_extraction.py +++ b/paddlenlp/taskflow/feature_extraction.py @@ -22,10 +22,45 @@ from .task import Task from .utils import dygraph_mode_guard, static_mode_guard +usage = r""" + from paddlenlp import Taskflow + from PIL import Image + + # multi modal feature_extraction with ernie_vil-2.0-base-zh + senta = Taskflow("feature_extraction") + image_embeds = vision_language([Image.open("demo/000000039769.jpg")]) + print(image_embeds) + ''' + Tensor(shape=[1, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[-0.59475428, -0.69795364, 0.22144008, 0.88066685, -0.58184201, + -0.73454666, 0.95557910, -0.61410815, 0.23474170, 0.13301648, + 0.86196446, 0.12281934, 0.69097638, 1.47614217, 0.07238606, + ... + ''' + text_embeds = vision_language(["猫的照片","狗的照片"]) + text_features = text_embeds["features"] + print(text_features) + ''' + Tensor(shape=[2, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[ 0.04250504, -0.41429776, 0.26163983, ..., 0.26221892, + 0.34387422, 0.18779707], + ''' + image_features /= image_features.norm(axis=-1, keepdim=True) + text_features /= text_features.norm(axis=-1, keepdim=True) + logits_per_image = 100 * image_features @ text_features.t() + probs = F.softmax(logits_per_image, axis=-1) + print(probs) + ''' + Tensor(shape=[1, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[0.99833173, 0.00166824]]) + ''' + """ + class MultimodalFeatureExtractionTask(Task): """ - The text_to_image generation model to generate the image. + Feature extraction task using no model head. This task extracts the hidden states from the base + model, which can be used as features in retrieval and clustering tasks. Args: task(string): The name of task. model(string): The model name in the task. @@ -99,9 +134,11 @@ def _check_input_text(self, inputs): Check whether the input text meet the requirement. """ inputs = inputs[0] - if isinstance(inputs, (str, Image.Image)): + if isinstance(inputs, str): if len(inputs) == 0: - raise ValueError("Invalid inputs, input text/image should not be empty, please check your input.") + raise ValueError("Invalid inputs, input text should not be empty, please check your input.") + inputs = [inputs] + elif isinstance(inputs, Image.Image): inputs = [inputs] elif isinstance(inputs, list): # and len(inputs[0].strip()) > 0 @@ -117,13 +154,13 @@ def _check_input_text(self, inputs): def _preprocess(self, inputs): """ - Transform the raw text to the model inputs, two steps involved: - 1) Transform the raw text to token ids. - 2) Generate the other model inputs from the raw text and token ids. + Transform the raw inputs to the model inputs, two steps involved: + 1) Transform the raw text/image to token ids/pixel_values. + 2) Generate the other model inputs from the raw text/image and token ids/pixel_values. 
""" inputs = self._check_input_text(inputs) batches = self._batchify(inputs, self._batch_size) - outputs = {"batches": batches, "text": inputs} + outputs = {"batches": batches, "inputs": inputs} return outputs def _run_model(self, inputs): diff --git a/pipelines/examples/image_text_retrieval/run_search_web.sh b/pipelines/examples/image_text_retrieval/run_search_web.sh index 4f6027cd82d9..fec15e00d197 100644 --- a/pipelines/examples/image_text_retrieval/run_search_web.sh +++ b/pipelines/examples/image_text_retrieval/run_search_web.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pipelines/pipelines/nodes/retriever/embedder.py b/pipelines/pipelines/nodes/retriever/embedder.py index 99a9c2c2f5b1..1e5d40569b3e 100644 --- a/pipelines/pipelines/nodes/retriever/embedder.py +++ b/pipelines/pipelines/nodes/retriever/embedder.py @@ -78,7 +78,7 @@ def __init__( feature_extractors_params = { content_type: {"max_length": 256, **(feature_extractors_params or {}).get(content_type, {})} - for content_type in ["text", "table", "image", "audio"] # FIXME get_args(ContentTypes) from Python3.8 on + for content_type in ["text", "table", "image"] # FIXME get_args(ContentTypes) from Python3.8 on } self.models = {} # replace str with ContentTypes starting from Python3.8 @@ -90,7 +90,7 @@ def __init__( sizes = {model.embedding_dim for model in self.models.values()} if None in sizes: logger.warning( - "Haystack could not find the output embedding dimensions for '%s'. " + "Pipelines could not find the output embedding dimensions for '%s'. " "Dimensions won't be checked before computing the embeddings.", ", ".join( { @@ -167,7 +167,7 @@ def _docs_to_data( of a text document, a linearized table, a PIL image object, and so on) """ docs_data: Dict[str, List[Any]] = { # FIXME replace str to ContentTypes from Python3.8 - key: [] for key in ["text", "table", "image", "audio"] + key: [] for key in ["text", "table", "image"] } # FIXME get_args(ContentTypes) from Python3.8 on for doc in documents: try: diff --git a/pipelines/utils/offline_ann_mm.py b/pipelines/utils/offline_ann_mm.py index b7d79e61a981..01a41f92cce6 100644 --- a/pipelines/utils/offline_ann_mm.py +++ b/pipelines/utils/offline_ann_mm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From f1c08b91a074e47891a5aa9c8a2a140dd4ab1bc1 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 9 Feb 2023 13:35:59 +0000 Subject: [PATCH 07/18] Add onnx support --- paddlenlp/taskflow/feature_extraction.py | 92 ++++++++++++++++--- .../pipelines/nodes/retriever/embedder.py | 16 +--- 2 files changed, 86 insertions(+), 22 deletions(-) diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py index 8486b57e5842..83e5438b6991 100644 --- a/paddlenlp/taskflow/feature_extraction.py +++ b/paddlenlp/taskflow/feature_extraction.py @@ -71,6 +71,7 @@ def __init__(self, task, model, batch_size=1, _static_mode=True, **kwargs): super().__init__(task=task, model=model, **kwargs) self._seed = None # we do not use batch + self.mode = "text" self._batch_size = batch_size self._construct_tokenizer(model_name=model) self._static_mode = _static_mode @@ -79,6 +80,7 @@ def __init__(self, task, model, batch_size=1, _static_mode=True, **kwargs): self.input_names_map = {} self.input_handles_map = {} self.output_handle_map = {} + self._check_predictor_type() if self._static_mode: self._get_inference_model() else: @@ -171,16 +173,29 @@ def _run_model(self, inputs): if self._static_mode: with static_mode_guard(): for batch_inputs in inputs["batches"]: - if "input_ids" in batch_inputs: - self.input_handles_map["text"][0].copy_from_cpu(batch_inputs["input_ids"]) - self.predictor_map["text"].run() - text_features = self.output_handle_map["text"][0].copy_to_cpu() - all_feats.append(text_features) - elif "pixel_values" in batch_inputs: - self.input_handles_map["image"][0].copy_from_cpu(batch_inputs["pixel_values"]) - self.predictor_map["image"].run() - image_features = self.output_handle_map["image"][0].copy_to_cpu() - all_feats.append(image_features) + if self._predictor_type == "paddle-inference": + if "input_ids" in batch_inputs: + self.input_handles_map["text"][0].copy_from_cpu(batch_inputs["input_ids"]) + self.predictor_map["text"].run() + text_features = self.output_handle_map["text"][0].copy_to_cpu() + all_feats.append(text_features) + elif "pixel_values" in batch_inputs: + self.input_handles_map["image"][0].copy_from_cpu(batch_inputs["pixel_values"]) + self.predictor_map["image"].run() + image_features = self.output_handle_map["image"][0].copy_to_cpu() + all_feats.append(image_features) + else: + # onnx mode + if "input_ids" in batch_inputs: + input_dict = {} + input_dict["input_ids"] = batch_inputs["input_ids"] + text_features = self.predictor_map["text"].run(None, input_dict)[0].tolist() + all_feats.append(text_features) + elif "pixel_values" in batch_inputs: + input_dict = {} + input_dict["pixel_values"] = batch_inputs["pixel_values"] + image_features = self.predictor_map["image"].run(None, input_dict)[0].tolist() + all_feats.append(image_features) else: for batch_inputs in inputs["batches"]: if "input_ids" in batch_inputs: @@ -203,7 +218,6 @@ def _construct_input_spec(self): """ Construct the input spec for the convert dygraph model to static model. 
""" - self._input_text_spec = [ paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"), ] @@ -234,6 +248,48 @@ def _convert_dygraph_to_static(self): paddle.jit.save(static_model, self.inference_model_path) logger.info("The inference model save in the path:{}".format(self.inference_model_path)) + def _prepare_onnx_mode(self): + try: + import onnx + import onnxruntime as ort + import paddle2onnx + from onnxconverter_common import float16 + except ImportError: + logger.warning( + "The inference precision is change to 'fp32', please install the dependencies that required for 'fp16' inference, pip install onnxruntime-gpu onnx onnxconverter-common" + ) + + onnx_dir = os.path.join(self._task_path, "onnx", self.mode) + # onnx_dir = os.path.join(self._task_path, "onnx") + if not os.path.exists(onnx_dir): + os.makedirs(onnx_dir) + float_onnx_file = os.path.join(onnx_dir, "model.onnx") + if not os.path.exists(float_onnx_file) or self._param_updated: + onnx_model = paddle2onnx.command.c_paddle_to_onnx( + model_file=self._static_model_file, + params_file=self._static_params_file, + opset_version=13, + enable_onnx_checker=True, + ) + with open(float_onnx_file, "wb") as f: + f.write(onnx_model) + fp16_model_file = os.path.join(onnx_dir, "fp16_model.onnx") + if not os.path.exists(fp16_model_file) or self._param_updated: + onnx_model = onnx.load_model(float_onnx_file) + trans_model = float16.convert_float_to_float16(onnx_model, keep_io_types=True) + onnx.save_model(trans_model, fp16_model_file) + providers = [("CUDAExecutionProvider", {"device_id": self.kwargs["device_id"]})] + sess_options = ort.SessionOptions() + sess_options.intra_op_num_threads = self._num_threads + sess_options.inter_op_num_threads = self._num_threads + self.predictor = ort.InferenceSession(fp16_model_file, sess_options=sess_options, providers=providers) + assert "CUDAExecutionProvider" in self.predictor.get_providers(), ( + "The environment for GPU inference is not set properly. " + "A possible cause is that you had installed both onnxruntime and onnxruntime-gpu. " + "Please run the following commands to reinstall: \n " + "1) pip uninstall -y onnxruntime onnxruntime-gpu \n 2) pip install onnxruntime-gpu" + ) + def _get_inference_model(self): """ Return the inference program, inputs and outputs in static mode. 
@@ -277,4 +333,18 @@ def _get_inference_model(self): self.output_handle_map["image"] = self.output_handle self._config_map["image"] = self._config else: + # Get text inference model + self.mode = "text" + self.inference_model_path = self.inference_text_model_path + self._static_model_file = self.inference_model_path + ".pdmodel" + self._static_params_file = self.inference_model_path + ".pdiparams" self._prepare_onnx_mode() + self.predictor_map["text"] = self.predictor + + # Get image inference model + self.mode = "image" + self.inference_model_path = self.inference_image_model_path + self._static_model_file = self.inference_model_path + ".pdmodel" + self._static_params_file = self.inference_model_path + ".pdiparams" + self._prepare_onnx_mode() + self.predictor_map["image"] = self.predictor diff --git a/pipelines/pipelines/nodes/retriever/embedder.py b/pipelines/pipelines/nodes/retriever/embedder.py index 1e5d40569b3e..c00e3f63b71c 100644 --- a/pipelines/pipelines/nodes/retriever/embedder.py +++ b/pipelines/pipelines/nodes/retriever/embedder.py @@ -28,17 +28,13 @@ FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]] -# TODO the keys should match with ContentTypes (currently 'audio' is missing) DOCUMENT_CONVERTERS = { # NOTE: Keep this '?' cleaning step, it needs to be double-checked for impact on the inference results. "text": lambda doc: doc.content[:-1] if doc.content[-1] == "?" else doc.content, - "table": lambda doc: " ".join( - doc.content.columns.tolist() + [cell for row in doc.content.values.tolist() for cell in row] - ), "image": lambda doc: Image.open(doc.content), } -CAN_EMBED_META = ["text", "table"] +CAN_EMBED_META = ["text"] class MultiModalEmbedder: @@ -52,13 +48,11 @@ def __init__( ): """ Init the Retriever and all its models from a local or remote model checkpoint. - The checkpoint format matches the Hugging Face transformers' model format. :param embedding_models: A dictionary matching a local path or remote name of encoder checkpoint with - the content type it should handle ("text", "table", "image", etc...). - The format is the one that Hugging Face Hub models use. + the content type it should handle ("text", "image", etc...). Expected input format: `{'text': 'name_or_path_to_text_model', 'image': 'name_or_path_to_image_model', ... }` Keep in mind that the models should output in the same embedding space for this retriever to work. - :param feature_extractors_params: A dictionary matching a content type ("text", "table", "image" and so on) with the + :param feature_extractors_params: A dictionary matching a content type ("text", "image" and so on) with the parameters of its own feature extractor if the model requires one. Expected input format: `{'text': {'param_name': 'param_value', ...}, 'image': {'param_name': 'param_value', ...}, ...}` :param batch_size: Number of questions or passages to encode at once. In case of multiple GPUs, this will be the total batch size. 
@@ -78,7 +72,7 @@ def __init__( feature_extractors_params = { content_type: {"max_length": 256, **(feature_extractors_params or {}).get(content_type, {})} - for content_type in ["text", "table", "image"] # FIXME get_args(ContentTypes) from Python3.8 on + for content_type in ["text", "image"] # FIXME get_args(ContentTypes) from Python3.8 on } self.models = {} # replace str with ContentTypes starting from Python3.8 @@ -167,7 +161,7 @@ def _docs_to_data( of a text document, a linearized table, a PIL image object, and so on) """ docs_data: Dict[str, List[Any]] = { # FIXME replace str to ContentTypes from Python3.8 - key: [] for key in ["text", "table", "image"] + key: [] for key in ["text", "image"] } # FIXME get_args(ContentTypes) from Python3.8 on for doc in documents: try: From 5840835883d6429802ec7e57aab2b0f6aa5f6e12 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 10 Feb 2023 10:25:11 +0000 Subject: [PATCH 08/18] Fix some bugs and remove unused comments --- paddlenlp/taskflow/feature_extraction.py | 5 ++--- .../image_text_retrieval_example.py | 13 ++++++++++++- pipelines/pipelines/nodes/retriever/embedder.py | 5 ++++- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py index 83e5438b6991..3b6c96b84968 100644 --- a/paddlenlp/taskflow/feature_extraction.py +++ b/paddlenlp/taskflow/feature_extraction.py @@ -260,7 +260,6 @@ def _prepare_onnx_mode(self): ) onnx_dir = os.path.join(self._task_path, "onnx", self.mode) - # onnx_dir = os.path.join(self._task_path, "onnx") if not os.path.exists(onnx_dir): os.makedirs(onnx_dir) float_onnx_file = os.path.join(onnx_dir, "model.onnx") @@ -333,7 +332,7 @@ def _get_inference_model(self): self.output_handle_map["image"] = self.output_handle self._config_map["image"] = self._config else: - # Get text inference model + # Get text onnx model self.mode = "text" self.inference_model_path = self.inference_text_model_path self._static_model_file = self.inference_model_path + ".pdmodel" @@ -341,7 +340,7 @@ def _get_inference_model(self): self._prepare_onnx_mode() self.predictor_map["text"] = self.predictor - # Get image inference model + # Get image onnx model self.mode = "image" self.inference_model_path = self.inference_image_model_path self._static_model_file = self.inference_model_path + ".pdmodel" diff --git a/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py b/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py index 671d30251159..b92117c1d165 100644 --- a/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py +++ b/pipelines/examples/image_text_retrieval/image_text_retrieval_example.py @@ -19,6 +19,7 @@ from pipelines.nodes import MultiModalRetriever from pipelines.pipelines import Pipeline from pipelines.schema import Document +from pipelines.utils import fetch_archive_from_http # yapf: disable parser = argparse.ArgumentParser() @@ -34,7 +35,7 @@ def image_text_retrieval_tutorial(): faiss_document_store = "faiss_document_store.db" if os.path.exists(args.index_name) and os.path.exists(faiss_document_store): - # connect to existed FAISS Index + # Connect to existed FAISS Index document_store = FAISSDocumentStore.load(args.index_name) retriever_mm = MultiModalRetriever( document_store=document_store, @@ -44,6 +45,12 @@ def image_text_retrieval_tutorial(): ) else: doc_dir = "data/wukong_test" + wukong_data = "https://paddlenlp.bj.bcebos.com/applications/wukong_test_demo.zip" + 
fetch_archive_from_http(url=wukong_data, output_dir=doc_dir) + if os.path.exists(args.index_name): + os.remove(args.index_name) + if os.path.exists(faiss_document_store): + os.remove(faiss_document_store) document_store = FAISSDocumentStore(embedding_dim=args.embedding_dim, faiss_index_factory_str="Flat") docs = [Document(content=f"./{doc_dir}/{filename}", content_type="image") for filename in os.listdir(doc_dir)] retriever_mm = MultiModalRetriever( @@ -52,8 +59,12 @@ def image_text_retrieval_tutorial(): query_type="text", document_embedding_models={"image": args.document_embedding_model}, ) + # Update metadata document_store.write_documents(docs) + # Update Embedding document_store.update_embeddings(retriever_mm) + # Save index + document_store.save(args.index_name) pipe = Pipeline() pipe.add_node(component=retriever_mm, name="Retriever", inputs=["Query"]) result = pipe.run(query="云南普者黑现纯白色⒌蒂莲", params={"Retriever": {"top_k": 5}}) diff --git a/pipelines/pipelines/nodes/retriever/embedder.py b/pipelines/pipelines/nodes/retriever/embedder.py index c00e3f63b71c..4cba0f78b0a9 100644 --- a/pipelines/pipelines/nodes/retriever/embedder.py +++ b/pipelines/pipelines/nodes/retriever/embedder.py @@ -77,7 +77,10 @@ def __init__( self.models = {} # replace str with ContentTypes starting from Python3.8 for content_type, embedding_model in embedding_models.items(): - self.models[content_type] = Taskflow("feature_extraction", model=embedding_model) + if content_type in ["text", "image"]: + self.models[content_type] = Taskflow("feature_extraction", model=embedding_model) + else: + raise ValueError(f"{content_type} is not a supported content.") # Check embedding sizes for models: they must all match if len(self.models) > 1: From 932cd027f33687ef74f47b85d1b5ee5a70ea50c4 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 13 Feb 2023 05:26:22 +0000 Subject: [PATCH 09/18] fix some errors and adjust onnx ouput config --- docs/model_zoo/taskflow.md | 1 + paddlenlp/taskflow/feature_extraction.py | 43 +----------------------- paddlenlp/taskflow/task.py | 8 ++++- paddlenlp/taskflow/taskflow.py | 27 ++++++++------- 4 files changed, 23 insertions(+), 56 deletions(-) diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md index a631f0b6bb2d..cb432844ee53 100644 --- a/docs/model_zoo/taskflow.md +++ b/docs/model_zoo/taskflow.md @@ -1791,6 +1791,7 @@ from paddlenlp import Taskflow >>> from paddlenlp import Taskflow >>> from PIL import Image >>> import paddle.nn.functional as F +>>> vision_language= Taskflow("feature_extraction") # 单条输入 >>> image_embeds = vision_language(Image.open("demo/000000039769.jpg")) >>> image_embeds["features"] diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py index 3b6c96b84968..4946902e08d8 100644 --- a/paddlenlp/taskflow/feature_extraction.py +++ b/paddlenlp/taskflow/feature_extraction.py @@ -27,7 +27,7 @@ from PIL import Image # multi modal feature_extraction with ernie_vil-2.0-base-zh - senta = Taskflow("feature_extraction") + vision_language = Taskflow("feature_extraction") image_embeds = vision_language([Image.open("demo/000000039769.jpg")]) print(image_embeds) ''' @@ -248,47 +248,6 @@ def _convert_dygraph_to_static(self): paddle.jit.save(static_model, self.inference_model_path) logger.info("The inference model save in the path:{}".format(self.inference_model_path)) - def _prepare_onnx_mode(self): - try: - import onnx - import onnxruntime as ort - import paddle2onnx - from onnxconverter_common import float16 - except ImportError: - 
logger.warning( - "The inference precision is change to 'fp32', please install the dependencies that required for 'fp16' inference, pip install onnxruntime-gpu onnx onnxconverter-common" - ) - - onnx_dir = os.path.join(self._task_path, "onnx", self.mode) - if not os.path.exists(onnx_dir): - os.makedirs(onnx_dir) - float_onnx_file = os.path.join(onnx_dir, "model.onnx") - if not os.path.exists(float_onnx_file) or self._param_updated: - onnx_model = paddle2onnx.command.c_paddle_to_onnx( - model_file=self._static_model_file, - params_file=self._static_params_file, - opset_version=13, - enable_onnx_checker=True, - ) - with open(float_onnx_file, "wb") as f: - f.write(onnx_model) - fp16_model_file = os.path.join(onnx_dir, "fp16_model.onnx") - if not os.path.exists(fp16_model_file) or self._param_updated: - onnx_model = onnx.load_model(float_onnx_file) - trans_model = float16.convert_float_to_float16(onnx_model, keep_io_types=True) - onnx.save_model(trans_model, fp16_model_file) - providers = [("CUDAExecutionProvider", {"device_id": self.kwargs["device_id"]})] - sess_options = ort.SessionOptions() - sess_options.intra_op_num_threads = self._num_threads - sess_options.inter_op_num_threads = self._num_threads - self.predictor = ort.InferenceSession(fp16_model_file, sess_options=sess_options, providers=providers) - assert "CUDAExecutionProvider" in self.predictor.get_providers(), ( - "The environment for GPU inference is not set properly. " - "A possible cause is that you had installed both onnxruntime and onnxruntime-gpu. " - "Please run the following commands to reinstall: \n " - "1) pip uninstall -y onnxruntime onnxruntime-gpu \n 2) pip install onnxruntime-gpu" - ) - def _get_inference_model(self): """ Return the inference program, inputs and outputs in static mode. 
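With the duplicated copy removed above, ONNX export lives only in the shared `Task._prepare_onnx_mode`, and the task.py hunk that follows keys the output directory on a per-task flag (called `mode` in this commit and renamed to `export_type` later in the series), so the text and image towers of a multimodal model no longer overwrite each other's exports. A sketch of just that directory decision, with an illustrative helper name:

    import os


    def onnx_output_dir(task_path, export_flag=None):
        """Mirror of the directory choice in Task._prepare_onnx_mode (sketch only)."""
        if export_flag is None:
            # Single-tower tasks keep the original flat layout: <task_path>/onnx
            return os.path.join(task_path, "onnx")
        # Multimodal tasks get one sub-directory per exported tower,
        # e.g. <task_path>/onnx/text and <task_path>/onnx/image.
        return os.path.join(task_path, "onnx", export_flag)
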
diff --git a/paddlenlp/taskflow/task.py b/paddlenlp/taskflow/task.py index a27410af525b..153b9ae00bbd 100644 --- a/paddlenlp/taskflow/task.py +++ b/paddlenlp/taskflow/task.py @@ -61,6 +61,8 @@ def __init__(self, model, task, priority_path=None, **kwargs): self._home_path = self.kwargs["home_path"] if "home_path" in self.kwargs else PPNLP_HOME self._task_flag = self.kwargs["task_flag"] if "task_flag" in self.kwargs else self.model self.from_hf_hub = kwargs.pop("from_hf_hub", False) + # Add mode flag for onnx output path redirection + self.mode = None if "task_path" in self.kwargs: self._task_path = self.kwargs["task_path"] @@ -221,8 +223,12 @@ def _prepare_onnx_mode(self): logger.warning( "The inference precision is change to 'fp32', please install the dependencies that required for 'fp16' inference, pip install onnxruntime-gpu onnx onnxconverter-common" ) + if self.mode is None: + onnx_dir = os.path.join(self._task_path, "onnx") + else: + # Compatible multimodal model for saving image and text path + onnx_dir = os.path.join(self._task_path, "onnx", self.mode) - onnx_dir = os.path.join(self._task_path, "onnx") if not os.path.exists(onnx_dir): os.mkdir(onnx_dir) float_onnx_file = os.path.join(onnx_dir, "model.onnx") diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index 57061cf727bc..da0e1da6f502 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -519,66 +519,67 @@ "models": { "PaddlePaddle/ernie_vil-2.0-base-zh": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-2.0-base-zh", + "task_flag": "feature_extraction-PaddlePaddle/ernie_vil-2.0-base-zh", + "task_priority_path": "PaddlePaddle/ernie_vil-2.0-base-zh", }, "OFA-Sys/chinese-clip-vit-base-patch16": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-base-patch16", + "task_flag": "feature_extraction-OFA-Sys/chinese-clip-vit-base-patch16", "task_priority_path": "OFA-Sys/chinese-clip-vit-base-patch16", }, "OFA-Sys/chinese-clip-vit-huge-patch14": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-huge-patch14", + "task_flag": "feature_extraction-OFA-Sys/chinese-clip-vit-huge-patch14", "task_priority_path": "OFA-Sys/chinese-clip-vit-huge-patch14", }, "OFA-Sys/chinese-clip-vit-large-patch14": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-large-patch14", + "task_flag": "feature_extraction-OFA-Sys/chinese-clip-vit-large-patch14", "task_priority_path": "OFA-Sys/chinese-clip-vit-large-patch14", }, "OFA-Sys/chinese-clip-vit-large-patch14-336px": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-OFA-Sys/chinese-clip-vit-large-patch14-336px", + "task_flag": "feature_extraction-OFA-Sys/chinese-clip-vit-large-patch14-336px", "task_priority_path": "OFA-Sys/chinese-clip-vit-large-patch14-336px", }, "openai/clip-vit-base-patch32": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-openai/clip-vit-base-patch32", + "task_flag": "feature_extraction-openai/clip-vit-base-patch32", "task_priority_path": "openai/clip-vit-base-patch32", }, "openai/clip-vit-base-patch16": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-openai/clip-vit-base-patch16", + "task_flag": "feature_extraction-openai/clip-vit-base-patch16", "task_priority_path": "openai/clip-vit-base-patch16", }, 
"openai/clip-vit-large-patch14": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-openai/clip-vit-large-patch14", + "task_flag": "feature_extraction-openai/clip-vit-large-patch14", "task_priority_path": "openai/clip-vit-large-patch14", }, "laion/CLIP-ViT-H-14-laion2B-s32B-b79K": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-laion/CLIP-ViT-H-14-laion2B-s32B-b79K", + "task_flag": "feature_extraction-laion/CLIP-ViT-H-14-laion2B-s32B-b79K", "task_priority_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", }, "laion/CLIP-ViT-B-32-laion2B-s34B-b79K": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-laion/CLIP-ViT-B-32-laion2B-s34B-b79K", + "task_flag": "feature_extraction-laion/CLIP-ViT-B-32-laion2B-s34B-b79K", "task_priority_path": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", }, "openai/clip-rn50": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-openai/clip-rn50", + "task_flag": "feature_extraction-openai/clip-rn50", "task_priority_path": "openai/clip-rn50", }, "openai/clip-rn101": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-openai/clip-rn101", + "task_flag": "feature_extraction-openai/clip-rn101", "task_priority_path": "openai/clip-rn101", }, "openai/clip-rn50x4": { "task_class": MultimodalFeatureExtractionTask, - "task_flag": "image_text_retrieval-openai/clip-rn50x4", + "task_flag": "feature_extraction-openai/clip-rn50x4", "task_priority_path": "openai/clip-rn50x4", }, }, From 70cf0d8ca076865b330131bff195021662c7885d Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 13 Feb 2023 05:31:57 +0000 Subject: [PATCH 10/18] Update docs --- docs/model_zoo/taskflow.md | 2 +- pipelines/pipelines/schema.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md index cb432844ee53..2e8330eebba3 100644 --- a/docs/model_zoo/taskflow.md +++ b/docs/model_zoo/taskflow.md @@ -47,7 +47,7 @@ PaddleNLP提供**开箱即用**的产业级NLP预置任务能力,无需训练 | [文档智能](#文档智能) | `Taskflow("document_intelligence")` | ✅ | ✅ | ✅ | ✅ | | 以多语言跨模态布局增强文档预训练模型ERNIE-Layout为核心底座 | | [问题生成](#问题生成) | `Taskflow("question_generation")` | ✅ | ✅ | ✅ | ✅ | | 问题生成大模型 | | [零样本文本分类](#零样本文本分类) | `Taskflow("zero_shot_text_classification")` | ✅ | ✅ | ✅ | | ✅ | 集成多场景的通用文本分类工具 | -| [通用特征提取](#通用特征提取) | `Taskflow("feature_extraction")` | ✅ | ✅ | ✅ | | | 集成文本,图片的特征抽取工具 | +| [模型特征提取](#模型特征提取) | `Taskflow("feature_extraction")` | ✅ | ✅ | ✅ | | | 集成文本,图片的特征抽取工具 | ## QuickStart diff --git a/pipelines/pipelines/schema.py b/pipelines/pipelines/schema.py index 639c591a7ea3..08c24e6400f4 100644 --- a/pipelines/pipelines/schema.py +++ b/pipelines/pipelines/schema.py @@ -50,7 +50,7 @@ BaseConfig.arbitrary_types_allowed = True #: Types of content_types supported -ContentTypes = Literal["text", "table", "image", "audio"] +ContentTypes = Literal["text", "image"] FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]] From 858a686ad3dbc8c158f2555d3d624aa5ca72d8e3 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 13 Feb 2023 07:35:32 +0000 Subject: [PATCH 11/18] Add taskflow loading finetune model --- paddlenlp/taskflow/feature_extraction.py | 72 ++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py index 4946902e08d8..1c4758a48e12 100644 --- a/paddlenlp/taskflow/feature_extraction.py +++ 
b/paddlenlp/taskflow/feature_extraction.py @@ -67,13 +67,77 @@ class MultimodalFeatureExtractionTask(Task): kwargs (dict, optional): Additional keyword arguments passed along to the specific task. """ + resource_files_names = { + "model_state": "model_state.pdparams", + "config": "config.json", + "vocab_file": "vocab.txt", + "preprocessor_config": "preprocessor_config.json", + "special_tokens_map": "special_tokens_map.json", + "tokenizer_config": "tokenizer_config.json", + } + resource_files_urls = { + "PaddlePaddle/ernie_vil-2.0-base-zh": { + "model_state": [ + "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/model_state.pdparams", + "38d8c8e01f74ba881e87d9a3f669e5ae", + ], + "config": [ + "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/config.json", + "caf929b450d5638e8df2a95c936519e7", + ], + "vocab_file": [ + "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/vocab.txt", + "1c1c1f4fd93c5bed3b4eebec4de976a8", + ], + "preprocessor_config": [ + "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/preprocessor_config.json", + "9a2e8da9f41896fedb86756b79355ee2", + ], + "special_tokens_map": [ + "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/special_tokens_map.json", + "8b3fb1023167bb4ab9d70708eb05f6ec", + ], + "tokenizer_config": [ + "https://paddlenlp.bj.bcebos.com/models/community/PaddlePaddle/ernie_vil-2.0-base-zh/tokenizer_config.json", + "da5385c23c8f522d33fc3aac829e4375", + ], + }, + "OFA-Sys/chinese-clip-vit-base-patch16": { + "model_state": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/model_state.pdparams", + "d594c94833b8cfeffc4f986712b3ef79", + ], + "config": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/config.json", + "3611b5c34ad69dcf91e3c1d03b01a93a", + ], + "vocab_file": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/vocab.txt", + "3b5b76c4aef48ecf8cb3abaafe960f09", + ], + "preprocessor_config": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/preprocessor_config.json", + "ba1fb66c75b18b3c9580ea5120e01ced", + ], + "special_tokens_map": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/special_tokens_map.json", + "8b3fb1023167bb4ab9d70708eb05f6ec", + ], + "tokenizer_config": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-base-patch16/tokenizer_config.json", + "573ba0466e15cdb5bd423ff7010735ce", + ], + }, + } + def __init__(self, task, model, batch_size=1, _static_mode=True, **kwargs): super().__init__(task=task, model=model, **kwargs) self._seed = None # we do not use batch self.mode = "text" self._batch_size = batch_size - self._construct_tokenizer(model_name=model) + self._check_task_files() + self._construct_tokenizer() self._static_mode = _static_mode self._config_map = {} self.predictor_map = {} @@ -90,14 +154,14 @@ def _construct_model(self, model): """ Construct the inference model for the predictor. """ - self._model = AutoModel.from_pretrained(model) + self._model = AutoModel.from_pretrained(self._task_path) self._model.eval() - def _construct_tokenizer(self, model_name): + def _construct_tokenizer(self): """ Construct the tokenizer for the predictor. 
""" - self._processor = AutoProcessor.from_pretrained(model_name) + self._processor = AutoProcessor.from_pretrained(self.model) def _batchify(self, data, batch_size): """ From 2651daeba42cee9103be0f9346b0083e8ff77ed8 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 13 Feb 2023 08:10:49 +0000 Subject: [PATCH 12/18] Rename mode to export_type --- paddlenlp/taskflow/feature_extraction.py | 6 +++--- paddlenlp/taskflow/task.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py index 1c4758a48e12..5fceda2fcadc 100644 --- a/paddlenlp/taskflow/feature_extraction.py +++ b/paddlenlp/taskflow/feature_extraction.py @@ -134,7 +134,7 @@ def __init__(self, task, model, batch_size=1, _static_mode=True, **kwargs): super().__init__(task=task, model=model, **kwargs) self._seed = None # we do not use batch - self.mode = "text" + self.export_type = "text" self._batch_size = batch_size self._check_task_files() self._construct_tokenizer() @@ -356,7 +356,7 @@ def _get_inference_model(self): self._config_map["image"] = self._config else: # Get text onnx model - self.mode = "text" + self.export_type = "text" self.inference_model_path = self.inference_text_model_path self._static_model_file = self.inference_model_path + ".pdmodel" self._static_params_file = self.inference_model_path + ".pdiparams" @@ -364,7 +364,7 @@ def _get_inference_model(self): self.predictor_map["text"] = self.predictor # Get image onnx model - self.mode = "image" + self.export_type = "image" self.inference_model_path = self.inference_image_model_path self._static_model_file = self.inference_model_path + ".pdmodel" self._static_params_file = self.inference_model_path + ".pdiparams" diff --git a/paddlenlp/taskflow/task.py b/paddlenlp/taskflow/task.py index 153b9ae00bbd..0e8a33c4b1a3 100644 --- a/paddlenlp/taskflow/task.py +++ b/paddlenlp/taskflow/task.py @@ -62,7 +62,7 @@ def __init__(self, model, task, priority_path=None, **kwargs): self._task_flag = self.kwargs["task_flag"] if "task_flag" in self.kwargs else self.model self.from_hf_hub = kwargs.pop("from_hf_hub", False) # Add mode flag for onnx output path redirection - self.mode = None + self.export_type = None if "task_path" in self.kwargs: self._task_path = self.kwargs["task_path"] @@ -223,11 +223,11 @@ def _prepare_onnx_mode(self): logger.warning( "The inference precision is change to 'fp32', please install the dependencies that required for 'fp16' inference, pip install onnxruntime-gpu onnx onnxconverter-common" ) - if self.mode is None: + if self.export_type is None: onnx_dir = os.path.join(self._task_path, "onnx") else: # Compatible multimodal model for saving image and text path - onnx_dir = os.path.join(self._task_path, "onnx", self.mode) + onnx_dir = os.path.join(self._task_path, "onnx", self.export_type) if not os.path.exists(onnx_dir): os.mkdir(onnx_dir) From f8fcbae9f1b3aa496eaa5f5d1825dd2337d316c1 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 13 Feb 2023 09:58:15 +0000 Subject: [PATCH 13/18] Remove clip english models --- docs/model_zoo/taskflow.md | 10 +---- paddlenlp/taskflow/feature_extraction.py | 52 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md index 2e8330eebba3..161cc802b85f 100644 --- a/docs/model_zoo/taskflow.md +++ b/docs/model_zoo/taskflow.md @@ -1837,17 +1837,9 @@ Tensor(shape=[1, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, | :---: | :--------: | 
:--------: | :--------: | | `PaddlePaddle/ernie_vil-2.0-base-zh` (默认) | ViT | ERNIE | 中文 | | `OFA-Sys/chinese-clip-vit-base-patch16` | ViT-B/16 |RoBERTa-wwm-Base| 中文 | - | `OFA-Sys/chinese-clip-vit-huge-patch14` | ViT-H/14 |RoBERTa-wwm-Large | 中文 | | `OFA-Sys/chinese-clip-vit-large-patch14` | ViT-L/14 | RoBERTa-wwm-Base | 中文 | | `OFA-Sys/chinese-clip-vit-large-patch14-336px` | ViT-L/14 | RoBERTa-wwm-Base | 中文 | - | `openai/clip-vit-base-patch32` | ViT-B/32 | transformer结构| 英文 | - | `openai/clip-vit-base-patch16` | ViT-B/16| transformer结构 | 英文 | - | `openai/clip-vit-large-patch14` | ViT-L/14 | transformer结构 | 英文 | - | `laion/CLIP-ViT-H-14-laion2B-s32B-b79K` | ViT-H/14 | transformer结构 | 英文 | - | `laion/CLIP-ViT-B-32-laion2B-s34B-b79K` | ViT-B/32 | transformer结构 | 英文 | - | `openai/clip-rn50` | RN50 | transformer结构 | 英文 | - | `openai/clip-rn101` | RN101 | transformer结构 | 英文 | - | `openai/clip-rn50x4` | RN50*4 | transformer结构 | 英文 | + #### 可配置参数说明 * `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py index 5fceda2fcadc..b7afe217414a 100644 --- a/paddlenlp/taskflow/feature_extraction.py +++ b/paddlenlp/taskflow/feature_extraction.py @@ -128,6 +128,58 @@ class MultimodalFeatureExtractionTask(Task): "573ba0466e15cdb5bd423ff7010735ce", ], }, + "OFA-Sys/chinese-clip-vit-large-patch14": { + "model_state": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/model_state.pdparams", + "5c0dde02d68179a9cc566173e53966c0", + ], + "config": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/config.json", + "a5e35843aa87ab1106e9f60f1e16b96d", + ], + "vocab_file": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/vocab.txt", + "3b5b76c4aef48ecf8cb3abaafe960f09", + ], + "preprocessor_config": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/preprocessor_config.json", + "ba1fb66c75b18b3c9580ea5120e01ced", + ], + "special_tokens_map": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/special_tokens_map.json", + "8b3fb1023167bb4ab9d70708eb05f6ec", + ], + "tokenizer_config": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14/tokenizer_config.json", + "573ba0466e15cdb5bd423ff7010735ce", + ], + }, + "OFA-Sys/chinese-clip-vit-large-patch14-336px": { + "model_state": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/model_state.pdparams", + "ee3eb7f9667cfb06338bea5757c5e0d7", + ], + "config": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/config.json", + "cb2794d99bea8c8f45901d177e663e1e", + ], + "vocab_file": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/vocab.txt", + "3b5b76c4aef48ecf8cb3abaafe960f09", + ], + "preprocessor_config": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/preprocessor_config.json", + "c52a0b3abe9bdd1c3c5a3d56797f4a03", + ], + "special_tokens_map": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/special_tokens_map.json", + "8b3fb1023167bb4ab9d70708eb05f6ec", + ], + "tokenizer_config": [ + "https://paddlenlp.bj.bcebos.com/models/community/OFA-Sys/chinese-clip-vit-large-patch14-336px/tokenizer_config.json", + 
"573ba0466e15cdb5bd423ff7010735ce", + ], + }, } def __init__(self, task, model, batch_size=1, _static_mode=True, **kwargs): From d6c5638dbd8eb0ac0db1d6dade95c909b8cbdea7 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 13 Feb 2023 12:25:55 +0000 Subject: [PATCH 14/18] Add unit test for feature extraction taskflow --- paddlenlp/taskflow/feature_extraction.py | 14 +- tests/taskflow/test_feature_extraction.py | 156 ++++++++++++++++++++++ 2 files changed, 163 insertions(+), 7 deletions(-) create mode 100644 tests/taskflow/test_feature_extraction.py diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py index b7afe217414a..0c9f854f2440 100644 --- a/paddlenlp/taskflow/feature_extraction.py +++ b/paddlenlp/taskflow/feature_extraction.py @@ -182,12 +182,13 @@ class MultimodalFeatureExtractionTask(Task): }, } - def __init__(self, task, model, batch_size=1, _static_mode=True, **kwargs): + def __init__(self, task, model, batch_size=1, _static_mode=True, return_tensors=True, **kwargs): super().__init__(task=task, model=model, **kwargs) self._seed = None # we do not use batch self.export_type = "text" self._batch_size = batch_size + self.return_tensors = return_tensors self._check_task_files() self._construct_tokenizer() self._static_mode = _static_mode @@ -316,18 +317,17 @@ def _run_model(self, inputs): for batch_inputs in inputs["batches"]: if "input_ids" in batch_inputs: text_features = self._model.get_text_features(input_ids=batch_inputs["input_ids"]) - all_feats.append(text_features) + all_feats.append(text_features.numpy()) if "pixel_values" in batch_inputs: image_features = self._model.get_image_features(pixel_values=batch_inputs["pixel_values"]) - all_feats.append(image_features) + all_feats.append(image_features.numpy()) inputs.update({"features": all_feats}) return inputs def _postprocess(self, inputs): - if self._static_mode: - inputs["features"] = paddle.to_tensor(np.concatenate(inputs["features"], axis=0)) - else: - inputs["features"] = paddle.concat(inputs["features"], axis=0) + inputs["features"] = np.concatenate(inputs["features"], axis=0) + if self.return_tensors: + inputs["features"] = paddle.to_tensor(inputs["features"]) return inputs def _construct_input_spec(self): diff --git a/tests/taskflow/test_feature_extraction.py b/tests/taskflow/test_feature_extraction.py new file mode 100644 index 000000000000..7dcb8b036d51 --- /dev/null +++ b/tests/taskflow/test_feature_extraction.py @@ -0,0 +1,156 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from tempfile import TemporaryDirectory + +import numpy as np +import paddle +from PIL import Image + +from paddlenlp.taskflow import Taskflow +from paddlenlp.taskflow.feature_extraction import MultimodalFeatureExtractionTask + + +class TestMultimodalFeatureExtractionTask(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.temp_dir = TemporaryDirectory() + cls.batch_size = 2 + cls.max_resolution = 40 + cls.min_resolution = 30 + cls.num_channels = 3 + + @classmethod + def tearDownClass(cls): + cls.temp_dir.cleanup() + + def test_small_model_pd(self): + feature_extractor = Taskflow(task="feature_extraction") + outputs = feature_extractor("This is a test") + self.assertEqual(outputs["features"].shape, [1, 768]) + + def test_return_tensors_pd(self): + feature_extractor = Taskflow(task="feature_extraction", return_tensors=True) + outputs = feature_extractor( + "This is a test", + ) + self.assertTrue(paddle.is_tensor(outputs["features"])) + + def prepare_inputs(self, equal_resolution=False, numpify=False, paddleify=False): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PaddlePaddle tensors if one specifies paddleify=True. + """ + + assert not (numpify and paddleify), "You cannot specify both numpy and PaddlePaddle tensors at the same time" + + if equal_resolution: + image_inputs = [] + for i in range(self.batch_size): + image_inputs.append( + np.random.randint( + 255, size=(self.num_channels, self.max_resolution, self.max_resolution), dtype=np.uint8 + ) + ) + else: + image_inputs = [] + for i in range(self.batch_size): + width, height = np.random.choice(np.arange(self.min_resolution, self.max_resolution), 2) + image_inputs.append(np.random.randint(255, size=(self.num_channels, width, height), dtype=np.uint8)) + + if not numpify and not paddleify: + # PIL expects the channel dimension as last dimension + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + if paddleify: + image_inputs = [paddle.to_tensor(x) for x in image_inputs] + + return image_inputs + + def test_feature_extraction_task(self): + input_text = (["这是一只猫", "这是一只狗"],) + # dygraph text test + dygraph_taskflow = MultimodalFeatureExtractionTask( + model="PaddlePaddle/ernie_vil-2.0-base-zh", + task="feature_extraction", + _static_mode=False, + return_tensors=False, + ) + dygraph_results = dygraph_taskflow(input_text) + shape = dygraph_results["features"].shape + self.assertEqual(shape[0], 2) + # static text test + static_taskflow = MultimodalFeatureExtractionTask( + model="PaddlePaddle/ernie_vil-2.0-base-zh", + task="feature_extraction", + _static_mode=True, + return_tensors=False, + device_id=0, + ) + static_results = static_taskflow(input_text) + self.assertEqual(static_results["features"].shape[0], 2) + + for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]): + for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()): + self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-6) + + input_image = (self.prepare_inputs(equal_resolution=True, paddleify=False),) + # dygraph image test + dygraph_results = dygraph_taskflow(input_image) + self.assertEqual(dygraph_results["features"].shape[0], 2) + + # static image test + static_results = static_taskflow(input_image) + self.assertEqual(static_results["features"].shape[0], 2) + + for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]): + for 
dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()): + self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-6) + + def test_taskflow_task(self): + input_text = ["这是一只猫", "这是一只狗"] + # dygraph test + dygraph_taskflow = Taskflow( + task="feature_extraction", + _static_mode=False, + return_tensors=False, + ) + dygraph_results = dygraph_taskflow(input_text) + shape = dygraph_results["features"].shape + + self.assertEqual(shape[0], 2) + # static test + static_taskflow = Taskflow( + task="feature_extraction", + _static_mode=True, + return_tensors=False, + ) + static_results = static_taskflow(input_text) + self.assertEqual(static_results["features"].shape[0], 2) + for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]): + for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()): + self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-6) + + input_image = self.prepare_inputs(equal_resolution=True, paddleify=False) + # dygraph image test + dygraph_results = dygraph_taskflow(input_image) + self.assertEqual(dygraph_results["features"].shape[0], 2) + + # static image test + static_results = static_taskflow(input_image) + self.assertEqual(static_results["features"].shape[0], 2) + + for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]): + for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()): + self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-6) From 0cd4588c066fa2784b107fe83d36fb6931692e06 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 13 Feb 2023 13:26:02 +0000 Subject: [PATCH 15/18] set delta to 1e-5 --- tests/taskflow/test_feature_extraction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/taskflow/test_feature_extraction.py b/tests/taskflow/test_feature_extraction.py index 7dcb8b036d51..6d4b297bb994 100644 --- a/tests/taskflow/test_feature_extraction.py +++ b/tests/taskflow/test_feature_extraction.py @@ -116,7 +116,7 @@ def test_feature_extraction_task(self): for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]): for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()): - self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-6) + self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-5) def test_taskflow_task(self): input_text = ["这是一只猫", "这是一只狗"] @@ -153,4 +153,4 @@ def test_taskflow_task(self): for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]): for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()): - self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-6) + self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-5) From 75d72dc2caac5c5eb0115069f02714d67c8fae52 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 13 Feb 2023 15:44:27 +0000 Subject: [PATCH 16/18] change delta to 1e-5 --- tests/taskflow/test_text_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/taskflow/test_text_classification.py b/tests/taskflow/test_text_classification.py index 2fa193138998..9afa1c0c4627 100644 --- a/tests/taskflow/test_text_classification.py +++ b/tests/taskflow/test_text_classification.py @@ -140,7 +140,7 @@ def test_classification_task(self, batch_size, problem_type, model): for dygraph_result, static_result in zip(dygraph_results, static_results): for dygraph_pred, 
static_pred in zip(dygraph_result["predictions"], static_result["predictions"]): self.assertEqual(dygraph_pred["label"], static_pred["label"]) - self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-6) + self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-5) # if multi_label, all predictions should be greater than the threshold if model == "multi_label": self.assertGreater(dygraph_pred["score"], dygraph_taskflow.multilabel_threshold) @@ -197,7 +197,7 @@ def test_taskflow_task(self, batch_size, problem_type, model): for dygraph_result, static_result in zip(dygraph_results, static_results): for dygraph_pred, static_pred in zip(dygraph_result["predictions"], static_result["predictions"]): self.assertEqual(dygraph_pred["label"], static_pred["label"]) - self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-6) + self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-5) # if multi_label, all predictions should be greater than the threshold if model == "multi_label": self.assertGreater(dygraph_pred["score"], dygraph_taskflow.task_instance.multilabel_threshold) From d07f8c4125873692a4893abfe1d542d35e6cefce Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 13 Feb 2023 16:22:36 +0000 Subject: [PATCH 17/18] change delta to 1e-5 --- tests/taskflow/test_feature_extraction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/taskflow/test_feature_extraction.py b/tests/taskflow/test_feature_extraction.py index 6d4b297bb994..bb057f481195 100644 --- a/tests/taskflow/test_feature_extraction.py +++ b/tests/taskflow/test_feature_extraction.py @@ -103,7 +103,7 @@ def test_feature_extraction_task(self): for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]): for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()): - self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-6) + self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-5) input_image = (self.prepare_inputs(equal_resolution=True, paddleify=False),) # dygraph image test @@ -140,7 +140,7 @@ def test_taskflow_task(self): self.assertEqual(static_results["features"].shape[0], 2) for dygraph_result, static_result in zip(dygraph_results["features"], static_results["features"]): for dygraph_pred, static_pred in zip(dygraph_result.tolist(), static_result.tolist()): - self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-6) + self.assertAlmostEqual(dygraph_pred, static_pred, delta=1e-5) input_image = self.prepare_inputs(equal_resolution=True, paddleify=False) # dygraph image test From 1b2e89d1a99545617100a81430198c3b9b220caa Mon Sep 17 00:00:00 2001 From: w5688414 Date: Tue, 14 Feb 2023 07:20:53 +0000 Subject: [PATCH 18/18] Change to is_static_model --- paddlenlp/taskflow/feature_extraction.py | 10 +++++----- tests/taskflow/test_feature_extraction.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/paddlenlp/taskflow/feature_extraction.py b/paddlenlp/taskflow/feature_extraction.py index 0c9f854f2440..beac02ee9495 100644 --- a/paddlenlp/taskflow/feature_extraction.py +++ b/paddlenlp/taskflow/feature_extraction.py @@ -182,7 +182,7 @@ class MultimodalFeatureExtractionTask(Task): }, } - def __init__(self, task, model, batch_size=1, _static_mode=True, return_tensors=True, **kwargs): + def __init__(self, task, model, batch_size=1, is_static_model=True, return_tensors=True, **kwargs): super().__init__(task=task, model=model, **kwargs) 
self._seed = None # we do not use batch @@ -191,14 +191,14 @@ def __init__(self, task, model, batch_size=1, _static_mode=True, return_tensors= self.return_tensors = return_tensors self._check_task_files() self._construct_tokenizer() - self._static_mode = _static_mode + self.is_static_model = is_static_model self._config_map = {} self.predictor_map = {} self.input_names_map = {} self.input_handles_map = {} self.output_handle_map = {} self._check_predictor_type() - if self._static_mode: + if self.is_static_model: self._get_inference_model() else: self._construct_model(model) @@ -228,7 +228,7 @@ def _parse_batch(batch_examples): else: batch_texts = None batch_images = batch_examples - if self._static_mode: + if self.is_static_model: tokenized_inputs = self._processor( text=batch_texts, images=batch_images, return_tensors="np", padding="max_length", truncation=True ) @@ -287,7 +287,7 @@ def _run_model(self, inputs): Run the task model from the outputs of the `_preprocess` function. """ all_feats = [] - if self._static_mode: + if self.is_static_model: with static_mode_guard(): for batch_inputs in inputs["batches"]: if self._predictor_type == "paddle-inference": diff --git a/tests/taskflow/test_feature_extraction.py b/tests/taskflow/test_feature_extraction.py index bb057f481195..bb15d7fb4ff2 100644 --- a/tests/taskflow/test_feature_extraction.py +++ b/tests/taskflow/test_feature_extraction.py @@ -84,7 +84,7 @@ def test_feature_extraction_task(self): dygraph_taskflow = MultimodalFeatureExtractionTask( model="PaddlePaddle/ernie_vil-2.0-base-zh", task="feature_extraction", - _static_mode=False, + is_static_model=False, return_tensors=False, ) dygraph_results = dygraph_taskflow(input_text) @@ -94,7 +94,7 @@ def test_feature_extraction_task(self): static_taskflow = MultimodalFeatureExtractionTask( model="PaddlePaddle/ernie_vil-2.0-base-zh", task="feature_extraction", - _static_mode=True, + is_static_model=True, return_tensors=False, device_id=0, ) @@ -123,7 +123,7 @@ def test_taskflow_task(self): # dygraph test dygraph_taskflow = Taskflow( task="feature_extraction", - _static_mode=False, + is_static_model=False, return_tensors=False, ) dygraph_results = dygraph_taskflow(input_text) @@ -133,7 +133,7 @@ def test_taskflow_task(self): # static test static_taskflow = Taskflow( task="feature_extraction", - _static_mode=True, + is_static_model=True, return_tensors=False, ) static_results = static_taskflow(input_text)
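
Taken together, the renamed `is_static_model` flag, the `return_tensors` switch and the `features` output exercised by these tests can be driven end to end roughly as below; a usage sketch that assumes the default `PaddlePaddle/ernie_vil-2.0-base-zh` weights are available and uses an illustrative image path:

    import paddle.nn.functional as F
    from PIL import Image

    from paddlenlp import Taskflow

    # Dynamic-graph inference that returns paddle.Tensor features; set is_static_model=True
    # to run through the exported static-graph/ONNX path instead.
    feature_extractor = Taskflow("feature_extraction", is_static_model=False, return_tensors=True)

    text_feats = feature_extractor(["这是一只猫", "这是一只狗"])["features"]                # two text embeddings
    image_feats = feature_extractor([Image.open("data/wukong_test/demo.jpg")])["features"]  # one image embedding

    # Both towers embed into the same space, so retrieval is a cosine-similarity ranking.
    scores = F.cosine_similarity(text_feats, image_feats)
    print(scores)

Because both encoders project into a shared embedding space, ranking candidate images against a text query reduces to this cosine-similarity sort, which is what the FAISS-backed retrieval example earlier in the series does at scale.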