From 29db6b8cb0861e1c183944f2b2574d89c5b13db0 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 18 Sep 2024 16:21:23 -0700 Subject: [PATCH] [red-knot] simplify subtypes from unions --- crates/red_knot_python_semantic/src/types.rs | 55 +++++++++++++++++-- .../src/types/builder.rs | 38 ++++++++++++- .../src/types/infer.rs | 3 +- 3 files changed, 86 insertions(+), 10 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 7a95fa94200c6..124e14c12f192 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -388,16 +388,18 @@ impl<'db> Type<'db> { } } - /// Return true if this type is [assignable to] type `target`. + /// Return true if this type is a [subtype of] type `target`. /// - /// [assignable to]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation - pub(crate) fn is_assignable_to(self, db: &'db dyn Db, target: Type<'db>) -> bool { + /// [subtype of]: https://typing.readthedocs.io/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence + pub(crate) fn is_subtype_of(self, db: &'db dyn Db, target: Type<'db>) -> bool { if self.is_equivalent_to(db, target) { return true; } match (self, target) { - (Type::Unknown | Type::Any | Type::Never, _) => true, - (_, Type::Unknown | Type::Any) => true, + (Type::Unknown | Type::Any, _) => false, + (_, Type::Unknown | Type::Any) => false, + (Type::Never, _) => true, + (_, Type::Never) => false, (Type::IntLiteral(_), Type::Instance(class)) if class.is_stdlib_symbol(db, "builtins", "int") => { @@ -417,12 +419,28 @@ impl<'db> Type<'db> { (ty, Type::Union(union)) => union .elements(db) .iter() - .any(|&elem_ty| ty.is_assignable_to(db, elem_ty)), + .any(|&elem_ty| ty.is_subtype_of(db, elem_ty)), // TODO _ => false, } } + /// Return true if this type is [assignable to] type `target`. + /// + /// [assignable to]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation + pub(crate) fn is_assignable_to(self, db: &'db dyn Db, target: Type<'db>) -> bool { + match (self, target) { + (Type::Unknown | Type::Any, _) => true, + (_, Type::Unknown | Type::Any) => true, + (ty, Type::Union(union)) => union + .elements(db) + .iter() + .any(|&elem_ty| ty.is_assignable_to(db, elem_ty)), + // TODO other types containing gradual forms (e.g. generics containing Any/Unknown) + _ => self.is_subtype_of(db, target), + } + } + /// Return true if this type is equivalent to type `other`. pub(crate) fn is_equivalent_to(self, _db: &'db dyn Db, other: Type<'db>) -> bool { // TODO equivalent but not identical structural types, differently-ordered unions and @@ -1123,6 +1141,31 @@ mod tests { assert!(!from.into_type(&db).is_assignable_to(&db, to.into_type(&db))); } + #[test_case(Ty::Never, Ty::IntLiteral(1))] + #[test_case(Ty::IntLiteral(1), Ty::BuiltinInstance("int"))] + #[test_case(Ty::StringLiteral("foo"), Ty::BuiltinInstance("str"))] + #[test_case(Ty::StringLiteral("foo"), Ty::LiteralString)] + #[test_case(Ty::LiteralString, Ty::BuiltinInstance("str"))] + #[test_case(Ty::BytesLiteral("foo"), Ty::BuiltinInstance("bytes"))] + #[test_case(Ty::IntLiteral(1), Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")]))] + fn is_subtype_of(from: Ty, to: Ty) { + let db = setup_db(); + assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db))); + } + + #[test_case(Ty::Unknown, Ty::IntLiteral(1))] + #[test_case(Ty::Any, Ty::IntLiteral(1))] + #[test_case(Ty::IntLiteral(1), Ty::Unknown)] + #[test_case(Ty::IntLiteral(1), Ty::Any)] + #[test_case(Ty::IntLiteral(1), Ty::Union(vec![Ty::Unknown, Ty::BuiltinInstance("str")]))] + #[test_case(Ty::IntLiteral(1), Ty::BuiltinInstance("str"))] + #[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))] + #[test_case(Ty::BuiltinInstance("int"), Ty::IntLiteral(1))] + fn is_not_subtype_of(from: Ty, to: Ty) { + let db = setup_db(); + assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db))); + } + #[test_case( Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]) diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 0db9fee05a7fc..b4b57aff8fe3f 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -46,11 +46,27 @@ impl<'db> UnionBuilder<'db> { pub(crate) fn add(mut self, ty: Type<'db>) -> Self { match ty { Type::Union(union) => { - self.elements.extend(union.elements(self.db)); + for element in union.elements(self.db) { + self = self.add(*element); + } } Type::Never => {} _ => { - self.elements.insert(ty); + let mut add = true; + let mut remove = vec![]; + for element in &self.elements { + if ty.is_subtype_of(self.db, *element) { + add = false; + } else if element.is_subtype_of(self.db, ty) { + remove.push(*element); + } + } + for element in remove { + self.elements.remove(&element); + } + if add { + self.elements.insert(ty); + } } } @@ -368,6 +384,24 @@ mod tests { assert_eq!(union.elements_vec(&db), &[t0, t1, t2]); } + #[test] + fn build_union_simplify_subtype() { + let db = setup_db(); + let t0 = builtins_symbol_ty(&db, "str").to_instance(&db); + let t1 = Type::LiteralString; + let t2 = Type::Unknown; + let u0 = UnionType::from_elements(&db, [t0, t1]); + let u1 = UnionType::from_elements(&db, [t1, t0]); + let u2 = UnionType::from_elements(&db, [t0, t1, t2]); + + assert_eq!(u0, t0); + assert_eq!(u1, t0); + assert_eq!(u2.expect_union().elements_vec(&db), &[t0, t2]); + } + + #[test] + fn build_union_no_simplify_any() {} + impl<'db> IntersectionType<'db> { fn pos_vec(self, db: &'db TestDb) -> Vec> { self.positive(db).into_iter().copied().collect() diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index aaf3a702e7593..bf8d4b90ce020 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -5802,8 +5802,7 @@ mod tests { .unwrap(); db.write_file("/src/c.pyi", "x: int").unwrap(); - // TODO this should simplify to just 'int' - assert_public_ty(&db, "/src/a.py", "x", "int | Literal[1]"); + assert_public_ty(&db, "/src/a.py", "x", "int"); } // Incremental inference tests