Skip to content

Commit

Permalink
Add quote support
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Jul 23, 2023
1 parent 1776cbd commit 270814f
Show file tree
Hide file tree
Showing 7 changed files with 409 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,10 @@ def f():
from module import Member

x: Member = 1


def f():
from pandas import DataFrame

def baz() -> DataFrame:
...
106 changes: 62 additions & 44 deletions crates/ruff/src/checkers/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ use ruff_python_ast::visitor::{walk_except_handler, walk_pattern, Visitor};
use ruff_python_ast::{cast, helpers, str, visitor};
use ruff_python_semantic::analyze::{branch_detection, typing, visibility};
use ruff_python_semantic::{
Binding, BindingFlags, BindingId, BindingKind, ContextualizedDefinition, Exceptions,
ExecutionContext, Export, FromImport, Globals, Import, Module, ModuleKind, ScopeId, ScopeKind,
SemanticModel, SemanticModelFlags, StarImport, SubmoduleImport,
Binding, BindingFlags, BindingId, BindingKind, ContextualizedDefinition, Exceptions, Export,
FromImport, Globals, Import, Module, ModuleKind, ScopeId, ScopeKind, SemanticModel,
SemanticModelFlags, StarImport, SubmoduleImport,
};
use ruff_python_stdlib::builtins::{BUILTINS, MAGIC_GLOBALS};
use ruff_python_stdlib::path::is_python_stub_file;
Expand Down Expand Up @@ -1835,11 +1835,7 @@ where
for name in names {
if let Some((scope_id, binding_id)) = self.semantic.nonlocal(name) {
// Mark the binding as "used".
self.semantic.add_local_reference(
binding_id,
name.range(),
ExecutionContext::Runtime,
);
self.semantic.add_local_reference(binding_id, name.range());

// Mark the binding in the enclosing scope as "rebound" in the current
// scope.
Expand Down Expand Up @@ -1895,7 +1891,7 @@ where
{
if let Some(expr) = &arg_with_default.def.annotation {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
Expand All @@ -1907,7 +1903,7 @@ where
if let Some(arg) = &args.vararg {
if let Some(expr) = &arg.annotation {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
Expand All @@ -1916,15 +1912,15 @@ where
if let Some(arg) = &args.kwarg {
if let Some(expr) = &arg.annotation {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
}
}
for expr in returns {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
Expand Down Expand Up @@ -2036,38 +2032,53 @@ where
value,
..
}) => {
// If we're in a class or module scope, then the annotation needs to be
// available at runtime.
// See: https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements
let runtime_annotation = if self.semantic.future_annotations() {
if self.semantic.scope().kind.is_class() {
let baseclasses = &self
.settings
.flake8_type_checking
.runtime_evaluated_base_classes;
let decorators = &self
.settings
.flake8_type_checking
.runtime_evaluated_decorators;
flake8_type_checking::helpers::runtime_evaluated(
enum AnnotationKind {
RuntimeRequired,
RuntimeEvaluated,
TypingOnly,
}

fn annotation_kind(model: &SemanticModel, settings: &Settings) -> AnnotationKind {
// If the annotation is in a class, and that class is marked as
// runtime-evaluated, treat the annotation as runtime-required.
if model.scope().kind.is_class() {
let baseclasses =
&settings.flake8_type_checking.runtime_evaluated_base_classes;
let decorators =
&settings.flake8_type_checking.runtime_evaluated_decorators;
if flake8_type_checking::helpers::runtime_required(
baseclasses,
decorators,
&self.semantic,
)
} else {
false
model,
) {
return AnnotationKind::RuntimeRequired;
}
}
} else {
matches!(
self.semantic.scope().kind,
ScopeKind::Class(_) | ScopeKind::Module
)
};

if runtime_annotation {
self.visit_runtime_annotation(annotation);
} else {
self.visit_annotation(annotation);
// If `__future__` annotations are enabled, then annotations are never evaluated
// at runtime, so we can treat them as typing-only.
if model.future_annotations() {
return AnnotationKind::TypingOnly;
}

// Otherwise, if we're in a class or module scope, then the annotation needs to
// be available at runtime.
// See: https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements
if matches!(model.scope().kind, ScopeKind::Class(_) | ScopeKind::Module) {
return AnnotationKind::RuntimeEvaluated;
}

AnnotationKind::TypingOnly
}

match annotation_kind(&self.semantic, self.settings) {
AnnotationKind::RuntimeRequired => {
self.visit_runtime_required_annotation(annotation);
}
AnnotationKind::RuntimeEvaluated => {
self.visit_runtime_evaluated_annotation(annotation);
}
AnnotationKind::TypingOnly => self.visit_annotation(annotation),
}
if let Some(expr) = value {
if self.semantic.match_typing_expr(annotation, "TypeAlias") {
Expand Down Expand Up @@ -4306,10 +4317,18 @@ impl<'a> Checker<'a> {
self.semantic.flags = snapshot;
}

/// Visit an [`Expr`], and treat it as a runtime-evaluated type annotation.
fn visit_runtime_evaluated_annotation(&mut self, expr: &'a Expr) {
let snapshot = self.semantic.flags;
self.semantic.flags |= SemanticModelFlags::RUNTIME_EVALUATED_ANNOTATION;
self.visit_type_definition(expr);
self.semantic.flags = snapshot;
}

/// Visit an [`Expr`], and treat it as a runtime-required type annotation.
fn visit_runtime_annotation(&mut self, expr: &'a Expr) {
fn visit_runtime_required_annotation(&mut self, expr: &'a Expr) {
let snapshot = self.semantic.flags;
self.semantic.flags |= SemanticModelFlags::RUNTIME_ANNOTATION;
self.semantic.flags |= SemanticModelFlags::RUNTIME_REQUIRED_ANNOTATION;
self.visit_type_definition(expr);
self.semantic.flags = snapshot;
}
Expand Down Expand Up @@ -4824,8 +4843,7 @@ impl<'a> Checker<'a> {
for (name, range) in exports {
if let Some(binding_id) = self.semantic.global_scope().get(name) {
// Mark anything referenced in `__all__` as used.
self.semantic
.add_global_reference(binding_id, range, ExecutionContext::Runtime);
self.semantic.add_global_reference(binding_id, range);
} else {
if self.semantic.global_scope().uses_star_imports() {
if self.enabled(Rule::UndefinedLocalWithImportStarUsage) {
Expand Down
21 changes: 15 additions & 6 deletions crates/ruff/src/rules/flake8_type_checking/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,40 @@ pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticMode
binding.context.is_runtime()
&& binding
.references()
.any(|reference_id| semantic.reference(reference_id).context().is_runtime())
.map(|reference_id| semantic.reference(reference_id))
.any(|reference| {
// This is like: typing context _or_ a runtime-required type annotation (since
// we're willing to quote it).
!(reference.in_type_checking_block()
|| reference.in_typing_only_annotation()
|| reference.in_runtime_evaluated_annotation()
|| reference.in_complex_string_type_definition()
|| reference.in_simple_string_type_definition())
})
} else {
false
}
}

pub(crate) fn runtime_evaluated(
pub(crate) fn runtime_required(
base_classes: &[String],
decorators: &[String],
semantic: &SemanticModel,
) -> bool {
if !base_classes.is_empty() {
if runtime_evaluated_base_class(base_classes, semantic) {
if runtime_required_base_class(base_classes, semantic) {
return true;
}
}
if !decorators.is_empty() {
if runtime_evaluated_decorators(decorators, semantic) {
if runtime_required_decorators(decorators, semantic) {
return true;
}
}
false
}

fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticModel) -> bool {
fn runtime_required_base_class(base_classes: &[String], semantic: &SemanticModel) -> bool {
if let ScopeKind::Class(ast::StmtClassDef { bases, .. }) = &semantic.scope().kind {
for base in bases {
if let Some(call_path) = semantic.resolve_call_path(base) {
Expand All @@ -52,7 +61,7 @@ fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticMode
false
}

fn runtime_evaluated_decorators(decorators: &[String], semantic: &SemanticModel) -> bool {
fn runtime_required_decorators(decorators: &[String], semantic: &SemanticModel) -> bool {
if let ScopeKind::Class(ast::StmtClassDef { decorator_list, .. }) = &semantic.scope().kind {
for decorator in decorator_list {
if let Some(call_path) = semantic.resolve_call_path(map_callable(&decorator.expression))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::Result;
use ruff_text_size::TextRange;
use rustc_hash::FxHashMap;

use ruff_diagnostics::{AutofixKind, Diagnostic, DiagnosticKind, Fix, Violation};
use ruff_diagnostics::{AutofixKind, Diagnostic, DiagnosticKind, Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_semantic::{Binding, NodeId, ResolvedReferenceId, Scope};

Expand Down Expand Up @@ -228,13 +228,19 @@ pub(crate) fn typing_only_runtime_import(
};

if binding.context.is_runtime()
&& binding.references().all(|reference_id| {
checker
.semantic()
.reference(reference_id)
.context()
.is_typing()
})
&& binding
.references()
.map(|reference_id| checker.semantic().reference(reference_id))
.all(|reference| {
// All references should be in a typing context _or_ a runtime-evaluated
// annotation (as opposed to a runtime-required annotation), which we can
// quote.
reference.in_type_checking_block()
|| reference.in_typing_only_annotation()
|| reference.in_runtime_evaluated_annotation()
|| reference.in_complex_string_type_definition()
|| reference.in_simple_string_type_definition()
})
{
// Extract the module base and level from the full name.
// Ex) `foo.bar.baz` -> `foo`, `0`
Expand Down Expand Up @@ -278,6 +284,7 @@ pub(crate) fn typing_only_runtime_import(
let import = Import {
qualified_name,
reference_id,
binding,
range: binding.range,
parent_range: binding.parent_range(checker.semantic()),
};
Expand Down Expand Up @@ -356,6 +363,8 @@ pub(crate) fn typing_only_runtime_import(
struct Import<'a> {
/// The qualified name of the import (e.g., `typing.List` for `from typing import List`).
qualified_name: &'a str,
/// The binding for the imported symbol.
binding: &'a Binding<'a>,
/// The first reference to the imported symbol.
reference_id: ResolvedReferenceId,
/// The trimmed range of the import (e.g., `List` in `from typing import List`).
Expand Down Expand Up @@ -449,8 +458,28 @@ fn fix_imports(checker: &Checker, stmt_id: NodeId, imports: &[Import]) -> Result
checker.semantic(),
)?;

Ok(
Fix::suggested_edits(remove_import_edit, add_import_edit.into_edits())
.isolate(checker.isolation(parent)),
// Step 3) Quote any runtime usages of the referenced symbol.
let quote_reference_edits = imports.iter().flat_map(|Import { binding, .. }| {
binding.references.iter().filter_map(|reference_id| {
let reference = checker.semantic().reference(*reference_id);
if reference.in_runtime_evaluated_annotation() {
Some(Edit::range_replacement(
// TODO(charlie): escape any quotes in the reference.
format!("\"{}\"", checker.locator().slice(reference.range())),
reference.range(),
))
} else {
None
}
})
});

Ok(Fix::suggested_edits(
remove_import_edit,
add_import_edit
.into_edits()
.into_iter()
.chain(quote_reference_edits),
)
.isolate(checker.isolation(parent)))
}
Original file line number Diff line number Diff line change
Expand Up @@ -248,5 +248,35 @@ TCH002.py:172:24: TCH002 [*] Move third-party import `module.Member` into a type
172 |- from module import Member
173 176 |
174 177 | x: Member = 1
175 178 |

TCH002.py:178:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block
|
177 | def f():
178 | from pandas import DataFrame
| ^^^^^^^^^ TCH002
179 |
180 | def baz() -> DataFrame:
|
= help: Move into type-checking block

Suggested fix
1 1 | """Tests to determine accurate detection of typing-only imports."""
2 |+from typing import TYPE_CHECKING
3 |+
4 |+if TYPE_CHECKING:
5 |+ from pandas import DataFrame
2 6 |
3 7 |
4 8 | def f():
--------------------------------------------------------------------------------
175 179 |
176 180 |
177 181 | def f():
178 |- from pandas import DataFrame
179 182 |
180 |- def baz() -> DataFrame:
183 |+ def baz() -> 'DataFrame':
181 184 | ...


Loading

0 comments on commit 270814f

Please sign in to comment.