Commit 0d05544

add scaling (#8256)
* add scaling
* add scaling
* add scaling
* format
1 parent: 662feb1

File tree

1 file changed: +12 -0 lines changed

paddlenlp/peft/lora/lora_config.py

Lines changed: 12 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import json
+import math
 import os
 from dataclasses import asdict, dataclass, field
 from typing import List, Optional, Union
@@ -94,6 +95,15 @@ def __post_init__(self):
             )
             self.use_quick_lora = False
 
+    @property
+    def scaling(self):
+        if not self.rslora and not self.pissa:
+            return self.lora_alpha / self.r
+        elif self.pissa:
+            return 1.0
+        else:
+            return self.lora_alpha / math.sqrt(self.r)
+
     @property
     def __dict__(self):
         return asdict(self)
@@ -114,6 +124,7 @@ def save_pretrained(self, save_directory):
         os.makedirs(save_directory, exist_ok=True)
 
         output_dict = self.__dict__
+        output_dict["scaling"] = self.scaling
         output_path = os.path.join(save_directory, LORA_CONFIG_NAME)
 
         # save it
@@ -136,6 +147,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
             raise ValueError(f"Can't find lora_config.json at '{pretrained_model_name_or_path}'")
 
         loaded_attributes = cls.from_json_file(config_file)
+        loaded_attributes.pop("scaling", None)
 
         config = cls(**kwargs)