Skip to content

Commit

Permalink
Upgrade remaining sources to Python 3.9
Browse files Browse the repository at this point in the history
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
  • Loading branch information
superbobry committed Dec 12, 2023
1 parent b077483 commit 0db6083
Show file tree
Hide file tree
Showing 139 changed files with 1,336 additions and 1,090 deletions.
18 changes: 10 additions & 8 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
import functools
from functools import partial
import logging
from typing import Any, Callable, Optional, Union
from typing import Any, Callable
import types

import numpy as np
Expand Down Expand Up @@ -131,8 +133,8 @@ def policy(prim, *args, **params):

@api_boundary
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
policy: Optional[Callable[..., bool]] = None,
static_argnums: Union[int, tuple[int, ...]] = (),
policy: Callable[..., bool] | None = None,
static_argnums: int | tuple[int, ...] = (),
) -> Callable:
"""Make ``fun`` recompute internal linearization points when differentiated.
Expand Down Expand Up @@ -574,8 +576,8 @@ def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
ad.reducing_transposes[remat_p] = remat_transpose

# TODO(mattjj): move this to ad.py
def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: Union[bool, Sequence[bool]],
out_zeros: Union[bool, Sequence[bool]],
def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],
out_zeros: bool | Sequence[bool],
reduce_axes: Sequence[core.AxisName],
) -> tuple[core.ClosedJaxpr, list[bool]]:
if type(in_linear) is bool:
Expand Down Expand Up @@ -639,7 +641,7 @@ def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,

# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], Optional[core.JaxprEqn]]:
) -> tuple[list[bool], core.JaxprEqn | None]:
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
new_params = dict(eqn.params, jaxpr=new_jaxpr)
if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
Expand Down Expand Up @@ -779,8 +781,8 @@ def checkpoint_wrapper(
*,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
static_argnums: int | tuple[int, ...] = (),
policy: Callable[..., bool] | None = None,
) -> Callable:
if concrete:
msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
Expand Down
26 changes: 14 additions & 12 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Iterable, Sequence
import inspect
import operator
from functools import partial
from typing import Any, Callable, Optional, Union
from typing import Any, Callable
import warnings

import numpy as np
Expand All @@ -39,7 +41,7 @@

map = safe_map

def _ensure_index(x: Any) -> Union[int, tuple[int, ...]]:
def _ensure_index(x: Any) -> int | tuple[int, ...]:
"""Ensure x is either an index or a tuple of indices."""
x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
try:
Expand All @@ -60,7 +62,7 @@ def _ensure_str(x: str) -> str:
raise TypeError(f"argument is not a string: {x}")
return x

def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> tuple[str, ...]:
def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
"""Convert x to a tuple of strings."""
if isinstance(x, str):
return (x,)
Expand Down Expand Up @@ -97,7 +99,7 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):

def flattened_fun_in_tree(
fn: lu.WrappedFun
) -> Optional[tuple[PyTreeDef, Callable[[], PyTreeDef], bool]]:
) -> tuple[PyTreeDef, Callable[[], PyTreeDef], bool] | None:
# This implementation relies on internal details of linear_util.py's
# WrappedFun, but it's for the worthy cause of better user error messages.
# It can fail (i.e. return None) if its WrappedFun argument is not transformed
Expand Down Expand Up @@ -473,8 +475,8 @@ def check_callable(fun):

def infer_argnums_and_argnames(
sig: inspect.Signature,
argnums: Union[int, Iterable[int], None],
argnames: Union[str, Iterable[str], None],
argnums: int | Iterable[int] | None,
argnames: str | Iterable[str] | None,
) -> tuple[tuple[int, ...], tuple[str, ...]]:
"""Infer missing argnums and argnames for a function with inspect."""
if argnums is None and argnames is None:
Expand Down Expand Up @@ -612,15 +614,15 @@ def api_hook(fun, tag: str):

def debug_info(traced_for: str, fun: Callable, args: tuple[Any],
kwargs: dict[str, Any], static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]) -> Optional[TracingDebugInfo]:
static_argnames: tuple[str, ...]) -> TracingDebugInfo | None:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
src = fun_sourceinfo(fun)
arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames)
if src is None or arg_names is None: return None
return TracingDebugInfo(traced_for, src, arg_names, None)

# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> Optional[str]:
def fun_sourceinfo(fun: Callable) -> str | None:
while isinstance(fun, partial):
fun = fun.func
fun = inspect.unwrap(fun)
Expand All @@ -632,7 +634,7 @@ def fun_sourceinfo(fun: Callable) -> Optional[str]:
return None

def _arg_names(fn, args, kwargs, static_argnums, static_argnames,
) -> Optional[tuple[str, ...]]:
) -> tuple[str, ...] | None:
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
Expand All @@ -651,8 +653,8 @@ def result_paths(*args, **kwargs):
ans = yield args, kwargs
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]

def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
result_paths: Optional[tuple[Optional[str], ...]] = None,
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
result_paths: tuple[str | None, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
Expand All @@ -665,7 +667,7 @@ def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
trace_debug.arg_names, tuple(result_paths))
return jaxpr.replace(debug_info=debug_info)

def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],
def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
res_paths: Callable[[], tuple[str, ...]]) -> lu.WrappedFun:
"Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
if dbg is None: return f
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/basearray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

# Note that type annotations for this file are defined in basearray.pyi

from __future__ import annotations

import abc
import numpy as np
from typing import Any, Union
Expand Down Expand Up @@ -73,7 +75,7 @@ def shape(self) -> tuple[int, ...]:

# Documentation for sharding-related methods and properties defined on ArrayImpl:
@abc.abstractmethod
def addressable_data(self, index: int) -> "Array":
def addressable_data(self, index: int) -> Array:
"""Return an array of the addressable data at a particular index."""

@property
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import dataclasses
import functools
import itertools as it
from typing import Union, Callable, TypeVar, Any
from typing import Callable, TypeVar, Any, Union

import numpy as np

Expand Down
7 changes: 4 additions & 3 deletions jax/_src/clusters/cloud_tpu_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
import re
import socket
import time
from typing import Optional
from jax._src import clusters
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm

Expand Down Expand Up @@ -99,7 +100,7 @@ def get_process_id(cls) -> int:
return int(get_metadata('agent-worker-number'))

@classmethod
def get_local_process_id(cls) -> Optional[int]:
def get_local_process_id(cls) -> int | None:
return None

class MultisliceGceTpuCluster(clusters.ClusterEnv):
Expand Down Expand Up @@ -147,7 +148,7 @@ def get_process_id(cls) -> int:
return process_id_in_slice + slice_id * processes_per_slice

@classmethod
def get_local_process_id(cls) -> Optional[int]:
def get_local_process_id(cls) -> int | None:
return None

@staticmethod
Expand Down
19 changes: 10 additions & 9 deletions jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
import logging
from typing import Optional
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm

logger = logging.getLogger(__name__)
Expand All @@ -29,7 +30,7 @@ class ClusterEnv:
:class:`ClusterEnv` subclasses are automatically detected when imported.
"""

_cluster_types: list[type['ClusterEnv']] = []
_cluster_types: list[type[ClusterEnv]] = []

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
Expand All @@ -38,12 +39,12 @@ def __init_subclass__(cls, **kwargs):
@classmethod
# pytype: disable=bad-return-type
def auto_detect_unset_distributed_params(cls,
coordinator_address: Optional[str],
num_processes: Optional[int],
process_id: Optional[int],
local_device_ids: Optional[Sequence[int]]
) -> tuple[Optional[str], Optional[int], Optional[int],
Optional[Sequence[int]]]:
coordinator_address: str | None,
num_processes: int | None,
process_id: int | None,
local_device_ids: Sequence[int] | None
) -> tuple[str | None, int | None, int | None,
Sequence[int] | None]:
if all(p is not None for p in (coordinator_address, num_processes,
process_id, local_device_ids)):
return (coordinator_address, num_processes, process_id,
Expand Down Expand Up @@ -100,7 +101,7 @@ def get_process_id(cls) -> int:
raise NotImplementedError("ClusterEnv subclasses must implement get_process_id")

@classmethod
def get_local_process_id(cls) -> Optional[int]:
def get_local_process_id(cls) -> int | None:
""" Get index of current process inside a host.
The method is only useful to support single device per process.
Expand Down
5 changes: 3 additions & 2 deletions jax/_src/clusters/ompi_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
import re
from typing import Optional
from jax._src import clusters

# OMPI_MCA_orte_hnp_uri exists only when processes are launched via mpirun or mpiexec
Expand Down Expand Up @@ -55,5 +56,5 @@ def get_process_id(cls) -> int:
return int(os.environ[_PROCESS_ID])

@classmethod
def get_local_process_id(cls) -> Optional[int]:
def get_local_process_id(cls) -> int | None:
return int(os.environ[_LOCAL_PROCESS_ID])
5 changes: 3 additions & 2 deletions jax/_src/clusters/slurm_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
from typing import Optional
from jax._src import clusters

_JOBID_PARAM = 'SLURM_JOB_ID'
Expand Down Expand Up @@ -58,5 +59,5 @@ def get_process_id(cls) -> int:
return int(os.environ[_PROCESS_ID])

@classmethod
def get_local_process_id(cls) -> Optional[int]:
def get_local_process_id(cls) -> int | None:
return int(os.environ[_LOCAL_PROCESS_ID])
9 changes: 5 additions & 4 deletions jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
import threading
from typing import Optional
import zlib

import numpy as np
Expand All @@ -35,7 +36,7 @@

logger = logging.getLogger(__name__)

_cache: Optional[CacheInterface] = None
_cache: CacheInterface | None = None

_cache_initialized: bool = False

Expand Down Expand Up @@ -102,7 +103,7 @@ def _initialize_cache() -> None:
logger.debug("Initialized persistent compilation cache at %s", path)


def _get_cache() -> Optional[CacheInterface]:
def _get_cache() -> CacheInterface | None:
# TODO(b/289098047): consider making this an API and changing the callers of
# get_executable_and_time() and put_executable_and_time() to call get_cache()
# and passing the result to them.
Expand All @@ -113,7 +114,7 @@ def _get_cache() -> Optional[CacheInterface]:

def get_executable_and_time(
cache_key: str, compile_options, backend
) -> tuple[Optional[xla_client.LoadedExecutable], Optional[int]]:
) -> tuple[xla_client.LoadedExecutable | None, int | None]:
"""Returns the cached executable and its compilation time if present, or None
otherwise.
"""
Expand Down
12 changes: 6 additions & 6 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
from __future__ import annotations

import collections
from collections import namedtuple
import collections # noqa: F401
from collections import defaultdict, namedtuple
from collections.abc import Generator, Hashable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
Expand All @@ -28,8 +28,8 @@
from operator import attrgetter
import threading
import types
from typing import (Any, Callable, ClassVar, DefaultDict, Generic, NamedTuple,
TypeVar, Union, cast, overload)
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, TypeVar,
cast, overload, Union)
import warnings
from weakref import ref

Expand Down Expand Up @@ -3031,10 +3031,10 @@ class JaxprPpSettings(NamedTuple):
# A JaxprPpContext allows us to globally uniquify variable names within nested
# Jaxprs.
class JaxprPpContext:
var_ids: DefaultDict[Var, int]
var_ids: defaultdict[Var, int]

def __init__(self):
self.var_ids = collections.defaultdict(it.count().__next__, {})
self.var_ids = defaultdict(it.count().__next__, {})


def pp_var(v: Var, context: JaxprPpContext) -> str:
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import operator
from typing import Callable, Optional
from typing import Callable

from jax import lax
from jax._src import api
Expand Down Expand Up @@ -47,7 +49,7 @@
@custom_api_util.register_custom_decorator_type
class custom_vmap:
fun: Callable
vmap_rule: Optional[Callable]
vmap_rule: Callable | None

def __init__(self, fun: Callable):
functools.update_wrapper(self, fun)
Expand Down
Loading

0 comments on commit 0db6083

Please sign in to comment.