Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#[contracts::requires(...)] + #[contracts::ensures(...)] #128045

Merged
merged 10 commits into from
Feb 5, 2025
9 changes: 8 additions & 1 deletion compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
@@ -3348,11 +3348,18 @@ pub struct Impl {
pub items: ThinVec<P<AssocItem>>,
}

#[derive(Clone, Encodable, Decodable, Debug, Default)]
pub struct FnContract {
pub requires: Option<P<Expr>>,
pub ensures: Option<P<Expr>>,
}

#[derive(Clone, Encodable, Decodable, Debug)]
pub struct Fn {
pub defaultness: Defaultness,
pub generics: Generics,
pub sig: FnSig,
pub contract: Option<P<FnContract>>,
pub body: Option<P<Block>>,
}

@@ -3650,7 +3657,7 @@ mod size_asserts {
static_assert_size!(Block, 32);
static_assert_size!(Expr, 72);
static_assert_size!(ExprKind, 40);
static_assert_size!(Fn, 160);
static_assert_size!(Fn, 168);
static_assert_size!(ForeignItem, 88);
static_assert_size!(ForeignItemKind, 16);
static_assert_size!(GenericArg, 24);
19 changes: 18 additions & 1 deletion compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
@@ -143,6 +143,10 @@ pub trait MutVisitor: Sized {
walk_flat_map_assoc_item(self, i, ctxt)
}

fn visit_contract(&mut self, c: &mut P<FnContract>) {
walk_contract(self, c);
}

fn visit_fn_decl(&mut self, d: &mut P<FnDecl>) {
walk_fn_decl(self, d);
}
@@ -958,13 +962,16 @@ fn walk_fn<T: MutVisitor>(vis: &mut T, kind: FnKind<'_>) {
_ctxt,
_ident,
_vis,
Fn { defaultness, generics, body, sig: FnSig { header, decl, span } },
Fn { defaultness, generics, contract, body, sig: FnSig { header, decl, span } },
) => {
// Identifier and visibility are visited as a part of the item.
visit_defaultness(vis, defaultness);
vis.visit_fn_header(header);
vis.visit_generics(generics);
vis.visit_fn_decl(decl);
if let Some(contract) = contract {
vis.visit_contract(contract);
}
if let Some(body) = body {
vis.visit_block(body);
}
@@ -979,6 +986,16 @@ fn walk_fn<T: MutVisitor>(vis: &mut T, kind: FnKind<'_>) {
}
}

fn walk_contract<T: MutVisitor>(vis: &mut T, contract: &mut P<FnContract>) {
let FnContract { requires, ensures } = contract.deref_mut();
if let Some(pred) = requires {
vis.visit_expr(pred);
}
if let Some(pred) = ensures {
vis.visit_expr(pred);
}
}

fn walk_fn_decl<T: MutVisitor>(vis: &mut T, decl: &mut P<FnDecl>) {
let FnDecl { inputs, output } = decl.deref_mut();
inputs.flat_map_in_place(|param| vis.flat_map_param(param));
17 changes: 16 additions & 1 deletion compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
@@ -188,6 +188,9 @@ pub trait Visitor<'ast>: Sized {
fn visit_closure_binder(&mut self, b: &'ast ClosureBinder) -> Self::Result {
walk_closure_binder(self, b)
}
fn visit_contract(&mut self, c: &'ast FnContract) -> Self::Result {
walk_contract(self, c)
}
fn visit_where_predicate(&mut self, p: &'ast WherePredicate) -> Self::Result {
walk_where_predicate(self, p)
}
@@ -800,6 +803,17 @@ pub fn walk_closure_binder<'a, V: Visitor<'a>>(
V::Result::output()
}

pub fn walk_contract<'a, V: Visitor<'a>>(visitor: &mut V, c: &'a FnContract) -> V::Result {
let FnContract { requires, ensures } = c;
if let Some(pred) = requires {
visitor.visit_expr(pred);
}
if let Some(pred) = ensures {
visitor.visit_expr(pred);
}
V::Result::output()
}

pub fn walk_where_predicate<'a, V: Visitor<'a>>(
visitor: &mut V,
predicate: &'a WherePredicate,
@@ -862,12 +876,13 @@ pub fn walk_fn<'a, V: Visitor<'a>>(visitor: &mut V, kind: FnKind<'a>) -> V::Resu
_ctxt,
_ident,
_vis,
Fn { defaultness: _, sig: FnSig { header, decl, span: _ }, generics, body },
Fn { defaultness: _, sig: FnSig { header, decl, span: _ }, generics, contract, body },
) => {
// Identifier and visibility are visited as a part of the item.
try_visit!(visitor.visit_fn_header(header));
try_visit!(visitor.visit_generics(generics));
try_visit!(visitor.visit_fn_decl(decl));
visit_opt!(visitor, visit_contract, contract);
visit_opt!(visitor, visit_block, body);
}
FnKind::Closure(binder, coroutine_kind, decl, body) => {
35 changes: 31 additions & 4 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
@@ -314,8 +314,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::ExprKind::Continue(self.lower_jump_destination(e.id, *opt_label))
}
ExprKind::Ret(e) => {
let e = e.as_ref().map(|x| self.lower_expr(x));
hir::ExprKind::Ret(e)
let expr = e.as_ref().map(|x| self.lower_expr(x));
self.checked_return(expr)
}
ExprKind::Yeet(sub_expr) => self.lower_expr_yeet(e.span, sub_expr.as_deref()),
ExprKind::Become(sub_expr) => {
@@ -382,6 +382,32 @@ impl<'hir> LoweringContext<'_, 'hir> {
})
}

/// Create an `ExprKind::Ret` that is preceded by a call to check contract ensures clause.
fn checked_return(&mut self, opt_expr: Option<&'hir hir::Expr<'hir>>) -> hir::ExprKind<'hir> {
let checked_ret = if let Some(Some((span, fresh_ident))) =
self.contract.as_ref().map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
{
let expr = opt_expr.unwrap_or_else(|| self.expr_unit(span));
Some(self.inject_ensures_check(expr, span, fresh_ident.0, fresh_ident.2))
} else {
opt_expr
};
hir::ExprKind::Ret(checked_ret)
}

/// Wraps an expression with a call to the ensures check before it gets returned.
pub(crate) fn inject_ensures_check(
&mut self,
expr: &'hir hir::Expr<'hir>,
span: Span,
check_ident: Ident,
check_hir_id: HirId,
) -> &'hir hir::Expr<'hir> {
let checker_fn = self.expr_ident(span, check_ident, check_hir_id);
let span = self.mark_span_with_reason(DesugaringKind::Contract, span, None);
self.expr_call(span, checker_fn, std::slice::from_ref(expr))
}

pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock {
self.with_new_scopes(c.value.span, |this| {
let def_id = this.local_def_id(c.id);
@@ -1970,7 +1996,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
),
))
} else {
self.arena.alloc(self.expr(try_span, hir::ExprKind::Ret(Some(from_residual_expr))))
let ret_expr = self.checked_return(Some(from_residual_expr));
self.arena.alloc(self.expr(try_span, ret_expr))
};
self.lower_attrs(ret_expr.hir_id, &attrs);

@@ -2019,7 +2046,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
let target_id = Ok(catch_id);
hir::ExprKind::Break(hir::Destination { label: None, target_id }, Some(from_yeet_expr))
} else {
hir::ExprKind::Ret(Some(from_yeet_expr))
self.checked_return(Some(from_yeet_expr))
}
}

93 changes: 89 additions & 4 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
@@ -207,9 +207,40 @@ impl<'hir> LoweringContext<'_, 'hir> {
sig: FnSig { decl, header, span: fn_sig_span },
generics,
body,
contract,
..
}) => {
self.with_new_scopes(*fn_sig_span, |this| {
assert!(this.contract.is_none());
if let Some(contract) = contract {
let requires = contract.requires.clone();
let ensures = contract.ensures.clone();
let ensures = ensures.map(|ens| {
// FIXME: this needs to be a fresh (or illegal) identifier to prevent
// accidental capture of a parameter or global variable.
let checker_ident: Ident =
Ident::from_str_and_span("__ensures_checker", ens.span);
let (checker_pat, checker_hir_id) = this.pat_ident_binding_mode_mut(
ens.span,
checker_ident,
hir::BindingMode::NONE,
);

crate::FnContractLoweringEnsures {
expr: ens,
fresh_ident: (checker_ident, checker_pat, checker_hir_id),
}
});

// Note: `with_new_scopes` will reinstall the outer
// item's contract (if any) after its callback finishes.
this.contract.replace(crate::FnContractLoweringInfo {
span,
requires,
ensures,
});
}

// Note: we don't need to change the return type from `T` to
// `impl Future<Output = T>` here because lower_body
// only cares about the input argument patterns in the function
@@ -1054,10 +1085,64 @@ impl<'hir> LoweringContext<'_, 'hir> {
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::BodyId {
self.lower_body(|this| {
(
this.arena.alloc_from_iter(decl.inputs.iter().map(|x| this.lower_param(x))),
body(this),
)
let params =
this.arena.alloc_from_iter(decl.inputs.iter().map(|x| this.lower_param(x)));
let result = body(this);

let opt_contract = this.contract.take();

// { body }
// ==>
// { contract_requires(PRECOND); { body } }
let Some(contract) = opt_contract else { return (params, result) };
let result_ref = this.arena.alloc(result);
let lit_unit = |this: &mut LoweringContext<'_, 'hir>| {
this.expr(contract.span, hir::ExprKind::Tup(&[]))
};

let precond: hir::Stmt<'hir> = if let Some(req) = contract.requires {
let lowered_req = this.lower_expr_mut(&req);
let precond = this.expr_call_lang_item_fn_mut(
req.span,
hir::LangItem::ContractCheckRequires,
&*arena_vec![this; lowered_req],
);
this.stmt_expr(req.span, precond)
} else {
let u = lit_unit(this);
this.stmt_expr(contract.span, u)
};

let (postcond_checker, result) = if let Some(ens) = contract.ensures {
let crate::FnContractLoweringEnsures { expr: ens, fresh_ident } = ens;
let lowered_ens: hir::Expr<'hir> = this.lower_expr_mut(&ens);
let postcond_checker = this.expr_call_lang_item_fn(
ens.span,
hir::LangItem::ContractBuildCheckEnsures,
&*arena_vec![this; lowered_ens],
);
let checker_binding_pat = fresh_ident.1;
(
this.stmt_let_pat(
None,
ens.span,
Some(postcond_checker),
this.arena.alloc(checker_binding_pat),
hir::LocalSource::Contract,
),
this.inject_ensures_check(result_ref, ens.span, fresh_ident.0, fresh_ident.2),
)
} else {
let u = lit_unit(this);
(this.stmt_expr(contract.span, u), &*result_ref)
};

let block = this.block_all(
contract.span,
arena_vec![this; precond, postcond_checker],
Some(result),
);
(params, this.expr_block(block))
})
}

20 changes: 20 additions & 0 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
@@ -86,6 +86,19 @@ mod path;

rustc_fluent_macro::fluent_messages! { "../messages.ftl" }

#[derive(Debug, Clone)]
struct FnContractLoweringInfo<'hir> {
pub span: Span,
pub requires: Option<ast::ptr::P<ast::Expr>>,
pub ensures: Option<FnContractLoweringEnsures<'hir>>,
}

#[derive(Debug, Clone)]
struct FnContractLoweringEnsures<'hir> {
expr: ast::ptr::P<ast::Expr>,
fresh_ident: (Ident, hir::Pat<'hir>, HirId),
}

struct LoweringContext<'a, 'hir> {
tcx: TyCtxt<'hir>,
resolver: &'a mut ResolverAstLowering,
@@ -100,6 +113,8 @@ struct LoweringContext<'a, 'hir> {
/// Collect items that were created by lowering the current owner.
children: Vec<(LocalDefId, hir::MaybeOwner<'hir>)>,

contract: Option<FnContractLoweringInfo<'hir>>,

coroutine_kind: Option<hir::CoroutineKind>,

/// When inside an `async` context, this is the `HirId` of the
@@ -148,6 +163,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
bodies: Vec::new(),
attrs: SortedMap::default(),
children: Vec::default(),
contract: None,
current_hir_id_owner: hir::CRATE_OWNER_ID,
item_local_id_counter: hir::ItemLocalId::ZERO,
ident_and_label_to_local_id: Default::default(),
@@ -834,12 +850,16 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
let was_in_loop_condition = self.is_in_loop_condition;
self.is_in_loop_condition = false;

let old_contract = self.contract.take();

let catch_scope = self.catch_scope.take();
let loop_scope = self.loop_scope.take();
let ret = f(self);
self.catch_scope = catch_scope;
self.loop_scope = loop_scope;

self.contract = old_contract;

self.is_in_loop_condition = was_in_loop_condition;

self.current_item = current_item;
2 changes: 1 addition & 1 deletion compiler/rustc_ast_passes/src/ast_validation.rs
Original file line number Diff line number Diff line change
@@ -917,7 +917,7 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
walk_list!(self, visit_attribute, &item.attrs);
return; // Avoid visiting again.
}
ItemKind::Fn(func @ box Fn { defaultness, generics: _, sig, body }) => {
ItemKind::Fn(func @ box Fn { defaultness, generics: _, sig, contract: _, body }) => {
self.check_defaultness(item.span, *defaultness);

let is_intrinsic =
2 changes: 2 additions & 0 deletions compiler/rustc_ast_passes/src/feature_gate.rs
Original file line number Diff line number Diff line change
@@ -548,6 +548,8 @@ pub fn check_crate(krate: &ast::Crate, sess: &Session, features: &Features) {
gate_all!(pin_ergonomics, "pinned reference syntax is experimental");
gate_all!(unsafe_fields, "`unsafe` fields are experimental");
gate_all!(unsafe_binders, "unsafe binder types are experimental");
gate_all!(contracts, "contracts are incomplete");
gate_all!(contracts_internals, "contract internal machinery is for internal use only");

if !visitor.features.never_patterns() {
if let Some(spans) = spans.get(&sym::never_patterns) {
Loading
Loading