@@ -160,18 +160,18 @@ def __init__(self, config):
                 has_bias=True,
                 gather_output=False,
             )
-            self.o_proj = RowParallelLinear(
+            self.c_proj = RowParallelLinear(
                 config.hidden_size,
                 self.projection_size,
                 has_bias=False,
                 input_is_parallel=True,
             )
         else:
             self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size, bias_attr=True)
-            self.o_proj = nn.Linear(
+            self.c_proj = nn.Linear(
                 config.hidden_size,
                 self.projection_size,
-                bias_attr=False,
+                bias_attr=not config.no_bias,
             )
 
         if config.rotary_pct == 1.0:
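
Note: in the non-parallel branch above, the output projection's bias is no longer hard-coded off; it now follows the config's no_bias flag via bias_attr=not config.no_bias. A minimal sketch of how that flag plays out with paddle.nn.Linear (the _Cfg class and its values are illustrative, not part of this commit):

import paddle.nn as nn

class _Cfg:  # hypothetical stand-in for QWenConfig
    hidden_size = 16
    no_bias = True

cfg = _Cfg()
# bias_attr=False drops the bias parameter entirely; a truthy value keeps it.
proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias_attr=not cfg.no_bias)
print(proj.bias)  # None when cfg.no_bias is True, a learnable parameter otherwise
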
@@ -377,7 +377,7 @@ def forward(
 
         # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim]
         # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism.
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.c_proj(attn_output)
         outputs = (attn_output, present)
         if output_attentions:
             outputs += (attn_weight,)
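
The context comment above describes the layout entering the renamed output projection: with sequence parallelism the sequence dimension is sharded across the model-parallel group and comes first; otherwise the usual batch-first layout is used. A shape-only illustration (all sizes made up):

import numpy as np

bs, q_len, num_head, head_dim, n = 2, 32, 8, 64, 4  # n = model-parallel degree (illustrative)

# without sequence parallelism: [bs, q_len, num_head * head_dim]
dense_layout = np.zeros((bs, q_len, num_head * head_dim))

# with sequence parallelism: each rank holds q_len / n tokens, sequence-first
sp_layout = np.zeros((q_len // n, bs, num_head * head_dim))

print(dense_layout.shape, sp_layout.shape)  # (2, 32, 512) (8, 2, 512)
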
@@ -576,8 +576,8 @@ def _get_name_mappings(cls, config: QWenConfig) -> List[StateDictNameMapping]:
                     f"h.{layer_index}.attn.c_attn.bias",
                 ],
                 [
-                    f"h.{layer_index}.attn.o_proj.weight",
-                    f"h.{layer_index}.attn.o_proj.weight",
+                    f"h.{layer_index}.attn.c_proj.weight",
+                    f"h.{layer_index}.attn.c_proj.weight",
                     "transpose",
                 ],
                 [
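
As used here, each name-mapping entry pairs a source checkpoint key with the target Paddle parameter name, with an optional action as the third element; the two names now coincide because the attribute was renamed to match the original QWen checkpoint's c_proj key. The "transpose" action typically reflects the usual weight-layout difference: torch's nn.Linear stores weight as [out_features, in_features], while Paddle's stores [in_features, out_features]. A minimal sketch of that conversion (shapes are illustrative):

import numpy as np

out_features, in_features = 4, 8
torch_style_weight = np.random.rand(out_features, in_features)  # [out, in]
paddle_style_weight = torch_style_weight.T                      # [in, out]
assert paddle_style_weight.shape == (in_features, out_features)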