Skip to content

Commit ba9c345

Browse files
authored
【Inference】fix step kernel (#9122)
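* fix step.cu when setting FLAGS_allocator_strategy=auto_growth: guard the step_block_list[ori_step_len - 1] read with an ori_step_len > 0 check (avoiding an index of -1 when the step list is empty) and remove the unconditional debug printf of ori_step_len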
1 parent af23e2d commit ba9c345

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

csrc/gpu/step.cu

Lines changed: 21 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,6 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
102102
}
103103
__syncthreads();
104104
}
105-
106105
// 为需要block的位置分配block,每个位置分配一个block
107106
if (tid < need_block_len[0]) {
108107
const int need_block_id = need_block_list[tid];
@@ -116,33 +115,33 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
116115
need_block_list[tid] = -1;
117116
}
118117
__syncthreads();
119-
120118
// 计算可以复原的query id
121119
if (tid == 0) {
122120
int ori_free_list_len = free_list_len[0];
123121
int ori_step_len = step_len[0];
124-
printf("ori_step_len %d\n", ori_step_len);
125-
int ori_step_block_id = step_block_list[ori_step_len - 1];
126-
int tmp_used_len = used_list_len[ori_step_block_id];
127-
// 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中)
128-
int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
129-
while (ori_step_len > 0 && ori_free_list_len >= used_len) {
122+
if (ori_step_len > 0) {
123+
int ori_step_block_id = step_block_list[ori_step_len - 1];
124+
int tmp_used_len = used_list_len[ori_step_block_id];
125+
// 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中)
126+
int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
127+
while (ori_step_len > 0 && ori_free_list_len >= used_len) {
130128
#ifdef DEBUG_STEP
131-
printf("recover seq_id: %d, free_list_len: %d, used_list_len: %d\n",
132-
ori_step_block_id, ori_free_list_len, used_len);
129+
printf("recover seq_id: %d, free_list_len: %d, used_list_len: %d\n",
130+
ori_step_block_id, ori_free_list_len, used_len);
133131
#endif
134-
recover_block_list[recover_len[0]] = ori_step_block_id;
135-
is_block_step[ori_step_block_id] = false;
136-
used_list_len[ori_step_block_id] = used_len;
137-
ori_free_list_len -= used_len;
138-
step_block_list[ori_step_len - 1] = -1;
139-
step_len[0] -= 1;
140-
recover_len[0] += 1;
141-
ori_step_len = step_len[0];
142-
if (ori_step_len > 0) {
143-
ori_step_block_id = step_block_list[ori_step_len - 1];
144-
tmp_used_len = used_list_len[ori_step_block_id];
145-
used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
132+
recover_block_list[recover_len[0]] = ori_step_block_id;
133+
is_block_step[ori_step_block_id] = false;
134+
used_list_len[ori_step_block_id] = used_len;
135+
ori_free_list_len -= used_len;
136+
step_block_list[ori_step_len - 1] = -1;
137+
step_len[0] -= 1;
138+
recover_len[0] += 1;
139+
ori_step_len = step_len[0];
140+
if (ori_step_len > 0) {
141+
ori_step_block_id = step_block_list[ori_step_len - 1];
142+
tmp_used_len = used_list_len[ori_step_block_id];
143+
used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
144+
}
146145
}
147146
}
148147
need_block_len[0] = 0;

0 commit comments

Comments
 (0)