@@ -364,12 +364,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class HFLanguageRepresentationNetwork(nn.Module):
     def __init__(self,
-                 model_path: str = 'google-bert/bert-base-uncased',
-                 embedding_size: int = 768,
-                 group_size: int = 8,
-                 norm_type: str = "simnorm",
-                 # norm_type: str = "layernorm", # TODO: Why does nan appear in the first step of training?
-                 tokenizer=None):
+                 model_path: str = 'google-bert/bert-base-uncased',
+                 embedding_size: int = 768,
+                 group_size: int = 8,
+                 final_norm_option_in_encoder: str = "layernorm",
+                 tokenizer=None):
         """
         Overview:
             This class defines a language representation network that utilizes a pretrained Hugging Face model.
@@ -379,7 +378,7 @@ def __init__(self,
             - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'.
             - embedding_size (int): The dimension of the output embeddings. Default is 768.
             - group_size (int): The group size for SimNorm when using normalization.
-            - norm_type (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
+            - final_norm_option_in_encoder (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
             - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model.
         """
         super().__init__()
@@ -389,12 +388,13 @@ def __init__(self,
 
         # In distributed training, only the rank 0 process downloads the model, and other processes load from cache to speed up startup.
         if get_rank() == 0:
-            self.model = AutoModel.from_pretrained(model_path)
+            self.pretrained_model = AutoModel.from_pretrained(model_path)
+
         if get_world_size() > 1:
             # Wait for rank 0 to finish loading the model.
             torch.distributed.barrier()
         if get_rank() != 0:
-            self.model = AutoModel.from_pretrained(model_path)
+            self.pretrained_model = AutoModel.from_pretrained(model_path)
 
         if tokenizer is None:
             # Only rank 0 downloads the tokenizer, and then other processes load it from cache.
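Note: below is a minimal, standalone sketch of the rank-0-first loading pattern used in the hunk above. The helper name load_pretrained_rank0_first is hypothetical (not part of this commit), and the sketch assumes torch.distributed is already initialized; in the actual module, get_rank/get_world_size come from its existing imports.

    import torch
    from transformers import AutoModel

    def load_pretrained_rank0_first(model_path: str, rank: int, world_size: int):
        if rank == 0:
            # Rank 0 downloads the checkpoint and warms the local cache.
            model = AutoModel.from_pretrained(model_path)
        if world_size > 1:
            # All other ranks wait until the download has finished.
            torch.distributed.barrier()
        if rank != 0:
            # Non-zero ranks now load the same checkpoint from the local cache.
            model = AutoModel.from_pretrained(model_path)
        return model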
@@ -409,15 +409,15 @@ def __init__(self,
 
         # Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings).
         self.embedding_size = embedding_size
-        self.embed_proj_head = nn.Linear(self.model.config.hidden_size, self.embedding_size)
+        self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)
 
-        # Select the normalization method based on the norm_type parameter.
-        if norm_type.lower() == "simnorm":
+        # Select the normalization method based on the final_norm_option_in_encoder parameter.
+        if final_norm_option_in_encoder.lower() == "simnorm":
             self.norm = SimNorm(simnorm_dim=group_size)
-        elif norm_type.lower() == "layernorm":
+        elif final_norm_option_in_encoder.lower() == "layernorm":
             self.norm = nn.LayerNorm(embedding_size)
         else:
-            raise NotImplementedError(f"Normalization type '{norm_type}' is not implemented. "
+            raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
                                       f"Choose 'simnorm' or 'layernorm'.")
 
     def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
@@ -433,26 +433,27 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
         Returns:
             - torch.Tensor: The processed language embedding with shape [batch_size, embedding_size].
         """
+
         # Construct the attention mask to exclude padding tokens.
         attention_mask = x != self.tokenizer.pad_token_id
 
         # Use no_grad context if specified to disable gradient computation.
         if no_grad:
             with torch.no_grad():
                 x = x.long()  # Ensure the input tensor is of type long.
-                outputs = self.model(x, attention_mask=attention_mask)
+                outputs = self.pretrained_model(x, attention_mask=attention_mask)
                 # Get the hidden state from the last layer and select the output corresponding to the [CLS] token.
                 cls_embedding = outputs.last_hidden_state[:, 0, :]
         else:
             x = x.long()
-            outputs = self.model(x, attention_mask=attention_mask)
+            outputs = self.pretrained_model(x, attention_mask=attention_mask)
             cls_embedding = outputs.last_hidden_state[:, 0, :]
 
         # Apply linear projection to obtain the desired output dimension.
         cls_embedding = self.embed_proj_head(cls_embedding)
         # Normalize the embeddings using the selected normalization layer (SimNorm or LayerNorm) to ensure training stability.
         cls_embedding = self.norm(cls_embedding)
-
+
         return cls_embedding
 
 
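Note: below is a minimal usage sketch for the encoder touched by the hunks above. It assumes network access (or a local cache) for 'google-bert/bert-base-uncased' and that HFLanguageRepresentationNetwork is imported from this module; the keyword values are illustrative, not prescribed by this commit.

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
    encoder = HFLanguageRepresentationNetwork(
        model_path='google-bert/bert-base-uncased',
        embedding_size=768,
        final_norm_option_in_encoder="layernorm",  # or "simnorm"
        tokenizer=tokenizer,
    )
    # Tokenize a toy observation and encode it; padding tokens are masked out inside forward().
    tokens = tokenizer(["a short text observation"], return_tensors="pt",
                       padding="max_length", max_length=32)
    embedding = encoder(tokens["input_ids"])  # shape: [1, 768]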
@@ -468,6 +469,7 @@ def __init__(
             norm_type: str = 'BN',
             embedding_dim: int = 256,
             group_size: int = 8,
+            final_norm_option_in_encoder: str = 'LayerNorm',  # TODO
     ) -> None:
         """
         Overview:
@@ -486,6 +488,8 @@ def __init__(
            - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
            - embedding_dim (:obj:`int`): The dimension of the latent state.
            - group_size (:obj:`int`): The dimension for simplicial normalization.
+            - final_norm_option_in_encoder (:obj:`str`): The normalization option for the final layer, defaults to 'SimNorm'. \
+                Options are 'SimNorm' and 'LayerNorm'.
        """
        super().__init__()
        assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
@@ -530,7 +534,14 @@ def __init__(
         elif self.observation_shape[1] in [84, 96]:
             self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)
 
-        self.sim_norm = SimNorm(simnorm_dim=group_size)
+        self.final_norm_option_in_encoder = final_norm_option_in_encoder
+        if self.final_norm_option_in_encoder == 'LayerNorm':
+            self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5)
+        elif self.final_norm_option_in_encoder == 'SimNorm':
+            self.final_norm = SimNorm(simnorm_dim=group_size)
+        else:
+            raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")
+
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -557,7 +568,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.view(-1, self.embedding_dim)
 
         # NOTE: very important for training stability.
-        x = self.sim_norm(x)
+        x = self.final_norm(x)
 
         return x
 
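Note: below is a self-contained sketch of what the two final_norm_option_in_encoder choices do to a latent vector. The SimNorm class here is a stand-in that mirrors the repo's simplicial normalization (softmax within consecutive groups of group_size features); the real implementation in this module should be preferred.

    import torch
    import torch.nn as nn

    class SimNorm(nn.Module):
        """Stand-in: softmax over consecutive groups of `simnorm_dim` features."""
        def __init__(self, simnorm_dim: int) -> None:
            super().__init__()
            self.dim = simnorm_dim

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            shape = x.shape
            x = x.view(*shape[:-1], -1, self.dim)
            x = torch.softmax(x, dim=-1)
            return x.view(*shape)

    embedding_dim, group_size = 256, 8
    latent = torch.randn(4, embedding_dim)
    for option in ('LayerNorm', 'SimNorm'):
        final_norm = nn.LayerNorm(embedding_dim, eps=1e-5) if option == 'LayerNorm' \
            else SimNorm(simnorm_dim=group_size)
        out = final_norm(latent)
        print(option, out.shape)  # both options keep the shape (4, 256)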
@@ -670,6 +681,7 @@ def __init__(
             activation: nn.Module = nn.GELU(approximate='tanh'),
             norm_type: Optional[str] = 'BN',
             group_size: int = 8,
+            final_norm_option_in_encoder: str = 'LayerNorm',  # TODO
     ) -> torch.Tensor:
         """
         Overview:
@@ -700,7 +712,15 @@ def __init__(
             # last_linear_layer_init_zero=True is beneficial for convergence speed.
             last_linear_layer_init_zero=True,
         )
-        self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+        # Select the normalization method based on the final_norm_option_in_encoder parameter.
+        if final_norm_option_in_encoder.lower() == "simnorm":
+            self.norm = SimNorm(simnorm_dim=group_size)
+        elif final_norm_option_in_encoder.lower() == "layernorm":
+            self.norm = nn.LayerNorm(hidden_channels)
+        else:
+            raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
+                                      f"Choose 'simnorm' or 'layernorm'.")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -709,8 +729,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
         """
         x = self.fc_representation(x)
-        # TODO
-        x = self.sim_norm(x)
+        x = self.norm(x)
+
         return x
 
 
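Note: below is a standalone sketch of the selection guard added in the last hunk. The helper build_final_norm is hypothetical (not part of this commit); the matching is case-insensitive because of .lower(), and the SimNorm import path is an assumption about this repo's layout.

    import torch.nn as nn

    def build_final_norm(option: str, hidden_channels: int = 64, group_size: int = 8) -> nn.Module:
        # Mirrors the branch structure above: simnorm -> SimNorm, layernorm -> nn.LayerNorm,
        # anything else fails fast instead of silently skipping normalization.
        if option.lower() == "simnorm":
            from lzero.model.common import SimNorm  # assumed import path
            return SimNorm(simnorm_dim=group_size)
        elif option.lower() == "layernorm":
            return nn.LayerNorm(hidden_channels)
        raise NotImplementedError(f"Normalization type '{option}' is not implemented. "
                                  f"Choose 'simnorm' or 'layernorm'.")

    print(build_final_norm("LayerNorm"))  # case-insensitive: 'LayerNorm' resolves to nn.LayerNorm
    # build_final_norm("batchnorm")       # would raise NotImplementedError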