Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

-Znext-solver caching #128828

Merged
merged 8 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,6 @@ macro_rules! arena_types {
[] dtorck_constraint: rustc_middle::traits::query::DropckConstraint<'tcx>,
[] candidate_step: rustc_middle::traits::query::CandidateStep<'tcx>,
[] autoderef_bad_ty: rustc_middle::traits::query::MethodAutoderefBadTy<'tcx>,
[] canonical_goal_evaluation:
rustc_type_ir::solve::inspect::CanonicalGoalEvaluationStep<
rustc_middle::ty::TyCtxt<'tcx>
>,
[] query_region_constraints: rustc_middle::infer::canonical::QueryRegionConstraints<'tcx>,
[] type_op_subtype:
rustc_middle::infer::canonical::Canonical<'tcx,
Expand Down
9 changes: 0 additions & 9 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
self.mk_predefined_opaques_in_body(data)
}
type DefiningOpaqueTypes = &'tcx ty::List<LocalDefId>;
type CanonicalGoalEvaluationStepRef =
&'tcx solve::inspect::CanonicalGoalEvaluationStep<TyCtxt<'tcx>>;
type CanonicalVars = CanonicalVarInfos<'tcx>;
fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars {
self.mk_canonical_var_infos(infos)
Expand Down Expand Up @@ -277,13 +275,6 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
self.debug_assert_args_compatible(def_id, args);
}

fn intern_canonical_goal_evaluation_step(
self,
step: solve::inspect::CanonicalGoalEvaluationStep<TyCtxt<'tcx>>,
) -> &'tcx solve::inspect::CanonicalGoalEvaluationStep<TyCtxt<'tcx>> {
self.arena.alloc(step)
}

fn mk_type_list_from_iter<I, T>(self, args: I) -> T::Output
where
I: Iterator<Item = T>,
Expand Down
106 changes: 14 additions & 92 deletions compiler/rustc_next_trait_solver/src/solve/inspect/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
//! see the comment on [ProofTreeBuilder].

use std::marker::PhantomData;
use std::mem;

use derive_where::derive_where;
use rustc_type_ir::inherent::*;
use rustc_type_ir::{self as ty, search_graph, Interner};
use rustc_type_ir::{self as ty, Interner};

use crate::delegate::SolverDelegate;
use crate::solve::eval_ctxt::canonical;
Expand Down Expand Up @@ -94,31 +93,10 @@ impl<I: Interner> WipGoalEvaluation<I> {
}
}

#[derive_where(PartialEq, Eq; I: Interner)]
pub(in crate::solve) enum WipCanonicalGoalEvaluationKind<I: Interner> {
Overflow,
CycleInStack,
ProvisionalCacheHit,
Interned { final_revision: I::CanonicalGoalEvaluationStepRef },
}

impl<I: Interner> std::fmt::Debug for WipCanonicalGoalEvaluationKind<I> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Overflow => write!(f, "Overflow"),
Self::CycleInStack => write!(f, "CycleInStack"),
Self::ProvisionalCacheHit => write!(f, "ProvisionalCacheHit"),
Self::Interned { final_revision: _ } => {
f.debug_struct("Interned").finish_non_exhaustive()
}
}
}
}

#[derive_where(PartialEq, Eq, Debug; I: Interner)]
struct WipCanonicalGoalEvaluation<I: Interner> {
goal: CanonicalInput<I>,
kind: Option<WipCanonicalGoalEvaluationKind<I>>,
encountered_overflow: bool,
/// Only used for uncached goals. After we finished evaluating
/// the goal, this is interned and moved into `kind`.
final_revision: Option<WipCanonicalGoalEvaluationStep<I>>,
Expand All @@ -127,25 +105,17 @@ struct WipCanonicalGoalEvaluation<I: Interner> {

impl<I: Interner> WipCanonicalGoalEvaluation<I> {
fn finalize(self) -> inspect::CanonicalGoalEvaluation<I> {
// We've already interned the final revision in
// `fn finalize_canonical_goal_evaluation`.
assert!(self.final_revision.is_none());
let kind = match self.kind.unwrap() {
WipCanonicalGoalEvaluationKind::Overflow => {
inspect::CanonicalGoalEvaluation {
goal: self.goal,
kind: if self.encountered_overflow {
assert!(self.final_revision.is_none());
inspect::CanonicalGoalEvaluationKind::Overflow
}
WipCanonicalGoalEvaluationKind::CycleInStack => {
inspect::CanonicalGoalEvaluationKind::CycleInStack
}
WipCanonicalGoalEvaluationKind::ProvisionalCacheHit => {
inspect::CanonicalGoalEvaluationKind::ProvisionalCacheHit
}
WipCanonicalGoalEvaluationKind::Interned { final_revision } => {
} else {
let final_revision = self.final_revision.unwrap().finalize();
inspect::CanonicalGoalEvaluationKind::Evaluation { final_revision }
}
};

inspect::CanonicalGoalEvaluation { goal: self.goal, kind, result: self.result.unwrap() }
},
result: self.result.unwrap(),
}
}
}

Expand Down Expand Up @@ -308,7 +278,7 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> ProofTreeBuilder<D> {
) -> ProofTreeBuilder<D> {
self.nested(|| WipCanonicalGoalEvaluation {
goal,
kind: None,
encountered_overflow: false,
final_revision: None,
result: None,
})
Expand All @@ -329,11 +299,11 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> ProofTreeBuilder<D> {
}
}

pub fn canonical_goal_evaluation_kind(&mut self, kind: WipCanonicalGoalEvaluationKind<I>) {
pub fn canonical_goal_evaluation_overflow(&mut self) {
if let Some(this) = self.as_mut() {
match this {
DebugSolver::CanonicalGoalEvaluation(canonical_goal_evaluation) => {
assert_eq!(canonical_goal_evaluation.kind.replace(kind), None);
canonical_goal_evaluation.encountered_overflow = true;
}
_ => unreachable!(),
};
Expand Down Expand Up @@ -547,51 +517,3 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> ProofTreeBuilder<D> {
}
}
}

impl<D, I> search_graph::ProofTreeBuilder<I> for ProofTreeBuilder<D>
where
D: SolverDelegate<Interner = I>,
I: Interner,
{
fn try_apply_proof_tree(
&mut self,
proof_tree: Option<I::CanonicalGoalEvaluationStepRef>,
) -> bool {
if !self.is_noop() {
if let Some(final_revision) = proof_tree {
let kind = WipCanonicalGoalEvaluationKind::Interned { final_revision };
self.canonical_goal_evaluation_kind(kind);
true
} else {
false
}
} else {
true
}
}

fn on_provisional_cache_hit(&mut self) {
self.canonical_goal_evaluation_kind(WipCanonicalGoalEvaluationKind::ProvisionalCacheHit);
}

fn on_cycle_in_stack(&mut self) {
self.canonical_goal_evaluation_kind(WipCanonicalGoalEvaluationKind::CycleInStack);
}

fn finalize_canonical_goal_evaluation(
&mut self,
tcx: I,
) -> Option<I::CanonicalGoalEvaluationStepRef> {
self.as_mut().map(|this| match this {
DebugSolver::CanonicalGoalEvaluation(evaluation) => {
let final_revision = mem::take(&mut evaluation.final_revision).unwrap();
let final_revision =
tcx.intern_canonical_goal_evaluation_step(final_revision.finalize());
let kind = WipCanonicalGoalEvaluationKind::Interned { final_revision };
assert_eq!(evaluation.kind.replace(kind), None);
final_revision
}
_ => unreachable!(),
})
}
}
66 changes: 44 additions & 22 deletions compiler/rustc_next_trait_solver/src/solve/search_graph.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::convert::Infallible;
use std::marker::PhantomData;

use rustc_type_ir::inherent::*;
use rustc_type_ir::search_graph::{self, CycleKind, UsageKind};
use rustc_type_ir::search_graph::{self, PathKind};
use rustc_type_ir::solve::{CanonicalInput, Certainty, QueryResult};
use rustc_type_ir::Interner;

use super::inspect::{self, ProofTreeBuilder};
use super::FIXPOINT_STEP_LIMIT;
use super::inspect::ProofTreeBuilder;
use super::{has_no_inference_or_external_constraints, FIXPOINT_STEP_LIMIT};
use crate::delegate::SolverDelegate;

/// This type is never constructed. We only use it to implement `search_graph::Delegate`
Expand All @@ -22,43 +23,48 @@ where
{
type Cx = D::Interner;

const ENABLE_PROVISIONAL_CACHE: bool = true;
type ValidationScope = Infallible;
fn enter_validation_scope(
_cx: Self::Cx,
_input: CanonicalInput<I>,
) -> Option<Self::ValidationScope> {
None
}

const FIXPOINT_STEP_LIMIT: usize = FIXPOINT_STEP_LIMIT;

type ProofTreeBuilder = ProofTreeBuilder<D>;
fn inspect_is_noop(inspect: &mut Self::ProofTreeBuilder) -> bool {
inspect.is_noop()
}

const DIVIDE_AVAILABLE_DEPTH_ON_OVERFLOW: usize = 4;
fn recursion_limit(cx: I) -> usize {
cx.recursion_limit()
}

fn initial_provisional_result(
cx: I,
kind: CycleKind,
kind: PathKind,
input: CanonicalInput<I>,
) -> QueryResult<I> {
match kind {
CycleKind::Coinductive => response_no_constraints(cx, input, Certainty::Yes),
CycleKind::Inductive => response_no_constraints(cx, input, Certainty::overflow(false)),
PathKind::Coinductive => response_no_constraints(cx, input, Certainty::Yes),
PathKind::Inductive => response_no_constraints(cx, input, Certainty::overflow(false)),
}
}

fn reached_fixpoint(
cx: I,
kind: UsageKind,
fn is_initial_provisional_result(
cx: Self::Cx,
kind: PathKind,
input: CanonicalInput<I>,
provisional_result: Option<QueryResult<I>>,
result: QueryResult<I>,
) -> bool {
if let Some(r) = provisional_result {
r == result
} else {
match kind {
UsageKind::Single(CycleKind::Coinductive) => {
response_no_constraints(cx, input, Certainty::Yes) == result
}
UsageKind::Single(CycleKind::Inductive) => {
response_no_constraints(cx, input, Certainty::overflow(false)) == result
}
UsageKind::Mixed => false,
match kind {
PathKind::Coinductive => response_no_constraints(cx, input, Certainty::Yes) == result,
PathKind::Inductive => {
response_no_constraints(cx, input, Certainty::overflow(false)) == result
}
}
}
Expand All @@ -68,14 +74,30 @@ where
inspect: &mut ProofTreeBuilder<D>,
input: CanonicalInput<I>,
) -> QueryResult<I> {
inspect.canonical_goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::Overflow);
inspect.canonical_goal_evaluation_overflow();
response_no_constraints(cx, input, Certainty::overflow(true))
}

fn on_fixpoint_overflow(cx: I, input: CanonicalInput<I>) -> QueryResult<I> {
response_no_constraints(cx, input, Certainty::overflow(false))
}

fn is_ambiguous_result(result: QueryResult<I>) -> bool {
result.is_ok_and(|response| {
has_no_inference_or_external_constraints(response)
&& matches!(response.value.certainty, Certainty::Maybe(_))
})
}

fn propagate_ambiguity(
cx: I,
for_input: CanonicalInput<I>,
from_result: QueryResult<I>,
) -> QueryResult<I> {
let certainty = from_result.unwrap().value.certainty;
response_no_constraints(cx, for_input, certainty)
}

fn step_is_coinductive(cx: I, input: CanonicalInput<I>) -> bool {
input.value.goal.predicate.is_coinductive(cx)
}
Expand Down
10 changes: 3 additions & 7 deletions compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,9 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {

pub fn candidates(&'a self) -> Vec<InspectCandidate<'a, 'tcx>> {
let mut candidates = vec![];
let last_eval_step = match self.evaluation_kind {
inspect::CanonicalGoalEvaluationKind::Overflow
| inspect::CanonicalGoalEvaluationKind::CycleInStack
| inspect::CanonicalGoalEvaluationKind::ProvisionalCacheHit => {
warn!("unexpected root evaluation: {:?}", self.evaluation_kind);
return vec![];
}
let last_eval_step = match &self.evaluation_kind {
// An annoying edge case in case the recursion limit is 0.
inspect::CanonicalGoalEvaluationKind::Overflow => return vec![],
inspect::CanonicalGoalEvaluationKind::Evaluation { final_revision } => final_revision,
};

Expand Down
24 changes: 8 additions & 16 deletions compiler/rustc_type_ir/src/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use derive_where::derive_where;
use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable};
#[cfg(feature = "nightly")]
use rustc_serialize::Decodable;
use tracing::debug;
use tracing::instrument;

use crate::data_structures::SsoHashSet;
use crate::fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable};
Expand Down Expand Up @@ -831,28 +831,20 @@ impl<'a, I: Interner> ArgFolder<'a, I> {
/// As indicated in the diagram, here the same type `&'a i32` is instantiated once, but in the
/// first case we do not increase the De Bruijn index and in the second case we do. The reason
/// is that only in the second case have we passed through a fn binder.
#[instrument(level = "trace", skip(self), fields(binders_passed = self.binders_passed), ret)]
fn shift_vars_through_binders<T: TypeFoldable<I>>(&self, val: T) -> T {
debug!(
"shift_vars(val={:?}, binders_passed={:?}, has_escaping_bound_vars={:?})",
val,
self.binders_passed,
val.has_escaping_bound_vars()
);

if self.binders_passed == 0 || !val.has_escaping_bound_vars() {
return val;
val
} else {
ty::fold::shift_vars(self.cx, val, self.binders_passed)
}

let result = ty::fold::shift_vars(TypeFolder::cx(self), val, self.binders_passed);
debug!("shift_vars: shifted result = {:?}", result);

result
}

fn shift_region_through_binders(&self, region: I::Region) -> I::Region {
if self.binders_passed == 0 || !region.has_escaping_bound_vars() {
return region;
region
} else {
ty::fold::shift_region(self.cx, region, self.binders_passed)
}
ty::fold::shift_region(self.cx, region, self.binders_passed)
}
}
11 changes: 5 additions & 6 deletions compiler/rustc_type_ir/src/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
use std::mem;

use rustc_index::{Idx, IndexVec};
use tracing::debug;
use tracing::instrument;

use crate::data_structures::Lrc;
use crate::inherent::*;
Expand Down Expand Up @@ -417,15 +417,14 @@ pub fn shift_region<I: Interner>(cx: I, region: I::Region, amount: u32) -> I::Re
}
}

#[instrument(level = "trace", skip(cx), ret)]
pub fn shift_vars<I: Interner, T>(cx: I, value: T, amount: u32) -> T
where
T: TypeFoldable<I>,
{
debug!("shift_vars(value={:?}, amount={})", value, amount);

if amount == 0 || !value.has_escaping_bound_vars() {
return value;
value
} else {
value.fold_with(&mut Shifter::new(cx, amount))
}

value.fold_with(&mut Shifter::new(cx, amount))
}
Loading
Loading