diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 000000000..c8bd05613
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,35 @@
+name: Lint
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
+  cancel-in-progress: true
+
+jobs:
+  lintrunner:
+    name: lintrunner
+
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+      - name: Install Lintrunner
+        run: |
+          pip install lintrunner
+          lintrunner init
+      - name: Run lintrunner on all files - Linux
+        run: |
+          set +e
+          if ! lintrunner -v --force-color --all-files --tee-json=lint.json; then
+              echo ""
+              echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner -m main\`.\e[0m"
+              exit 1
+          fi
diff --git a/.lintrunner.toml b/.lintrunner.toml
new file mode 100644
index 000000000..c551cb732
--- /dev/null
+++ b/.lintrunner.toml
@@ -0,0 +1,20 @@
+merge_base_with = "origin/main"
+
+[[linter]]
+code = 'RUFF'
+include_patterns = ['test/smoke_test/*.py']
+command = [
+    'python3',
+    'tools/linter/adapters/ruff_linter.py',
+    '--config=pyproject.toml',
+    '--show-disable',
+    '--',
+    '@{{PATHSFILE}}'
+]
+init_command = [
+    'python3',
+    'tools/linter/adapters/pip_init.py',
+    '--dry-run={{DRYRUN}}',
+    'ruff==0.0.290',
+]
+is_formatter = true
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..efa884a07
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,23 @@
+[tool.ruff]
+target-version = "py38"
+line-length = 120
+select = [
+    "B",
+    "C4",
+    "G",
+    "E",
+    "F",
+    "SIM1",
+    "W",
+    # Not included in flake8
+    "UP",
+    "PERF",
+    "PGH004",
+    "PIE807",
+    "PIE810",
+    "PLE",
+    "PLR1722", # use sys exit
+    "PLW3301", # nested min max
+    "RUF017",
+    "TRY302",
+]
diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py
index ca44b0369..8ae1d1c51 100644
--- a/test/smoke_test/smoke_test.py
+++ b/test/smoke_test/smoke_test.py
@@ -1,10 +1,8 @@
 import os
 import re
 import sys
-from pathlib import Path
 import argparse
 import torch
-import platform
 import importlib
 import subprocess
 import torch._dynamo
@@ -41,7 +39,7 @@
 
 class Net(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
         self.fc1 = nn.Linear(9216, 1)
@@ -69,7 +67,7 @@ def check_version(package: str) -> None:
 
 
 def check_nightly_binaries_date(package: str) -> None:
-    from datetime import datetime, timedelta
+    from datetime import datetime
     format_dt = '%Y%m%d'
 
     date_t_str = re.findall("dev\\d+", torch.__version__)
@@ -177,11 +175,11 @@ def smoke_test_linalg() -> None:
     print("Testing smoke_test_linalg")
     A = torch.randn(5, 3)
     U, S, Vh = torch.linalg.svd(A, full_matrices=False)
-    U.shape, S.shape, Vh.shape
+    assert U.shape == A.shape and S.shape == torch.Size([3]) and Vh.shape == torch.Size([3, 3])
     torch.dist(A, U @ torch.diag(S) @ Vh)
 
     U, S, Vh = torch.linalg.svd(A)
-    U.shape, S.shape, Vh.shape
+    assert U.shape == torch.Size([5, 5]) and S.shape == torch.Size([3]) and Vh.shape == torch.Size([3, 3])
     torch.dist(A, U[:, :3] @ torch.diag(S) @ Vh)
 
     A = torch.randn(7, 5, 3)
@@ -234,9 +232,9 @@ def smoke_test_modules():
                     smoke_test_command, stderr=subprocess.STDOUT, shell=True,
                     universal_newlines=True)
             except subprocess.CalledProcessError as exc:
-                raise RuntimeError(f"Module {module['name']} FAIL: {exc.returncode} Output: {exc.output}")
+                raise RuntimeError(f"Module {module['name']} FAIL: {exc.returncode} Output: {exc.output}") from exc
             else:
-                print("Output: \n{}\n".format(output))
+                print(f"Output: \n{output}\n")
 
 
 def main() -> None:
diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py
new file mode 100644
index 000000000..f177a920d
--- /dev/null
+++ b/tools/linter/adapters/pip_init.py
@@ -0,0 +1,84 @@
+"""
+Initializer script that installs stuff to pip.
+"""
+import argparse
+import logging
+import os
+import subprocess
+import sys
+import time
+
+from typing import List
+
+
+def run_command(args: List[str]) -> "subprocess.CompletedProcess[bytes]":
+    """Run ``args`` with ``check=True`` and debug-log how long the command took."""
+    logging.debug("$ %s", " ".join(args))
+    start_time = time.monotonic()
+    try:
+        return subprocess.run(args, check=True)
+    finally:
+        end_time = time.monotonic()
+        logging.debug("took %dms", (end_time - start_time) * 1000)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="pip initializer")
+    parser.add_argument(
+        "packages",
+        nargs="+",
+        help="pip packages to install",
+    )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        help="verbose logging",
+    )
+    parser.add_argument(
+        "--dry-run", help="do not install anything, just print what would be done."
+    )
+    parser.add_argument(
+        "--no-black-binary",
+        help="do not use pre-compiled binaries from pip for black.",
+        action="store_true",
+    )
+
+    args = parser.parse_args()
+
+    logging.basicConfig(
+        format="<%(threadName)s:%(levelname)s> %(message)s",
+        level=logging.NOTSET if args.verbose else logging.DEBUG,
+        stream=sys.stderr,
+    )
+
+    pip_args = ["pip3", "install"]
+
+    # If we are in a global install, use `--user` to install so that you do not
+    # need root access in order to initialize linters.
+    #
+    # However, `pip install --user` interacts poorly with virtualenvs (see:
+    # https://bit.ly/3vD4kvl) and conda (see: https://bit.ly/3KG7ZfU). So in
+    # these cases perform a regular installation.
+    in_conda = os.environ.get("CONDA_PREFIX") is not None
+    in_virtualenv = os.environ.get("VIRTUAL_ENV") is not None
+    if not in_conda and not in_virtualenv:
+        pip_args.append("--user")
+
+    pip_args.extend(args.packages)
+
+    for package in args.packages:
+        package_name, _, version = package.partition("=")
+        if version == "":
+            raise RuntimeError(
+                f"Package {package_name} did not have a version specified. "
+                "Please specify a version to produce a consistent linting experience."
+            )
+        if args.no_black_binary and "black" in package_name:
+            pip_args.append(f"--no-binary={package_name}")
+
+    dry_run = args.dry_run == "1"
+    if dry_run:
+        print(f"Would have run: {pip_args}")
+        sys.exit(0)
+
+    run_command(pip_args)
diff --git a/tools/linter/adapters/ruff_linter.py b/tools/linter/adapters/ruff_linter.py
new file mode 100644
index 000000000..451834aa7
--- /dev/null
+++ b/tools/linter/adapters/ruff_linter.py
@@ -0,0 +1,472 @@
+"""Adapter for https://github.com/charliermarsh/ruff."""
+
+from __future__ import annotations
+
+import argparse
+import concurrent.futures
+import dataclasses
+import enum
+import json
+import logging
+import os
+import subprocess
+import sys
+import time
+from typing import Any, BinaryIO
+
+LINTER_CODE = "RUFF"
+IS_WINDOWS: bool = os.name == "nt"
+
+
+def eprint(*args: Any, **kwargs: Any) -> None:
+    """Print to stderr."""
+    print(*args, file=sys.stderr, flush=True, **kwargs)
+
+
+class LintSeverity(str, enum.Enum):
+    """Severity of a lint message."""
+
+    ERROR = "error"
+    WARNING = "warning"
+    ADVICE = "advice"
+    DISABLED = "disabled"
+
+
+@dataclasses.dataclass(frozen=True)
+class LintMessage:
+    """A lint message defined by https://docs.rs/lintrunner/latest/lintrunner/lint_message/struct.LintMessage.html."""
+
+    path: str | None
+    line: int | None
+    char: int | None
+    code: str
+    severity: LintSeverity
+    name: str
+    original: str | None
+    replacement: str | None
+    description: str | None
+
+    def asdict(self) -> dict[str, Any]:
+        """Return the message fields as a plain dict (via ``dataclasses.asdict``)."""
+        return dataclasses.asdict(self)
+
+    def display(self) -> None:
+        """Print to stdout for lintrunner to consume."""
+        print(json.dumps(self.asdict()), flush=True)
+
+
+def as_posix(name: str) -> str:
+    """Replace backslashes with forward slashes on Windows; return ``name`` as-is elsewhere."""
+    return name.replace("\\", "/") if IS_WINDOWS else name
+
+
+def _run_command(
+    args: list[str],
+    *,
+    timeout: int | None,
+    stdin: BinaryIO | None,
+    input: bytes | None,
+    check: bool,
+    cwd: os.PathLike[Any] | None,
+) -> subprocess.CompletedProcess[bytes]:
+    """Run ``args`` once with captured output and debug-log the elapsed time."""
+    logging.debug("$ %s", " ".join(args))
+    start_time = time.monotonic()
+    try:
+        if input is not None:
+            return subprocess.run(
+                args,
+                capture_output=True,
+                shell=False,
+                input=input,
+                timeout=timeout,
+                check=check,
+                cwd=cwd,
+            )
+
+        return subprocess.run(
+            args,
+            stdin=stdin,
+            capture_output=True,
+            shell=False,
+            timeout=timeout,
+            check=check,
+            cwd=cwd,
+        )
+    finally:
+        end_time = time.monotonic()
+        logging.debug("took %dms", (end_time - start_time) * 1000)
+
+
+def run_command(
+    args: list[str],
+    *,
+    retries: int = 0,
+    timeout: int | None = None,
+    stdin: BinaryIO | None = None,
+    input: bytes | None = None,
+    check: bool = False,
+    cwd: os.PathLike[Any] | None = None,
+) -> subprocess.CompletedProcess[bytes]:
+    """Run ``args``, retrying up to ``retries`` times when the command times out."""
+    remaining_retries = retries
+    while True:
+        try:
+            return _run_command(
+                args, timeout=timeout, stdin=stdin, input=input, check=check, cwd=cwd
+            )
+        except subprocess.TimeoutExpired as err:
+            if remaining_retries == 0:
+                raise err
+            remaining_retries -= 1
+            logging.warning(
+                "(%s/%s) Retrying because command failed with: %r",
+                retries - remaining_retries,
+                retries,
+                err,
+            )
+            time.sleep(1)
+
+
+def add_default_options(parser: argparse.ArgumentParser) -> None:
+    """Add default options to a parser.
+
+    This should be called the last in the chain of add_argument calls.
+    """
+    parser.add_argument(
+        "--retries",
+        type=int,
+        default=3,
+        help="number of times to retry if the linter times out.",
+    )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        help="verbose logging",
+    )
+    parser.add_argument(
+        "filenames",
+        nargs="+",
+        help="paths to lint",
+    )
+
+
+def explain_rule(code: str) -> str:
+    """Ask ``ruff rule`` about ``code`` and return its linter name and summary."""
+    proc = run_command(
+        [sys.executable, "-m", "ruff", "rule", "--format=json", code],
+        check=True,
+    )
+    rule = json.loads(str(proc.stdout, "utf-8").strip())
+    return f"\n{rule['linter']}: {rule['summary']}"
+
+
+def get_issue_severity(code: str) -> LintSeverity:
+    """Map a rule code to a severity by prefix, defaulting to WARNING."""
+    # "B901": `return x` inside a generator
+    # "B902": Invalid first argument to a method
+    # "B903": __slots__ efficiency
+    # "B950": Line too long
+    # "C4": Flake8 Comprehensions
+    # "C9": Cyclomatic complexity
+    # "E2": PEP8 horizontal whitespace "errors"
+    # "E3": PEP8 blank line "errors"
+    # "E5": PEP8 line length "errors"
+    # "T400": type checking Notes
+    # "T49": internal type checker errors or unmatched messages
+    if any(
+        code.startswith(x)
+        for x in (
+            "B9",
+            "C4",
+            "C9",
+            "E2",
+            "E3",
+            "E5",
+            "T400",
+            "T49",
+            "PLC",
+            "PLR",
+        )
+    ):
+        return LintSeverity.ADVICE
+
+    # "F821": Undefined name
+    # "E999": syntax error
+    if any(code.startswith(x) for x in ("F821", "E999", "PLE")):
+        return LintSeverity.ERROR
+
+    # "F": PyFlakes Error
+    # "B": flake8-bugbear Error
+    # "E": PEP8 "Error"
+    # "W": PEP8 Warning
+    # possibly other plugins...
+    return LintSeverity.WARNING
+
+
+def format_lint_message(
+    message: str, code: str, rules: dict[str, str], show_disable: bool
+) -> str:
+    """Append the rule explanation, docs link and optional noqa hint to ``message``."""
+    if rules:
+        message += f".\n{rules.get(code) or ''}"
+    message += ".\nSee https://beta.ruff.rs/docs/rules/"
+    if show_disable:
+        message += f".\n\nTo disable, use `  # noqa: {code}`"
+    return message
+
+
+def check_files(
+    filenames: list[str],
+    severities: dict[str, LintSeverity],
+    *,
+    config: str | None,
+    retries: int,
+    timeout: int,
+    explain: bool,
+    show_disable: bool,
+) -> list[LintMessage]:
+    """Run ruff on ``filenames`` and convert its JSON output into lint messages."""
+    try:
+        proc = run_command(
+            [
+                sys.executable,
+                "-m",
+                "ruff",
+                "--exit-zero",
+                "--quiet",
+                "--format=json",
+                *([f"--config={config}"] if config else []),
+                *filenames,
+            ],
+            retries=retries,
+            timeout=timeout,
+            check=True,
+        )
+    except (OSError, subprocess.CalledProcessError) as err:
+        return [
+            LintMessage(
+                path=None,
+                line=None,
+                char=None,
+                code=LINTER_CODE,
+                severity=LintSeverity.ERROR,
+                name="command-failed",
+                original=None,
+                replacement=None,
+                description=(
+                    f"Failed due to {err.__class__.__name__}:\n{err}"
+                    if not isinstance(err, subprocess.CalledProcessError)
+                    else (
+                        f"COMMAND (exit code {err.returncode})\n"
+                        f"{' '.join(as_posix(x) for x in err.cmd)}\n\n"
+                        f"STDERR\n{err.stderr.decode('utf-8').strip() or '(empty)'}\n\n"
+                        f"STDOUT\n{err.stdout.decode('utf-8').strip() or '(empty)'}"
+                    )
+                ),
+            )
+        ]
+
+    stdout = str(proc.stdout, "utf-8").strip()
+    vulnerabilities = json.loads(stdout)
+
+    if explain:
+        all_codes = {v["code"] for v in vulnerabilities}
+        rules = {code: explain_rule(code) for code in all_codes}
+    else:
+        rules = {}
+
+    return [
+        LintMessage(
+            path=vuln["filename"],
+            name=vuln["code"],
+            description=(
+                format_lint_message(
+                    vuln["message"],
+                    vuln["code"],
+                    rules,
+                    show_disable,
+                )
+            ),
+            line=int(vuln["location"]["row"]),
+            char=int(vuln["location"]["column"]),
+            code=LINTER_CODE,
+            severity=severities.get(vuln["code"], get_issue_severity(vuln["code"])),
+            original=None,
+            replacement=None,
+        )
+        for vuln in vulnerabilities
+    ]
+
+
+def check_file_for_fixes(
+    filename: str,
+    *,
+    config: str | None,
+    retries: int,
+    timeout: int,
+) -> list[LintMessage]:
+    """Run ``ruff --fix-only`` on one file and return a patch message if it changes."""
+    try:
+        with open(filename, "rb") as f:
+            original = f.read()
+        with open(filename, "rb") as f:
+            proc_fix = run_command(
+                [
+                    sys.executable,
+                    "-m",
+                    "ruff",
+                    "--fix-only",
+                    "--exit-zero",
+                    *([f"--config={config}"] if config else []),
+                    "--stdin-filename",
+                    filename,
+                    "-",
+                ],
+                stdin=f,
+                retries=retries,
+                timeout=timeout,
+                check=True,
+            )
+    except (OSError, subprocess.CalledProcessError) as err:
+        return [
+            LintMessage(
+                path=None,
+                line=None,
+                char=None,
+                code=LINTER_CODE,
+                severity=LintSeverity.ERROR,
+                name="command-failed",
+                original=None,
+                replacement=None,
+                description=(
+                    f"Failed due to {err.__class__.__name__}:\n{err}"
+                    if not isinstance(err, subprocess.CalledProcessError)
+                    else (
+                        f"COMMAND (exit code {err.returncode})\n"
+                        f"{' '.join(as_posix(x) for x in err.cmd)}\n\n"
+                        f"STDERR\n{err.stderr.decode('utf-8').strip() or '(empty)'}\n\n"
+                        f"STDOUT\n{err.stdout.decode('utf-8').strip() or '(empty)'}"
+                    )
+                ),
+            )
+        ]
+
+    replacement = proc_fix.stdout
+    if original == replacement:
+        return []
+
+    return [
+        LintMessage(
+            path=filename,
+            name="format",
+            description="Run `lintrunner -a` to apply this patch.",
+            line=None,
+            char=None,
+            code=LINTER_CODE,
+            severity=LintSeverity.WARNING,
+            original=original.decode("utf-8"),
+            replacement=replacement.decode("utf-8"),
+        )
+    ]
+
+
+def main() -> None:
+    """Parse arguments, print lint messages as JSON, then suggest fixes for files with lints."""
+    parser = argparse.ArgumentParser(
+        description=f"Ruff linter. Linter code: {LINTER_CODE}. Use with RUFF-FIX to auto-fix issues.",
+        fromfile_prefix_chars="@",
+    )
+    parser.add_argument(
+        "--config",
+        default=None,
+        help="Path to the `pyproject.toml` or `ruff.toml` file to use for configuration",
+    )
+    parser.add_argument(
+        "--explain",
+        action="store_true",
+        help="Explain a rule",
+    )
+    parser.add_argument(
+        "--show-disable",
+        action="store_true",
+        help="Show how to disable a lint message",
+    )
+    parser.add_argument(
+        "--timeout",
+        default=90,
+        type=int,
+        help="Seconds to wait for ruff",
+    )
+    parser.add_argument(
+        "--severity",
+        action="append",
+        help="map code to severity (e.g. `F401:advice`). This option can be used multiple times.",
+    )
+    parser.add_argument(
+        "--no-fix",
+        action="store_true",
+        help="Do not suggest fixes",
+    )
+    add_default_options(parser)
+    args = parser.parse_args()
+
+    logging.basicConfig(
+        format="<%(threadName)s:%(levelname)s> %(message)s",
+        level=logging.NOTSET
+        if args.verbose
+        else logging.DEBUG
+        if len(args.filenames) < 1000
+        else logging.INFO,
+        stream=sys.stderr,
+    )
+
+    severities: dict[str, LintSeverity] = {}
+    if args.severity:
+        for severity in args.severity:
+            parts = severity.split(":", 1)
+            assert len(parts) == 2, f"invalid severity `{severity}`"
+            severities[parts[0]] = LintSeverity(parts[1])
+
+    lint_messages = check_files(
+        args.filenames,
+        severities=severities,
+        config=args.config,
+        retries=args.retries,
+        timeout=args.timeout,
+        explain=args.explain,
+        show_disable=args.show_disable,
+    )
+    for lint_message in lint_messages:
+        lint_message.display()
+
+    if args.no_fix or not lint_messages:
+        # If we're not fixing, we can exit early
+        return
+
+    files_with_lints = {lint.path for lint in lint_messages if lint.path is not None}
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=os.cpu_count(),
+        thread_name_prefix="Thread",
+    ) as executor:
+        futures = {
+            executor.submit(
+                check_file_for_fixes,
+                path,
+                config=args.config,
+                retries=args.retries,
+                timeout=args.timeout,
+            ): path
+            for path in files_with_lints
+        }
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                for lint_message in future.result():
+                    lint_message.display()
+            except Exception:  # Catch all exceptions for lintrunner
+                logging.critical('Failed at "%s".', futures[future])
+                raise
+
+
+if __name__ == "__main__":
+    main()