diff --git a/csrc/xpu/README.md b/csrc/xpu/README.md new file mode 100644 index 000000000000..c1f64e29a08c --- /dev/null +++ b/csrc/xpu/README.md @@ -0,0 +1,32 @@ +# ernie-bot-custom-ops +ernie bot 昆仑自定义算子库。 + +## 快速开始 +# 构建 XDNN plugin 和 Paddle 自定义算子库 +``` +$ cd src +$ wget https://baidu-kunlun-product.su.bcebos.com/KL-SDK/klsdk-dev/20240429/xdnn-ubuntu_x86_64.tar.gz +$ wget https://baidu-kunlun-product.su.bcebos.com/KL-SDK/klsdk-dev/20240429/xre-ubuntu_x86_64.tar.gz +$ wget -q --no-check-certificate https://klx-sdk-release-public.su.bcebos.com/xtdk_llvm15/dev/2.7.98.2/xtdk-llvm15-ubuntu1604_x86_64.tar.gz +$ tar -xf xdnn-ubuntu_x86_64.tar.gz +$ tar -xf xre-ubuntu_x86_64.tar.gz +$ tar -xf xtdk-llvm15-ubuntu1604_x86_64.tar.gz +$ export PWD=$(pwd) +$ export XDNN_PATH=${PWD}/xdnn-ubuntu_x86_64/ +$ export XRE_PATH=${PWD}/xre-ubuntu_x86_64/ +$ export CLANG_PATH=${PWD}/xtdk-llvm15-ubuntu1604_x86_64/ +$ bash ./cmake_build.sh +``` + +## 测试 +# 运行 get_padding_offset_v2 单测 +``` +$ cd test/python +$ python test_get_padding_offset_v2.py +``` + +## 如何贡献 +``` +$ pip install pre-commit==2.17.0 +$ pre-commit install +``` diff --git a/csrc/xpu/src/cmake_build.sh b/csrc/xpu/src/cmake_build.sh new file mode 100755 index 000000000000..934cdf42d6a0 --- /dev/null +++ b/csrc/xpu/src/cmake_build.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set -e + +# export XDNN_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xdnn-ubuntu_x86_64/ # +# export XRE_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xre-ubuntu_x86_64/ # +# export CLANG_PATH=xtdk-ubuntu_1604_x86_64 # +# export HOST_SYSROOT=/opt/compiler/gcc-8.2/bin/gcc # + +cd plugin +./cmake_build.sh +cd - + +python -m pip uninstall paddlenlp_ops -y +python setup.py install diff --git a/csrc/xpu/src/get_padding_offset_v2.cc b/csrc/xpu/src/get_padding_offset_v2.cc new file mode 100644 index 000000000000..b2219df0b856 --- /dev/null +++ b/csrc/xpu/src/get_padding_offset_v2.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +std::vector GetPaddingOffset(const paddle::Tensor& input_ids, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& token_num, + const paddle::Tensor& seq_len) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + + std::vector input_ids_shape = input_ids.shape(); + const int bsz = seq_len.shape()[0]; + const int seq_length = input_ids_shape[1]; + auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false); + auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); + + + const int token_num_data = cpu_token_num.data()[0]; + auto x_remove_padding = paddle::full( + {token_num_data}, 0, paddle::DataType::INT64, input_ids.place()); + auto padding_offset = paddle::full( + {token_num_data}, 0, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_q = + paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_k = + paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + int r = baidu::xpu::api::plugin::get_padding_offset( + xpu_ctx->x_context(), + padding_offset.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + x_remove_padding.data(), + input_ids.data(), + cum_offsets.data(), + seq_len.data(), + seq_length, + bsz); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed."); + return {x_remove_padding, + cum_offsets_out, + padding_offset, + cu_seqlens_q, + cu_seqlens_k}; +} + +std::vector> GetPaddingOffsetInferShape( + const std::vector& input_ids_shape, + const std::vector& cum_offsets_shape, + const std::vector& token_num_shape, + const std::vector& seq_len_shape) { + int64_t bsz = seq_len_shape[0]; + int64_t seq_len = input_ids_shape[1]; + return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; +} + +std::vector 
GetPaddingOffsetInferDtype( + const paddle::DataType& input_ids_dtype, + const paddle::DataType& cum_offsets_dtype, + const paddle::DataType& token_num_dtype, + const paddle::DataType& seq_len_dtype) { + return {input_ids_dtype, + seq_len_dtype, + seq_len_dtype, + seq_len_dtype, + seq_len_dtype}; +} + +PD_BUILD_OP(get_padding_offset_v2) + .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"}) + .Outputs({"x_remove_padding", + "cum_offsets_out", + "padding_offset", + "cu_seqlens_q", + "cu_seqlens_k"}) + .SetKernelFn(PD_KERNEL(GetPaddingOffset)) + .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype)); diff --git a/csrc/xpu/src/get_token_penalty_multi_scores_v2.cc b/csrc/xpu/src/get_token_penalty_multi_scores_v2.cc new file mode 100644 index 000000000000..c7bb929121c8 --- /dev/null +++ b/csrc/xpu/src/get_token_penalty_multi_scores_v2.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_scores, + const paddle::Tensor& presence_scores, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + int64_t bs = logits.shape()[0]; + PADDLE_ENFORCE_LE( + bs, + 640, + phi::errors::InvalidArgument( + "Only support bsz <= 1024, but received bsz is %d", bs)); + int64_t length = logits.shape()[1]; + int64_t length_id = pre_ids.shape()[1]; + int64_t length_bad_words = bad_tokens.shape()[0]; + int64_t end_length = eos_token_id.shape()[0]; + switch (logits.type()) { + case paddle::DataType::FLOAT16: { + using XPUType = typename XPUTypeTrait::Type; + typedef paddle::float16 data_t; + int r = baidu::xpu::api::plugin::token_penalty_multi_scores( + xpu_ctx->x_context(), + pre_ids.data(), + reinterpret_cast( + const_cast(logits.data())), + reinterpret_cast(penalty_scores.data()), + reinterpret_cast(frequency_scores.data()), + reinterpret_cast(presence_scores.data()), + temperatures.data(), + cur_len.data(), + min_len.data(), + eos_token_id.data(), + bad_tokens.data(), + bs, + length, + length_id, + end_length, + length_bad_words); + PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed."); + } break; + case paddle::DataType::FLOAT32: { + int r = baidu::xpu::api::plugin::token_penalty_multi_scores( + xpu_ctx->x_context(), + pre_ids.data(), + const_cast(logits.data()), + penalty_scores.data(), + frequency_scores.data(), + presence_scores.data(), + temperatures.data(), + 
cur_len.data(), + min_len.data(), + eos_token_id.data(), + bad_tokens.data(), + bs, + length, + length_id, + end_length, + length_bad_words); + PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed."); + } break; + default: + PD_THROW( + "NOT supported data type. " + "Only float16 and float32 are supported. "); + break; + } +} + +PD_BUILD_OP(get_token_penalty_multi_scores_v2) + .Inputs({"pre_ids", + "logits", + "penalty_scores", + "frequency_scores", + "presence_scores", + "temperatures", + "bad_tokens", + "cur_len", + "min_len", + "eos_token_id"}) + .Outputs({"logits_out"}) + .SetInplaceMap({{"logits", "logits_out"}}) + .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores)); diff --git a/csrc/xpu/src/plugin/CMakeLists.txt b/csrc/xpu/src/plugin/CMakeLists.txt new file mode 100644 index 000000000000..32f563434851 --- /dev/null +++ b/csrc/xpu/src/plugin/CMakeLists.txt @@ -0,0 +1,378 @@ +cmake_minimum_required(VERSION 3.10) + +project(xpuplugin LANGUAGES CXX) + +if(NOT DEFINED BUILD_STANDALONE) + if(NOT DEFINED XPU_INC_DIR) + message( + FATAL_ERROR + "XPU_INC_DIR not set, or directory ${XPU_INC_DIR} not found, please compile with PaddlePaddle." + ) + endif() + if(NOT DEFINED XPU_LIB_DIR) + message( + FATAL_ERROR + "XPU_LIB_DIR not set, or directory ${XPU_LIB_DIR} not found, please compile with PaddlePaddle." + ) + endif() + set(XDNN_INC_DIR ${XPU_INC_DIR}) + set(XDNN_LIB_DIR ${XPU_LIB_DIR}) + set(XRE_INC_DIR ${XPU_INC_DIR}) + set(XRE_LIB_DIR ${XPU_LIB_DIR}) + set(XPU_DEPS xpulib) # Depends cmake/external/xpu.cmake +else() + if(NOT DEFINED XDNN_PATH) + set(XDNN_PATH $ENV{XDNN_PATH}) + endif() + if(NOT DEFINED XRE_PATH) + set(XRE_PATH $ENV{XRE_PATH}) + endif() + if(NOT IS_DIRECTORY ${XDNN_PATH}) + message( + FATAL_ERROR + "XDNN_PATH not set, or directory ${XDNN_PATH} not found, please export XDNN_PATH=." + ) + endif() + if(NOT IS_DIRECTORY ${XRE_PATH}) + message( + FATAL_ERROR + "XRE_PATH not set, or directory ${XRE_PATH} not found, please export XRE_PATH=." 
+ ) + endif() + set(XDNN_INC_DIR ${XDNN_PATH}/include) + set(XDNN_LIB_DIR ${XDNN_PATH}/so) + set(XRE_INC_DIR ${XRE_PATH}/include) + set(XRE_LIB_DIR ${XRE_PATH}/so) +endif() + +if(NOT DEFINED CLANG_PATH) + set(CLANG_PATH $ENV{CLANG_PATH}) +endif() +if(NOT IS_DIRECTORY ${CLANG_PATH}) + message( + FATAL_ERROR + "Directory ${CLANG_PATH} not found, please export CLANG_PATH=." + ) +endif() + +message(STATUS "Build with CLANG_PATH=" ${CLANG_PATH}) +set(XPU_CLANG ${CLANG_PATH}/bin/clang++) +message(STATUS "Build with XPU_CLANG=" ${XPU_CLANG}) + +if(NOT DEFINED HOST_SYSROOT) + set(HOST_SYSROOT $ENV{HOST_SYSROOT}) +endif() +if(HOST_SYSROOT) + if(NOT IS_DIRECTORY ${HOST_SYSROOT}) + message( + FATAL_ERROR + "Directory ${HOST_SYSROOT} not found, please export HOST_SYSROOT=." + ) + endif() +endif() + +if(NOT DEFINED HOST_ARCH) + set(HOST_ARCH $ENV{HOST_ARCH}) +endif() +if(NOT HOST_ARCH) + set(HOST_ARCH x86_64-baidu-linux-gnu) +endif() + +if(NOT DEFINED TARGET_ARCH) + set(TARGET_ARCH $ENV{TARGET_ARCH}) +endif() +if(NOT TARGET_ARCH) + set(TARGET_ARCH x86_64-baidu-linux-gnu) +endif() + +if(NOT DEFINED TOOLCHAIN_ARGS) + set(TOOLCHAIN_ARGS $ENV{TOOLCHAIN_ARGS}) +endif() +if(HOST_ARCH MATCHES "x86_64") + if(TARGET_ARCH MATCHES "x86_64") + if(EXISTS ${HOST_SYSROOT}/bin/g++) + set(HOST_CXX ${HOST_SYSROOT}/bin/g++) + set(HOST_AR ${HOST_SYSROOT}/bin/ar) + if(NOT EXISTS ${HOST_AR}) + # try gcc-ar + set(HOST_AR ${HOST_SYSROOT}/bin/gcc-ar) + endif() + else() + set(HOST_CXX /usr/bin/g++) + set(HOST_AR /usr/bin/ar) + endif() + endif() + if(TARGET_ARCH MATCHES "aarch64") + set(TOOLCHAIN_ARGS "${TOOLCHAIN_ARGS} --gcc-toolchain=${HOST_SYSROOT}") + set(HOST_SYSROOT ${HOST_SYSROOT}/aarch64-linux-gnu/libc) + set(HOST_CXX ${CMAKE_CXX_COMPILER}) + set(HOST_AR ${CMAKE_AR}) + endif() +endif() +if(HOST_ARCH MATCHES "aarch64") + if(TARGET_ARCH MATCHES "aarch64") + if(EXISTS ${HOST_SYSROOT}/bin/g++) + set(HOST_CXX ${HOST_SYSROOT}/bin/g++) + set(HOST_AR ${HOST_SYSROOT}/bin/ar) + else() + set(HOST_CXX 
/usr/bin/g++) + set(HOST_AR /usr/bin/ar) + endif() + endif() +endif() + +set(OPT_LEVEL "-O2") +message(STATUS "Build with TARGET_ARCH=" ${TARGET_ARCH}) +message(STATUS "Build with TOOLCHAIN_ARGS=" ${TOOLCHAIN_ARGS}) +message(STATUS "Build with HOST_SYSROOT=" ${HOST_SYSROOT}) +message(STATUS "Build with HOST_CXX=" ${HOST_CXX}) +message(STATUS "Build with HOST_AR=" ${HOST_AR}) + +separate_arguments(TOOLCHAIN_ARGS) +# compile xpu kernel macro function +macro( + compile_kernel + kernel_path + kernel_name + xpu_n + rule + device_o_extra_flags + host_o_extra_flags + xpu_n_macro) + set(arg_rule ${rule}) + separate_arguments(arg_rule) + set(arg_device_o_extra_flags ${device_o_extra_flags}) + separate_arguments(arg_device_o_extra_flags) + set(arg_host_o_extra_flags ${host_o_extra_flags}) + separate_arguments(arg_host_o_extra_flags) + + add_custom_command( + OUTPUT ${kernel_name}.device.bin.o ${kernel_name}.o + COMMAND + ${XPU_CLANG} -std=c++11 ${OPT_LEVEL} ${arg_device_o_extra_flags} -c + ${kernel_path} -D ${xpu_n_macro} --target=${TARGET_ARCH} ${HOST_XPU_FLAGS} + --basename ${kernel_name} -fno-builtin --xpu-arch=${xpu_n} -fPIC + -Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast -Werror -mllvm + --xpu-inline-cost -mllvm --xpu-inline-hot-call -I${XDNN_INC_DIR} + -I${CMAKE_CURRENT_SOURCE_DIR}/include -I${CMAKE_CURRENT_SOURCE_DIR}/src + -I${CMAKE_CURRENT_SOURCE_DIR}/src/kernel + -I${CMAKE_CURRENT_SOURCE_DIR}/src/kernel/include ${arg_rule} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${kernel_path} + COMMENT ${kernel_name}.device.bin.o ${kernel_name}.o + VERBATIM) + + list(APPEND xpuplugin_kernels_depends ${kernel_name}.device.bin.o + ${kernel_name}.o) +endmacro() + +macro( + __compile_kernel_with_rules + kernel_path + kernel_name + xpu_n + rules_path + device_o_extra_flags + host_o_extra_flags + xpu_n_macro) + file(STRINGS ${rules_path} rules) + + foreach(rule IN LISTS rules) + message(STATUS " Instantiate with '${rule}'") + execute_process( + COMMAND bash 
"-c" "echo -n ${rule} | md5sum | cut -c1-6" + OUTPUT_VARIABLE rule_md5 + OUTPUT_STRIP_TRAILING_WHITESPACE) + set(kernel_name_md5 ${kernel_name}_${rule_md5}) + compile_kernel( + ${kernel_path} + ${kernel_name_md5} + ${xpu_n} + ${rule} + ${device_o_extra_flags} + ${host_o_extra_flags} + ${xpu_n_macro}) + endforeach() +endmacro() + +macro( + compile_kernel_with_rules + kernel_path + kernel_name + xpu_n + rules_path + device_o_extra_flags + host_o_extra_flags + xpu_n_macro) + # reconfigure if file |rules_path| was modified + set_property( + DIRECTORY + APPEND + PROPERTY CMAKE_CONFIGURE_DEPENDS ${rules_path}) + __compile_kernel_with_rules( + ${kernel_path} + ${kernel_name} + ${xpu_n} + ${rules_path} + ${device_o_extra_flags} + ${host_o_extra_flags} + ${xpu_n_macro}) +endmacro() + +macro(search_and_compile_kernel xpu_n) + if(${xpu_n} STREQUAL "xpu1") + set(XPU_DEVICE_O_EXTRA_FLAGS " ") + set(XPU_HOST_O_EXTRA_FLAGS " ") + set(XPU_KERNEL_PATH "src/kernel/cpp/*.xpu") + set(xpu_n_macro "__XPU1__") + elseif(${xpu_n} STREQUAL "xpu2") + set(XPU_DEVICE_O_EXTRA_FLAGS "--xpu-arch=xpu2") + set(XPU_HOST_O_EXTRA_FLAGS "--xpu-arch=xpu2") + set(XPU_KERNEL_PATH "src/kernel/kunlun2cpp/*.xpu") + set(xpu_n_macro "__XPU2__") + elseif(${xpu_n} STREQUAL "xpu3") + set(XPU_DEVICE_O_EXTRA_FLAGS "--xpu-arch=xpu3") + set(XPU_HOST_O_EXTRA_FLAGS "--xpu-arch=xpu3") + set(XPU_KERNEL_PATH "src/kernel/kunlun3cpp/*.xpu") + set(xpu_n_macro "__XPU3__") + else() + message(FATAL_ERROR "Are you sure? 
${xpu_n}") + endif() + file(GLOB_RECURSE xpu_kernels ${XPU_KERNEL_PATH}) + list(LENGTH xpu_kernels xpu_kernels_num) + message(STATUS "Found ${xpu_kernels_num} ${xpu_n} kernels") + + foreach(xpu_kernel IN LISTS xpu_kernels) + message(STATUS "Process ${xpu_kernel}") + get_filename_component(kernel_name ${xpu_kernel} NAME_WE) + get_filename_component(kernel_dir ${xpu_kernel} DIRECTORY) + set(kernel_rules ${kernel_dir}/${kernel_name}.rules) + set(kernel_name ${xpu_n}_${kernel_name}) + if(EXISTS ${kernel_rules}) + compile_kernel_with_rules( + ${xpu_kernel} + ${kernel_name} + ${xpu_n} + ${kernel_rules} + ${XPU_DEVICE_O_EXTRA_FLAGS} + ${XPU_HOST_O_EXTRA_FLAGS} + ${xpu_n_macro}) + else() + compile_kernel( + ${xpu_kernel} + ${kernel_name} + ${xpu_n} + " " + ${XPU_DEVICE_O_EXTRA_FLAGS} + ${XPU_HOST_O_EXTRA_FLAGS} + ${xpu_n_macro}) + endif() + endforeach() +endmacro() + +# compile xpu kernels +search_and_compile_kernel("xpu1") +search_and_compile_kernel("xpu2") +search_and_compile_kernel("xpu3") + +# compile xpu wrappers +file(GLOB_RECURSE xpu_wrappers src/wrapper/*.cpp) +list(LENGTH xpu_wrappers xpu_wrappers_num) +message(STATUS "Found ${xpu_wrappers_num} XPU wrappers") + +foreach(xpu_wrapper IN LISTS xpu_wrappers) + message(STATUS "Process ${xpu_wrapper}") + get_filename_component(wrapper_name ${xpu_wrapper} NAME_WE) + set(wrapper_target ${wrapper_name}_wrapper) + + add_custom_target( + ${wrapper_target} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS wrapper_build/${wrapper_name}.wrapper.d + wrapper_build/${wrapper_name}.wrapper.o + COMMENT ${wrapper_target} + VERBATIM) + + add_custom_command( + OUTPUT wrapper_build/${wrapper_name}.wrapper.d + COMMAND ${CMAKE_COMMAND} -E make_directory wrapper_build + COMMAND + ${XPU_CLANG} -M -MQ wrapper_build/${wrapper_name}.wrapper.o -MF + wrapper_build/${wrapper_name}.wrapper.d -std=c++11 -x xpu -c + ${xpu_wrapper} -I${XDNN_INC_DIR} -I${XRE_INC_DIR} + -I${CMAKE_CURRENT_SOURCE_DIR}/include -I${CMAKE_CURRENT_SOURCE_DIR}/src + 
-I${CMAKE_CURRENT_SOURCE_DIR}/src/wrapper -D_GNU_SOURCE + -D__STDC_LIMIT_MACROS -DNDEBUG ${TOOLCHAIN_ARGS} --target=${TARGET_ARCH} + -fPIC -Werror -Wreorder -fvisibility=hidden --xpu-host-only + ${XPU_MF_FLAGS} + COMMAND + ${CMAKE_COMMAND} -E cmake_depends "Unix Makefiles" ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR} ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/CMakeFiles/${wrapper_target}.dir/DependInfo.cmake + --color=$(COLOR) + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${xpu_wrapper} ${XPU_DEPS} + COMMENT wrapper_build/${wrapper_name}.wrapper.d + VERBATIM) + + add_custom_command( + OUTPUT wrapper_build/${wrapper_name}.wrapper.o + COMMAND ${CMAKE_COMMAND} -E make_directory wrapper_build + COMMAND + ${XPU_CLANG} -std=c++11 ${EXTRA_FLAGS} ${OPT_LEVEL} -x xpu -c + ${xpu_wrapper} -o wrapper_build/${wrapper_name}.wrapper.o + -I${XDNN_INC_DIR} -I${XRE_INC_DIR} -I${CMAKE_CURRENT_SOURCE_DIR}/include + -I${CMAKE_CURRENT_SOURCE_DIR}/src + -I${CMAKE_CURRENT_SOURCE_DIR}/src/wrapper -D_GNU_SOURCE + -D__STDC_LIMIT_MACROS -DNDEBUG ${TOOLCHAIN_ARGS} --target=${TARGET_ARCH} + -fPIC -Wunused-variable -Werror -Wreorder -fvisibility=hidden + --xpu-host-only ${HOST_XPU_FLAGS} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS wrapper_build/${wrapper_name}.wrapper.d + COMMENT wrapper_build/${wrapper_name}.wrapper.o + VERBATIM) + list(APPEND xpuplugin_wrapper_depends wrapper_build/${wrapper_name}.wrapper.o) +endforeach() + +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.a + COMMAND ${HOST_AR} rcs ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.a + ${xpuplugin_kernels_depends} ${xpuplugin_wrapper_depends} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${xpuplugin_kernels_depends} ${xpuplugin_wrapper_depends} + COMMENT ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.a + VERBATIM) + +add_custom_target( + xpuplugin_a + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${xpuplugin_kernels_depends} ${xpuplugin_wrapper_depends} + 
${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.a + COMMENT xpuplugin_a + VERBATIM) + +add_custom_target( + xpuplugin_so ALL + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS xpuplugin_a ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.so + COMMENT xpuplugin_so) + +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.so + COMMAND + ${HOST_CXX} -shared -o ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.so -Xlinker + \"-\(\" -Wl,--whole-archive ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.a + -Wl,--no-whole-archive -L${XDNN_LIB_DIR} -L${XRE_LIB_DIR} -lxpurt -lxpuapi + -Wl,--no-undefined -Wl,-soname,libxpuplugin.so -lstdc++ -ldl -lm -lpthread + -specs=${CMAKE_CURRENT_SOURCE_DIR}/src/linker.specs -Xlinker \"-\)\"\; + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.a + COMMENT ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.so) + +if(NOT DEFINED BUILD_STANDALONE) + add_library(xpuplugin STATIC IMPORTED GLOBAL) + add_dependencies(xpuplugin xpuplugin_a) + set_target_properties( + xpuplugin PROPERTIES IMPORTED_LOCATION + ${CMAKE_CURRENT_BINARY_DIR}/libxpuplugin.a) +endif() diff --git a/csrc/xpu/src/plugin/README.md b/csrc/xpu/src/plugin/README.md new file mode 100644 index 000000000000..b3a8ef889697 --- /dev/null +++ b/csrc/xpu/src/plugin/README.md @@ -0,0 +1,16 @@ +# XPU PLUGIN +## Standalone build and test. +``` +$ cd plugin +Modify ./cmake_build.sh to set the path of XDNN, XRE and XTDK. +$ ./cmake_build.sh +``` +## Build with PaddlePaddle +### Copy to the source code of PaddlePaddle. +``` +$ cp -rf plugin /paddle/phi/xpu +``` +### Add -DWITH_XPU_PLUGIN=ON as extra cmake arguments. +``` +$ cmake .. -DWITH_XPU_PLUGIN=ON +``` diff --git a/csrc/xpu/src/plugin/cmake_build.sh b/csrc/xpu/src/plugin/cmake_build.sh new file mode 100755 index 000000000000..44f7344adadb --- /dev/null +++ b/csrc/xpu/src/plugin/cmake_build.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +# export XDNN_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xdnn-ubuntu_x86_64/ # +# export XRE_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xre-ubuntu_x86_64/ # +# export CLANG_PATH=xtdk-ubuntu_1604_x86_64 # +# export HOST_SYSROOT=/opt/compiler/gcc-8.2 # + +rm -rf build +mkdir build +cd build +cmake -DCMAKE_VERBOSE_MAKEFILE=ON -DBUILD_STANDALONE=ON .. +make -j 32 diff --git a/csrc/xpu/src/plugin/include/xpu/plugin.h b/csrc/xpu/src/plugin/include/xpu/plugin.h new file mode 100644 index 000000000000..a437ddc47bcf --- /dev/null +++ b/csrc/xpu/src/plugin/include/xpu/plugin.h @@ -0,0 +1,107 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#pragma once +#include "xpu/xdnn.h" + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +template +DLL_EXPORT int set_stop_value_multi_ends(Context* ctx, + bool* stop_flags, + T* topk_ids, + T* next_tokens, + const T* end_ids, + const int* seq_lens, + const int bs, + const int end_length, + const bool beam_search); + + +DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx, + const bool* stop_flags, + int64_t* pre_ids_all, + const int64_t* input_ids, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* step_idx, + int bs, + int length, + int length_input_ids); + +template +DLL_EXPORT int token_penalty_multi_scores(Context* ctx, + const int64_t* pre_ids, + T* logits, + const T* penalty_scores, + const T* frequency_scores, + const T* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words); + +DLL_EXPORT int get_padding_offset(Context* ctx, + int* padding_offset, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + int64_t* x_remove_padding, + const int64_t* input_ids, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + const int bs); + +DLL_EXPORT int update_inputs(Context* ctx, + bool* not_need_stop, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* input_ids, + const int64_t* stop_nums, + const bool* stop_flags, + const bool* is_block_step, + const int64_t* next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride); + +template +DLL_EXPORT int rebuild_padding(Context *ctx, + T *output_data, // [bs, dim_embed] + const T *input_data, // [token_num, dim_embed] + const int *cum_offsets, // [bs] + const int *seq_len_decoder, // [bs] + const 
int *seq_len_encoder, // [bs] + const int seq_len, + const int dim_embed, + const int elem_nums); + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/ban_bad_words.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/ban_bad_words.xpu new file mode 100644 index 000000000000..541a1f8351ec --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/ban_bad_words.xpu @@ -0,0 +1,58 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu2 { +namespace plugin { + +template +inline __device__ void update_bad_words_logit(_global_ptr_ T* logits) { + __local__ T min_value = -1e10; + mfence_lm(); + LM2GM((void*)&(min_value), logits, sizeof(T)); +} + +template <> +inline __device__ void update_bad_words_logit( + _global_ptr_ float16* logits) { + __local__ short min_value = 0xFBFF; + mfence_lm(); + LM2GM((void*)&(min_value), logits, sizeof(float16)); +} + +template +__global__ void ban_bad_words(T* logits, + const int64_t* bad_words_list, + const int64_t bs, + const int64_t length, + const int64_t bad_words_length) { + int tid = core_id() * cluster_num() + cluster_id(); + int nthreads = cluster_num() * core_num(); + int start = -1; + int end = -1; + partition( + tid, nthreads, static_cast(bs * bad_words_length), 1, &start, &end); + for (int i = start; i < end; i++) { + int batch_idx = i / bad_words_length; + int bad_words_idx = i - batch_idx * bad_words_length; + int64_t bad_words_token_id = -1; + mfence_lm(); + GM2LM(bad_words_list + bad_words_idx, + (void*)&(bad_words_token_id), + sizeof(int64_t)); + if (bad_words_token_id >= length || bad_words_token_id < 0) continue; + update_bad_words_logit(logits + batch_idx * length + bad_words_token_id); + } +} + +#define _XPU_DEF__BAN_BAD_WORDS_(DATA_TYPE) \ + template __global__ void ban_bad_words(DATA_TYPE* logits, \ + const int64_t* bad_words_list, \ + const int64_t 
bs, \ + const int64_t length, \ + const int64_t bad_words_length); +_XPU_DEF__BAN_BAD_WORDS_(float); +_XPU_DEF__BAN_BAD_WORDS_(float16); + +} // namespace plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/get_padding_offset.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/get_padding_offset.xpu new file mode 100644 index 000000000000..1357b53f9c99 --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/get_padding_offset.xpu @@ -0,0 +1,51 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu2 { +namespace plugin { + +__global__ void get_padding_offset(int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int tid = clusterid * ncores + cid; + + int buf_len = 32; + __simd__ int padding_offset_lm[buf_len]; + __simd__ int cum_offsets_lm[16]; + int seq_len_lm; + for (int i = clusterid; i < bs; i += nclusters) { + GM2LM_ASYNC(seq_lens + i, &seq_len_lm, sizeof(int)); + GM2LM(cum_offsets + i - 1, cum_offsets_lm, 2 * sizeof(int)); + if (i == 0) { + cum_offsets_lm[0] = 0; + } + for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) { + int cur_len = min(seq_len_lm - j, buf_len); + for (int k = 0; k < cur_len; k++) { + padding_offset_lm[k] = cum_offsets_lm[0]; + } + LM2GM(padding_offset_lm, + padding_offset + i * max_seq_len - cum_offsets_lm[0] + j, + cur_len * sizeof(int)); + } + if (cid == 0) { + LM2GM_ASYNC(cum_offsets_lm, cum_offsets_out + i, sizeof(int)); + int cum_seq_len = (i + 1) * max_seq_len - cum_offsets_lm[1]; + LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + i + 1, sizeof(int)); + LM2GM(&cum_seq_len, cu_seqlens_k + i + 1, sizeof(int)); + } + } +} + +} // namespace plugin +} // namespace xpu2 diff 
--git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/get_value_by_id.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/get_value_by_id.xpu new file mode 100644 index 000000000000..ce7da1d7d150 --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/get_value_by_id.xpu @@ -0,0 +1,53 @@ +#include "xpu/kernel/xtdk.h" +#include "xpu/kernel/xtdk_math.h" +#include "xpu/kernel/xtdk_simd.h" +#include "xpu/kernel/cluster.h" + +namespace xpu2 { +namespace plugin { + +// assert batch <= 512 +template +__global__ void get_value_by_id(const T* logits, + const TID* ids, + T* logits_out, + int batch, + int seq_len, + int hidden_dim) { + int tid = core_id() * cluster_num() + cluster_id(); + int nthreads = core_num() * cluster_num(); + constexpr int buf_size = 1024; + __simd__ T lm_buf[buf_size]; + __simd__ TID lm_ids[batch]; + int block_cnt = roundup_div(hidden_dim, buf_size); + GM2LM(ids, lm_ids, batch * sizeof(TID)); + + for (int i = tid; i < batch * block_cnt; i += nthreads) { + int curr_bs = i / block_cnt; + int curr_block = i % block_cnt; + TID curr_id = lm_ids[curr_bs]; + if (curr_id == -1) { + curr_id = 0; + } + int src_offset = curr_bs * seq_len * hidden_dim + curr_id * hidden_dim + + curr_block * buf_size; + int dst_offset = curr_bs * hidden_dim + curr_block * buf_size; + int read_len = min(buf_size, hidden_dim - curr_block * buf_size); + GM2LM(logits + src_offset, lm_buf, read_len * sizeof(T)); + LM2GM(lm_buf, logits_out + dst_offset, read_len * sizeof(T)); + } +} + +#define _XPU_DEF__GET_VALUE_BY_ID_(DTYPE, IDTYPE) \ + template __global__ void get_value_by_id(const DTYPE* logits, \ + const IDTYPE* ids, \ + DTYPE* logits_out, \ + int batch, \ + int seq_len, \ + int hidden_dim); + +_XPU_DEF__GET_VALUE_BY_ID_(float, int); +_XPU_DEF__GET_VALUE_BY_ID_(float16, int); + +} // namespace plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/min_length_logits_process.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/min_length_logits_process.xpu new 
file mode 100644 index 000000000000..818c4dd3580b --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/min_length_logits_process.xpu @@ -0,0 +1,73 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu2 { +namespace plugin { + +template +inline __device__ void update_logit(_global_ptr_ T* logits) { + __local__ T min_value = -1e10; + mfence_lm(); + LM2GM((void*)&(min_value), logits, sizeof(T)); +} + +template <> +inline __device__ void update_logit( + _global_ptr_ float16* logits) { + __local__ short min_value = 0xFBFF; + mfence_lm(); + LM2GM((void*)&(min_value), logits, sizeof(float16)); +} + +template +__global__ void min_length_logits_process(T* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length) { + int ncores = core_num(); + int cid = core_id(); + int tid = cluster_num() * cid + cluster_id(); + int nthreads = cluster_num() * ncores; + + int64_t cur_len_now; + int64_t min_len_now; + int64_t eos_token_id_now; + int64_t bi; + int64_t end_num; + __simd__ float float32logits_now[32]; + + for (int64_t i = tid; i < bs * end_length; i += nthreads) { + bi = i / end_length; + end_num = i % end_length; + mfence_lm(); + GM2LM_ASYNC(cur_len + bi, (void*)&(cur_len_now), sizeof(int64_t)); + GM2LM_ASYNC(min_len + bi, (void*)&(min_len_now), sizeof(int64_t)); + mfence(); + if (cur_len_now >= 0 && cur_len_now < min_len_now) { + GM2LM( + eos_token_id + end_num, (void*)&(eos_token_id_now), sizeof(int64_t)); + update_logit(logits + bi * length + eos_token_id_now); + } + } +} + +#define _XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(DATA_TYPE) \ + template __global__ void min_length_logits_process( \ + DATA_TYPE * logits, \ + const int64_t* cur_len, \ + const int64_t* min_len, \ + const int64_t* eos_token_id, \ + const int64_t bs, \ + const int64_t length, \ + const 
int64_t length_id, \ + const int64_t end_length); +_XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(float); +_XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(float16); + +} // namespace plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/rebuild_padding.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/rebuild_padding.xpu new file mode 100644 index 000000000000..4c5883f923b4 --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/rebuild_padding.xpu @@ -0,0 +1,71 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/xtdk_io.h" + +namespace xpu2 { +namespace plugin { + +template +__global__ void rebuild_padding(T *output_data, // [bs, dim_embed] + const T *input_data, // [token_num, dim_embed] + const int *cum_offsets, // [bs] + const int *seq_len_decoder, // [bs] + const int *seq_len_encoder, // [bs] + const int seq_len, + const int dim_embed, + const int elem_nums) { + int ncores = core_num(); + int cid = core_id(); + int tid = cluster_num() * cid + cluster_id(); + int nthreads = cluster_num() * ncores; + int bs = elem_nums / dim_embed; + __local__ int cum_offsets_lm[bs]; + __local__ int seq_len_decoder_lm[bs]; + __local__ int seq_len_encoder_lm[bs]; + GM2LM_ASYNC(cum_offsets, cum_offsets_lm, bs * sizeof(int)); + GM2LM_ASYNC(seq_len_decoder, seq_len_decoder_lm, bs * sizeof(int)); + GM2LM_ASYNC(seq_len_encoder, seq_len_encoder_lm, bs * sizeof(int)); + mfence(); + int buf_size = 128; + __local__ T buf[buf_size]; + if (dim_embed< buf_size){ + buf_size = dim_embed; + } + int read_len = (dim_embed - 1) / buf_size + 1; + + for (int64_t i = tid; i < bs * read_len; i += nthreads) { + int bi = i / (read_len); + int bias_idx = i % (read_len); + int seq_id = 0; + // just encoder or stop, get last token; just decoder, get first token. 
+ if (seq_len_decoder_lm[bi] == 0) { + if (seq_len_encoder_lm[bi] != 0) { + seq_id = seq_len_encoder_lm[bi] - 1; + } else { + continue; + } + } + int ori_token_idx = bi * seq_len - cum_offsets_lm[bi] + seq_id; + int src_offset = ori_token_idx * dim_embed + bias_idx * buf_size; + int copy_len = (bias_idx + 1) * buf_size <= dim_embed ? buf_size : dim_embed - buf_size; + GM2LM(input_data + src_offset, buf, copy_len * sizeof(T)); + LM2GM(buf, output_data + bi * dim_embed + bias_idx * buf_size, copy_len * sizeof(T)); + } +} + +#define _XPU_DEF_REBUILD_PADING_(DATA_TYPE) \ + template __global__ void rebuild_padding( \ + DATA_TYPE * output_data, \ + const DATA_TYPE *input_data, \ + const int *cum_offsets, \ + const int *seq_len_decoder, \ + const int *seq_len_encoder, \ + const int seq_len, \ + const int dim_embed, \ + const int elem_nums); +_XPU_DEF_REBUILD_PADING_(float); +_XPU_DEF_REBUILD_PADING_(float16); + +} // namespace plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/remove_padding.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/remove_padding.xpu new file mode 100644 index 000000000000..740e24b2c684 --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/remove_padding.xpu @@ -0,0 +1,40 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu2 { +namespace plugin { + +__global__ void remove_padding(int64_t *x_remove_padding, + const int64_t *input_data, + const int *seq_lens, + const int *cum_offsets, + const int sequence_length, + const int bs) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int tid = clusterid * ncores + cid; + + int buf_len = 32; + __simd__ int64_t input_lm[buf_len]; + int seq_len_lm; + int cum_offset_lm; + for (int i = clusterid; i < bs; i += nclusters) { + GM2LM_ASYNC(seq_lens + i, &seq_len_lm, sizeof(int)); + GM2LM(cum_offsets + i, &cum_offset_lm, 
sizeof(int)); + for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) { + int cur_len = min(seq_len_lm - j, buf_len); + GM2LM(input_data + i * sequence_length + j, + input_lm, + sizeof(int64_t) * cur_len); + LM2GM(input_lm, + x_remove_padding + i * sequence_length - cum_offset_lm + j, + sizeof(int64_t) * cur_len); + } + } +} + +} // namespace plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/set_stop_value_multi_ends.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/set_stop_value_multi_ends.xpu new file mode 100644 index 000000000000..f7e15482a4b3 --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/set_stop_value_multi_ends.xpu @@ -0,0 +1,91 @@ +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/xtdk.h" +#include "xpu/kernel/xtdk_math.h" +#include "xpu/kernel/xtdk_simd.h" + +namespace xpu2 { +namespace plugin { + +template +static inline __device__ bool is_in_end(const T id, + const T* end_ids, + const int length) { + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return false; +} + +template +__global__ void set_stop_value_multi_ends(bool* stop_flags, + T* topk_ids, + T* next_tokens, + const T* end_ids, + const int* seq_lens, + const int bs, + const int end_length, + const bool beam_search) { + int ncores = core_num(); + int cid = core_id(); + int tid = cluster_num() * cid + cluster_id(); + int nthreads = cluster_num() * ncores; + + int startidx = -1; + int endidx = -1; + partition(tid, nthreads, bs, 1, &startidx, &endidx); + if (startidx >= endidx) return; + + constexpr int buf_len = 64; + __simd__ __local__ bool stop_flags_lm[buf_len]; + __simd__ __local__ T topk_ids_lm[buf_len]; + __simd__ __local__ T next_tokens_lm[buf_len]; + __simd__ __local__ T end_ids_lm[256]; + __simd__ __local__ int seq_lens_lm[buf_len]; + GM2LM_ASYNC(end_ids, end_ids_lm, sizeof(T) * end_length); + + for (int64_t i = startidx; i < endidx; i += buf_len) { + int readlen = 
min(static_cast(buf_len), endidx - i); + GM2LM_ASYNC(stop_flags + i, stop_flags_lm, sizeof(bool) * readlen); + GM2LM_ASYNC(topk_ids + i, topk_ids_lm, sizeof(T) * readlen); + GM2LM_ASYNC(next_tokens + i, next_tokens_lm, sizeof(T) * readlen); + GM2LM_ASYNC(seq_lens + i, seq_lens_lm, sizeof(int) * readlen); + mfence(); + for (int j = 0; j < readlen; j++) { + if (stop_flags_lm[j]) { + if (seq_lens_lm[j] == 0) { + topk_ids_lm[j] = -1; + } else { + topk_ids_lm[j] = end_ids_lm[0]; + next_tokens_lm[j] = end_ids_lm[0]; + } + } else { + next_tokens_lm[j] = topk_ids_lm[j]; + } + if (!beam_search && is_in_end(topk_ids_lm[j], end_ids_lm, end_length)) { + stop_flags_lm[j] = true; + } + } + mfence_lm(); + LM2GM_ASYNC(topk_ids_lm, topk_ids + i, sizeof(T) * readlen); + LM2GM_ASYNC(next_tokens_lm, next_tokens + i, sizeof(T) * readlen); + LM2GM_ASYNC(stop_flags_lm, stop_flags + i, sizeof(bool) * readlen); + mfence(); + } +} + +#define _XPU_DEF__SET_VALUE_BY_FLAGS_BOTH_(DATA_TYPE) \ + template __global__ void set_stop_value_multi_ends( \ + bool* stop_flags, \ + DATA_TYPE* topk_ids, \ + DATA_TYPE* next_tokens, \ + const DATA_TYPE* end_ids, \ + const int* seq_lens, \ + const int bs, \ + const int end_length, \ + const bool beam_search); +_XPU_DEF__SET_VALUE_BY_FLAGS_BOTH_(int64_t); + +} // namespace plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/set_value_by_flags_and_idx.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/set_value_by_flags_and_idx.xpu new file mode 100644 index 000000000000..24d6cf747cfc --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/set_value_by_flags_and_idx.xpu @@ -0,0 +1,50 @@ +#include "xpu/kernel/cluster.h" +namespace xpu2 { +namespace plugin { + +__global__ void set_value_by_flags_and_idx(const bool* stop_flags, + int64_t* pre_ids_all, + const int64_t* input_ids, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* step_idx, + int bs, + int length, + int length_input_ids) { + int ncores = 
core_num(); + int cid = core_id(); + int tid = cluster_num() * cid + cluster_id(); + int nthreads = cluster_num() * ncores; + + bool stop_flags_now; + int64_t input_ids_now; + int seq_len_enc; + int seq_len_dec; + int64_t step_idx_now; + for (int i = tid; i < bs; i += nthreads) { + GM2LM_ASYNC(stop_flags + i, (void*)&(stop_flags_now), sizeof(bool)); + GM2LM_ASYNC(seq_lens_encoder + i, (void*)&(seq_len_enc), sizeof(int)); + GM2LM_ASYNC(seq_lens_decoder + i, (void*)&(seq_len_dec), sizeof(int)); + GM2LM(step_idx + i, (void*)&(step_idx_now), sizeof(int64_t)); + if (!stop_flags_now && step_idx_now >= 0 && + (seq_len_dec != 0 || seq_len_enc != 0)) { + if (seq_len_dec == 0) { + // encoder, get last token accord to seq_lens_encoder + GM2LM(input_ids + i * length_input_ids + seq_len_enc - 1, + (void*)&(input_ids_now), + sizeof(int64_t)); + } else { + // decoder, get first token + GM2LM(input_ids + i * length_input_ids, + (void*)&(input_ids_now), + sizeof(int64_t)); + } + LM2GM((void*)&(input_ids_now), + pre_ids_all + i * length + step_idx_now, + sizeof(int64_t)); + } + } +} + +} // namespace plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/update_inputs.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/update_inputs.xpu new file mode 100644 index 000000000000..3bfcf0121f00 --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/update_inputs.xpu @@ -0,0 +1,75 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu2 { +namespace plugin { + +__global__ void update_inputs(bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int 
nclusters = cluster_num(); + int tid = clusterid * ncores + cid; + if (clusterid != 0) return; + + // assert bsz <= 1024 + const int max_bs = 1024; + __shared__ bool stop_flags_sm[max_bs]; + __shared__ int stop_flags_int_sm[max_bs]; + if (cid == 0) { + GM2SM(stop_flags, stop_flags_sm, sizeof(bool) * bsz); + } + sync_all(); + for (int i = cid; i < bsz; i += ncores) { + int seq_len_encoder; + int seq_len_decoder; + bool is_block_step_lm; + GM2LM_ASYNC(seq_lens_encoder + i, &seq_len_encoder, sizeof(int)); + GM2LM_ASYNC(seq_lens_decoder + i, &seq_len_decoder, sizeof(int)); + GM2LM_ASYNC(is_block_step + i, &is_block_step_lm, sizeof(bool)); + mfence(); + + bool stop_flag_now = stop_flags_sm[i]; + stop_flags_int_sm[i] = is_block_step_lm ? 0 : stop_flag_now; + int seq_len_decoder_update = + stop_flag_now + ? 0 + : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1); + LM2GM_ASYNC(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + int seq_len_this_time_update = !stop_flag_now; + LM2GM_ASYNC(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); + int seq_len_encoder_update = 0; + LM2GM(&seq_len_encoder_update, seq_lens_encoder + i, sizeof(int)); + int64_t input_ids_update; + GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t)); + LM2GM(&input_ids_update, input_ids + i * input_ids_stride, sizeof(int64_t)); + } + sync_cluster(); + + int stop_sum = 0; + if (cid == 0) { + for (int i = 0; i < bsz; i++) { + stop_sum += stop_flags_int_sm[i]; + } + stop_sum += (max_bsz - bsz); + int64_t stop_num; + GM2LM(stop_nums, &stop_num, sizeof(int64_t)); + bool not_need_stop_update = stop_sum < static_cast(stop_num); + LM2GM(¬_need_stop_update, not_need_stop, sizeof(bool)); + } +} + +} // namespace plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/update_repeat_times.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/update_repeat_times.xpu new file mode 100644 index 000000000000..09f859a63b42 --- /dev/null +++ 
b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/update_repeat_times.xpu @@ -0,0 +1,75 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu2 { +namespace plugin { + +static __device__ void atomic_add(_shared_ptr_ int *ptr, int v) { + bool fail = true; + while (fail) { + int a; + __asm__ __volatile__("loada.w %0,%1" : "=&r"(a) : "r"(ptr)); + a += v; + __asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(fail) : "r"(a), "r"(ptr)); + } +} + +__global__ void update_repeat_times(const int64_t *pre_ids, + const int64_t *cur_len, + int *repeat_times, + const int64_t bs, + const int64_t length, + const int64_t length_id) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int tid = clusterid * ncores + cid; + + const int max_sm_len = 256 * 1024 / sizeof(int); + __shared__ int repeated_times_sm[max_sm_len]; + int64_t pre_id_lm; + int n_length = (length + max_sm_len - 1) / max_sm_len; + + // assert bs <= 640 + const int max_bs = 640; + int64_t cur_len_lm[max_bs]; + GM2LM(cur_len, cur_len_lm, bs * sizeof(int64_t)); + + for (int nli = 0; nli < n_length; nli++) { + int step = nli * max_sm_len; + int cur_length = min(max_sm_len, length - step); + for (int64_t bi = clusterid; bi < bs; bi += nclusters) { + if (cur_len_lm[bi] < 0) { + continue; + } + if (cid == 0) { + GM2SM(repeat_times + bi * length + step, + repeated_times_sm, + sizeof(int) * cur_length); + } + sync_cluster(); + for (int i = cid; i < length_id; i += ncores) { + GM2LM(pre_ids + bi * length_id + i, &pre_id_lm, sizeof(int64_t)); + if (pre_id_lm < 0) { + break; + } + if (pre_id_lm >= step && pre_id_lm < step + cur_length) { + atomic_add(repeated_times_sm + pre_id_lm - step, 1); + mfence(); + } + } + sync_cluster(); + if (cid == 0) { + SM2GM(repeated_times_sm, + repeat_times + bi * length + step, + sizeof(int) * cur_length); + } + sync_cluster(); + } + } +} + +} // namespace 
plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/update_value_by_repeat_times.xpu b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/update_value_by_repeat_times.xpu new file mode 100644 index 000000000000..969efdcf93d1 --- /dev/null +++ b/csrc/xpu/src/plugin/src/kernel/kunlun2cpp/update_value_by_repeat_times.xpu @@ -0,0 +1,213 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu2 { +namespace plugin { + +__device__ void do_cast(const int* xlm, float* ylm, int64_t len) { + for (int64_t i = 0; i < len; i += 8) { + ylm[i] = static_cast(xlm[i]); + ylm[i + 1] = static_cast(xlm[i + 1]); + ylm[i + 2] = static_cast(xlm[i + 2]); + ylm[i + 3] = static_cast(xlm[i + 3]); + ylm[i + 4] = static_cast(xlm[i + 4]); + ylm[i + 5] = static_cast(xlm[i + 5]); + ylm[i + 6] = static_cast(xlm[i + 6]); + ylm[i + 7] = static_cast(xlm[i + 7]); + } + mfence_lm(); +} + +template +__global__ void update_value_by_repeat_times(const int *repeat_times, + const T *penalty_scores, + const T *frequency_score, + const T *presence_score, + const float *temperatures, + T *logits, + const int64_t bs, + const int64_t length) { + int ncores = core_num(); + int cid = core_id(); + int thread_id = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + int start = -1; + int end = -1; + partition( + thread_id, nthreads, static_cast(bs * length), 1, &start, &end); + + int bs_start = start / length; + int bs_end = end / length; + const int param_len = 256; + // ncores = 64 for xpu2 + __shared__ __simd__ float alpha_buf[param_len * 64]; + __shared__ __simd__ float beta_buf[param_len * 64]; + __shared__ __simd__ float gamma_buf[param_len * 64]; + __shared__ __simd__ float temperatures_buf[param_len * 64]; + _shared_ptr_ float *alpha_sm = alpha_buf + cid * param_len; + _shared_ptr_ float *beta_sm = beta_buf + cid * param_len; + _shared_ptr_ 
float *gamma_sm = gamma_buf + cid * param_len; + _shared_ptr_ float *temperatures_sm = temperatures_buf + cid * param_len; + int read_param_len = bs_end - bs_start + 1; + GM2SM_ASYNC(penalty_scores + bs_start, alpha_sm, read_param_len * sizeof(T)); + GM2SM_ASYNC(frequency_score + bs_start, beta_sm, read_param_len * sizeof(T)); + GM2SM_ASYNC(presence_score + bs_start, gamma_sm, read_param_len * sizeof(T)); + GM2SM(temperatures + bs_start, temperatures_sm, read_param_len * sizeof(T)); + primitive_cast_sm( + (const _shared_ptr_ T *)(alpha_sm), alpha_sm, read_param_len); + primitive_cast_sm( + (const _shared_ptr_ T *)(beta_sm), beta_sm, read_param_len); + primitive_cast_sm( + (const _shared_ptr_ T *)(gamma_sm), gamma_sm, read_param_len); + + float logit_now; + float alpha; + float beta; + float gamma; + float temperature; + int time; + const int buffer_len = 768; + __simd__ float logits_lm[buffer_len]; + int times_lm[buffer_len]; + + for (int i = start; i < end; i += buffer_len) { + int read_len = min(end - i, buffer_len); + GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T)); + GM2LM(repeat_times + i, times_lm, read_len * sizeof(int)); + primitive_cast((const T *)(logits_lm), logits_lm, read_len); + for (int j = 0; j < read_len; j++) { + time = times_lm[j]; + logit_now = logits_lm[j]; + int param_idx = (i + j) / length - bs_start; + temperature = temperatures_sm[param_idx]; + if (time != 0) { + alpha = alpha_sm[param_idx]; + beta = beta_sm[param_idx]; + gamma = gamma_sm[param_idx]; + logit_now = logit_now < 0.0f ? 
logit_now * alpha : logit_now / alpha; + logit_now = logit_now - time * beta - gamma; + } + logits_lm[j] = logit_now / temperature; + } + primitive_cast(logits_lm, (T *)logits_lm, read_len); + LM2GM(logits_lm, logits + i, read_len * sizeof(T)); + } +} + +#define _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(DATA_TYPE) \ + template __global__ void update_value_by_repeat_times( \ + const int *repeat_times, \ + const DATA_TYPE *penalty_scores, \ + const DATA_TYPE *frequency_score, \ + const DATA_TYPE *presence_score, \ + const float *temperatures, \ + DATA_TYPE *logits, \ + const int64_t bs, \ + const int64_t length); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(float); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(float16); + +template +__global__ void update_value_by_repeat_times_simd( + const int *repeat_times, // [bs * length] + const T *penalty_scores, // [bs] + const T *frequency_score, // [bs] + const T *presence_score, // [bs] + const float *temperatures, // [bs] + T *logits, // [bs * length] + const int64_t bs, + const int64_t length) { + int ncores = core_num(); + int cid = core_id(); + int thread_id = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + int start = -1; + int end = -1; + partition(thread_id, nthreads, static_cast(bs * length), 16, &start, &end); + + const int param_len = 256; + // ncores = 64 for xpu3 + __shared__ __simd__ float alpha_buf[param_len * 64]; + __shared__ __simd__ float beta_buf[param_len * 64]; + __shared__ __simd__ float gamma_buf[param_len * 64]; + __shared__ __simd__ float temperatures_buf[param_len * 64]; + if (cid == 0) { + GM2SM_ASYNC(penalty_scores, alpha_buf, bs * sizeof(T)); + GM2SM_ASYNC(frequency_score, beta_buf, bs * sizeof(T)); + GM2SM_ASYNC(presence_score, gamma_buf, bs * sizeof(T)); + GM2SM(temperatures, temperatures_buf, bs * sizeof(float)); + primitive_cast_sm( + (const _shared_ptr_ T *)(alpha_buf), alpha_buf, bs); + primitive_cast_sm( + (const _shared_ptr_ T *)(beta_buf), beta_buf, bs); + 
primitive_cast_sm( + (const _shared_ptr_ T *)(gamma_buf), gamma_buf, bs); + } + mfence(); + sync_all(); + + float logit_now; + float alpha; + float beta; + float gamma; + float temperature; + int time; + const int buffer_len = 768; + __simd__ float logits_lm[buffer_len]; + __simd__ float times_lm[buffer_len]; + + float32x16_t logits_; + float32x16_t logits_tmp_0; + float32x16_t logits_tmp_1; + float32x16_t time_; + + for (int i = start; i < end; i += buffer_len) { + int read_len = min(end - i, buffer_len); + GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T)); + GM2LM(repeat_times + i, times_lm, read_len * sizeof(int)); + primitive_cast((const T *)(logits_lm), logits_lm, read_len); + // no impl for primitive_cast, so we do it by ourself + do_cast((const int *)(times_lm), times_lm, read_len); + int time_mask = 0; + int logit_mask = 0; + for (int j = 0; j < read_len; j += 16) { + time_ = vload_lm_float32x16(times_lm + j); + logits_ = vload_lm_float32x16(logits_lm + j); + int param_idx = (i + j) / length; + temperature = temperatures_buf[param_idx]; + alpha = alpha_buf[param_idx]; + beta = beta_buf[param_idx]; + gamma = gamma_buf[param_idx]; + time_mask = svneq_float32x16(0.f, time_); // time != 0 mask + logit_mask = svle_float32x16(0.f, logits_); // logit >= 0 mask + time_ = svmul_float32x16(beta, time_); // time * beta + time_ = svadd_float32x16(gamma, time_); // time * beta + gamma + logits_ = svmul_float32x16_mh(alpha, logits_, logits_, (time_mask & ~logit_mask)); // when time != 0 && logit < 0, do alpha * logit + logits_ = svmul_float32x16_mh(1.0f / alpha, logits_, logits_, (time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha + logits_ = vvsub_float32x16_mh(logits_, time_, logits_, time_mask); // when time != 0, do logit = logit - time * beta - gamma; + logits_ = svmul_float32x16(1.0f / temperature, logits_); // logit / temperature + vstore_lm_float32x16(logits_lm + j, logits_); + } + mfence_lm(); + primitive_cast(logits_lm, (T *)logits_lm, 
read_len); + LM2GM(logits_lm, logits + i, read_len * sizeof(T)); + } +} + +#define _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(DATA_TYPE) \ + template __global__ void update_value_by_repeat_times_simd( \ + const int *repeat_times, \ + const DATA_TYPE *penalty_scores, \ + const DATA_TYPE *frequency_score, \ + const DATA_TYPE *presence_score, \ + const float *temperatures, \ + DATA_TYPE *logits, \ + const int64_t bs, \ + const int64_t length); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float16); + +} // namespace plugin +} // namespace xpu2 diff --git a/csrc/xpu/src/plugin/src/linker.specs b/csrc/xpu/src/plugin/src/linker.specs new file mode 100644 index 000000000000..55f6f7837074 --- /dev/null +++ b/csrc/xpu/src/plugin/src/linker.specs @@ -0,0 +1,6 @@ +# overwrite incorrect rpath arguments +# its original value is: +# -rpath $ORIGIN:$ORIGIN/lib:$ORIGIN/lib64:$ORIGIN/../lib:$ORIGIN/../lib64:/opt/compiler/gcc-4.8.2/lib:/opt/compiler/gcc-4.8.2/lib64 +# specify your own rpath if needed. +*linker: +collect2 -rpath $ORIGIN diff --git a/csrc/xpu/src/plugin/src/wrapper/get_padding_offset.cpp b/csrc/xpu/src/plugin/src/wrapper/get_padding_offset.cpp new file mode 100644 index 000000000000..344d341396da --- /dev/null +++ b/csrc/xpu/src/plugin/src/wrapper/get_padding_offset.cpp @@ -0,0 +1,189 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu2 { +namespace plugin { + +__attribute__((global)) void get_padding_offset(int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs); +__attribute__((global)) void remove_padding(int64_t *x_remove_padding, + const int64_t *input_data, + const int *seq_lens, + const int *cum_offsets, + const int sequence_length, + const int bs); + +} // namespace plugin +} // namespace xpu2 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int get_padding_offset_cpu(int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs) { + for (int i = 0; i < bs; i++) { + int cum_offset = i == 0 ? 0 : cum_offsets[i - 1]; + for (int j = 0; j < seq_lens[i]; j++) { + padding_offset[i * max_seq_len - cum_offset + j] = cum_offset; + } + cum_offsets_out[i] = cum_offset; + int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i]; + cu_seqlens_q[i + 1] = cum_seq_len; + cu_seqlens_k[i + 1] = cum_seq_len; + } + return api::SUCCESS; +} + +static int remove_padding_cpu(int64_t *x_remove_padding, + const int64_t *input_data, + const int *seq_lens, + const int *cum_offsets, + const int sequence_length, + const int bs) { + for (int i = 0; i < bs; i++) { + for (int j = 0; j < seq_lens[i]; j++) { + const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j; + const int src_seq_id = i * sequence_length + j; + x_remove_padding[tgt_seq_id] = input_data[src_seq_id]; + } + } + return api::SUCCESS; +} + +static int cpu_wrapper(Context *ctx, + int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + int64_t *x_remove_padding, + const int64_t *input_ids, + const int *cum_offsets, + const int *seq_lens, 
+ const int max_seq_len, + const int bs) { + get_padding_offset_cpu(padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bs); + remove_padding_cpu( + x_remove_padding, input_ids, seq_lens, cum_offsets_out, max_seq_len, bs); + return api::SUCCESS; +} + +static int xpu2or3_wrapper(Context *ctx, + int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + int64_t *x_remove_padding, + const int64_t *input_ids, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs) { + using XPU_INT64 = typename XPUIndexType::type; + auto get_padding_offset = xpu2::plugin::get_padding_offset; + auto remove_padding = xpu2::plugin::remove_padding; + get_padding_offset<<ncluster(), 64, ctx->xpu_stream>>>(padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bs); + remove_padding<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(x_remove_padding), + reinterpret_cast(input_ids), + seq_lens, + cum_offsets_out, + max_seq_len, + bs); + return api::SUCCESS; +} + +int get_padding_offset(Context *ctx, + int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + int64_t *x_remove_padding, + const int64_t *input_ids, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int); + WRAPPER_DUMP_PARAM4( + ctx, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k); + WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets, seq_lens); + WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + x_remove_padding, + input_ids, + cum_offsets, + seq_lens, + max_seq_len, + bs); + } + if (ctx->dev().type() == api::kXPU2 || 
ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + x_remove_padding, + input_ids, + cum_offsets, + seq_lens, + max_seq_len, + bs); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/csrc/xpu/src/plugin/src/wrapper/nn_set_stop_value_multi_ends.cpp b/csrc/xpu/src/plugin/src/wrapper/nn_set_stop_value_multi_ends.cpp new file mode 100644 index 000000000000..b1a1aa3342fb --- /dev/null +++ b/csrc/xpu/src/plugin/src/wrapper/nn_set_stop_value_multi_ends.cpp @@ -0,0 +1,160 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu2 { +namespace plugin { +template +__attribute__((global)) void set_stop_value_multi_ends(bool* stop_flags, + T* topk_ids, + T* next_tokens, + const T* end_ids, + const int* seq_lens, + const int bs, + const int end_length, + const bool beam_search); +} // namespace plugin +} // namespace xpu2 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +template +__inline__ bool is_in_end(const T id, const T* end_ids, int length) { + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return false; +} + +template +static int cpu_wrapper(Context* ctx, + bool* stop_flags, + T* topk_ids, + T* next_tokens, + const T* end_ids, + const int* seq_lens, + const int bs, + const int end_length, + const bool beam_search) { + for (int i = 0; i < bs; i++) { + if (stop_flags[i]) { + if (seq_lens[i] == 0) { + topk_ids[i] = -1; + } else { + topk_ids[i] = end_ids[0]; + next_tokens[i] = end_ids[0]; + } + } else { + next_tokens[i] = topk_ids[i]; + } + if (!beam_search && is_in_end(topk_ids[i], end_ids, end_length)) { + stop_flags[i] = true; + } + } + return api::SUCCESS; +} + +template +static int xpu2or3_wrapper(Context* ctx, + bool* stop_flags, + T* topk_ids, + T* next_tokens, + const T* end_ids, + const int* seq_lens, + const int bs, + const int end_length, + const bool beam_search) { + using XPU_TID = typename XPUIndexType::type; + auto set_stop_value_multi_ends = xpu2::plugin::set_stop_value_multi_ends; + set_stop_value_multi_ends<<ncluster(), 64, ctx->xpu_stream>>>( + stop_flags, + reinterpret_cast(topk_ids), + reinterpret_cast(next_tokens), + reinterpret_cast(end_ids), + seq_lens, + bs, + end_length, + beam_search); + return api::SUCCESS; +} + +template +int set_stop_value_multi_ends(Context* ctx, + bool* stop_flags, + T* topk_ids, + T* next_tokens, + const T* end_ids, + const int* seq_lens, + const int bs, + 
const int end_length, + const bool beam_search) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "set_stop_value_multi_ends", T); + WRAPPER_DUMP_PARAM5( + ctx, stop_flags, topk_ids, next_tokens, end_ids, seq_lens); + WRAPPER_DUMP_PARAM3(ctx, bs, end_length, beam_search); + WRAPPER_DUMP(ctx); + WRAPPER_CHECK_PTR(ctx, bool, bs, stop_flags); + WRAPPER_CHECK_PTR(ctx, T, bs, topk_ids); + WRAPPER_CHECK_PTR(ctx, T, end_length, end_ids); + WRAPPER_CHECK_PTR(ctx, T, bs, seq_lens); + WRAPPER_ASSERT_LE(ctx, end_length, 1024); // assume end_length <= 1024 + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + stop_flags, + topk_ids, + next_tokens, + end_ids, + seq_lens, + bs, + end_length, + beam_search); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + stop_flags, + topk_ids, + next_tokens, + end_ids, + seq_lens, + bs, + end_length, + beam_search); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +template int set_stop_value_multi_ends(Context* ctx, + bool* stop_flags, + int64_t* topk_ids, + int64_t* next_tokens, + const int64_t* end_ids, + const int* seq_lens, + const int bs, + const int end_length, + const bool beam_search); +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/csrc/xpu/src/plugin/src/wrapper/nn_set_value_by_flags_and_idx.cpp b/csrc/xpu/src/plugin/src/wrapper/nn_set_value_by_flags_and_idx.cpp new file mode 100644 index 000000000000..e0d402182efd --- /dev/null +++ b/csrc/xpu/src/plugin/src/wrapper/nn_set_value_by_flags_and_idx.cpp @@ -0,0 +1,199 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu2 { +namespace plugin { + +__attribute__((global)) void set_value_by_flags_and_idx( + const bool* stop_flags, + int64_t* pre_ids_all, + const int64_t* input_ids, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* step_idx, + int bs, + int length, + int length_input_ids); + +} // namespace plugin +} // namespace xpu2 + +namespace xpu3 { +namespace plugin { + +__attribute__((global)) void set_value_by_flags_and_idx( + const bool* stop_flags, + int64_t* pre_ids_all, + const int64_t* input_ids, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* step_idx, + int bs, + int length, + int length_input_ids); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context* ctx, + const bool* stop_flags, + int64_t* pre_ids_all, + const int64_t* pre_ids, + const int64_t* step_idx, + const int bs, + const int length) { + for (int i = 0; i < bs; i++) { + int64_t* pre_ids_all_now = pre_ids_all + i * length; + if (!stop_flags[i] && step_idx[i] >= 0) { + pre_ids_all_now[step_idx[i]] = pre_ids[i]; + } + } + return api::SUCCESS; +} + +static int cpu_wrapper(Context* ctx, + const bool* stop_flags, + int64_t* pre_ids_all, + const int64_t* input_ids, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* step_idx, + int bs, + int length, + int length_input_ids) { + for (int i = 0; i < bs; 
i++) { + if (!stop_flags[i]) { + int64_t* pre_ids_all_now = pre_ids_all + i * length; + const int64_t* input_ids_now = input_ids + i * length_input_ids; + const int seq_len_dec = seq_lens_decoder[i]; + const int seq_len_enc = seq_lens_encoder[i]; + if (seq_len_dec == 0 && seq_len_enc == 0) continue; + if (step_idx[i] >= 0) { + if (seq_len_dec == + 0) { // encoder, get last token accord to seq_lens_encoder + pre_ids_all_now[step_idx[i]] = input_ids_now[seq_len_enc - 1]; + } else { // decoder, get first token + pre_ids_all_now[step_idx[i]] = input_ids_now[0]; + } + } + } + } + return api::SUCCESS; +} + +static int xpu2or3_wrapper(Context* ctx, + const bool* stop_flags, + int64_t* pre_ids_all, + const int64_t* input_ids, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* step_idx, + int bs, + int length, + int length_input_ids) { + using XPU_INT64 = typename XPUIndexType::type; + auto set_value_by_flags_and_idx_kernel = xpu2::plugin::set_value_by_flags_and_idx; + set_value_by_flags_and_idx_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + stop_flags, + reinterpret_cast(pre_ids_all), + reinterpret_cast(input_ids), + seq_lens_encoder, + seq_lens_decoder, + reinterpret_cast(step_idx), + bs, + length, + length_input_ids); + return api::SUCCESS; +} + +int set_value_by_flags_and_idx(Context* ctx, + const bool* stop_flags, + int64_t* pre_ids_all, + const int64_t* input_ids, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* step_idx, + int bs, + int length, + int length_input_ids) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "set_value_by_flags_and_idx", int64_t); + WRAPPER_DUMP_PARAM6(ctx, + stop_flags, + pre_ids_all, + input_ids, + seq_lens_encoder, + seq_lens_decoder, + step_idx); + WRAPPER_DUMP_PARAM3(ctx, bs, length, length_input_ids); + WRAPPER_DUMP(ctx); + int64_t stop_flags_len = -1; + int64_t pre_ids_all_len = -1; + int64_t input_ids_len = -1; + int64_t seq_lens_encoder_len = -1; + int64_t 
seq_lens_decoder_len = -1; + int64_t step_idx_len = -1; + WRAPPER_CHECK_SHAPE(ctx, &stop_flags_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &pre_ids_all_len, {bs, length}); + WRAPPER_CHECK_SHAPE(ctx, &input_ids_len, {bs, length_input_ids}); + WRAPPER_CHECK_SHAPE(ctx, &seq_lens_encoder_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &seq_lens_decoder_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &step_idx_len, {bs}); + WRAPPER_CHECK_PTR(ctx, int64_t, stop_flags_len, stop_flags); + WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_all_len, pre_ids_all); + WRAPPER_CHECK_PTR(ctx, int64_t, input_ids_len, input_ids); + WRAPPER_CHECK_PTR(ctx, int, seq_lens_encoder_len, seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, int, seq_lens_decoder_len, seq_lens_decoder); + WRAPPER_CHECK_PTR(ctx, int64_t, step_idx_len, step_idx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + stop_flags, + pre_ids_all, + input_ids, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + bs, + length, + length_input_ids); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + stop_flags, + pre_ids_all, + input_ids, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + bs, + length, + length_input_ids); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/csrc/xpu/src/plugin/src/wrapper/nn_token_penalty_multi_scores.cpp b/csrc/xpu/src/plugin/src/wrapper/nn_token_penalty_multi_scores.cpp new file mode 100644 index 000000000000..f411e1b795e6 --- /dev/null +++ b/csrc/xpu/src/plugin/src/wrapper/nn_token_penalty_multi_scores.cpp @@ -0,0 +1,424 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu2 { +namespace plugin { + +template +__attribute__((global)) void min_length_logits_process( + T* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length); +__attribute__((global)) void update_repeat_times(const int64_t* pre_ids, + const int64_t* cur_len, + int* repeat_times, + const int64_t bs, + const int64_t length, + const int64_t length_id); +template +__attribute__((global)) void update_value_by_repeat_times( + const int* repeat_times, + const T* penalty_scores, + const T* frequency_score, + const T* presence_score, + const float* temperatures, + T* logits, + const int64_t bs, + const int64_t length); +template +__attribute__((global)) void update_value_by_repeat_times_simd( + const int* repeat_times, + const T* penalty_scores, + const T* frequency_score, + const T* presence_score, + const float* temperatures, + T* logits, + const int64_t bs, + const int64_t length); +template +__attribute__((global)) void ban_bad_words(T* logits, + const int64_t* bad_words_list, + const int64_t bs, + const int64_t length, + const int64_t bad_words_length); + +} // namespace plugin +} // namespace xpu2 + +namespace xpu3 { +namespace plugin { + +template +__attribute__((global)) void min_length_logits_process( + T* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const 
int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length); +__attribute__((global)) void update_repeat_times(const int64_t* pre_ids, + const int64_t* cur_len, + int* repeat_times, + const int64_t bs, + const int64_t length, + const int64_t length_id); +template +__attribute__((global)) void update_value_by_repeat_times( + const int* repeat_times, + const T* penalty_scores, + const T* frequency_score, + const T* presence_score, + const float* temperatures, + T* logits, + const int64_t bs, + const int64_t length); +template +__attribute__((global)) void update_value_by_repeat_times_simd( + const int* repeat_times, + const T* penalty_scores, + const T* frequency_score, + const T* presence_score, + const float* temperatures, + T* logits, + const int64_t bs, + const int64_t length); +template +__attribute__((global)) void ban_bad_words(T* logits, + const int64_t* bad_words_list, + const int64_t bs, + const int64_t length, + const int64_t bad_words_length); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +void update_repeat_times_cpu(const int64_t* pre_ids, + const int64_t* cur_len, + int* repeat_times, + const int64_t bs, + const int64_t length, + const int64_t length_id) { + for (int64_t i = 0; i < bs; i++) { + if (cur_len[i] >= 0) { + for (int64_t j = 0; j < length_id; j++) { + int64_t id = pre_ids[i * length_id + j]; + if (id < 0 || id >= length) continue; + repeat_times[i * length + id] += 1; + } + } + } +} + +void ban_bad_words_cpu(float* logits, + const int64_t* bad_words_list, + const int64_t bs, + const int64_t length, + const int64_t bad_words_length) { + for (int64_t i = 0; i < bs; i++) { + float* logits_now = logits + i * length; + for (int64_t j = 0; j < bad_words_length; j++) { + int64_t bad_words_token_id = bad_words_list[j]; + if (bad_words_token_id >= length || bad_words_token_id < 0) continue; + logits_now[bad_words_token_id] = -1e10; + } + } +} + 
+template +static int cpu_wrapper(Context* ctx, + const int64_t* pre_ids, + T* logits, + const T* penalty_scores, + const T* frequency_scores, + const T* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words) { + std::vector logitsfp32(bs * length); + std::vector penalty_scoresfp32(bs); + std::vector frequency_scoresfp32(bs); + std::vector presence_scoresfp32(bs); + std::vector repeat_times_buffer(bs * length, 0); + int ret = api::cast(ctx, logits, logitsfp32.data(), bs * length); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + ret = api::cast(ctx, penalty_scores, penalty_scoresfp32.data(), bs); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + ret = api::cast( + ctx, frequency_scores, frequency_scoresfp32.data(), bs); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + ret = + api::cast(ctx, presence_scores, presence_scoresfp32.data(), bs); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + for (int64_t i = 0; i < bs; i++) { + if (cur_len[i] >= 0 && cur_len[i] < min_len[i]) { + for (int64_t j = 0; j < end_length; j++) { + logitsfp32[i * length + eos_token_id[j]] = -1e4; + } + } + } + int* repeat_times = &(repeat_times_buffer[0]); + update_repeat_times_cpu( + pre_ids, cur_len, repeat_times, bs, length, length_id); + for (int64_t i = 0; i < bs; i++) { + float alpha = penalty_scoresfp32[i]; + float beta = frequency_scoresfp32[i]; + float gamma = presence_scoresfp32[i]; + float temperature = temperatures[i]; + for (int64_t j = 0; j < length; j++) { + int times = repeat_times[i * length + j]; + float logit_now = logitsfp32[i * length + j]; + if (times != 0) { + logit_now = logit_now < 0 ? 
logit_now * alpha : logit_now / alpha; + logit_now = logit_now - times * beta - gamma; + } + logitsfp32[i * length + j] = logit_now / temperature; + } + } + if (bad_words && length_bad_words > 0) { + ban_bad_words_cpu( + logitsfp32.data(), bad_words, bs, length, length_bad_words); + } + ret = api::cast(ctx, logitsfp32.data(), logits, bs * length); + return ret; +} + +template +static int xpu2or3_wrapper(Context* ctx, + const int64_t* pre_ids, + T* logits, + const T* penalty_scores, + const T* frequency_scores, + const T* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words) { + api::ctx_guard RAII_GUARD(ctx); + using XPU_INT64 = typename XPUIndexType::type; + auto min_length_logits_process_kernel = xpu2::plugin::min_length_logits_process; + auto update_repeat_times_kernel = xpu2::plugin::update_repeat_times; + auto update_value_by_repeat_times_kernel = xpu2::plugin::update_value_by_repeat_times; + if(length % 16 == 0) { + update_value_by_repeat_times_kernel = xpu2::plugin::update_value_by_repeat_times_simd; + } + auto ban_bad_words_kernel = xpu2::plugin::ban_bad_words; + + int* repeat_times = RAII_GUARD.alloc_l3_or_gm(bs * length); + WRAPPER_ASSERT_WORKSPACE(ctx, repeat_times); + int ret = api::constant(ctx, repeat_times, bs * length, 0); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + + update_repeat_times_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(pre_ids), + reinterpret_cast(cur_len), + repeat_times, + bs, + length, + length_id); + min_length_logits_process_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + logits, + reinterpret_cast(cur_len), + reinterpret_cast(min_len), + reinterpret_cast(eos_token_id), + bs, + length, + length_id, + end_length); + update_value_by_repeat_times_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + 
repeat_times, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + logits, + bs, + length); + + if (bad_words && length_bad_words > 0) { + ban_bad_words_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + logits, + reinterpret_cast(bad_words), + bs, + length, + length_bad_words); + } + return api::SUCCESS; +} + +template +int token_penalty_multi_scores(Context* ctx, + const int64_t* pre_ids, + T* logits, + const T* penalty_scores, + const T* frequency_scores, + const T* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "token_penalty_multi_scores", T); + WRAPPER_DUMP_PARAM6(ctx, + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures); + WRAPPER_DUMP_PARAM3(ctx, cur_len, min_len, eos_token_id); + WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length); + WRAPPER_DUMP(ctx); + // TODO(mayang02) shape check + int64_t pre_ids_len = -1; + int64_t logits_len = -1; + int64_t penalty_scores_len = -1; + int64_t frequency_scores_len = -1; + int64_t presence_scores_len = -1; + int64_t temperatures_len = -1; + int64_t cur_len_len = -1; + int64_t min_len_len = -1; + int64_t eos_token_id_len = -1; + int64_t bad_words_len = -1; + WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id}); + WRAPPER_CHECK_SHAPE(ctx, &logits_len, {bs, length}); + WRAPPER_CHECK_SHAPE(ctx, &penalty_scores_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &frequency_scores_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &presence_scores_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &temperatures_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &cur_len_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length}); + 
WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words}); + WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids); + WRAPPER_CHECK_PTR(ctx, T, logits_len, logits); + WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores); + WRAPPER_CHECK_PTR(ctx, T, frequency_scores_len, frequency_scores); + WRAPPER_CHECK_PTR(ctx, T, presence_scores_len, presence_scores); + WRAPPER_CHECK_PTR(ctx, float, temperatures_len, temperatures); + WRAPPER_CHECK_PTR(ctx, int64_t, cur_len_len, cur_len); + WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len); + WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id); + WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + cur_len, + min_len, + eos_token_id, + bad_words, + bs, + length, + length_id, + end_length, + length_bad_words); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + cur_len, + min_len, + eos_token_id, + bad_words, + bs, + length, + length_id, + end_length, + length_bad_words); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +template int token_penalty_multi_scores(Context* ctx, + const int64_t* pre_ids, + float* logits, + const float* penalty_scores, + const float* frequency_scores, + const float* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words); +template int token_penalty_multi_scores( + Context* ctx, + const int64_t* pre_ids, + float16* logits, + const float16* penalty_scores, + const float16* frequency_scores, + const float16* presence_scores, + const float* 
temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words); +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/csrc/xpu/src/plugin/src/wrapper/rebuild_padding.cpp b/csrc/xpu/src/plugin/src/wrapper/rebuild_padding.cpp new file mode 100644 index 000000000000..c5f4543306a0 --- /dev/null +++ b/csrc/xpu/src/plugin/src/wrapper/rebuild_padding.cpp @@ -0,0 +1,153 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu2 { +namespace plugin { + +template +__attribute__((global)) void rebuild_padding(T *output_data, // [bs, dim_embed] + const T *input_data, // [token_num, dim_embed] + const int *cum_offsets, // [bs] + const int *seq_len_decoder, // [bs] + const int *seq_len_encoder, // [bs] + const int seq_len, + const int dim_embed, + const int elem_nums); + +} // namespace plugin +} // namespace xpu2 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { +template +static int cpu_wrapper(Context *ctx, + T *output_data, // [bs, dim_embed] + const T *input_data, // [token_num, dim_embed] + const int *cum_offsets, // [bs] + const int *seq_len_decoder, // [bs] + const int *seq_len_encoder, // [bs] + const int seq_len, + const int dim_embed, + const int elem_nums) { + for (int i=0;i < elem_nums;i++){ + const int bi = i / dim_embed; + const int bias_idx = i % dim_embed; + int seq_id = 0; + // just encoder or stop, get last token; just decoder, get first token. 
+ if (seq_len_decoder[bi] == 0) { + if (seq_len_encoder[bi] != 0) { + seq_id = seq_len_encoder[bi] - 1; + } else { + continue; + } + } + const int ori_token_idx = bi * seq_len - cum_offsets[bi] + seq_id; + const int src_offset = ori_token_idx * dim_embed + bias_idx; + output_data[i] = input_data[src_offset]; + } + + return api::SUCCESS; +} +template +static int xpu2or3_wrapper(Context *ctx, + T *output_data, // [bs, dim_embed] + const T *input_data, // [token_num, dim_embed] + const int *cum_offsets, // [bs] + const int *seq_len_decoder, // [bs] + const int *seq_len_encoder, // [bs] + const int seq_len, + const int dim_embed, + const int elem_nums) { + xpu2::plugin::rebuild_padding<<ncluster(), 64, ctx->xpu_stream>>>(output_data, + input_data, + cum_offsets, + seq_len_decoder, + seq_len_encoder, + seq_len, + dim_embed, + elem_nums); + return api::SUCCESS; +} + +template +int rebuild_padding(Context *ctx, + T *output_data, // [bs, dim_embed] + const T *input_data, // [token_num, dim_embed] + const int *cum_offsets, // [bs] + const int *seq_len_decoder, // [bs] + const int *seq_len_encoder, // [bs] + const int seq_len, + const int dim_embed, + const int elem_nums) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "rebuild_padding", T); + WRAPPER_DUMP_PARAM5( + ctx, output_data, input_data, cum_offsets, seq_len_decoder, seq_len_encoder); + WRAPPER_DUMP_PARAM3(ctx, seq_len, dim_embed, elem_nums); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + output_data, + input_data, + cum_offsets, + seq_len_decoder, + seq_len_encoder, + seq_len, + dim_embed, + elem_nums); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + output_data, + input_data, + cum_offsets, + seq_len_decoder, + seq_len_encoder, + seq_len, + dim_embed, + elem_nums); + } + WRAPPER_UNIMPLEMENTED(ctx); +} +template int rebuild_padding(Context *ctx, + float *output_data, // [bs, dim_embed] + const float 
*input_data, // [token_num, dim_embed] + const int *cum_offsets, // [bs] + const int *seq_len_decoder, // [bs] + const int *seq_len_encoder, // [bs] + const int seq_len, + const int dim_embed, + const int elem_nums); +template int rebuild_padding(Context *ctx, + float16 *output_data, // [bs, dim_embed] + const float16 *input_data, // [token_num, dim_embed] + const int *cum_offsets, // [bs] + const int *seq_len_decoder, // [bs] + const int *seq_len_encoder, // [bs] + const int seq_len, + const int dim_embed, + const int elem_nums); + + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu \ No newline at end of file diff --git a/csrc/xpu/src/plugin/src/wrapper/update_inputs.cpp b/csrc/xpu/src/plugin/src/wrapper/update_inputs.cpp new file mode 100644 index 000000000000..43aedc33f953 --- /dev/null +++ b/csrc/xpu/src/plugin/src/wrapper/update_inputs.cpp @@ -0,0 +1,192 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu2 { +namespace plugin { + +__attribute__((global)) void update_inputs(bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride); + +} // namespace plugin +} // namespace xpu2 + +namespace xpu3 { +namespace plugin { + +__attribute__((global)) void update_inputs(bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context *ctx, + bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride) { + std::vector stop_flag_now_int(max_bsz, 1); + for (int i = 0; i < bsz; i++) { + bool stop_flags_now = stop_flags[i]; + stop_flag_now_int[i] = is_block_step[i] ? 0 : stop_flags_now; + const int seq_len_encoder = seq_lens_encoder[i]; + const int seq_len_decoder = seq_lens_decoder[i]; + + seq_lens_decoder[i] = + stop_flags[i] + ? 0 + : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1); + + seq_lens_this_time[i] = stop_flags[i] ? 
0 : 1; + seq_lens_encoder[i] = 0; + int64_t *input_ids_now = input_ids + i * input_ids_stride; + input_ids_now[0] = next_tokens[i]; + } + int64_t stop_sum = 0; + for (size_t i = 0; i < stop_flag_now_int.size(); i++) { + stop_sum += stop_flag_now_int[i]; + } + not_need_stop[0] = stop_sum < stop_nums[0]; + return api::SUCCESS; +} + +static int xpu2or3_wrapper(Context *ctx, + bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride) { + using XPU_INT64 = typename XPUIndexType::type; + auto update_inputs = xpu2::plugin::update_inputs; + update_inputs<<ncluster(), 64, ctx->xpu_stream>>>( + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + reinterpret_cast(input_ids), + reinterpret_cast(stop_nums), + stop_flags, + is_block_step, + reinterpret_cast(next_tokens), + bsz, + max_bsz, + input_ids_stride); + return api::SUCCESS; +} + +int update_inputs(Context *ctx, + bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs", int); + WRAPPER_DUMP_PARAM5(ctx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + input_ids); + WRAPPER_DUMP_PARAM4(ctx, stop_nums, stop_flags, is_block_step, next_tokens); + WRAPPER_DUMP_PARAM3(ctx, bsz, max_bsz, input_ids_stride); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + input_ids, + stop_nums, + stop_flags, + 
is_block_step, + next_tokens, + bsz, + max_bsz, + input_ids_stride); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + input_ids, + stop_nums, + stop_flags, + is_block_step, + next_tokens, + bsz, + max_bsz, + input_ids_stride); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/csrc/xpu/src/rebuild_padding_v2.cc b/csrc/xpu/src/rebuild_padding_v2.cc new file mode 100644 index 000000000000..ed261960a84e --- /dev/null +++ b/csrc/xpu/src/rebuild_padding_v2.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +std::vector RebuildPaddingV2(const paddle::Tensor& tmp_out, // [token_num, dim_embed] + const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + int max_input_length) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + std::vector tmp_out_shape = tmp_out.shape(); + const int token_num = tmp_out_shape[0]; + const int dim_embed = tmp_out_shape[1]; + const int bsz = cum_offsets.shape()[0]; + auto out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place()); + int elem_nums = out.numel(); + switch (tmp_out.type()) { + case paddle::DataType::FLOAT16: { + using XPUType = typename XPUTypeTrait::Type; + typedef paddle::float16 data_t; + int r = baidu::xpu::api::plugin::rebuild_padding( + xpu_ctx->x_context(), + reinterpret_cast(out.data()), + reinterpret_cast(tmp_out.data()), + cum_offsets.data(), + seq_lens_decoder.data(), + seq_lens_encoder.data(), + max_input_length, + dim_embed, + elem_nums + ); + PD_CHECK(r == 0, "xpu::plugin::rebuild_padding failed."); + } break; + case paddle::DataType::FLOAT32: { + int r = baidu::xpu::api::plugin::rebuild_padding( + xpu_ctx->x_context(), + out.data(), + tmp_out.data(), + cum_offsets.data(), + seq_lens_decoder.data(), + seq_lens_encoder.data(), + max_input_length, + dim_embed, + elem_nums + ); + PD_CHECK(r == 0, "xpu::plugin::rebuild_padding failed."); + } break; + default: + PD_THROW( + "NOT supported data type. " + "Only float16 and float32 are supported. 
"); + break; + } + return {out}; +} + +std::vector> RebuildPaddingV2InferShape(const std::vector& tmp_out_shape, + const std::vector& cum_offsets_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_encoder_shape) { + int64_t bsz = cum_offsets_shape[0]; + int64_t dim_embed = tmp_out_shape[1]; + return {{bsz, dim_embed}}; +} + +std::vector RebuildPaddingV2InferDtype(const paddle::DataType& tmp_out_dtype, + const paddle::DataType& cum_offsets_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_encoder_dtype) { + return {tmp_out_dtype}; +} + +PD_BUILD_OP(rebuild_padding_v2) + .Inputs({"tmp_out", "cum_offsets", "seq_lens_decoder", "seq_lens_encoder"}) + .Outputs({"out"}) + .Attrs({"max_input_length: int"}) + .SetKernelFn(PD_KERNEL(RebuildPaddingV2)) + .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingV2InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingV2InferDtype)); \ No newline at end of file diff --git a/csrc/xpu/src/set_stop_value_multi_ends_v2.cc b/csrc/xpu/src/set_stop_value_multi_ends_v2.cc new file mode 100644 index 000000000000..e587b105d21f --- /dev/null +++ b/csrc/xpu/src/set_stop_value_multi_ends_v2.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +void GetStopFlagsMulti(const paddle::Tensor &topk_ids, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens, + const paddle::Tensor &end_ids, + const paddle::Tensor &next_tokens) { + PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); + PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + std::vector shape = topk_ids.shape(); + int64_t bs_now = shape[0]; + int64_t end_length = end_ids.shape()[0]; + bool beam_search = false; + int r = baidu::xpu::api::plugin::set_stop_value_multi_ends( + xpu_ctx->x_context(), + const_cast(stop_flags.data()), + const_cast(topk_ids.data()), + const_cast(next_tokens.data()), + end_ids.data(), + seq_lens.data(), + bs_now, + end_length, + beam_search); + PD_CHECK(r == 0, "xpu::plugin::set_stop_value_multi_ends failed."); +} + +PD_BUILD_OP(set_stop_value_multi_ends_v2) + .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens"}) + .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"}) + .SetInplaceMap({{"topk_ids", "topk_ids_out"}, + {"stop_flags", "stop_flags_out"}, + {"next_tokens", "next_tokens_out"}}) + .SetKernelFn(PD_KERNEL(GetStopFlagsMulti)); diff --git a/csrc/xpu/src/set_value_by_flags_and_idx_v2.cc b/csrc/xpu/src/set_value_by_flags_and_idx_v2.cc new file mode 100644 index 000000000000..327d736a483b --- /dev/null +++ b/csrc/xpu/src/set_value_by_flags_and_idx_v2.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +void SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all, + const paddle::Tensor& input_ids, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_flags) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + std::vector pre_ids_all_shape = pre_ids_all.shape(); + int bs = seq_lens_this_time.shape()[0]; + int length = pre_ids_all.shape()[1]; + int length_input_ids = input_ids.shape()[1]; + int r = baidu::xpu::api::plugin::set_value_by_flags_and_idx( + xpu_ctx->x_context(), + stop_flags.data(), + const_cast(pre_ids_all.data()), + input_ids.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + step_idx.data(), + bs, + length, + length_input_ids); + PD_CHECK(r == 0, "xpu::plugin::set_value_by_flags_and_idx failed."); +} + +PD_BUILD_OP(set_value_by_flags_and_idx_v2) + .Inputs({"pre_ids_all", + "input_ids", + "seq_lens_this_time", + "seq_lens_encoder", + "seq_lens_decoder", + "step_idx", + "stop_flags"}) + .Outputs({"pre_ids_all_out"}) + .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}}) + .SetKernelFn(PD_KERNEL(SetValueByFlagsAndIdx)); diff --git a/csrc/xpu/src/setup.py b/csrc/xpu/src/setup.py new file mode 100644 index 000000000000..9b9b255650fb --- /dev/null +++ b/csrc/xpu/src/setup.py @@ -0,0 
+1,47 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved. + +Build and setup XPU custom ops for ERNIE Bot. +""" + +from paddle.utils.cpp_extension import CppExtension, setup + +setup( + name="custom_setup_ops", + ext_modules=[ + CppExtension( + sources=[ + "./set_stop_value_multi_ends_v2.cc", + "./set_value_by_flags_and_idx_v2.cc", + "./get_token_penalty_multi_scores_v2.cc", + "./get_padding_offset_v2.cc", + "./update_inputs.cc", + "./rebuild_padding_v2.cc", + "../../generation/save_with_output.cc", + "../../generation/save_with_output_msg.cc", + "../../generation/get_output.cc", + ], + include_dirs=["./plugin/include"], + extra_objects=["./plugin/build/libxpuplugin.a"], + extra_compile_args={ + "cxx": ["-D_GLIBCXX_USE_CXX11_ABI=1", "-DPADDLE_WITH_XPU"] + }, + ) + ], +) diff --git a/csrc/xpu/src/update_inputs.cc b/csrc/xpu/src/update_inputs.cc new file mode 100644 index 000000000000..b162bf7ab075 --- /dev/null +++ b/csrc/xpu/src/update_inputs.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +void UpdateInputes(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // xpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + + const int max_bsz = stop_flags.shape()[0]; + PADDLE_ENFORCE_LE( + max_bsz, + 1024, + phi::errors::InvalidArgument( + "Only support max_bs <= 1024, but received max_bs is %d", max_bsz)); + const int now_bsz = seq_lens_this_time.shape()[0]; + const int input_ids_stride = input_ids.shape()[1]; + int r = baidu::xpu::api::plugin::update_inputs( + xpu_ctx->x_context(), + const_cast(not_need_stop.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(input_ids.data()), + stop_nums.data(), + stop_flags.data(), + is_block_step.data(), + next_tokens.data(), + now_bsz, + max_bsz, + input_ids_stride); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed."); +} + +PD_BUILD_OP(update_inputs) + .Inputs({"stop_flags", + "not_need_stop", + "seq_lens_this_time", + "seq_lens_encoder", + "seq_lens_decoder", 
+ "input_ids", + "stop_nums", + "next_tokens", + "is_block_step"}) + .Outputs({"not_need_stop_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "input_ids_out"}) + .SetInplaceMap({{"not_need_stop", "not_need_stop_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"input_ids", "input_ids_out"}}) + .SetKernelFn(PD_KERNEL(UpdateInputes)); diff --git a/csrc/xpu/test/python/test_get_padding_offset_v2.py b/csrc/xpu/test/python/test_get_padding_offset_v2.py new file mode 100644 index 000000000000..d80915c80819 --- /dev/null +++ b/csrc/xpu/test/python/test_get_padding_offset_v2.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import paddle +import unittest +from paddlenlp_ops import get_padding_offset_v2 + +np.random.seed(2023) + +class GetPaddingOffsetV2Test(unittest.TestCase): + def test_get_padding_offset_v2(self): + max_len = 10 + seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1) + cum_offset = np.cumsum((max_len - seq_lens).flatten(), -1, "int32") + token_num = np.sum(seq_lens) + bs = seq_lens.shape[0] + input_ids = np.zeros([bs, max_len], "int64") + for i in range(bs): + ids_len = seq_lens[i, 0] + input_ids[i, 0:ids_len] = np.random.randint(1, 10, seq_lens[i, 0], "int64") + + x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( + paddle.to_tensor(input_ids), + paddle.to_tensor(cum_offset), + paddle.to_tensor(token_num), + paddle.to_tensor(seq_lens), + ) + + print("input_ids:\n", input_ids) + print("cum_offset:\n", cum_offset) + print("token_num:\n", token_num) + print("seq_lens:\n", seq_lens) + print("x_remove_padding:\n", x_remove_padding) + print("cum_offsets_out:\n", cum_offsets_out) + print("padding_offset:\n", padding_offset) + print("cu_seqlens_q:\n", cu_seqlens_q) + print("cu_seqlens_k:\n", cu_seqlens_k) + + ref_x_remove_padding = np.array( + [8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64") + ref_cum_offsets_out = np.array([0, 6, 13], "int32") + ref_padding_offset = np.array( + [0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32") + ref_cu_seqlens_q = np.array([0 , 4 , 7 , 13], "int32") + ref_cu_seqlens_k = np.array([0 , 4 , 7 , 13], "int32") + + assert sum(ref_x_remove_padding + - x_remove_padding) == 0, 'Check x_remove_padding failed.' + assert sum(ref_cum_offsets_out + - cum_offsets_out) == 0, 'Check cum_offsets_out failed.' + assert sum(ref_padding_offset + - padding_offset) == 0, 'Check padding_offset failed.' + assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, 'Check cu_seqlens_q failed.' + assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, 'Check cu_seqlens_k failed.' 
+ +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/csrc/xpu/test/python/test_get_token_penalty_multi_scores_v2.py b/csrc/xpu/test/python/test_get_token_penalty_multi_scores_v2.py new file mode 100644 index 000000000000..43ff11f80eaf --- /dev/null +++ b/csrc/xpu/test/python/test_get_token_penalty_multi_scores_v2.py @@ -0,0 +1,254 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from paddlenlp_ops import get_token_penalty_multi_scores_v2 +import unittest + +paddle.seed(2023) +class GetTokenPenaltyMultiScoresV2Test(unittest.TestCase): + def test_get_token_penalty_multi_scores_v2(self): + pre_ids = paddle.to_tensor([[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]], "int64") + logits = paddle.to_tensor([[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1], [0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1]], "float32") + penalty_scores = paddle.to_tensor([1.0, 1.0], "float32") + frequency_scores = paddle.to_tensor([0.1, 0.1], "float32") + presence_scores = paddle.to_tensor([0.0, 0.0], "float32") + temperatures = paddle.to_tensor([0.5, 0.25], "float32") + bad_tokens = paddle.to_tensor([0, 1], "int64") + cur_len = paddle.to_tensor([7, 6], "int64") + min_len = paddle.to_tensor([1, 8], "int64") + eos_token_id = paddle.to_tensor([2, 9], "int64") + print("logits\n", logits) + get_token_penalty_multi_scores_v2( + 
pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id, + ) + print("pre_ids\n", pre_ids) + print("logits\n", logits) + print("penalty_scores\n", penalty_scores) + print("frequency_scores\n", frequency_scores) + print("presence_scores\n", presence_scores) + print("temperatures\n", temperatures) + print("bad_tokens\n", bad_tokens) + print("cur_len\n", cur_len) + print("min_len\n", min_len) + print("eos_token_id\n", eos_token_id) + + ref_logits = np.array( + [ + [ + -10000000000, + -10000000000, + 0.6, + 0.6, + 0.8, + 1, + 1.2, + 0, + 0.2, + 0, + ], + [ + -10000000000, + -10000000000, + -40000000000, + 2.4, + 1.6, + 1.2, + 0, + 0, + 0.4, + -40000000000, + ], + ], + "float32", + ) + diff_logits = np.sum(np.abs(ref_logits - logits.numpy())) + print("diff_logits\n", diff_logits) + assert diff_logits < 1e-6, 'Check failed.' + + pre_ids = paddle.to_tensor([[2, 3, 3, 5, 8, 9, 3, 9, 1, 8, 9, 2, 3, 8, 8, 9, 9, 1, 4, 2, 6, 2, 6, 8, + 7, 2, 2, 3, 8, 1, 5, 7, 9, 2, 2, 9, 1, 4, 9, 8, 5, 8, 5, 7, 3, 6, 4, 4, + 9, 9, 8, 5, 5, 2, 2, 9, 4, 8, 1, 9, 6, 9, 2, 2, 7, 2, 2, 9, 4, 6, 4, 6, + 1, 4, 1, 9, 1, 8, 8, 5, 7, 9, 4, 2, 5, 1, 1, 4, 1, 5, 5, 4, 4, 2, 1, 8, + 7, 1, 2, 9, 6, 7, 9, 6, 7, 7, 4, 9, 9, 7, 5, 1, 8, 9, 8, 8, 5, 4, 6, 4, + 7, 5, 5, 7, 6, 9, 3, 9], + [7, 8, 1, 3, 1, 7, 6, 3, 5, 3, 8, 3, 1, 9, 7, 1, 1, 9, 5, 4, 9, 6, 1, 9, + 3, 8, 3, 9, 9, 6, 4, 2, 8, 5, 3, 1, 6, 9, 1, 3, 9, 8, 1, 7, 5, 1, 5, 1, + 8, 7, 4, 5, 9, 8, 7, 4, 7, 3, 6, 4, 6, 6, 5, 5, 2, 9, 9, 5, 8, 8, 4, 8, + 2, 8, 1, 3, 9, 1, 8, 5, 8, 3, 8, 8, 2, 7, 3, 7, 5, 7, 2, 6, 3, 5, 1, 4, + 6, 1, 9, 8, 2, 2, 3, 6, 7, 6, 2, 6, 5, 1, 5, 6, 2, 1, 6, 4, 7, 7, 3, 8, + 5, 1, 9, 1, 2, 8, 6, 8]]) + logits = paddle.to_tensor([[0.16274983, 0.61470598, 0.94366980, 0.82005417, 0.50752640, 0.38316748, + 0.92648441, 0.24050158, 0.05461595, 0.42218581, 0.36270225, 0.15464807, + 0.13614719, 0.67509544, 0.40315166, 0.10671722, 0.24832056, 0.76091218, + 
0.11598995, 0.10962527, 0.04688513, 0.81536716, 0.72259802, 0.60476679, + 0.16701800, 0.84160781, 0.79649884, 0.78021604, 0.75329530, 0.98587888, + 0.13421868, 0.16027625, 0.15269397, 0.06228730, 0.73856270, 0.34721911, + 0.73683006, 0.78178608, 0.32068327, 0.79906309, 0.44214272, 0.63330448, + 0.08016958, 0.63367140, 0.19788943, 0.55346787, 0.11142531, 0.90518415, + 0.21236691, 0.81587470, 0.83752930, 0.70979482, 0.35684183, 0.28715104, + 0.87162822, 0.17679396, 0.98725849, 0.76129991, 0.04090235, 0.37181064, + 0.63317049, 0.24689502, 0.21126501, 0.57617670, 0.74346697, 0.40613672, + 0.56907010, 0.68556929, 0.29032683, 0.17866278, 0.35165095, 0.97015840, + 0.70785582, 0.54259878, 0.14712237, 0.90483177, 0.02094105, 0.36411613, + 0.02495066, 0.88874054, 0.88895452, 0.86216462, 0.58062190, 0.95583254, + 0.20553111, 0.29870346, 0.69652933, 0.36861244, 0.85316223, 0.50240189, + 0.17566244, 0.61080140, 0.88203174, 0.98675215, 0.24344546, 0.17213407, + 0.78160852, 0.25165486, 0.48188508, 0.82812423, 0.10199814, 0.90475923, + 0.66907483, 0.71910626, 0.40660757, 0.59460294, 0.70212913, 0.90841550, + 0.00329034, 0.11290466, 0.89654654, 0.69114941, 0.29473618, 0.62027222, + 0.37333879, 0.98911142, 0.46510187, 0.65914583, 0.73022646, 0.12790845, + 0.12817244, 0.43015456, 0.75011456, 0.43562204, 0.48086026, 0.75587070, + 0.98481447, 0.77367836], + [0.12336024, 0.74152875, 0.09191196, 0.99301219, 0.44764417, 0.01848883, + 0.78326035, 0.99228370, 0.81447607, 0.02627683, 0.51033205, 0.98703283, + 0.15247856, 0.77640921, 0.60799915, 0.87518770, 0.76818430, 0.86542630, + 0.31795895, 0.04829503, 0.85567141, 0.30271924, 0.67515039, 0.59728831, + 0.78710967, 0.75111693, 0.56837374, 0.49085775, 0.91510201, 0.59545547, + 0.99482232, 0.59036905, 0.58267909, 0.28770933, 0.53237396, 0.95318258, + 0.93987304, 0.61142951, 0.26737869, 0.52285451, 0.03479086, 0.61631846, + 0.66777998, 0.15736090, 0.00447258, 0.37035006, 0.15281211, 0.95372260, + 0.25963321, 0.61036694, 0.15020694, 0.19171195, 
0.55252832, 0.00391038, + 0.31052542, 0.96495175, 0.42586124, 0.05630261, 0.99728668, 0.01856293, + 0.83201504, 0.10701843, 0.56434178, 0.38009524, 0.51095045, 0.13202040, + 0.07133843, 0.75313550, 0.17111187, 0.80716974, 0.00172165, 0.83906764, + 0.73240769, 0.85843354, 0.11042888, 0.07912333, 0.33689004, 0.22334915, + 0.59059596, 0.52789515, 0.29831955, 0.39515004, 0.55602801, 0.83818001, + 0.05865780, 0.25654668, 0.76624149, 0.35190639, 0.04158346, 0.59157544, + 0.30779791, 0.94609004, 0.10759670, 0.65575141, 0.37828529, 0.29571742, + 0.76361233, 0.72476572, 0.18568406, 0.85430276, 0.02057583, 0.76195669, + 0.65507215, 0.69129735, 0.25084621, 0.75223947, 0.06064088, 0.20287007, + 0.35887691, 0.75043523, 0.47575447, 0.40021798, 0.44464844, 0.67975360, + 0.40443239, 0.71052992, 0.21782248, 0.50568426, 0.89037591, 0.06661721, + 0.28788096, 0.70773387, 0.42428264, 0.80419677, 0.42710736, 0.87317258, + 0.88229448, 0.79217333]]) + # pre_ids = paddle.to_tensor(np.float32(np.random.random([2, 1024]))) + # logits = paddle.to_tensor(np.float32(np.random.random([2, 1024]))) + penalty_scores = paddle.to_tensor([1.0, 1.0], "float32") + frequency_scores = paddle.to_tensor([0.1, 0.1], "float32") + presence_scores = paddle.to_tensor([0.0, 0.0], "float32") + temperatures = paddle.to_tensor([0.5, 0.25], "float32") + bad_tokens = paddle.to_tensor([0, 1], "int64") + cur_len = paddle.to_tensor([7, 6], "int64") + min_len = paddle.to_tensor([1, 8], "int64") + eos_token_id = paddle.to_tensor([2, 9], "int64") + print("logits\n", logits) + get_token_penalty_multi_scores_v2( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id, + ) + print("pre_ids\n", pre_ids) + print("logits\n", logits) + print("penalty_scores\n", penalty_scores) + print("frequency_scores\n", frequency_scores) + print("presence_scores\n", presence_scores) + print("temperatures\n", temperatures) + print("bad_tokens\n", bad_tokens) 
+ print("cur_len\n", cur_len) + print("min_len\n", min_len) + print("eos_token_id\n", eos_token_id) + + ref_logits = np.array( + [[-10000000000., -10000000000., 1.88733959 , 1.64010835 , + 1.01505280 , 0.76633495 , 1.85296881 , 0.48100317 , + 0.10923190 , 0.84437162 , 0.72540450 , 0.30929613 , + 0.27229437 , 1.35019088 , 0.80630332 , 0.21343444 , + 0.49664113 , 1.52182436 , 0.23197991 , 0.21925054 , + 0.09377026 , 1.63073432 , 1.44519603 , 1.20953357 , + 0.33403599 , 1.68321562 , 1.59299767 , 1.56043208 , + 1.50659060 , 1.97175777 , 0.26843736 , 0.32055250 , + 0.30538794 , 0.12457460 , 1.47712541 , 0.69443822 , + 1.47366011 , 1.56357217 , 0.64136654 , 1.59812617 , + 0.88428545 , 1.26660895 , 0.16033916 , 1.26734281 , + 0.39577886 , 1.10693574 , 0.22285062 , 1.81036830 , + 0.42473382 , 1.63174939 , 1.67505860 , 1.41958964 , + 0.71368366 , 0.57430208 , 1.74325645 , 0.35358793 , + 1.97451699 , 1.52259982 , 0.08180470 , 0.74362129 , + 1.26634097 , 0.49379003 , 0.42253003 , 1.15235341 , + 1.48693395 , 0.81227344 , 1.13814020 , 1.37113857 , + 0.58065367 , 0.35732555 , 0.70330191 , 1.94031680 , + 1.41571164 , 1.08519757 , 0.29424474 , 1.80966353 , + 0.04188210 , 0.72823226 , 0.04990132 , 1.77748108 , + 1.77790904 , 1.72432923 , 1.16124380 , 1.91166508 , + 0.41106221 , 0.59740692 , 1.39305866 , 0.73722488 , + 1.70632446 , 1.00480378 , 0.35132489 , 1.22160280 , + 1.76406348 , 1.97350430 , 0.48689091 , 0.34426814 , + 1.56321704 , 0.50330973 , 0.96377015 , 1.65624845 , + 0.20399629 , 1.80951846 , 1.33814967 , 1.43821251 , + 0.81321514 , 1.18920588 , 1.40425825 , 1.81683099 , + 0.00658068 , 0.22580932 , 1.79309309 , 1.38229883 , + 0.58947235 , 1.24054444 , 0.74667758 , 1.97822285 , + 0.93020374 , 1.31829166 , 1.46045291 , 0.25581691 , + 0.25634488 , 0.86030912 , 1.50022912 , 0.87124407 , + 0.96172053 , 1.51174140 , 1.96962893 , 1.54735672 ], + [-10000000000., -10000000000., -40000000000. 
, 3.97204876 , + 1.79057670 , 0.07395532 , 3.13304138 , 3.96913481 , + 3.25790429 , -40000000000. , 2.04132819 , 3.94813132 , + 0.60991424 , 3.10563684 , 2.43199658 , 3.50075078 , + 3.07273722 , 3.46170521 , 1.27183580 , 0.19318011 , + 3.42268562 , 1.21087694 , 2.70060158 , 2.38915324 , + 3.14843869 , 3.00446773 , 2.27349496 , 1.96343100 , + 3.66040802 , 2.38182187 , 3.97928929 , 2.36147618 , + 2.33071637 , 1.15083730 , 2.12949586 , 3.81273031 , + 3.75949216 , 2.44571805 , 1.06951475 , 2.09141803 , + 0.13916343 , 2.46527386 , 2.67111993 , 0.62944359 , + 0.01789032 , 1.48140025 , 0.61124843 , 3.81489038 , + 1.03853285 , 2.44146776 , 0.60082775 , 0.76684779 , + 2.21011329 , 0.01564152 , 1.24210167 , 3.85980701 , + 1.70344496 , 0.22521044 , 3.98914671 , 0.07425172 , + 3.32806015 , 0.42807373 , 2.25736713 , 1.52038097 , + 2.04380178 , 0.52808160 , 0.28535372 , 3.01254201 , + 0.68444747 , 3.22867894 , 0.00688660 , 3.35627055 , + 2.92963076 , 3.43373418 , 0.44171551 , 0.31649333 , + 1.34756017 , 0.89339662 , 2.36238384 , 2.11158061 , + 1.19327819 , 1.58060014 , 2.22411203 , 3.35272002 , + 0.23463120 , 1.02618670 , 3.06496596 , 1.40762556 , + 0.16633384 , 2.36630177 , 1.23119164 , 3.78436017 , + 0.43038681 , 2.62300563 , 1.51314116 , 1.18286967 , + 3.05444932 , 2.89906287 , 0.74273622 , 3.41721106 , + 0.08230332 , 3.04782677 , 2.62028861 , 2.76518941 , + 1.00338483 , 3.00895786 , 0.24256352 , 0.81148028 , + 1.43550766 , 3.00174093 , 1.90301788 , 1.60087192 , + 1.77859378 , 2.71901441 , 1.61772954 , 2.84211969 , + 0.87128991 , 2.02273703 , 3.56150365 , 0.26646885 , + 1.15152383 , 2.83093548 , 1.69713056 , 3.21678710 , + 1.70842946 , 3.49269032 , 3.52917790 , 3.16869330 ]] + , + "float32", + ) + print(logits.numpy()) + diff_logits = np.sum(np.abs(ref_logits - logits.numpy())) + print("diff_logits\n", diff_logits) + # assert diff_logits < 1e-6, 'Check failed.' 
+ +if __name__ == '__main__': + unittest.main() diff --git a/csrc/xpu/test/python/test_rebuild_padding_v2.py b/csrc/xpu/test/python/test_rebuild_padding_v2.py new file mode 100644 index 000000000000..685c799883f5 --- /dev/null +++ b/csrc/xpu/test/python/test_rebuild_padding_v2.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import unittest +from paddlenlp_ops import rebuild_padding_v2 + +np.random.seed(2024) + +class GetRebuildPaddingV2Test(unittest.TestCase): + def test_rebuild_padding_v2(self): + max_len = 10 + seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1) + seq_lens_decoder = np.zeros_like(seq_lens).astype("int32") + + cum_offsets = np.insert(np.cumsum((max_len - seq_lens).flatten(), -1, "int32"),0,0)[:-1] + token_num = np.sum(seq_lens) + bs = seq_lens.shape[0] + dim_emb = 129 + tmp_out = np.random.random((token_num, dim_emb)).astype("float16") + # print("tmp_out:\n", paddle.to_tensor(tmp_out)) + # print("cum_offsets:\n", paddle.to_tensor(cum_offsets)) + # print("seq_lens_decoder:\n", paddle.to_tensor(seq_lens_decoder)) + # print("seq_lens:\n", paddle.to_tensor(seq_lens)) + + + out = rebuild_padding_v2( + paddle.to_tensor(tmp_out), + paddle.to_tensor(cum_offsets), + paddle.to_tensor(seq_lens_decoder), + paddle.to_tensor(seq_lens), + max_len + ) + + def rebuild_padding_cpu(tmp_out, cum_offsets, seq_lens_decoder, seq_len_encoder, 
max_len): + bs = seq_lens.shape[0] + dim_emb = tmp_out.shape[1] + output_data = np.zeros((bs, dim_emb)).flatten() + seq_len = max_len + tmp_out = tmp_out.flatten() + for i in range(bs*dim_emb): + bi = i // dim_emb + bias_idx = i % dim_emb + seq_id = 0 + # just encoder or stop, get last token; just decoder, get first token. + if (seq_lens_decoder[bi] == 0): + if seq_len_encoder[bi] != 0: + seq_id = seq_len_encoder[bi] - 1 + else: + continue + ori_token_idx = bi * seq_len - cum_offsets[bi] + seq_id + src_offset = ori_token_idx * dim_emb + bias_idx + output_data[i] = tmp_out[src_offset] + return output_data.reshape(bs, dim_emb) + + out_ = rebuild_padding_cpu(tmp_out, cum_offsets, seq_lens_decoder, seq_lens, max_len) + + np.testing.assert_allclose(out.numpy(), out_, atol=1e-05, rtol=1e-05) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/csrc/xpu/test/python/test_set_stop_value_multi_ends_v2.py b/csrc/xpu/test/python/test_set_stop_value_multi_ends_v2.py new file mode 100644 index 000000000000..02f34c8e2d08 --- /dev/null +++ b/csrc/xpu/test/python/test_set_stop_value_multi_ends_v2.py @@ -0,0 +1,130 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import paddle +import unittest +from paddlenlp_ops import set_stop_value_multi_ends_v2 +np.random.seed(1) + +class GetStopValueMultiEndsV2Test(unittest.TestCase): + def test_set_stop_value_multi_ends_v2(self): + bs = 64 + + # test beam_search=False + topk_ids = paddle.arange(0, bs, dtype="int64") + next_tokens = paddle.full([bs], 0, dtype="int64") + stop_flags = paddle.to_tensor(np.random.randint(0, 2, [bs]), "bool") + seq_lens = paddle.to_tensor(np.random.randint(0, 5, [bs]), "int32") + end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64") + print("topk_ids\n", topk_ids) + print("next_tokens\n", next_tokens) + print("stop_flags\n", stop_flags) + set_stop_value_multi_ends_v2(topk_ids, stop_flags, + seq_lens, end_ids, next_tokens) + print("topk_ids\n", topk_ids) + print("next_tokens\n", next_tokens) + print("stop_flags\n", stop_flags) + print("seq_lens\n", seq_lens) + print("end_ids\n", end_ids) + + ref_topk_ids = np.array( + [0, 0, 2, 3, -1, 0, 0, 0, 0, 9, 10, 0, 12, 0, -1, 15, 16, 0, + 18, 19, 20, 0, 22, 23, 0, 25, 26, 27, -1, 29, 30, 31, 0, 0, 0, -1, + -1, 37, 38, 39, -1, -1, 0, 0, 0, 0, 46, -1, 0, 49, 50, 0, 52, 53, + 0, -1, 0, 57, -1, 59, 60, 0, 0, 63], + "int64", + ) + ref_next_tokens = np.array( + [0, 0, 2, 3, 0, 0, 0, 0, 0, 9, 10, 0, 12, 0, 0, 15, 16, 0, + 18, 19, 20, 0, 22, 23, 0, 25, 26, 27, 0, 29, 30, 31, 0, 0, 0, 0, + 0, 37, 38, 39, 0, 0, 0, 0, 0, 0, 46, 0, 0, 49, 50, 0, 52, 53, + 0, 0, 0, 57, 0, 59, 60, 0, 0, 63], + "int64", + ) + ref_stop_flags = np.array( + [True, True, True, True, True, True, True, True, True, False, + False, True, False, True, True, False, False, True, False, False, + False, True, False, False, True, False, False, False, True, False, + False, False, True, True, True, True, True, False, False, False, + True, True, True, True, True, True, False, True, True, False, + False, True, False, False, True, True, True, False, True, False, + False, True, True, False], + "bool", + ) + diff_topk_ids = 
np.sum(np.abs(ref_topk_ids - topk_ids.numpy())) + print("diff_topk_ids\n", diff_topk_ids) + assert diff_topk_ids == 0, 'Check failed.' + diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy())) + print("diff_next_tokens\n", diff_next_tokens) + assert diff_next_tokens == 0, 'Check failed.' + diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype( + np.int32) - stop_flags.numpy().astype(np.int32))) + print("diff_stop_flags\n", diff_stop_flags) + assert diff_stop_flags == 0, 'Check failed.' + + # test beam_search=True + # topk_ids = paddle.arange(0, bs, dtype="int64") + # next_tokens = paddle.full([bs], 0, dtype="int64") + # stop_flags = paddle.to_tensor(np.random.randint(0, 2, [bs]), "bool") + # seq_lens = paddle.to_tensor(np.random.randint(0, 5, [bs]), "int32") + # end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64") + # print("topk_ids\n", topk_ids) + # print("next_tokens\n", next_tokens) + # print("stop_flags\n", stop_flags) + # set_stop_value_multi_ends_v2(topk_ids, stop_flags, + # seq_lens, end_ids, next_tokens, True) + # print("topk_ids\n", topk_ids) + # print("next_tokens\n", next_tokens) + # print("stop_flags\n", stop_flags) + # print("seq_lens\n", seq_lens) + # print("end_ids\n", end_ids) + + # ref_topk_ids = np.array( + # [0, 1, 2, 3, 4, 0, 6, 7, -1, 9, 10, 0, -1, 13, 14, 15, 0, 17, + # 18, 19, 20, 0, 22, 23, 24, 25, -1, -1, 28, 29, 0, 0, -1, 33, 34, 35, + # 36, 37, 0, -1, 0, 41, -1, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, + # 0, 0, 0, 0, 58, -1, 60, 61, -1, 63], + # "int64", + # ) + # ref_next_tokens = np.array( + # [0, 1, 2, 3, 4, 0, 6, 7, 0, 9, 10, 0, 0, 13, 14, 15, 0, 17, + # 18, 19, 20, 0, 22, 23, 24, 25, 0, 0, 28, 29, 0, 0, 0, 33, 34, 35, + # 36, 37, 0, 0, 0, 41, 0, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, + # 0, 0, 0, 0, 58, 0, 60, 61, 0, 63], + # "int64", + # ) + # ref_stop_flags = np.array( + # [False, False, False, False, False, True, False, False, True, False, + # False, True, True, False, False, False, True, False, False, False, + # 
False, True, False, False, False, False, True, True, False, False,
        #      True, True, True, False, False, False, False, False, True, True,
        #      True, False, True, True, False, False, False, True, True, False,
        #      True, True, True, False, True, True, True, True, False, True,
        #      False, False, True, False],
        #     "bool",
        # )
        # diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
        # print("diff_topk_ids\n", diff_topk_ids)
        # assert diff_topk_ids == 0, 'Check failed.'
        # diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
        # print("diff_next_tokens\n", diff_next_tokens)
        # assert diff_next_tokens == 0, 'Check failed.'
        # diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(
        #     np.int32) - stop_flags.numpy().astype(np.int32)))
        # print("diff_stop_flags\n", diff_stop_flags)
        # assert diff_stop_flags == 0, 'Check failed.'

if __name__ == '__main__':
    unittest.main()
\ No newline at end of file
diff --git a/csrc/xpu/test/python/test_set_value_by_flags_and_idx_v2.py b/csrc/xpu/test/python/test_set_value_by_flags_and_idx_v2.py
new file mode 100644
index 000000000000..9f718020ade3
--- /dev/null
+++ b/csrc/xpu/test/python/test_set_value_by_flags_and_idx_v2.py
@@ -0,0 +1,75 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +import numpy as np +import paddle +import unittest +from paddlenlp_ops import set_value_by_flags_and_idx_v2 + +paddle.seed(2023) + +class GetStopValueMultiEndsV2Test(unittest.TestCase): + def test_set_stop_value_multi_ends_v2(self): + pre_ids_all = paddle.to_tensor([[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]], "int64") + input_ids = paddle.to_tensor([[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1]], "int64") + seq_lens_this_time = paddle.to_tensor([1, 1], "int32") + seq_lens_encoder = paddle.to_tensor([1, 1], "int32") + seq_lens_decoder = paddle.to_tensor([1, 1], "int32") + step_idx = paddle.to_tensor([1, 1], "int64") + stop_flags = paddle.to_tensor([0, 1], "bool") + print("pre_ids_all\n", pre_ids_all) + set_value_by_flags_and_idx_v2(pre_ids_all, input_ids, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, step_idx, stop_flags) + print("pre_ids_all\n", pre_ids_all) + print("input_ids\n", input_ids) + print("seq_lens_this_time\n", seq_lens_this_time) + print("seq_lens_encoder\n", seq_lens_encoder) + print("seq_lens_decoder\n", seq_lens_decoder) + print("step_idx\n", step_idx) + print("stop_flags\n", stop_flags) + + ref_pre_ids_all = np.array( + [ + [ + 1, + 1, + 3, + 4, + 5, + 6, + 7, + -1, + -1, + -1, + ], + [ + 1, + 9, + 7, + 6, + 5, + 4, + -1, + -1, + -1, + -1, + ], + ], + "int64", + ) + diff_pre_ids_all = np.sum(np.abs(ref_pre_ids_all - pre_ids_all.numpy())) + print("diff_pre_ids_all\n", diff_pre_ids_all) + assert diff_pre_ids_all == 0, 'Check failed.' + +if __name__ == '__main__': + unittest.main() diff --git a/csrc/xpu/test/python/test_update_inputs.py b/csrc/xpu/test/python/test_update_inputs.py new file mode 100644 index 000000000000..94de52e92514 --- /dev/null +++ b/csrc/xpu/test/python/test_update_inputs.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import unittest +from paddlenlp_ops import update_inputs + +np.random.seed(2023) +class GetUpdateInputsTest(unittest.TestCase): + def test_update_inputs(self): + bs = 48 + max_bs = 64 + max_input_length = 6144 + + stop_flags = np.random.randint(0, 2, max_bs).astype("bool") + not_need_stop = np.array([1], "bool") + seq_lens_this_time = np.zeros([bs], "int32") + seq_lens_encoder = np.zeros([max_bs], "int32") + seq_lens_decoder = np.zeros([max_bs], "int32") + for i in range(bs): + if i % 2 == 0: + seq_lens_encoder[i] = i + seq_lens_this_time[i] = i + else: + seq_lens_decoder[i] = i + seq_lens_this_time[i] = 1 + input_ids_np = np.random.randint(1, 10, [max_bs, max_input_length], "int64") + stop_nums = np.array([max_bs], "int64") + next_tokens = np.random.randint(1, 10, [max_bs], "int64") + is_block_step = np.random.randint(0, 2, [max_bs]).astype("bool") + + stop_flags = paddle.to_tensor(stop_flags) + not_need_stop = paddle.to_tensor(not_need_stop, place=paddle.CPUPlace()) + seq_lens_this_time = paddle.to_tensor(seq_lens_this_time) + seq_lens_encoder = paddle.to_tensor(seq_lens_encoder) + seq_lens_decoder = paddle.to_tensor(seq_lens_decoder) + input_ids = paddle.to_tensor(input_ids_np) + stop_nums = paddle.to_tensor(stop_nums) + next_tokens = paddle.to_tensor(next_tokens) + is_block_step = paddle.to_tensor(is_block_step) + + print("stop_flags:\n", stop_flags) + print("not_need_stop:\n", 
not_need_stop) + print("seq_lens_this_time:\n", seq_lens_this_time) + print("seq_lens_encoder:\n", seq_lens_encoder) + print("seq_lens_decoder:\n", seq_lens_decoder) + print("input_ids:\n", input_ids) + print("stop_nums:\n", stop_nums) + print("next_tokens:\n", next_tokens) + print("is_block_step:\n", is_block_step) + + update_inputs( + stop_flags, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + input_ids, + stop_nums, + next_tokens, + is_block_step + ) + + print("-" * 50) + print("stop_flags:\n", stop_flags) + print("not_need_stop:\n", not_need_stop) + print("seq_lens_this_time:\n", seq_lens_this_time) + print("seq_lens_encoder:\n", seq_lens_encoder) + print("seq_lens_decoder:\n", seq_lens_decoder) + print("input_ids:\n", input_ids) + print("stop_nums:\n", stop_nums) + print("next_tokens:\n", next_tokens) + + ref_not_need_stop_out = np.array([True]) + ref_seq_lens_this_time_out = np.array([0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, + 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1], "int32") + ref_seq_lens_encoder_out = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "int32") + ref_seq_lens_decoder_out = np.array([0, 0, 2, 0, 0, 6, 0, 8, 8, 10, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 20, 22, 0, 24, + 24, 0, 26, 28, 0, 0, 0, 32, 32, 0, 34, 0, 0, 38, 0, 40, 0, 0, 42, 0, 0, 46, 46, 48, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "int32") + input_ids_np[:, 0] = np.array([6, 5, 9, 8, 6, 2, 8, 1, 3, 1, 3, 6, 9, 8, 1, 9, 1, 8, 8, 6, 7, 6, 5, 3, + 5, 9, 3, 6, 3, 9, 8, 8, 8, 8, 4, 8, 7, 4, 2, 3, 5, 8, 4, 2, 5, 6, 8, 9, + 6, 7, 4, 2, 4, 6, 2, 3, 4, 9, 7, 2, 1, 8, 7, 8], "int64") + + assert not_need_stop.numpy() == ref_not_need_stop_out, 'Check not_need_stop failed.' 
+ assert np.all(seq_lens_this_time.numpy() + == ref_seq_lens_this_time_out), 'Check seq_lens_this_time failed.' + assert np.all(seq_lens_encoder.numpy() + == ref_seq_lens_encoder_out), 'Check seq_lens_encoder failed.' + assert np.all(seq_lens_decoder.numpy() + == ref_seq_lens_decoder_out), 'Check seq_lens_decoder failed.' + assert np.all(input_ids.numpy() + == input_ids_np), 'Check input_ids failed.' + +if __name__ == '__main__': + unittest.main()