add inputs_embeds to Bart/MBart/Unified_Transformer/Unimo/CodeGen #3769
Conversation
```diff
@@ -515,11 +540,13 @@ def set_input_embeddings(self, value):

     def forward(
         self,
-        input_ids,
+        input_ids=None,
```
Please also update the documentation and mark this argument as `optional`; same for the later occurrences.
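A rough sketch of what the requested doc change could look like; the wording below is illustrative, not the final PR text:

```python
class BartEncoder:  # sketch only, not the full class
    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        r"""
        Args:
            input_ids (Tensor, optional):
                Indices of input sequence tokens in the vocabulary. Only needed
                when `inputs_embeds` is not provided. Defaults to `None`.
            inputs_embeds (Tensor, optional):
                Pre-computed token embeddings of shape
                `[batch_size, sequence_length, hidden_size]`. Defaults to `None`.
        """
        ...
```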
```diff
-        if attention_mask is None:
-            assert input_ids is not None, "input_ids should be " "specified when generating attention_mask"
+        if attention_mask is None and input_ids is not None:
+            # assert input_ids is not None, "input_ids should be " \
```
Could this be handled the same way as the CodeGen change: log a warning when `input_ids` is None, and delete this commented-out assert.
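A sketch of the suggested CodeGen-style handling; it assumes this sits inside the model's `forward`, with `self.pad_token_id` and a module-level `logger` already available:

```python
if attention_mask is None:
    if input_ids is not None:
        # Build the default mask from padding positions, as before.
        attention_mask = (
            paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2])
            * -1e4
        )
    else:
        # Warn instead of asserting; attention_mask simply stays None.
        logger.warning("Provided inputs_embeds without attention_mask")
```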
```python
        elif input_ids is not None:
            inputs_sample = input_ids
        elif input_embeddings is not None:
            inputs_sample = input_embeddings[:, :, -1]
```
The meaning of `inputs_sample` here is a bit odd. Could the later `paddle.expand_as()` be replaced with `paddle.expand()`, and this part changed to obtain `input_shape` instead? That would also avoid the repeated `paddle.shape()` calls further down.
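A sketch of the suggested refactor, resolving `input_shape` once and broadcasting with `paddle.expand`; the variable names are illustrative:

```python
# Resolve the shape once instead of keeping an inputs_sample tensor around.
if input_ids is not None:
    input_shape = paddle.shape(input_ids)             # [batch_size, seq_len]
else:
    input_shape = paddle.shape(input_embeddings)[:2]  # drop the hidden dim

if position_ids is None:
    # Broadcast sequential ids to the batch with paddle.expand
    # instead of paddle.expand_as on inputs_sample.
    position_ids = paddle.expand(
        paddle.arange(end=input_shape[1], dtype="int64").unsqueeze(0),
        input_shape,
    )
```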
```python
            ).astype("int64")
        else:
            logger.warning(
                "position_ids or pad_token_ids should be provided when input_embeds is specified, otherwise an unexpected result may be returned"
```
In the "otherwise" clause, please spell out what position ids will actually be used; "an unexpected result" is too vague.
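One way to make the message concrete, assuming the fallback is a plain sequential range over the sequence length (that fallback is an assumption here, not something stated in the PR):

```python
logger.warning(
    "position_ids or pad_token_id should be provided when inputs_embeds is "
    "specified; falling back to sequential position ids [0, 1, ..., seq_len - 1], "
    "which ignores any padding tokens."
)
```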
```python
                paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2])
                * -1e4
            )
            logger.warning("provided inputs_embeds without attention_mask")
```
Capitalize the first letter, and also note that the default value None will be used as the attention mask, which means no masking is applied. Same for the later occurrences.
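The message could then read along these lines (wording illustrative):

```python
logger.warning(
    "Provided inputs_embeds without attention_mask; the default value None will "
    "be used as attention_mask, which means no masking is applied."
)
```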
```python
        elif input_ids is not None:
            inputs_sample = input_ids
        elif input_embeddings is not None:
            inputs_sample = input_embeddings[:, :, -1]
```
Same as above; this needs the same change.
- Added `optional` to the `input_ids` documentation
- Removed the commented-out `assert input_ids is not None` and added a logger warning instead
- Replaced the `paddle.expand_as` usage with `paddle.expand`
- Expanded the `logger.warning` messages and capitalized the first letter
Codecov Report
```diff
@@            Coverage Diff             @@
##           develop    #3769      +/-   ##
===========================================
+ Coverage    32.95%   33.06%   +0.10%
===========================================
  Files          400      400
  Lines        56031    56131     +100
===========================================
+ Hits         18466    18560      +94
- Misses       37565    37571       +6
```
PR types
New features

PR changes
Models

Description
- Add `inputs_embeds` to Bart/MBart/Unified_Transformer/Unimo/CodeGen
- Set `use_cache` to `False` if `labels` is provided for the Bart and MBart models (during training) to save memory
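A rough usage sketch of the new argument with the Bart model; the checkpoint name and the exact call details are illustrative and may differ from the final API:

```python
import paddle
from paddlenlp.transformers import BartModel, BartTokenizer

model = BartModel.from_pretrained("bart-base")        # illustrative checkpoint
tokenizer = BartTokenizer.from_pretrained("bart-base")

token_ids = tokenizer("PaddleNLP adds inputs_embeds support")["input_ids"]
input_ids = paddle.to_tensor([token_ids], dtype="int64")

# Compute the token embeddings manually (e.g. to inject soft prompts),
# then pass them in place of input_ids on the encoder side.
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs = model(inputs_embeds=inputs_embeds, decoder_input_ids=input_ids)
```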