Skip to content

Commit

Permalink
Add quote support
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Dec 10, 2023
1 parent febc69a commit 57915b4
Show file tree
Hide file tree
Showing 9 changed files with 581 additions and 94 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
def f():
from pandas import DataFrame

def baz() -> DataFrame:
...


def f():
from pandas import DataFrame

def baz() -> DataFrame[int]:
...


def f():
from pandas import DataFrame

def baz() -> DataFrame["int"]:
...


def f():
import pandas as pd

def baz() -> pd.DataFrame:
...


def f():
import pandas as pd

def baz() -> pd.DataFrame.Extra:
...
96 changes: 65 additions & 31 deletions crates/ruff_linter/src/checkers/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,10 @@ where
.chain(&parameters.kwonlyargs)
{
if let Some(expr) = &parameter_with_default.parameter.annotation {
if runtime_annotation || singledispatch {
self.visit_runtime_annotation(expr);
if singledispatch {
self.visit_runtime_required_annotation(expr);
} else if runtime_annotation {
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
Expand All @@ -528,7 +530,7 @@ where
if let Some(arg) = &parameters.vararg {
if let Some(expr) = &arg.annotation {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
Expand All @@ -537,15 +539,15 @@ where
if let Some(arg) = &parameters.kwarg {
if let Some(expr) = &arg.annotation {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
}
}
for expr in returns {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
Expand Down Expand Up @@ -676,40 +678,64 @@ where
value,
..
}) => {
// If we're in a class or module scope, then the annotation needs to be
// available at runtime.
// See: https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements
let runtime_annotation = if self.semantic.future_annotations() {
self.semantic
enum AnnotationKind {
RuntimeRequired,
RuntimeEvaluated,
TypingOnly,
}

fn annotation_kind(
semantic: &SemanticModel,
settings: &LinterSettings,
) -> AnnotationKind {
// If the annotation is in a class, and that class is marked as
// runtime-evaluated, treat the annotation as runtime-required.
// TODO(charlie): We could also include function calls here.
if semantic
.current_scope()
.kind
.as_class()
.is_some_and(|class_def| {
flake8_type_checking::helpers::runtime_evaluated_class(
flake8_type_checking::helpers::runtime_required_class(
class_def,
&self
.settings
.flake8_type_checking
.runtime_evaluated_base_classes,
&self
.settings
.flake8_type_checking
.runtime_evaluated_decorators,
&self.semantic,
&settings.flake8_type_checking.runtime_evaluated_base_classes,
&settings.flake8_type_checking.runtime_evaluated_decorators,
semantic,
)
})
} else {
matches!(
self.semantic.current_scope().kind,
{
return AnnotationKind::RuntimeRequired;
}

// If `__future__` annotations are enabled, then annotations are never evaluated
// at runtime, so we can treat them as typing-only.
if semantic.future_annotations() {
return AnnotationKind::TypingOnly;
}

// Otherwise, if we're in a class or module scope, then the annotation needs to
// be available at runtime.
// See: https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements
if matches!(
semantic.current_scope().kind,
ScopeKind::Class(_) | ScopeKind::Module
)
};
) {
return AnnotationKind::RuntimeEvaluated;
}

if runtime_annotation {
self.visit_runtime_annotation(annotation);
} else {
self.visit_annotation(annotation);
AnnotationKind::TypingOnly
}

match annotation_kind(&self.semantic, self.settings) {
AnnotationKind::RuntimeRequired => {
self.visit_runtime_required_annotation(annotation);
}
AnnotationKind::RuntimeEvaluated => {
self.visit_runtime_evaluated_annotation(annotation);
}
AnnotationKind::TypingOnly => self.visit_annotation(annotation),
}

if let Some(expr) = value {
if self.semantic.match_typing_expr(annotation, "TypeAlias") {
self.visit_type_definition(expr);
Expand Down Expand Up @@ -1526,10 +1552,18 @@ impl<'a> Checker<'a> {
self.semantic.flags = snapshot;
}

/// Visit an [`Expr`], and treat it as a runtime-evaluated type annotation.
fn visit_runtime_evaluated_annotation(&mut self, expr: &'a Expr) {
let snapshot = self.semantic.flags;
self.semantic.flags |= SemanticModelFlags::RUNTIME_EVALUATED_ANNOTATION;
self.visit_type_definition(expr);
self.semantic.flags = snapshot;
}

/// Visit an [`Expr`], and treat it as a runtime-required type annotation.
fn visit_runtime_annotation(&mut self, expr: &'a Expr) {
fn visit_runtime_required_annotation(&mut self, expr: &'a Expr) {
let snapshot = self.semantic.flags;
self.semantic.flags |= SemanticModelFlags::RUNTIME_ANNOTATION;
self.semantic.flags |= SemanticModelFlags::RUNTIME_REQUIRED_ANNOTATION;
self.visit_type_definition(expr);
self.semantic.flags = snapshot;
}
Expand Down
25 changes: 17 additions & 8 deletions crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,37 @@ pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticMode
binding.context.is_runtime()
&& binding
.references()
.any(|reference_id| semantic.reference(reference_id).context().is_runtime())
.map(|reference_id| semantic.reference(reference_id))
.any(|reference| {
// This is like: typing context _or_ a runtime-required type annotation (since
// we're willing to quote it).
!(reference.in_type_checking_block()
|| reference.in_typing_only_annotation()
|| reference.in_runtime_evaluated_annotation()
|| reference.in_complex_string_type_definition()
|| reference.in_simple_string_type_definition())
})
} else {
false
}
}

pub(crate) fn runtime_evaluated_class(
pub(crate) fn runtime_required_class(
class_def: &ast::StmtClassDef,
base_classes: &[String],
decorators: &[String],
semantic: &SemanticModel,
) -> bool {
if runtime_evaluated_base_class(class_def, base_classes, semantic) {
if runtime_required_base_class(class_def, base_classes, semantic) {
return true;
}
if runtime_evaluated_decorators(class_def, decorators, semantic) {
if runtime_required_decorators(class_def, decorators, semantic) {
return true;
}
false
}

fn runtime_evaluated_base_class(
fn runtime_required_base_class(
class_def: &ast::StmtClassDef,
base_classes: &[String],
semantic: &SemanticModel,
Expand All @@ -45,7 +54,7 @@ fn runtime_evaluated_base_class(
seen: &mut FxHashSet<BindingId>,
) -> bool {
class_def.bases().iter().any(|expr| {
// If the base class is itself runtime-evaluated, then this is too.
// If the base class is itself runtime-required, then this is too.
// Ex) `class Foo(BaseModel): ...`
if semantic
.resolve_call_path(map_subscript(expr))
Expand All @@ -58,7 +67,7 @@ fn runtime_evaluated_base_class(
return true;
}

// If the base class extends a runtime-evaluated class, then this does too.
// If the base class extends a runtime-required class, then this does too.
// Ex) `class Bar(BaseModel): ...; class Foo(Bar): ...`
if let Some(id) = semantic.lookup_attribute(map_subscript(expr)) {
if seen.insert(id) {
Expand Down Expand Up @@ -86,7 +95,7 @@ fn runtime_evaluated_base_class(
inner(class_def, base_classes, semantic, &mut FxHashSet::default())
}

fn runtime_evaluated_decorators(
fn runtime_required_decorators(
class_def: &ast::StmtClassDef,
decorators: &[String],
semantic: &SemanticModel,
Expand Down
1 change: 1 addition & 0 deletions crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ mod tests {
#[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("TCH003.py"))]
#[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("snapshot.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("TCH002.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("quote.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("singledispatch.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("strict.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("typing_modules_1.py"))]
Expand Down
Loading

0 comments on commit 57915b4

Please sign in to comment.