From 2f790af726e69d66952233957129f22c70b98001 Mon Sep 17 00:00:00 2001 From: "Celina G. Val" Date: Mon, 11 Dec 2023 14:53:39 -0800 Subject: [PATCH] Fix BinOp ty assertion and `fn_sig` for closures Also added a few more util methods to TyKind to check for specific types. --- compiler/rustc_smir/src/rustc_smir/context.rs | 12 ++ compiler/stable_mir/src/compiler_interface.rs | 6 + compiler/stable_mir/src/mir/body.rs | 22 +-- compiler/stable_mir/src/ty.rs | 139 +++++++++++++++++- 4 files changed, 167 insertions(+), 12 deletions(-) diff --git a/compiler/rustc_smir/src/rustc_smir/context.rs b/compiler/rustc_smir/src/rustc_smir/context.rs index 22e9f66ba9655..4ec5e2a538708 100644 --- a/compiler/rustc_smir/src/rustc_smir/context.rs +++ b/compiler/rustc_smir/src/rustc_smir/context.rs @@ -213,6 +213,11 @@ impl<'tcx> Context for TablesWrapper<'tcx> { def.internal(&mut *tables).is_box() } + fn adt_is_simd(&self, def: AdtDef) -> bool { + let mut tables = self.0.borrow_mut(); + def.internal(&mut *tables).repr().simd() + } + fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig { let mut tables = self.0.borrow_mut(); let def_id = def.0.internal(&mut *tables); @@ -220,6 +225,13 @@ impl<'tcx> Context for TablesWrapper<'tcx> { sig.stable(&mut *tables) } + fn closure_sig(&self, args: &GenericArgs) -> PolyFnSig { + let mut tables = self.0.borrow_mut(); + let args_ref = args.internal(&mut *tables); + let sig = args_ref.as_closure().sig(); + sig.stable(&mut *tables) + } + fn adt_variants_len(&self, def: AdtDef) -> usize { let mut tables = self.0.borrow_mut(); def.internal(&mut *tables).variants().len() diff --git a/compiler/stable_mir/src/compiler_interface.rs b/compiler/stable_mir/src/compiler_interface.rs index 17c5212fb9cd4..2fac59e71fd5b 100644 --- a/compiler/stable_mir/src/compiler_interface.rs +++ b/compiler/stable_mir/src/compiler_interface.rs @@ -69,9 +69,15 @@ pub trait Context { /// Returns if the ADT is a box. fn adt_is_box(&self, def: AdtDef) -> bool; + /// Returns whether this ADT is simd. + fn adt_is_simd(&self, def: AdtDef) -> bool; + /// Retrieve the function signature for the given generic arguments. fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig; + /// Retrieve the closure signature for the given generic arguments. + fn closure_sig(&self, args: &GenericArgs) -> PolyFnSig; + /// The number of variants in this ADT. fn adt_variants_len(&self, def: AdtDef) -> usize; diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs index 663275d9a0f8c..883baeb9f71fc 100644 --- a/compiler/stable_mir/src/mir/body.rs +++ b/compiler/stable_mir/src/mir/body.rs @@ -228,7 +228,7 @@ pub struct InlineAsmOperand { pub raw_rpr: String, } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum UnwindAction { Continue, Unreachable, @@ -248,7 +248,7 @@ pub enum AssertMessage { MisalignedPointerDereference { required: Operand, found: Operand }, } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum BinOp { Add, AddUnchecked, @@ -278,8 +278,10 @@ impl BinOp { /// Return the type of this operation for the given input Ty. /// This function does not perform type checking, and it currently doesn't handle SIMD. pub fn ty(&self, lhs_ty: Ty, rhs_ty: Ty) -> Ty { - assert!(lhs_ty.kind().is_primitive()); - assert!(rhs_ty.kind().is_primitive()); + let lhs_kind = lhs_ty.kind(); + let rhs_kind = rhs_ty.kind(); + assert!(lhs_kind.is_primitive() || lhs_kind.is_any_ptr()); + assert!(rhs_kind.is_primitive() || rhs_kind.is_any_ptr()); match self { BinOp::Add | BinOp::AddUnchecked @@ -306,7 +308,7 @@ impl BinOp { } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum UnOp { Not, Neg, @@ -319,7 +321,7 @@ pub enum CoroutineKind { Gen(CoroutineSource), } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum CoroutineSource { Block, Closure, @@ -343,7 +345,7 @@ pub enum FakeReadCause { } /// Describes what kind of retag is to be performed -#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub enum RetagKind { FnEntry, TwoPhase, @@ -351,7 +353,7 @@ pub enum RetagKind { Default, } -#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub enum Variance { Covariant, Invariant, @@ -862,7 +864,7 @@ pub enum Safety { Normal, } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum PointerCoercion { /// Go from a fn-item type to a fn-pointer type. ReifyFnPointer, @@ -889,7 +891,7 @@ pub enum PointerCoercion { Unsize, } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum CastKind { PointerExposeAddress, PointerFromExposedAddress, diff --git a/compiler/stable_mir/src/ty.rs b/compiler/stable_mir/src/ty.rs index bea7702bd34bf..836fabd8319f2 100644 --- a/compiler/stable_mir/src/ty.rs +++ b/compiler/stable_mir/src/ty.rs @@ -214,38 +214,62 @@ impl TyKind { if let TyKind::RigidTy(inner) = self { Some(inner) } else { None } } + #[inline] pub fn is_unit(&self) -> bool { matches!(self, TyKind::RigidTy(RigidTy::Tuple(data)) if data.is_empty()) } + #[inline] pub fn is_bool(&self) -> bool { matches!(self, TyKind::RigidTy(RigidTy::Bool)) } + #[inline] + pub fn is_char(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Char)) + } + + #[inline] pub fn is_trait(&self) -> bool { matches!(self, TyKind::RigidTy(RigidTy::Dynamic(_, _, DynKind::Dyn))) } + #[inline] pub fn is_enum(&self) -> bool { matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.kind() == AdtKind::Enum) } + #[inline] pub fn is_struct(&self) -> bool { matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.kind() == AdtKind::Struct) } + #[inline] pub fn is_union(&self) -> bool { matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.kind() == AdtKind::Union) } + #[inline] + pub fn is_adt(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Adt(..))) + } + + #[inline] + pub fn is_ref(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Ref(..))) + } + + #[inline] pub fn is_fn(&self) -> bool { matches!(self, TyKind::RigidTy(RigidTy::FnDef(..))) } + #[inline] pub fn is_fn_ptr(&self) -> bool { matches!(self, TyKind::RigidTy(RigidTy::FnPtr(..))) } + #[inline] pub fn is_primitive(&self) -> bool { matches!( self, @@ -259,6 +283,102 @@ impl TyKind { ) } + /// A scalar type is one that denotes an atomic datum, with no sub-components. + /// (A RawPtr is scalar because it represents a non-managed pointer, so its + /// contents are abstract to rustc.) + #[inline] + pub fn is_scalar(&self) -> bool { + matches!( + self, + TyKind::RigidTy(RigidTy::Bool) + | TyKind::RigidTy(RigidTy::Char) + | TyKind::RigidTy(RigidTy::Int(_)) + | TyKind::RigidTy(RigidTy::Float(_)) + | TyKind::RigidTy(RigidTy::Uint(_)) + | TyKind::RigidTy(RigidTy::FnDef(..)) + | TyKind::RigidTy(RigidTy::FnPtr(_)) + | TyKind::RigidTy(RigidTy::RawPtr(..)) + ) + } + + #[inline] + pub fn is_float(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Float(_))) + } + + #[inline] + pub fn is_integral(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Int(_) | RigidTy::Uint(_))) + } + + #[inline] + pub fn is_numeric(&self) -> bool { + self.is_integral() || self.is_float() + } + + #[inline] + pub fn is_signed(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Int(_))) + } + + #[inline] + pub fn is_str(&self) -> bool { + *self == TyKind::RigidTy(RigidTy::Str) + } + + #[inline] + pub fn is_slice(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Slice(_))) + } + + #[inline] + pub fn is_array(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Array(..))) + } + + #[inline] + pub fn is_mutable_ptr(&self) -> bool { + matches!( + self, + TyKind::RigidTy(RigidTy::RawPtr(_, Mutability::Mut)) + | TyKind::RigidTy(RigidTy::Ref(_, _, Mutability::Mut)) + ) + } + + #[inline] + pub fn is_raw_ptr(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::RawPtr(..))) + } + + /// Tests if this is any kind of primitive pointer type (reference, raw pointer, fn pointer). + #[inline] + pub fn is_any_ptr(&self) -> bool { + self.is_ref() || self.is_raw_ptr() || self.is_fn_ptr() + } + + #[inline] + pub fn is_coroutine(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Coroutine(..))) + } + + #[inline] + pub fn is_closure(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Closure(..))) + } + + #[inline] + pub fn is_box(&self) -> bool { + match self { + TyKind::RigidTy(RigidTy::Adt(def, _)) => def.is_box(), + _ => false, + } + } + + #[inline] + pub fn is_simd(&self) -> bool { + matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.is_simd()) + } + pub fn trait_principal(&self) -> Option> { if let TyKind::RigidTy(RigidTy::Dynamic(predicates, _, _)) = self { if let Some(Binder { value: ExistentialPredicate::Trait(trait_ref), bound_vars }) = @@ -300,12 +420,12 @@ impl TyKind { } } - /// Get the function signature for function like types (Fn, FnPtr, Closure, Coroutine) - /// FIXME(closure) + /// Get the function signature for function like types (Fn, FnPtr, and Closure) pub fn fn_sig(&self) -> Option { match self { TyKind::RigidTy(RigidTy::FnDef(def, args)) => Some(with(|cx| cx.fn_sig(*def, args))), TyKind::RigidTy(RigidTy::FnPtr(sig)) => Some(sig.clone()), + TyKind::RigidTy(RigidTy::Closure(_def, args)) => Some(with(|cx| cx.closure_sig(args))), _ => None, } } @@ -481,6 +601,10 @@ impl AdtDef { with(|cx| cx.adt_is_box(*self)) } + pub fn is_simd(&self) -> bool { + with(|cx| cx.adt_is_simd(*self)) + } + /// The number of variants in this ADT. pub fn num_variants(&self) -> usize { with(|cx| cx.adt_variants_len(*self)) @@ -738,6 +862,7 @@ pub enum Abi { RiscvInterruptS, } +/// A Binder represents a possibly generic type and its bound vars. #[derive(Clone, Debug, Eq, PartialEq)] pub struct Binder { pub value: T, @@ -745,6 +870,16 @@ pub struct Binder { } impl Binder { + /// Create a new binder with the given bound vars. + pub fn new(value: T, bound_vars: Vec) -> Self { + Binder { value, bound_vars } + } + + /// Create a new binder with no bounded variable. + pub fn dummy(value: T) -> Self { + Binder { value, bound_vars: vec![] } + } + pub fn skip_binder(self) -> T { self.value }