Skip to content

Commit 594a050

Browse files
committed
update by environment
1 parent 370d2c9 commit 594a050

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

llm/run_pretrain.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,6 @@ class ModelArguments:
223223
default=None,
224224
metadata={"help": "num_hidden_layers."},
225225
)
226-
use_casual_mask: Optional[bool] = field(
227-
default=True,
228-
metadata={"help": "whether to use casual mask"},
229-
)
230226

231227

232228
def create_pretrained_dataset(
@@ -480,7 +476,6 @@ def main():
480476
config.pp_recompute_interval = model_args.pp_recompute_interval
481477
config.recompute_use_reentrant = model_args.recompute_use_reentrant
482478
config.use_recompute = training_args.recompute
483-
config.use_casual_mask = model_args.use_casual_mask
484479

485480
config.tensor_parallel_degree = training_args.tensor_parallel_degree
486481
config.tensor_parallel_rank = training_args.tensor_parallel_rank

paddlenlp/transformers/llama/modeling.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ def _get_interleave_power_of_2(n):
115115
)
116116

117117

118+
def get_use_casual_mask():
119+
"""Get the value of the 'USE_CASUAL_MASK' environment variable."""
120+
return os.getenv("USE_CASUAL_MASK", "False")
121+
122+
118123
def build_alibi_tensor(
119124
bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
120125
) -> Tensor:
@@ -1532,9 +1537,8 @@ def forward(
15321537
if position_ids is None:
15331538
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
15341539

1535-
use_casual_mask = (
1536-
True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
1537-
)
1540+
use_casual_mask = get_use_casual_mask()
1541+
15381542
if use_casual_mask:
15391543
attention_mask = None
15401544
else:

0 commit comments

Comments
 (0)