
Commit 22ae267

joey12300 and ZeyuChen authored
[FastTokenizer] Add clip fast tokenizer (#3746)
* Add clip fast tokenizer
* Add clip fast tokenizer unittest
* Add ThreadNum Set Get
* Add set, get thread num pybind
* Add pybind of set get thread num
* Add sido foframework or clip
* Add README
* Remove omp
* Add EncodeBatch for 2 vectors of strings
* Fix AddedToken
* Remove words_idx print
* Add comments

Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
1 parent ec8ef93 commit 22ae267

29 files changed: +682 −301 lines

fast_tokenizer/CMakeLists.txt

Lines changed: 1 addition & 12 deletions
@@ -4,7 +4,7 @@ project(tokenizers LANGUAGES CXX C VERSION 1.0)
 
 option(WITH_TESTING "Compile PaddleNLP fast_tokenizer with unit testing" OFF)
 option(WITH_PYTHON "Compile PaddleNLP fast_tokenizer with python interpreter" ON)
-add_definitions(-DFASTERTOKENIZER_LIB)
+add_definitions(-DFASTTOKENIZER_LIB)
 
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 set (PUBLIC_DEPEND_LIBS "")
@@ -108,17 +108,6 @@ ELSE(WIN32)
 set (PUBLIC_DEPEND_LIBS ${CMAKE_DL_LIBS})
 ENDIF(WIN32)
 
-# For OpenMP
-# openmp not support well for now on windows
-if (NOT APPLE AND NOT WIN32) # Linux
-  find_package(OpenMP)
-  if (OPENMP_FOUND)
-    add_definitions(-DWITH_OMP)
-    set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
-    set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
-  endif()
-endif()
-
 set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR})
 set(TOKENIZERS_INSTALL_INCLUDE_DIR ${PROJECT_SOURCE_DIR})

fast_tokenizer/FastTokenizer.cmake

Lines changed: 7 additions & 7 deletions
@@ -16,18 +16,18 @@ endif()
 
 set(LIBRARY_NAME core_tokenizers)
 
-set(FASTER_TOKENIZER_INCS "")
-list(APPEND FASTER_TOKENIZER_INCS ${CMAKE_CURRENT_LIST_DIR}/include)
-list(APPEND FASTER_TOKENIZER_INCS ${CMAKE_CURRENT_LIST_DIR}/third_party/include)
+set(FAST_TOKENIZER_INCS "")
+list(APPEND FAST_TOKENIZER_INCS ${CMAKE_CURRENT_LIST_DIR}/include)
+list(APPEND FAST_TOKENIZER_INCS ${CMAKE_CURRENT_LIST_DIR}/third_party/include)
 
-set(FASTER_TOKENIZER_LIBS "")
+set(FAST_TOKENIZER_LIBS "")
 find_library(FTLIB ${LIBRARY_NAME} ${CMAKE_CURRENT_LIST_DIR}/lib NO_DEFAULT_PATH)
-list(APPEND FASTER_TOKENIZER_LIBS ${FTLIB})
+list(APPEND FAST_TOKENIZER_LIBS ${FTLIB})
 
 if (WIN32)
   find_library(ICUDT icudt ${CMAKE_CURRENT_LIST_DIR}/third_party/lib NO_DEFAULT_PATH)
-  list(APPEND FASTER_TOKENIZER_LIBS ${ICUDT})
+  list(APPEND FAST_TOKENIZER_LIBS ${ICUDT})
 
   find_library(ICUUC icuuc ${CMAKE_CURRENT_LIST_DIR}/third_party/lib NO_DEFAULT_PATH)
-  list(APPEND FASTER_TOKENIZER_LIBS ${ICUUC})
+  list(APPEND FAST_TOKENIZER_LIBS ${ICUUC})
 endif()

fast_tokenizer/README.md

Lines changed: 17 additions & 4 deletions
@@ -17,7 +17,7 @@ FastTokenizer is an easy-to-use, powerful, cross-platform, high-performance text preprocessing…
 
 ## Features
 
-- High performance. Because the core is implemented in C++, it is far faster than typical Python tokenizer implementations. On text classification tasks, FastTokenizer reaches speedups of up to 20x over the Python tokenizer.
+- High performance. Because the core is implemented in C++, it is far faster than typical Python tokenizer implementations. On text classification tasks, FastTokenizer reaches speedups of up to 20x over the Python tokenizer. Multi-threaded tokenization of text batches is supported; tokenization is single-threaded by default.
 - Cross-platform. FastTokenizer currently runs on Windows x64, Linux x64, and macOS 10.14+.
 - Multiple programming languages. FastTokenizer can be used from both C++ and Python.
 - Highly flexible. Users can build a tokenizer that meets their needs by combining different FastTokenizer components.
@@ -26,12 +26,12 @@ FastTokenizer is an easy-to-use, powerful, cross-platform, high-performance text preprocessing…
 
 The following describes how to use the Python version of FastTokenizer; for the C++ version, see the [FastTokenizer C++ Demo](./fast_tokenizer/demo/README.md).
 
-### Prerequisites
+### Requirements
 
 - Windows 64-bit
 - Linux x64
 - macOS 10.14+ (on M1-chip Macs, the x86_64 build of Anaconda must be used as the Python environment)
-- Python 3.6 ~ 3.9
+- Python 3.6 ~ 3.10
 
 ### Installing FastTokenizer
 
@@ -53,7 +53,11 @@ wget https://bj.bcebos.com/paddlenlp/models/transformers/ernie/vocab.txt
 FastTokenizer ships with tokenizers commonly used in NLP tasks, such as ErnieFastTokenizer. The following shows basic FastTokenizer usage.
 
 ```python
+import fast_tokenizer
 from fast_tokenizer import ErnieFastTokenizer, models
+
+# 0. (Optional) Set the number of threads
+fast_tokenizer.set_thread_num(1)
 # 1. Load the vocabulary
 vocab = models.WordPiece.read_file("ernie_vocab.txt")
 # 2. Instantiate an ErnieFastTokenizer
@@ -96,10 +100,19 @@ Q: I have already enabled the `use_fast=True` switch in AutoTokenizer.from_pretrained…
 A: There are three situations in which enabling `use_fast=True` may fail to improve performance:
 1. fast_tokenizer is not installed. If the `use_fast` switch is enabled without the fast_tokenizer library installed, PaddleNLP emits the warning: "Can't find the fast_tokenizer package, please ensure install fast_tokenizer correctly. "
 
-2. The loaded tokenizer type has no fast version yet. Fast versions are currently available for four tokenizers: BERT, ERNIE, TinyBERT, and ERNIE-M. If the `use_fast` switch is enabled for a tokenizer with no fast version, PaddleNLP emits the warning: "The tokenizer XXX doesn't have the fast version. Please check the map paddlenlp.transformers.auto.tokenizer.FASTER_TOKENIZER_MAPPING_NAMES to see which fast tokenizers are currently supported."
+2. The loaded tokenizer type has no fast version yet. Fast versions are currently available for four tokenizers: BERT, ERNIE, TinyBERT, and ERNIE-M. If the `use_fast` switch is enabled for a tokenizer with no fast version, PaddleNLP emits the warning: "The tokenizer XXX doesn't have the fast version. Please check the map paddlenlp.transformers.auto.tokenizer.FAST_TOKENIZER_MAPPING_NAMES to see which fast tokenizers are currently supported."
 
 3. The texts to tokenize are very short (e.g., an average length below 5). In that case tokenization may not be the bottleneck of the whole preprocessing pipeline, so overall performance may not improve even with FastTokenizer.
 
+Q: How do I use multiple threads to speed up tokenization?
+
+A: Call `fast_tokenizer.set_thread_num(xxx)` to tokenize with multiple threads. Multi-threaded tokenization should be enabled with care; it is worth considering in the following scenarios:
+1. Ample CPU resources. If inference also runs on the CPU, multi-threaded tokenization may compete for resources and degrade inference performance.
+
+2. Large batch sizes. With small batches, multi-threading may bring no speedup at all and may even increase latency because of thread scheduling. Consider multi-threaded tokenization only when the batch size is larger than 4.
+
+3. Long texts. With short texts, multi-threading may bring no speedup at all and may even increase latency because of thread scheduling. Consider multi-threaded tokenization only when the average text length is larger than 16.
+
 ## Related Documentation
 
 [FastTokenizer Compilation Guide](docs/compile/README.md)
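The FAQ above gates multi-threading on batch size (> 4) and average text length (> 16). A minimal C++ sketch of that heuristic, using the `SetThreadNum`/`GetThreadNum` API this commit adds to `fast_tokenizer/core/base.h` (the `PickThreadNum` helper and the batch contents are illustrative, not part of the commit):

```cpp
#include <string>
#include <vector>

#include "fast_tokenizer/core/base.h"  // SetThreadNum / GetThreadNum (added in this commit)

namespace core = paddlenlp::fast_tokenizer::core;

// Hypothetical helper following the README guidance: stay single-threaded
// unless the batch has more than 4 texts averaging more than 16 characters.
int PickThreadNum(const std::vector<std::string>& batch, int max_threads) {
  if (batch.size() <= 4) return 1;
  size_t total_len = 0;
  for (const auto& text : batch) total_len += text.size();
  if (total_len / batch.size() <= 16) return 1;
  return max_threads;
}

int main() {
  std::vector<std::string> batch = {"a batch", "of texts", "to tokenize", "..."};
  core::SetThreadNum(PickThreadNum(batch, /*max_threads=*/4));
  // ... build a tokenizer (e.g. ErnieFastTokenizer) and encode the batch ...
  return 0;  // this 4-text batch stays single-threaded
}
```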

fast_tokenizer/fast_tokenizer/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ endif()
 else(WITH_PYTHON)
   # add_subdirectory(tokenizers)
   cc_library(core_tokenizers SHARED
-    SRCS tokenizers/ernie_fast_tokenizer.cc
+    SRCS tokenizers/ernie_fast_tokenizer.cc tokenizers/clip_fast_tokenizer.cc
     DEPS normalizers pretokenizers models decoders
     postprocessors core added_vocabulary tokenizer json)
 

fast_tokenizer/fast_tokenizer/core/CMakeLists.txt

Lines changed: 3 additions & 2 deletions

@@ -1,3 +1,4 @@
 cc_library(added_vocabulary SRCS added_vocabulary.cc DEPS normalizers pretokenizers json)
-cc_library(tokenizer SRCS tokenizer.cc DEPS added_vocabulary json decoders trie models postprocessors)
-cc_library(core SRCS encoding.cc DEPS json)
+cc_library(base SRCS base.cc)
+cc_library(tokenizer SRCS tokenizer.cc DEPS added_vocabulary json decoders trie models postprocessors base)
+cc_library(core SRCS encoding.cc DEPS json base)
fast_tokenizer/fast_tokenizer/core/base.cc

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "fast_tokenizer/core/base.h"
+#include <thread>
+
+namespace paddlenlp {
+namespace fast_tokenizer {
+namespace core {
+
+static int fast_tokenizer_thread_num = 1;
+
+void SetThreadNum(int thread_num) { fast_tokenizer_thread_num = thread_num; }
+
+int GetThreadNum() { return fast_tokenizer_thread_num; }
+
+void RunMultiThread(std::function<void(size_t, size_t)> func,
+                    size_t batch_size) {
+  int thread_num = GetThreadNum();
+  std::vector<std::thread> vectorOfThread;
+  size_t start_index = 0;
+  size_t step_index = ceil(batch_size / float(thread_num));
+
+  for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
+    vectorOfThread.emplace_back(std::thread(func, start_index, step_index));
+    start_index = start_index + step_index;
+  }
+  for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
+    vectorOfThread[thread_index].join();
+  }
+}
+
+}  // namespace core
+}  // namespace fast_tokenizer
+}  // namespace paddlenlp
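Note how this new `RunMultiThread` differs from the OpenMP-era version removed from `encoding.cc` below: the thread count now comes from the explicit `SetThreadNum` setting rather than `OMP_NUM_THREADS`. Each worker receives a `(start_index, step_index)` pair, and because `step_index` is rounded up, the last chunk can extend past `batch_size`, so workers are expected to clamp. A hedged usage sketch (the lambda and data are illustrative, not from the commit):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

#include "fast_tokenizer/core/base.h"

namespace core = paddlenlp::fast_tokenizer::core;

int main() {
  std::vector<int> results(10, 0);  // stand-in for per-text tokenization output

  core::SetThreadNum(3);  // ceil(10 / 3) = 4 -> chunks start at 0, 4, 8

  // Each worker gets (start, step); clamp because start + step may overshoot
  // the batch size on the last thread (8 + 4 > 10 here).
  auto worker = [&results](size_t start, size_t step) {
    size_t end = std::min(start + step, results.size());
    for (size_t i = start; i < end; ++i) {
      results[i] = static_cast<int>(i);  // "process" item i
    }
  };
  core::RunMultiThread(worker, results.size());

  for (int v : results) std::printf("%d ", v);  // 0 1 2 ... 9
  return 0;
}
```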

fast_tokenizer/fast_tokenizer/core/base.h

Lines changed: 7 additions & 0 deletions
@@ -366,6 +366,13 @@ struct FASTTOKENIZER_DECL BPEWord {
   std::vector<Symbol> symbols_;
 };
 
+FASTTOKENIZER_DECL void SetThreadNum(int thread_num);
+
+FASTTOKENIZER_DECL int GetThreadNum();
+
+FASTTOKENIZER_DECL void RunMultiThread(std::function<void(size_t, size_t)> func,
+                                       size_t batch_size);
+
 }  // namespace core
 }  // namespace fast_tokenizer
 }  // namespace paddlenlp

fast_tokenizer/fast_tokenizer/core/encoding.cc

Lines changed: 0 additions & 60 deletions
@@ -19,10 +19,6 @@ limitations under the License. */
 #include <sstream>
 #include "glog/logging.h"
 
-#ifdef WITH_OMP
-#include <omp.h>
-#endif
-
 namespace paddlenlp {
 namespace fast_tokenizer {
 namespace core {
@@ -547,15 +543,6 @@ std::string Encoding::DebugString() const {
     oss << "{" << iter->first << " : (" << iter->second.first << ", "
         << iter->second.second << ") }, ";
   }
-  oss << "\n";
-
-  oss << "words_idx:";
-  for (int i = 0; i < words_idx_.size(); ++i) {
-    oss << words_idx_[i];
-    if (i < words_idx_.size() - 1) {
-      oss << ", ";
-    }
-  }
   return oss.str();
 }
 
@@ -667,62 +654,15 @@ void PadEncodings(std::vector<Encoding>* encodings, const PadMethod& method) {
     pad_length += pad_length - pad_length % method.pad_to_multiple_of_;
   }
   auto batch_size = encodings->size();
-#ifdef WITH_OMP
-#pragma omp parallel for if (batch_size >= 4 && omp_get_max_threads() > 1)
-  for (int i = 0; i < batch_size; ++i) {
-    auto& encoding = (*encodings)[i];
-    encoding.Pad(pad_length,
-                 method.pad_id_,
-                 method.pad_token_type_id_,
-                 method.pad_token_,
-                 method.direction_);
-  }
-#else
   auto func = std::bind(&MultiThreadPadEncodings,
                         encodings,
                         std::ref(method),
                         pad_length,
                         std::placeholders::_1,
                         std::placeholders::_2);
   RunMultiThread(func, batch_size);
-#endif
 }
 
-int GetThreadNum(size_t batch_size) {
-  char* env_var = std::getenv("OMP_NUM_THREADS");
-  int thread_num = std::atoi(env_var);
-  if (batch_size <= 0) {
-    thread_num = 1;
-    VLOG(3) << "batch_size <=0, we set OMP_NUM_THREADS = 1";
-  } else {
-    int best_num = ceil(batch_size / 4.0);
-    if (thread_num > best_num) {
-      thread_num = best_num;
-      VLOG(3) << "OMP_NUM_THREADS > batch_size/4, we set OMP_NUM_THREADS = "
-                 "batch_size/4";
-    } else if (thread_num == 0) {
-      thread_num = best_num;
-      VLOG(3) << "OMP_NUM_THREADS == 0, we set OMP_NUM_THREADS = batch_size/4";
-    }
-  }
-  return thread_num;
-}
-
-void RunMultiThread(std::function<void(size_t, size_t)> func,
-                    size_t batch_size) {
-  int thread_num = GetThreadNum(batch_size);
-  std::vector<std::thread> vectorOfThread;
-  size_t start_index = 0;
-  size_t step_index = ceil(batch_size / float(thread_num));
-
-  for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
-    vectorOfThread.emplace_back(std::thread(func, start_index, step_index));
-    start_index = start_index + step_index;
-  }
-  for (size_t thread_index = 0; thread_index < thread_num; thread_index++) {
-    vectorOfThread[thread_index].join();
-  }
-}
 
 }  // namespace core
 }  // namespace fast_tokenizer
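`PadEncodings` shows the intended calling pattern: a worker that also needs fixed arguments (the encodings, the pad method, the pad length) is adapted to the `(start, step)` interface with `std::bind` and the two placeholders. A sketch of the same pattern with a hypothetical chunk worker (not from the commit), again clamping `start + step` inside the worker:

```cpp
#include <algorithm>
#include <functional>
#include <vector>

#include "fast_tokenizer/core/base.h"

namespace core = paddlenlp::fast_tokenizer::core;

// Hypothetical analogue of MultiThreadPadEncodings: fixed arguments first,
// then the (start, step) pair that RunMultiThread supplies via placeholders.
static void ScaleChunk(std::vector<float>* values, float factor,
                       size_t start, size_t step) {
  size_t end = std::min(start + step, values->size());
  for (size_t i = start; i < end; ++i) (*values)[i] *= factor;
}

int main() {
  std::vector<float> values(100, 1.0f);
  auto func = std::bind(&ScaleChunk, &values, 2.0f,
                        std::placeholders::_1, std::placeholders::_2);
  core::RunMultiThread(func, values.size());  // single thread by default
  return 0;
}
```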

fast_tokenizer/fast_tokenizer/core/encoding.h

Lines changed: 0 additions & 4 deletions
@@ -130,10 +130,6 @@ bool FASTTOKENIZER_DECL TruncateEncodings(Encoding* encoding,
 void FASTTOKENIZER_DECL PadEncodings(std::vector<Encoding>* encoding,
                                      const PadMethod& method);
 
-int FASTTOKENIZER_DECL GetThreadNum(size_t batch_size);
-
-void FASTTOKENIZER_DECL RunMultiThread(std::function<void(size_t, size_t)> func,
-                                       size_t batch_size);
 }  // namespace core
 }  // namespace fast_tokenizer
 }  // namespace paddlenlp
