Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 5024ae1

Browse files
committedMar 10, 2025
fix usage of autodiff macro with inner functions
1 parent 3ea711f commit 5024ae1

File tree

1 file changed

+82
-20
lines changed

1 file changed

+82
-20
lines changed
 

‎compiler/rustc_builtin_macros/src/autodiff.rs

+82-20
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ mod llvm_enzyme {
135135
}
136136
let dcx = ecx.sess.dcx();
137137
// first get the annotable item:
138-
let (sig, is_impl): (FnSig, bool) = match &item {
138+
let sig: FnSig = match &item {
139139
Annotatable::Item(iitem) => {
140140
let sig = match &iitem.kind {
141141
ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
@@ -144,7 +144,7 @@ mod llvm_enzyme {
144144
return vec![item];
145145
}
146146
};
147-
(sig.clone(), false)
147+
sig.clone()
148148
}
149149
Annotatable::AssocItem(assoc_item, _) => {
150150
let sig = match &assoc_item.kind {
@@ -154,7 +154,24 @@ mod llvm_enzyme {
154154
return vec![item];
155155
}
156156
};
157-
(sig.clone(), true)
157+
sig.clone()
158+
}
159+
Annotatable::Stmt(stmt) => {
160+
let sig = match &stmt.kind {
161+
ast::StmtKind::Item(iitem) => match &iitem.kind {
162+
ast::ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
163+
_ => {
164+
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
165+
return vec![item];
166+
}
167+
},
168+
_ => {
169+
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
170+
return vec![item];
171+
}
172+
};
173+
174+
sig.clone()
158175
}
159176
_ => {
160177
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
@@ -178,6 +195,10 @@ mod llvm_enzyme {
178195
Annotatable::AssocItem(assoc_item, _) => {
179196
(assoc_item.vis.clone(), assoc_item.ident.clone())
180197
}
198+
Annotatable::Stmt(stmt) => match &stmt.kind {
199+
ast::StmtKind::Item(iitem) => (iitem.vis.clone(), iitem.ident.clone()),
200+
_ => unreachable!("stmt kind checked previously"),
201+
},
181202
_ => {
182203
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
183204
return vec![item];
@@ -302,6 +323,21 @@ mod llvm_enzyme {
302323
}
303324
Annotatable::AssocItem(assoc_item.clone(), i)
304325
}
326+
Annotatable::Stmt(ref mut stmt) => {
327+
match stmt.kind {
328+
ast::StmtKind::Item(ref mut item) => {
329+
if !item.attrs.iter().any(|a| a.id == attr.id) {
330+
item.attrs.push(attr.clone());
331+
}
332+
if !item.attrs.iter().any(|a| a.id == inline_never.id) {
333+
item.attrs.push(inline_never.clone());
334+
}
335+
}
336+
_ => unreachable!("stmt kind checked previously"),
337+
};
338+
339+
Annotatable::Stmt(stmt.clone())
340+
}
305341
_ => {
306342
unreachable!("annotatable kind checked previously")
307343
}
@@ -319,23 +355,49 @@ mod llvm_enzyme {
319355
span,
320356
};
321357

322-
let d_annotatable = if is_impl {
323-
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
324-
let d_fn = P(ast::AssocItem {
325-
attrs: thin_vec![d_attr, inline_never],
326-
id: ast::DUMMY_NODE_ID,
327-
span,
328-
vis,
329-
ident: d_ident,
330-
kind: assoc_item,
331-
tokens: None,
332-
});
333-
Annotatable::AssocItem(d_fn, Impl)
334-
} else {
335-
let mut d_fn =
336-
ecx.item(span, d_ident, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
337-
d_fn.vis = vis;
338-
Annotatable::Item(d_fn)
358+
let d_annotatable = match &item {
359+
Annotatable::AssocItem(_, _) => {
360+
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
361+
let d_fn = P(ast::AssocItem {
362+
attrs: thin_vec![d_attr.clone(), inline_never],
363+
id: ast::DUMMY_NODE_ID,
364+
span,
365+
vis,
366+
ident: d_ident,
367+
kind: assoc_item,
368+
tokens: None,
369+
});
370+
Annotatable::AssocItem(d_fn, Impl)
371+
}
372+
Annotatable::Item(_) => {
373+
let mut d_fn = ecx.item(
374+
span,
375+
d_ident,
376+
thin_vec![d_attr.clone(), inline_never],
377+
ItemKind::Fn(asdf),
378+
);
379+
d_fn.vis = vis;
380+
381+
Annotatable::Item(d_fn)
382+
}
383+
Annotatable::Stmt(_) => {
384+
let mut d_fn = ecx.item(
385+
span,
386+
d_ident,
387+
thin_vec![d_attr.clone(), inline_never],
388+
ItemKind::Fn(asdf),
389+
);
390+
d_fn.vis = vis;
391+
392+
Annotatable::Stmt(P(ast::Stmt {
393+
id: ast::DUMMY_NODE_ID,
394+
kind: ast::StmtKind::Item(d_fn),
395+
span,
396+
}))
397+
}
398+
_ => {
399+
unreachable!()
400+
}
339401
};
340402

341403
return vec![orig_annotatable, d_annotatable];

0 commit comments

Comments
 (0)
Failed to load comments.