Skip to content

Autodiff cleanups #138627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 21, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 67 additions & 22 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,30 +359,27 @@ mod llvm_enzyme {
ty
}

/// We only want this function to type-check, since we will replace the body
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
/// so instead we build something that should pass. We also add a inline_asm
/// line, as one more barrier for rustc to prevent inlining of this function.
/// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
/// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
/// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
/// this function (which should never happen, since it is only a placeholder).
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
/// from optimizing any arguments away.
fn gen_enzyme_body(
// Will generate a body of the type:
// ```
// {
// unsafe {
// asm!("NOP");
// }
// ::core::hint::black_box(primal(args));
// ::core::hint::black_box((args, ret));
// <This part remains to be done by following function>
// }
// ```
fn init_body_helper(
ecx: &ExtCtxt<'_>,
x: &AutoDiffAttrs,
n_active: u32,
sig: &ast::FnSig,
d_sig: &ast::FnSig,
span: Span,
primal: Ident,
new_names: &[String],
span: Span,
sig_span: Span,
new_decl_span: Span,
idents: Vec<Ident>,
idents: &[Ident],
errored: bool,
) -> P<ast::Block> {
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
let noop = ast::InlineAsm {
asm_macro: ast::AsmMacro::Asm,
Expand Down Expand Up @@ -431,6 +428,54 @@ mod llvm_enzyme {
}
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));

(body, primal_call, black_box_primal_call, blackbox_call_expr)
}

/// We only want this function to type-check, since we will replace the body
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
/// so instead we build something that should pass. We also add a inline_asm
/// line, as one more barrier for rustc to prevent inlining of this function.
/// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
/// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
/// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
/// this function (which should never happen, since it is only a placeholder).
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
/// from optimizing any arguments away.
fn gen_enzyme_body(
ecx: &ExtCtxt<'_>,
x: &AutoDiffAttrs,
n_active: u32,
sig: &ast::FnSig,
d_sig: &ast::FnSig,
primal: Ident,
new_names: &[String],
span: Span,
sig_span: Span,
_new_decl_span: Span,
idents: Vec<Ident>,
errored: bool,
) -> P<ast::Block> {
let new_decl_span = d_sig.span;

// Just adding some default inline-asm and black_box usages to prevent early inlining
// and optimizations which alter the function signature.
//
// The bb_primal_call is the black_box call of the primal function. We keep it around,
// since it has the convenient property of returning the type of the primal function,
// Remember, we only care to match types here.
// No matter which return we pick, we always wrap it into a std::hint::black_box call,
// to prevent rustc from propagating it into the caller.
let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
ecx,
span,
primal,
new_names,
sig_span,
new_decl_span,
&idents,
errored,
);

if !has_ret(&d_sig.decl.output) {
// there is no return type that we have to match, () works fine.
return body;
Expand All @@ -442,7 +487,7 @@ mod llvm_enzyme {

if primal_ret && n_active == 0 && x.mode.is_rev() {
// We only have the primal ret.
body.stmts.push(ecx.stmt_expr(black_box_primal_call));
body.stmts.push(ecx.stmt_expr(bb_primal_call));
return body;
}

Expand Down Expand Up @@ -534,11 +579,11 @@ mod llvm_enzyme {
return body;
}
[arg] => {
ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![arg.clone()]);
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
}
args => {
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![ret_tuple]);
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
}
}
assert!(has_ret(&d_sig.decl.output));
Expand All @@ -551,7 +596,7 @@ mod llvm_enzyme {
ecx: &ExtCtxt<'_>,
span: Span,
primal: Ident,
idents: Vec<Ident>,
idents: &[Ident],
) -> P<ast::Expr> {
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
if has_self {
Expand Down