@@ -135,7 +135,7 @@ mod llvm_enzyme {
135
135
}
136
136
let dcx = ecx. sess . dcx ( ) ;
137
137
// first get the annotable item:
138
- let ( sig, is_impl ) : ( FnSig , bool ) = match & item {
138
+ let sig: FnSig = match & item {
139
139
Annotatable :: Item ( iitem) => {
140
140
let sig = match & iitem. kind {
141
141
ItemKind :: Fn ( box ast:: Fn { sig, .. } ) => sig,
@@ -144,7 +144,7 @@ mod llvm_enzyme {
144
144
return vec ! [ item] ;
145
145
}
146
146
} ;
147
- ( sig. clone ( ) , false )
147
+ sig. clone ( )
148
148
}
149
149
Annotatable :: AssocItem ( assoc_item, _) => {
150
150
let sig = match & assoc_item. kind {
@@ -154,7 +154,24 @@ mod llvm_enzyme {
154
154
return vec ! [ item] ;
155
155
}
156
156
} ;
157
- ( sig. clone ( ) , true )
157
+ sig. clone ( )
158
+ }
159
+ Annotatable :: Stmt ( stmt) => {
160
+ let sig = match & stmt. kind {
161
+ ast:: StmtKind :: Item ( iitem) => match & iitem. kind {
162
+ ast:: ItemKind :: Fn ( box ast:: Fn { sig, .. } ) => sig,
163
+ _ => {
164
+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
165
+ return vec ! [ item] ;
166
+ }
167
+ } ,
168
+ _ => {
169
+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
170
+ return vec ! [ item] ;
171
+ }
172
+ } ;
173
+
174
+ sig. clone ( )
158
175
}
159
176
_ => {
160
177
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
@@ -178,6 +195,10 @@ mod llvm_enzyme {
178
195
Annotatable :: AssocItem ( assoc_item, _) => {
179
196
( assoc_item. vis . clone ( ) , assoc_item. ident . clone ( ) )
180
197
}
198
+ Annotatable :: Stmt ( stmt) => match & stmt. kind {
199
+ ast:: StmtKind :: Item ( iitem) => ( iitem. vis . clone ( ) , iitem. ident . clone ( ) ) ,
200
+ _ => unreachable ! ( "stmt kind checked previously" ) ,
201
+ } ,
181
202
_ => {
182
203
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
183
204
return vec ! [ item] ;
@@ -302,6 +323,21 @@ mod llvm_enzyme {
302
323
}
303
324
Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
304
325
}
326
+ Annotatable :: Stmt ( ref mut stmt) => {
327
+ match stmt. kind {
328
+ ast:: StmtKind :: Item ( ref mut item) => {
329
+ if !item. attrs . iter ( ) . any ( |a| a. id == attr. id ) {
330
+ item. attrs . push ( attr. clone ( ) ) ;
331
+ }
332
+ if !item. attrs . iter ( ) . any ( |a| a. id == inline_never. id ) {
333
+ item. attrs . push ( inline_never. clone ( ) ) ;
334
+ }
335
+ }
336
+ _ => unreachable ! ( "stmt kind checked previously" ) ,
337
+ } ;
338
+
339
+ Annotatable :: Stmt ( stmt. clone ( ) )
340
+ }
305
341
_ => {
306
342
unreachable ! ( "annotatable kind checked previously" )
307
343
}
@@ -319,23 +355,49 @@ mod llvm_enzyme {
319
355
span,
320
356
} ;
321
357
322
- let d_annotatable = if is_impl {
323
- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
324
- let d_fn = P ( ast:: AssocItem {
325
- attrs : thin_vec ! [ d_attr, inline_never] ,
326
- id : ast:: DUMMY_NODE_ID ,
327
- span,
328
- vis,
329
- ident : d_ident,
330
- kind : assoc_item,
331
- tokens : None ,
332
- } ) ;
333
- Annotatable :: AssocItem ( d_fn, Impl )
334
- } else {
335
- let mut d_fn =
336
- ecx. item ( span, d_ident, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
337
- d_fn. vis = vis;
338
- Annotatable :: Item ( d_fn)
358
+ let d_annotatable = match & item {
359
+ Annotatable :: AssocItem ( _, _) => {
360
+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
361
+ let d_fn = P ( ast:: AssocItem {
362
+ attrs : thin_vec ! [ d_attr. clone( ) , inline_never] ,
363
+ id : ast:: DUMMY_NODE_ID ,
364
+ span,
365
+ vis,
366
+ ident : d_ident,
367
+ kind : assoc_item,
368
+ tokens : None ,
369
+ } ) ;
370
+ Annotatable :: AssocItem ( d_fn, Impl )
371
+ }
372
+ Annotatable :: Item ( _) => {
373
+ let mut d_fn = ecx. item (
374
+ span,
375
+ d_ident,
376
+ thin_vec ! [ d_attr. clone( ) , inline_never] ,
377
+ ItemKind :: Fn ( asdf) ,
378
+ ) ;
379
+ d_fn. vis = vis;
380
+
381
+ Annotatable :: Item ( d_fn)
382
+ }
383
+ Annotatable :: Stmt ( _) => {
384
+ let mut d_fn = ecx. item (
385
+ span,
386
+ d_ident,
387
+ thin_vec ! [ d_attr. clone( ) , inline_never] ,
388
+ ItemKind :: Fn ( asdf) ,
389
+ ) ;
390
+ d_fn. vis = vis;
391
+
392
+ Annotatable :: Stmt ( P ( ast:: Stmt {
393
+ id : ast:: DUMMY_NODE_ID ,
394
+ kind : ast:: StmtKind :: Item ( d_fn) ,
395
+ span,
396
+ } ) )
397
+ }
398
+ _ => {
399
+ unreachable ! ( )
400
+ }
339
401
} ;
340
402
341
403
return vec ! [ orig_annotatable, d_annotatable] ;
0 commit comments