Skip to content

Commit 2af80b1

Browse files
pkhk-1lyuwenyu
andauthored
add pp-inscaptagger (PaddlePaddle#727)
Co-authored-by: lyuwenyu <wenyu.lyu@gmail.com>
1 parent 249d537 commit 2af80b1

File tree

7 files changed

+508
-1
lines changed

7 files changed

+508
-1
lines changed

paddlemix/auto/processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _get_processor_class(cls, pretrained_model_name_or_path, text_model_name_or_
124124

125125
for names, processor_class in cls._processor_mapping.items():
126126

127-
if names.lower() in pretrained_model_name_or_path.lower().replace("-", "_").replace("vicuna", "llava"):
127+
if names.lower() in pretrained_model_name_or_path.lower().replace("-", "_").replace("vicuna", "llava").replace("inscaptagger", "llava"):
128128

129129
attributes = processor_class["processor"].attributes
130130
attributes_dict = {}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
from pathlib import Path
17+
from functools import partial
18+
from paddlemix.datacopilot.core import MMDataset
19+
from paddlemix.datacopilot.misc import enumerate_chunk
20+
from paddlemix.datacopilot.nn import PPInsCapTagger
21+
import paddle
22+
import json
23+
24+
25+
class QAschema(argparse.Action):
26+
def __call__(self, parser, namespace, values, option_string=None):
27+
assert len(values)%2 == 0, "QA content must be a list of pairs"
28+
values = list(zip(values[0::2], values[1::2]))
29+
setattr(namespace, self.dest, values)
30+
31+
32+
if __name__ == '__main__':
33+
34+
base = argparse.ArgumentParser(add_help=False)
35+
base.add_argument('-m', '--model-name-or-path', type=str, default='paddlemix/PP-InsCapTagger')
36+
base.add_argument('-t', '--dtype', type=str, default='float16')
37+
base.add_argument('-k', '--k-start', type=int, default=0)
38+
base.add_argument('-o', '--output-dir', default='SFT_tag_output_test')
39+
base.add_argument('--seed', type=int, default=0)
40+
41+
42+
parser = argparse.ArgumentParser()
43+
subs = parser.add_subparsers(help='mod of data: json_data/single_data', dest='mod')
44+
json_parser = subs.add_parser('json_data', parents=[base])
45+
json_parser.add_argument('-d', '--dataset-path', type=str, required=True)
46+
47+
single_parser = subs.add_parser('single_data', parents=[base])
48+
single_parser.add_argument('-image', '--image-path', type=str, required=True)
49+
single_parser.add_argument('-qa', '--qa-content', nargs='+', type=str, required=True, action=QAschema)
50+
51+
args = parser.parse_args()
52+
paddle.seed(seed=args.seed)
53+
54+
if args.mod == 'json_data':
55+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
56+
m = PPInsCapTagger(args.model_name_or_path)
57+
dataset = MMDataset.from_auto(args.dataset_path)
58+
print('loading dataset...')
59+
print('data size==', len(dataset))
60+
for i, subdata in enumerate_chunk(dataset, chunk_size=1000, start=args.k_start):
61+
print(f'convert {i}th(1000) data')
62+
subdata: MMDataset
63+
subdata = subdata.map(m.inference, max_workers=1)
64+
subdata.export_json(f'{args.output_dir}/tagger_{i:05}.json')
65+
print(f'{i*1000}th(1000) data save to {args.output_dir}/tagger_{i:05}.json')
66+
67+
if args.mod == 'single_data':
68+
item = {}
69+
item["image"] = args.image_path
70+
item['conversations'] = args.qa_content
71+
m = PPInsCapTagger(args.model_name_or_path)
72+
tag_item = m(item)
73+
print(tag_item)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
2+
# PP-InsCapTagger
3+
4+
## 方案简介
5+
6+
PP-InsCapTagger(Instance Capability Tagger) 是 DataCopilot 基于 PaddleMix 实现的数据集行为标签模型,用于为多模态数据实例能力打标,通过实例能力分布对数据集进行优化,可以提高模型训练效率,为数据集分析和评价提供了一种高效的方案。
7+
结合模型推理打标结果对LLaVA SFT数据集进行优化,可以**提高LLaVA模型SFT阶段50%的训练效率**
8+
9+
数据实例能力标签:在多模态任务中,每条数据都可以抽象出一种或多种能力,在训练时,模型会从这些数据中学习并增强自身对应的能力,如下图。为了评价和优化数据集,我们可以通过模型为每条多模态数据在模型训练中贡献的实例能力进行打标,并根据打标结果中数据实例能力分布进行数据集的优化,进而提升模型的训练效率。
10+
11+
<p align="center">
12+
<img src="https://github.com/user-attachments/assets/e2a8931f-ce24-47c5-9970-b42031bb28c5" align="middle" width = "800" />
13+
</p>
14+
15+
PP-InsCapTagger 基于 PaddleMix 进行训练,使用 `llava-v1.6-7b` 模型作为 `base` 模型。数据集使用多模态数据 [LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) 的部分图片和多轮对话内容,并通过 GPT-4o 为每一条数据的实例能力进行打标,并将打标结果作为该条数据的 `tags` 属性进行保存,然后使用 DataCopilot 实现数据集的高效预处理,结合原始多轮对话内容和 `tags` 结果重构数据集的 `question``answer`
16+
17+
PP-InsCapTagger 部分训练和推理的细节可以参考AI Studio 项目:[基于PaddleMIX的数据集行为标签分类器训推实例](https://aistudio.baidu.com/projectdetail/7917712)
18+
19+
## 模型使用示例
20+
21+
本项目提供 PP-InsCapTagger 使用脚本 `inference.py`, 通过`single_data``json_data`两种推理模式,可以分别实现以图像-文本对输入的单条样本推理 和 以`json`文件输入的批量数据推理。
22+
23+
### 单样本推理:
24+
25+
输入图片:<center><img src="https://github.com/user-attachments/assets/1c2fec64-3c94-4782-bc85-ccb083c1f4b2" width = "250"/></center>
26+
27+
输入多轮对话:
28+
29+
```
30+
Q: What animal is in the image? A: The image features a dog.
31+
Q: What color are the dog's eyes? A: The dog has blue eyes.
32+
Q: Where is the dog situated in the image? A: The dog is situated inside a vehicle, on a front passenger seat.
33+
```
34+
35+
```bash
36+
# PaddleMIX根目录下执行
37+
python paddlemix/datacopilot/example/pp_inscaptagger/inference.py \
38+
single_data \
39+
-m paddlemix/PP-InsCapTagger \
40+
-image https://paddlenlp.bj.bcebos.com/models/community/paddlemix/PP-InsCapTagger/demo.jpg \
41+
-qa "What animal is in the image?" "The image features a dog." \
42+
"What color are the dog's eyes?" "The dog has blue eyes." \
43+
"Where is the dog situated in the image?" "The dog is situated inside a vehicle, on a front passenger seat."
44+
```
45+
46+
其中,`-m`表示模型所用权重路径,当值为`paddlemix/PP-InsCapTagger`时,会自动下载`PP-InsCapTagger`模型到本地;`-image`表示输入的图像地址(本地地址\http链接);`-qa`表示输入的多轮对话内容,以空格分隔。
47+
48+
### 批量数据推理:
49+
50+
```bash
51+
# PaddleMIX根目录下执行
52+
python paddlemix/datacopilot/example/pp_inscaptagger/inference.py \
53+
json_data \
54+
-m paddlemix/PP-InsCapTagger \
55+
-d path/to/your/data.json \
56+
-k 0 \
57+
-o path/to/your/output-dir
58+
```
59+
其中,`path/to/your/data.json` 为输入的批量数据文件路径,格式如下:
60+
61+
```json
62+
[
63+
{
64+
"image": "http://ecx.images-amazon.com/images/I/51ntbts0gmL.jpg",
65+
"conversations": [
66+
[
67+
"<image>\nWhat is the genre of this book?",
68+
"Literature & Fiction"
69+
],
70+
[
71+
"What is the color of this book?",
72+
"Red and black"
73+
]
74+
75+
]
76+
},
77+
{
78+
"image": "http://ecx.images-amazon.com/images/I/51cc3XrLevL.jpg",
79+
"conversations": [
80+
[
81+
"<image>\nWhat is the title of this book?",
82+
"Beyond Bigger Leaner Stronger: The Advanced Guide to Building Muscle, Staying Lean, and Getting Strong (The Build Muscle, Get Lean, and Stay Healthy Series)"
83+
]
84+
]
85+
}
86+
]
87+
```
88+
`-k`表示脚本批量处理的起始位置为第k个chunk的数据,默认为0,当处理中断时可以更改处理起始位置;`path/to/your/output-dir`表示处理结果json文件保存的位置,所有chunk的处理结果分别保存在对应的json文件中,命名为`tagger_{i:05}.json`
89+
90+
## 标签使用案例
91+
92+
LLaVA v1.5模型SFT阶段训练时,使用的指令微调数据集为[LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150)中llava_v1_5_mix665k数据集,该数据集为多个数据集混合而成,相比于预训练数据集,该数据集规模更大,同时在实例能力分布上也存在较大的差异。为了优化该数据集的实例能力分布,进而提高模型训练效率,我们使用PP-InsCapTagger对数据集进行打标,并统计标签分布。
93+
94+
使用PP-InsCapTagger对llava_v1_5_mix665k数据集进行打标,可以得到7913个标签,对数量最多的前100个标签分布进行可视化,可以看出标签分布存在较大的差异,如下图所示:
95+
96+
<details>
97+
<summary>See</summary>
98+
<center><img src="https://github.com/user-attachments/assets/48e30848-fe18-4e1a-a9a5-6c6f18ad9029" width = "300"/></center>
99+
</details>
100+
101+
102+
为了对llava_v1_5_mix665k数据集进行优化,我们使用PP-InsCapTagger打标的标签结果对数据集进行筛选,**首先确定出能够覆盖80%数据的单条数据的标签数量N,然后在数据集标签集合中选出标签数量占比前0.7%的标签作为一个筛选集合R,对于llava_v1_5_mix665k数据集中的每条数据,如果该条数据标签数量小于N,且该条数据的所有标签均在集合R中,则删除该条数据,否则保留该条数据**。通过该筛选策略,最终保留数据集规模为原始数据集的50%左右。
103+
104+
我们分别使用llava_v1_5_mix665k数据集和筛选后的数据集进行llava-1.5-7b SFT阶段训练,对比结果如下表所示:
105+
106+
| Version | ScienceQA | TextVQA | VQAv2 | GQA | mmmu | mme |
107+
|:----------------------:|:-----------:|:---------:|:-------:|:-------:|:-------:|:----------------:|
108+
| llava-1.5-7b <br> (paper) | 66.8 | 58.2 | 78.5 | 62.0 | - | - |
109+
| llava-1.5-7b <br> (rerun) | 69.01 | 57.6 | 79.0 | 62.95 | 36.89 | 1521 <br> 323 |
110+
| llava-1.5-7b <br> (tag 50%/our) | 70.24 | 57.12 | 78.32 | 62.14 | 37.11 | 1476 <br> 338 |
111+
112+
通过PP-InsCapTagger的打标和优化,50%数据集与原始数据集的训练效果基本持平,大大提高了模型训练效率。
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
from collections import Counter
17+
import numpy as np
18+
import ast
19+
import matplotlib.pyplot as plt
20+
import glob
21+
from paddlemix.datacopilot.core import MMDataset
22+
from tqdm import tqdm
23+
24+
def merge_json_files(folder_path):
25+
newdataset = MMDataset()
26+
pathes = sorted(glob.glob(f'{folder_path}/*.json'))
27+
file_count = len(pathes)
28+
for path in sorted(glob.glob(f'{folder_path}/*.json')):
29+
newdataset += MMDataset.from_json(path)
30+
output_file = f'merged_{file_count}.json'
31+
newdataset.export_json(output_file)
32+
return output_file
33+
def all_tag_count(data_json):
34+
data = json.load(open(data_json, encoding='utf-8'))
35+
tag_counts = {}
36+
n=0
37+
for item in data:
38+
try:
39+
tags = ast.literal_eval(item["tag"])['tags']
40+
for tag in list(set(tags)):
41+
# 如果tag中包含逗号,则分割tag
42+
if',' in tag:
43+
# 使用split()方法按照逗号分割字符串
44+
split_strings = tag.split(',')
45+
# 去除每个字符串两端的空格
46+
tags = [s.strip() for s in split_strings]
47+
for tag in tags:
48+
if tag in tag_counts:
49+
tag_counts[tag] += 1
50+
else:
51+
tag_counts[tag] = 1
52+
53+
if tag in tag_counts:
54+
tag_counts[tag] += 1
55+
else:
56+
tag_counts[tag] = 1
57+
except:
58+
n+=1
59+
print('无效tag的数据数量:',n)
60+
print('数据集总量:',len(data))
61+
print('tag数量:',len(tag_counts))
62+
sorted_tag_counts = sorted(tag_counts.items(), key=lambda item: item[1], reverse=True)
63+
output_file = data_json.replace('.json', '_tag_count.json')
64+
with open(output_file, 'w', encoding='utf-8') as f:
65+
json.dump(sorted_tag_counts, f, ensure_ascii=False, indent=4)
66+
return sorted_tag_counts
67+
68+
def one_data_tag_count(data_json):
69+
data = json.load(open(data_json, encoding='utf-8'))
70+
# 统计每条数据中tag的数量
71+
tag_counts = []
72+
for item in data:
73+
try:
74+
tags = ast.literal_eval(item["tag"])['tags']
75+
tag_counts.append(len(tags))
76+
except:
77+
print(item["tag"])
78+
79+
# 统计每个tag数量级别的数据条数
80+
tag_count_freq = Counter(tag_counts)
81+
# 按tag数量排序并计算累积数据覆盖数量
82+
sorted_tag_counts = sorted(tag_count_freq.items(), key=lambda x: x[0], reverse=True)
83+
# 将统计结果保存为字典
84+
tag_count_freq_dict = dict(sorted_tag_counts)
85+
# 将统计结果保存到JSON文件
86+
output_file = data_json.replace('.json', '_tag_count_statistics.json')
87+
with open(output_file, 'w') as f:
88+
json.dump(tag_count_freq_dict, f, indent=4)
89+
90+
cumulative_coverage = np.cumsum([count for _, count in sorted_tag_counts])
91+
# 找到覆盖90%和80%数据的最少tag数量
92+
total_data = len(tag_counts)
93+
cover_90_percent = next(tag for tag, cum_cov in zip([tag for tag, _ in sorted_tag_counts], cumulative_coverage) if cum_cov >= 0.8 * total_data)
94+
cover_80_percent = next(tag for tag, cum_cov in zip([tag for tag, _ in sorted_tag_counts], cumulative_coverage) if cum_cov >= 0.6 * total_data)
95+
print(f"可以覆盖90%数据的单条数据tag数量: {cover_90_percent}")
96+
print(f"可以覆盖80%数据的单条数据tag数量: {cover_80_percent}")
97+
98+
99+
def tag_count_freq_plot(tag_count_file, topn):
100+
data = json.load(open(tag_count_file, encoding='utf-8'))
101+
# 设置字体,使用您安装的中文字体名称
102+
plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei'] # 例如:SimHei、Microsoft YaHei、WenQuanYi Zen Hei等
103+
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
104+
# 确保以UTF-8编码读取JSON文件
105+
data = json.load(open(tag_file, encoding='utf-8'))
106+
data = data[:topn] # 绘制部分数据
107+
categories, values = zip(*data)
108+
plt.figure(figsize=(10, 30))
109+
plt.barh(categories, values, color='skyblue')
110+
plt.xlabel('数量')
111+
plt.title('类别分布')
112+
plt.yticks(fontsize=8) # 调整字体大小以适应显示
113+
# 保存图形
114+
im_path = tag_file.replace('.json', '_plot.png')
115+
plt.savefig(im_path, bbox_inches='tight')
116+
117+
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import re
17+
import os
18+
import requests
19+
from io import BytesIO
20+
from PIL import Image
21+
from datacopilot.core import MMDataset
22+
import datacopilot.hub as hub
23+
from functools import partial
24+
import ast
25+
26+
tag_num = 3
27+
def process(item, all_tags):
28+
try:
29+
tags = ast.literal_eval(item["tag"])['tags']
30+
tags = set(tags)
31+
32+
tag_counts=len(tags)
33+
34+
if tag_counts < tag_num and tags - all_tags == set():
35+
return None
36+
else:
37+
return item
38+
except:
39+
return item
40+
41+
if __name__ == '__main__':
42+
tag_most_ratio = 0.007
43+
all_tags = set()
44+
path = 'path/to/your/tag_file.json'
45+
tag_path = 'path/to/your/tag_file_tag_count.json'
46+
tag_count_list = MMDataset.from_json(tag_path)
47+
tag_num = len(tag_count_list)
48+
print(f'{path}数据集tags的种类总数为:',tag_num)
49+
50+
tag_used_num = int(tag_num*tag_most_ratio)
51+
print(f'数量占比前{tag_most_ratio}的tags的种类总数为:',tag_used_num)
52+
for t,n in tag_count_list[:tag_used_num]:
53+
all_tags.add(t)
54+
print(f'使用的前{tag_most_ratio}%的tags:',all_tags)
55+
56+
dataset = MMDataset.from_json(path)
57+
data_len = len(dataset)
58+
print('原始数据集长度:',data_len)
59+
func = partial(
60+
process,
61+
all_tags=all_tags
62+
)
63+
dataset = dataset.map(func)
64+
newdataset = dataset.nonempty()
65+
out_data_len = len(newdataset)
66+
print('筛选后数据集数量:',out_data_len)
67+
print('筛选后数据集占原数据集比例: ', out_data_len/data_len)
68+
69+
newdataset.export_json(path.replace('.json', f'_filter_{out_data_len}_tag.json'))

paddlemix/datacopilot/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414

1515

1616
from ._lid import FastTextLIDModel
17+
from .inscaptagger import PPInsCapTagger
1718

0 commit comments

Comments
 (0)