diff --git a/deploy/llava/export_model.py b/deploy/llava/export_model.py
index a64abdcbb..57c6ddc42 100644
--- a/deploy/llava/export_model.py
+++ b/deploy/llava/export_model.py
@@ -30,13 +30,13 @@ def export_encode_text(model, config, compute_dtype):
 
 
 def export_encode_image(model, compute_dtype):
-
+    paddle.save(model.llama.image_newline,args.save_path + "/encode_image/clip/image_newline.pdparams")
     # convert to static graph with specific input description
     model = paddle.jit.to_static(
         model.encode_images,
         input_spec=[
-            paddle.static.InputSpec(shape=[None, 3, 336, 336], dtype=compute_dtype),  # images
-        ],
+            paddle.static.InputSpec(shape=[None,3, 336, 336], dtype=compute_dtype),  # images
+        ]
     )
 
     # save to static model
@@ -76,6 +76,7 @@ def export_encode_image(model, compute_dtype):
         vision_tower = model.get_vision_tower()
         vision_tower.load_model()
         model.eval()
+
         export_encode_image(model, compute_dtype)
 
     elif args.encode_text:
diff --git a/deploy/llava/run_static_predict.py b/deploy/llava/run_static_predict.py
index 73a83538f..0255c9729 100644
--- a/deploy/llava/run_static_predict.py
+++ b/deploy/llava/run_static_predict.py
@@ -26,6 +26,8 @@
     IMAGE_TOKEN_INDEX,
 )
 from paddlemix.models.llava.conversation import conv_templates
+from paddlemix.models.llava.mm_utils import load_image,get_anyres_image_grid_shape
+from paddlemix.models.llava.base_model import unpad_image
 from paddlemix.utils.log import logger
 
 
@@ -39,15 +41,20 @@ def __init__(self, args):
         self.args = args
         self.config = AutoConfigMIX.from_pretrained(args.model_name_or_path)
+        self.clip_config = AutoConfigMIX.from_pretrained(self.config.mm_vision_tower)
+
         self.tokenizer = AutoTokenizerMIX.from_pretrained(args.model_name_or_path)
-        self.processor, _ = AutoProcessorMIX.from_pretrained(args.model_name_or_path, eval="eval")
+        self.processor, _ = AutoProcessorMIX.from_pretrained(args.model_name_or_path, image_aspect_ratio=self.config.image_aspect_ratio,eval="eval")
 
         self.first_predictor = self.create_predictor(args.first_model_path)
         print(f"first_model_path: {args.first_model_path}, {self.first_predictor}")
+
         self.second_predictor = self.create_predictor(args.second_model_path)
         print(f"second_model_path: {args.second_model_path}, {self.second_predictor}")
 
+        self.image_newline = paddle.load(os.path.join(args.first_model_path, "image_newline.pdparams"))
+
     def create_predictor(self, model_path):
         from paddlenlp.utils.import_utils import import_module
@@ -77,9 +84,79 @@ def create_predictor(self, model_path):
         return predictor
 
     @paddle.no_grad()
-    def encode_images(self, pixel_values):
-        language_model_inputs = self.first_predictor.run(pixel_values)
-        return language_model_inputs
+    def encode_images(self, images, image_sizes):
+        if type(images) is list or images.ndim == 5:
+            if type(images) is list:
+                images = [(x.unsqueeze(axis=0) if x.ndim == 3 else x) for x in images]
+            concat_images = paddle.concat(x=[image for image in images], axis=0)
+
+            image_features = self.first_predictor.run(concat_images)[0]
+
+            split_sizes = [image.shape[0] for image in images]
+            image_features = paddle.split(image_features, split_sizes, axis=0)
+            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+            if mm_patch_merge_type == "flat":
+                image_features = [x.flatten(start_axis=0, stop_axis=1) for x in image_features]
+            elif mm_patch_merge_type.startswith("spatial"):
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    if image_feature.shape[0] > 1:
+                        base_image_feature = image_feature[0]
+                        image_feature = image_feature[1:]
+                        height = width = self.clip_config.image_resolution // self.clip_config.vision_patch_size
+                        assert height * width == base_image_feature.shape[0]
+                        if image_aspect_ratio == "anyres":
+                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                                image_sizes[image_idx],
+                                self.config.image_grid_pinpoints,
+                                self.clip_config.image_resolution,
+                            )
+
+                            image_feature = paddle.reshape(
+                                image_feature, (num_patch_height, num_patch_width, height, width, -1)
+                            )
+                        else:
+                            raise NotImplementedError
+                        if "unpad" in mm_patch_merge_type:
+                            image_feature = image_feature.transpose(perm=[4, 0, 2, 1, 3])
+                            image_feature = image_feature.flatten(start_axis=1, stop_axis=2).flatten(
+                                start_axis=2, stop_axis=3
+                            )
+                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                            image_feature = paddle.concat(
+                                x=(
+                                    image_feature,
+                                    self.image_newline[:, (None), (None)].expand(
+                                        shape=[*image_feature.shape[:-1], 1]
+                                    ).astype(image_feature.dtype),
+                                ),
+                                axis=-1,
+                            )
+                            x = image_feature.flatten(start_axis=1, stop_axis=2)
+                            perm_12 = list(range(x.ndim))
+                            perm_12[0] = 1
+                            perm_12[1] = 0
+                            image_feature = x.transpose(perm=perm_12)
+                        else:
+                            image_feature = image_feature.transpose(perm=[0, 2, 1, 3, 4])
+                            image_feature = image_feature.flatten(start_axis=0, stop_axis=3)
+                        image_feature = paddle.concat(x=(base_image_feature, image_feature), axis=0)
+                    else:
+                        image_feature = image_feature[0]
+                        if "unpad" in mm_patch_merge_type:
+                            image_feature = paddle.concat(
+                                x=(image_feature, self.image_newline[None].to(image_feature.place)), axis=0
+                            )
+                    new_image_features.append(image_feature)
+                image_features = new_image_features
+                image_features = paddle.stack(x=image_features, axis=0)
+            else:
+                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
+        else:
+            image_features = self.first_predictor.run(images)[0]
+
+        return image_features
 
     @paddle.no_grad()
     def generate_with_image_features(self, image_features, input_ids):
@@ -225,9 +302,9 @@ def pre_processing(self, inp, first_message):
         conv.append_message(conv.roles[1], None)
         prompt = conv.get_prompt()
         record = {"image": self.args.image_file, "conversations": prompt}
-
+        image_size = load_image(args.image_file).size
         data_dict = self.processor(record=record, image_aspect_ratio=self.config.image_aspect_ratio)
-
+        data_dict['image_size'] = [image_size]
         return data_dict
 
     def post_processing(self, generate_ids):
@@ -245,8 +322,8 @@ def run_benchmark(self):
         inp = "user: Generate the caption in English with grounding"
         data_dict = self.pre_processing(inp, first_message)
         image = paddle.cast(data_dict["images"], self.compute_dtype)
-
-        image_features = self.encode_images(image)[0]
+
+        image_features = self.encode_images(image,data_dict['image_size'])
 
         generate_ids, _ = self.generate_with_image_features(
             image_features,
@@ -277,9 +354,9 @@ def predict(self):
             print(f"{roles[1]}: ", end="")
             data_dict = self.pre_processing(inp, first_message)
             image = paddle.cast(data_dict["images"], self.compute_dtype)
-
-            image_features = self.encode_images(image)[0]
-
+
+            image_features = self.encode_images(image,data_dict['image_size'])
+
             generate_ids, _ = self.generate_with_image_features(
                 image_features,
                 data_dict["input_ids"],
diff --git a/paddlemix/models/llava/clip_model.py b/paddlemix/models/llava/clip_model.py
index d5d9918e0..344493128 100644
--- a/paddlemix/models/llava/clip_model.py
+++ b/paddlemix/models/llava/clip_model.py
@@ -947,6 +947,7 @@ def __init__(self, config: CLIPVisionConfig):
         super().__init__()
         self.config = config
         embed_dim = config.hidden_size
+        self.embed_dim = embed_dim
         self.input_resolution = config.image_size
         self.class_embedding = self.create_parameter(
             (embed_dim,),
@@ -1016,17 +1017,18 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         target_dtype = self.conv1.weight.dtype
         pixel_values = self.conv1(pixel_values.cast(target_dtype))
-
+
         # for to_static
-        pixel_values_shape = paddle.to_tensor(pixel_values.shape, dtype="int32")
+        pixel_values_shape = paddle.shape(pixel_values)
         pixel_values = pixel_values.reshape(
             (pixel_values_shape[0], pixel_values_shape[1], pixel_values_shape[2] * pixel_values_shape[3])
         )
         pixel_values = pixel_values.transpose((0, 2, 1))
         embedding_output = paddle.concat(
-            [self.class_embedding.unsqueeze([0, 1]).expand([pixel_values.shape[0], -1, -1]), pixel_values], axis=1
+            [self.class_embedding.unsqueeze([0, 1]).expand([pixel_values_shape[0], -1, -1]), pixel_values], axis=1
         )
+
         hidden_states = embedding_output + self.positional_embedding.weight
         hidden_states = self.ln_pre(hidden_states)
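For reference, below is a minimal standalone sketch (not part of the patch) of the "spatial ... unpad" merge order that the new static-graph encode_images reproduces for anyres images. The tile resolution, hidden size, image size, and grid pinpoints are made-up example values; the two helpers are imported exactly as the patch imports them, and the random tensors stand in for first_predictor.run output and image_newline.pdparams.

import paddle
from paddlemix.models.llava.base_model import unpad_image
from paddlemix.models.llava.mm_utils import get_anyres_image_grid_shape

# Example values only: a 336px CLIP tower with 14px patches -> 24x24 tokens per tile;
# the real numbers come from clip_config / the LLaVA config.
height = width = 336 // 14
hidden = 1024
image_size = (640, 480)  # (width, height) of the original image
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]

# How many high-resolution tiles the image was split into.
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_size, grid_pinpoints, 336)

# Dummy per-tile features; the base tile (image_feature[0]) is handled separately in the patch.
tiles = paddle.randn([num_patch_height * num_patch_width, height * width, hidden])
feat = paddle.reshape(tiles, (num_patch_height, num_patch_width, height, width, -1))

feat = feat.transpose(perm=[4, 0, 2, 1, 3])      # (C, nPh, h, nPw, w)
feat = feat.flatten(start_axis=1, stop_axis=2)   # (C, nPh*h, nPw, w)
feat = feat.flatten(start_axis=2, stop_axis=3)   # (C, nPh*h, nPw*w)
feat = unpad_image(feat, image_size)             # crop the letterbox padding
newline = paddle.randn([hidden])                 # stands in for image_newline.pdparams
feat = paddle.concat(
    x=(feat, newline[:, None, None].expand(shape=[*feat.shape[:-1], 1]).astype(feat.dtype)),
    axis=-1,
)                                                # append one "newline" column per feature row
feat = feat.flatten(start_axis=1, stop_axis=2).transpose(perm=[1, 0])  # (tokens, C)
print(feat.shape)

The sketch only illustrates the reshape/transpose/unpad/newline ordering; in the patch this result is further concatenated after the base tile's features before being fed to generate_with_image_features.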