@@ -173,7 +173,7 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list


-def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
+def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
     is_fleet_init = True
     tensor_parallel_degree = 1
     try:
@@ -191,15 +191,15 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
     if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
         # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
         input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
-        logits = paddle.matmul(input_parallel, y, transpose_y=False)
+        logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)

         if tensor_parallel_output:
             return logits

         return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

     else:
-        logits = paddle.matmul(x, y, transpose_y=False)
+        logits = paddle.matmul(x, y, transpose_y=transpose_y)
         return logits

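Note on the new `transpose_y` argument: weights tied to the embedding keep the embedding layout `[vocab_size, hidden_size]`, while the stand-alone head keeps `[hidden_size, vocab_size]`, so the matmul has to be told which orientation it was given. A minimal single-card sketch (plain `paddle.matmul`, no fleet initialization, illustrative sizes) of the equivalence:

```python
import paddle

hidden_size, vocab_size = 8, 32                    # illustrative sizes
x = paddle.randn([2, 4, hidden_size])              # [batch, seq, hidden]

w_head = paddle.randn([hidden_size, vocab_size])   # classic lm_head layout
w_tied = paddle.transpose(w_head, [1, 0])          # embedding layout [vocab, hidden]

logits_a = paddle.matmul(x, w_head, transpose_y=False)
logits_b = paddle.matmul(x, w_tied, transpose_y=True)
print(paddle.allclose(logits_a, logits_b))         # same logits, different storage
```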
@@ -1267,7 +1267,8 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
         for mapping in model_mappings:
             mapping[0] = "model." + mapping[0]
             mapping[1] = "llama." + mapping[1]
-        model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
+        if not config.tie_word_embeddings:
+            model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])

         mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
         return mappings
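The guard above matters for checkpoint conversion: with tied embeddings there is no independent `lm_head.weight` to look up in the source state dict. A plain-list sketch (not the real `_get_name_mappings`; the mapping table is reduced to an illustrative subset) of the effect:

```python
def build_mappings(tie_word_embeddings):
    # illustrative subset of the real mapping table
    model_mappings = [["embed_tokens.weight", "embed_tokens.weight"]]
    if not tie_word_embeddings:
        model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
    return model_mappings

print(build_mappings(tie_word_embeddings=True))    # no lm_head entry to convert
print(build_mappings(tie_word_embeddings=False))   # lm_head entry, transposed on load
```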
@@ -1288,13 +1289,17 @@ def get_tensor_parallel_split_mappings(num_layers):
             final_actions = {}

             base_actions = {
-                "lm_head.weight": partial(fn, is_column=True),
                 # Row Linear
                 "embed_tokens.weight": partial(fn, is_column=False),
                 "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
                 "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
             }

+            if config.tie_word_embeddings:
+                base_actions["lm_head.weight"] = partial(fn, is_column=False)
+            else:
+                base_actions["lm_head.weight"] = partial(fn, is_column=True)
+
             if not config.vocab_size % config.tensor_parallel_degree == 0:
                 base_actions.pop("lm_head.weight")
                 base_actions.pop("embed_tokens.weight")
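Why the tied head switches to `is_column=False`: it shares the embedding's `[vocab_size, hidden_size]` layout, which is sharded along the vocab rows, while the untied `[hidden_size, vocab_size]` head is sharded along the vocab columns. A hedged sketch of the axis choice only, with `np.split` standing in for the framework's real split helper `fn`:

```python
import numpy as np

hidden_size, vocab_size, tp_degree = 6, 32, 4

untied_head = np.zeros([hidden_size, vocab_size])      # [hidden, vocab]
tied_head = np.zeros([vocab_size, hidden_size])        # [vocab, hidden]

col_shards = np.split(untied_head, tp_degree, axis=1)  # is_column=True
row_shards = np.split(tied_head, tp_degree, axis=0)    # is_column=False

print(col_shards[0].shape)  # (6, 8): each rank keeps hidden x (vocab / tp)
print(row_shards[0].shape)  # (8, 6): each rank keeps (vocab / tp) x hidden
```

Both choices leave one vocab shard per rank; the same distinction is why `split_axis` becomes 0 instead of 1 in `LlamaLMHead` below.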
@@ -1842,29 +1847,40 @@ def backward(ctx, grad):


 class LlamaLMHead(nn.Layer):
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, embedding_weights=None, transpose_y=False):
         super(LlamaLMHead, self).__init__()
         self.config = config
         if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
             vocab_size = config.vocab_size // config.tensor_parallel_degree
         else:
             vocab_size = config.vocab_size

-        if vocab_size != config.vocab_size:
-            with get_rng_state_tracker().rng_state():
+        self.transpose_y = transpose_y
+        if transpose_y:
+            if embedding_weights is not None:
+                self.weight = embedding_weights
+            else:
                 self.weight = self.create_parameter(
-                    shape=[config.hidden_size, vocab_size],
+                    shape=[vocab_size, config.hidden_size],
                     dtype=paddle.get_default_dtype(),
                 )
         else:
-            self.weight = self.create_parameter(
-                shape=[config.hidden_size, vocab_size],
-                dtype=paddle.get_default_dtype(),
-            )
+            if vocab_size != config.vocab_size:
+                with get_rng_state_tracker().rng_state():
+                    self.weight = self.create_parameter(
+                        shape=[config.hidden_size, vocab_size],
+                        dtype=paddle.get_default_dtype(),
+                    )
+            else:
+                self.weight = self.create_parameter(
+                    shape=[config.hidden_size, vocab_size],
+                    dtype=paddle.get_default_dtype(),
+                )

         # Must set distributed attr for Tensor Parallel !
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
         if self.weight.is_distributed:
-            self.weight.split_axis = 1
+            # for tie_word_embeddings
+            self.weight.split_axis = 0 if self.transpose_y else 1
         if get_env_device() == "xpu":
             try:
                 from paddle_xpu.layers.nn import (  # noqa: F401
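The point of the `embedding_weights` branch is that the head holds the very same `Parameter` object as the embedding, so tying costs no extra memory and both uses train one tensor. A toy stand-in (not the actual `LlamaLMHead`) to illustrate:

```python
import paddle
import paddle.nn as nn

vocab_size, hidden_size = 32, 8
embed = nn.Embedding(vocab_size, hidden_size)          # weight: [vocab, hidden]


class TiedHead(nn.Layer):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight                            # shared parameter, not a copy

    def forward(self, hidden_states):
        # embedding layout, hence transpose_y=True
        return paddle.matmul(hidden_states, self.weight, transpose_y=True)


head = TiedHead(embed.weight)
print(head.weight is embed.weight)                      # True: one parameter, two roles
print(head(paddle.randn([2, 4, hidden_size])).shape)    # [2, 4, 32]
```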
@@ -1892,22 +1908,33 @@ def forward(self, hidden_states, tensor_parallel_output=None):

         if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
             logits = self.xpu_parallel_matmul(
-                hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
+                hidden_states,
+                self.weight,
+                transpose_y=self.transpose_y,
+                tensor_parallel_output=tensor_parallel_output,
+                training=self.training,
             )
         else:
-            logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+            logits = parallel_matmul(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
         return logits


 class LlamaForCausalLM(LlamaPretrainedModel):
     enable_to_static_method = True
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
         self.config = config

         self.llama = LlamaModel(config)
-        self.lm_head = LlamaLMHead(config)
+        if config.tie_word_embeddings:
+            self.lm_head = LlamaLMHead(config, embedding_weights=self.llama.embed_tokens.weight, transpose_y=True)
+            self.tie_weights()
+        else:
+            self.lm_head = LlamaLMHead(config)
         self.criterion = LlamaPretrainingCriterion(config)

     def get_input_embeddings(self):
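End-to-end, the new constructor path can be exercised as below. This is a hedged usage sketch: it assumes the patch is applied and that `LlamaConfig` / `LlamaForCausalLM` are importable from `paddlenlp.transformers` with these constructor arguments; the sizes are tiny and purely illustrative.

```python
from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    tie_word_embeddings=True,  # takes the new branch in __init__
)
model = LlamaForCausalLM(config)

# With tying enabled the head reuses the embedding parameter, so there is no
# independent lm_head.weight left to train or to save separately.
print(model.lm_head.weight is model.llama.embed_tokens.weight)  # expected: True
```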