remove phi3 in internvl2 and refine format #715

Merged
6 commits merged on Sep 27, 2024
Changes from 3 commits
.gitignore: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ __pycache__/
/lib/
/lib64/
/output/
/work_dirs/
/inference_model/
/output_inference/
/parts/
paddlemix/examples/internvl2/chat_demo.py: 88 additions & 32 deletions
@@ -13,14 +13,15 @@
# limitations under the License.

import argparse

import paddle
import paddle.vision.transforms as T
from paddlenlp.transformers import Llama3Tokenizer, LlamaTokenizer, Qwen2Tokenizer
from PIL import Image

from paddlemix.datasets.internvl_dataset import dynamic_preprocess
from paddlemix.models.internvl2.internlm2 import InternLM2Tokenizer
from paddlenlp.transformers import AutoTokenizer, Qwen2Tokenizer, LlamaTokenizer, Llama3Tokenizer
from paddlemix.models.internvl2.internvl_chat import InternVLChatModel
from paddlemix.datasets.internvl_dataset import dynamic_preprocess

paddle.set_grad_enabled(False)

@@ -30,17 +31,19 @@

def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
# T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation='bicubic'),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
transform = T.Compose(
[
# T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation="bicubic"),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD),
]
)
return transform


def load_image(image_file, input_size=448, max_num=12):
image = Image.open(image_file).convert('RGB')
image = Image.open(image_file).convert("RGB")
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
@@ -49,34 +52,85 @@ def load_image(image_file, input_size=448, max_num=12):


def load_tokenizer(model_size, model_path):
if model_size in ['1B']:
if model_size in ["1B"]:
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
# TODO:
tokenizer.added_tokens_encoder = {'<|endoftext|>': 151643, '<|im_start|>': 151644, '<|im_end|>': 151645, '<img>': 151646, '</img>': 151647, '<IMG_CONTEXT>': 151648, '<quad>': 151649, '</quad>': 151650, '<ref>': 151651, '</ref>': 151652, '<box>': 151653, '</box>': 151654}
tokenizer.added_tokens_encoder = {
"<|endoftext|>": 151643,
"<|im_start|>": 151644,
"<|im_end|>": 151645,
"<img>": 151646,
"</img>": 151647,
"<IMG_CONTEXT>": 151648,
"<quad>": 151649,
"</quad>": 151650,
"<ref>": 151651,
"</ref>": 151652,
"<box>": 151653,
"</box>": 151654,
}
tokenizer.added_tokens_decoder = {v: k for k, v in tokenizer.added_tokens_encoder.items()}

elif model_size in ['2B', '8B', '26B']:
elif model_size in ["2B", "8B", "26B"]:
tokenizer = InternLM2Tokenizer.from_pretrained(model_path)
# TODO:
tokenizer.added_tokens_encoder = {'<unk>': 0, '<s>': 1, '</s>': 2, '<|plugin|>': 92538, '<|interpreter|>': 92539, '<|action_end|>': 92540, '<|action_start|>': 92541, '<|im_end|>': 92542, '<|im_start|>': 92543, '<img>': 92544, '</img>': 92545, '<IMG_CONTEXT>': 92546, '<quad>': 92547, '</quad>': 92548, '<ref>': 92549, '</ref>': 92550, '<box>': 92551, '</box>': 92552}
tokenizer.added_tokens_decoder = {v: k for k, v in tokenizer.added_tokens_encoder.items()}

elif model_size in ['4B']:
tokenizer = LlamaTokenizer.from_pretrained(model_path)
# TODO:
tokenizer.added_tokens_encoder = {'<unk>': 0, '<s>': 1, '</s>': 2, '<|endoftext|>': 32000, '<|assistant|>': 32001, '<|placeholder1|>': 32002, '<|placeholder2|>': 32003, '<|placeholder3|>': 32004, '<|placeholder4|>': 32005, '<|system|>': 32006, '<|end|>': 32007, '<|placeholder5|>': 32008, '<|placeholder6|>': 32009, '<|user|>': 32010, '<img>': 32011, '</img>': 32012, '<IMG_CONTEXT>': 32013, '<quad>': 32014, '</quad>': 32015, '<ref>': 32016, '</ref>': 32017, '<box>': 32018, '</box>': 32019}
tokenizer.added_tokens_encoder = {
"<unk>": 0,
"<s>": 1,
"</s>": 2,
"<|plugin|>": 92538,
"<|interpreter|>": 92539,
"<|action_end|>": 92540,
"<|action_start|>": 92541,
"<|im_end|>": 92542,
"<|im_start|>": 92543,
"<img>": 92544,
"</img>": 92545,
"<IMG_CONTEXT>": 92546,
"<quad>": 92547,
"</quad>": 92548,
"<ref>": 92549,
"</ref>": 92550,
"<box>": 92551,
"</box>": 92552,
}
tokenizer.added_tokens_decoder = {v: k for k, v in tokenizer.added_tokens_encoder.items()}

elif model_size in ['40B']:
elif model_size in ["40B"]:
tokenizer = LlamaTokenizer.from_pretrained(model_path)
# TODO:
tokenizer.added_tokens_encoder = {'<unk>': 0, '<|startoftext|>': 1, '<|endoftext|>': 2, '<|im_start|>': 6, '<|im_end|>': 7, '<img>': 68, '</img>': 70, '<IMG_CONTEXT>': 64000, '<quad>': 64001, '</quad>': 64002, '<ref>': 64003, '</ref>': 64004, '<box>': 64005, '</box>': 64006}
tokenizer.added_tokens_encoder = {
"<unk>": 0,
"<|startoftext|>": 1,
"<|endoftext|>": 2,
"<|im_start|>": 6,
"<|im_end|>": 7,
"<img>": 68,
"</img>": 70,
"<IMG_CONTEXT>": 64000,
"<quad>": 64001,
"</quad>": 64002,
"<ref>": 64003,
"</ref>": 64004,
"<box>": 64005,
"</box>": 64006,
}
tokenizer.added_tokens_decoder = {v: k for k, v in tokenizer.added_tokens_encoder.items()}

elif model_size in ['76B']:
elif model_size in ["76B"]:
tokenizer = Llama3Tokenizer.from_pretrained(model_path)
# TODO:
tokenizer.added_tokens_encoder = {'<img>': 128256, '</img>': 128257, '<IMG_CONTEXT>': 128258, '<quad>': 128259, '</quad>': 128260, '<ref>': 128261, '</ref>': 128262, '<box>': 128263, '</box>': 128264}
tokenizer.added_tokens_encoder = {
"<img>": 128256,
"</img>": 128257,
"<IMG_CONTEXT>": 128258,
"<quad>": 128259,
"</quad>": 128260,
"<ref>": 128261,
"</ref>": 128262,
"<box>": 128263,
"</box>": 128264,
}
tokenizer.added_tokens_decoder = {v: k for k, v in tokenizer.added_tokens_encoder.items()}

else:
@@ -86,28 +140,30 @@ def load_tokenizer(model_size, model_path):


def main(args):
if args.image_path is not None and args.image_path != 'None':
if args.image_path is not None and args.image_path != "None":
pixel_values = load_image(args.image_path, max_num=12).to(paddle.bfloat16)
args.text = '<image>\n' + args.text
args.text = "<image>\n" + args.text

else:
pixel_values = None

# init model and tokenizer
MODEL_PATH = args.model_name_or_path
model_size = MODEL_PATH.split('-')[-1]
print(f'model size: {model_size}')
model_size = MODEL_PATH.split("-")[-1]
print(f"model size: {model_size}")
tokenizer = load_tokenizer(model_size, MODEL_PATH)
print('tokenizer:\n', tokenizer)
print('len(tokenizer): ', len(tokenizer))
print("tokenizer:\n", tokenizer)
print("len(tokenizer): ", len(tokenizer))

model = InternVLChatModel.from_pretrained(MODEL_PATH).eval()

generation_config = dict(max_new_tokens=1024, do_sample=False)

with paddle.no_grad():
response, history = model.chat(tokenizer, pixel_values, args.text, generation_config, history=None, return_history=True)
print(f'User: {args.text}\nAssistant: {response}')
response, history = model.chat(
tokenizer, pixel_values, args.text, generation_config, history=None, return_history=True
)
print(f"User: {args.text}\nAssistant: {response}")


if __name__ == "__main__":
@@ -119,6 +175,6 @@ def main(args):
help="pretrained ckpt and tokenizer",
)
parser.add_argument("--image_path", type=str, default=None)
parser.add_argument("--text", type=str, default='Please describe the image shortly.', required=True)
parser.add_argument("--text", type=str, default="Please describe the image shortly.", required=True)
args = parser.parse_args()
main(args)
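
For context, a minimal usage sketch of the refactored demo. The checkpoint name OpenGVLab/InternVL2-8B and the image path demo.jpg are illustrative placeholders, not part of this PR:

python paddlemix/examples/internvl2/chat_demo.py \
    --model_name_or_path OpenGVLab/InternVL2-8B \
    --image_path demo.jpg \
    --text "Please describe the image shortly."

main() derives the model size from the suffix after the last "-" in --model_name_or_path ("8B" here), and load_tokenizer uses that size to pick the tokenizer class (Qwen2Tokenizer for 1B, InternLM2Tokenizer for 2B/8B/26B, LlamaTokenizer for 40B, Llama3Tokenizer for 76B) and to register the image and chat special tokens; the 4B (Phi-3) branch is removed by this PR.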