From d60a077d4498d9c735553700b2467213e290bbbb Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Mon, 11 Dec 2023 13:59:29 +0000
Subject: [PATCH] Upgrade remaining sources to Python 3.9

This PR is a follow-up to #18881. The changes were generated by adding

    from __future__ import annotations

to the files which did not already have it and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py

---
 jax/_src/ad_checkpoint.py | 18 +-
 jax/_src/api_util.py | 26 +--
 jax/_src/basearray.py | 4 +-
 jax/_src/checkify.py | 2 +-
 jax/_src/clusters/cloud_tpu_cluster.py | 7 +-
 jax/_src/clusters/cluster.py | 19 ++-
 jax/_src/clusters/ompi_cluster.py | 5 +-
 jax/_src/clusters/slurm_cluster.py | 5 +-
 jax/_src/compilation_cache.py | 9 +-
 jax/_src/core.py | 12 +-
 jax/_src/custom_batching.py | 6 +-
 jax/_src/custom_derivatives.py | 12 +-
 jax/_src/custom_transpose.py | 6 +-
 jax/_src/debugging.py | 10 +-
 jax/_src/distributed.py | 28 ++--
 jax/_src/environment_info.py | 7 +-
 jax/_src/image/scale.py | 10 +-
 .../export_back_compat_test_util.py | 18 +-
 jax/_src/internal_test_util/lax_test_util.py | 6 +-
 jax/_src/internal_test_util/test_harnesses.py | 20 ++-
 jax/_src/interpreters/ad.py | 8 +-
 jax/_src/interpreters/batching.py | 5 +-
 jax/_src/interpreters/mlir.py | 4 +-
 jax/_src/interpreters/partial_eval.py | 33 ++--
 jax/_src/interpreters/pxla.py | 8 +-
 jax/_src/interpreters/xla.py | 17 +-
 jax/_src/jaxpr_util.py | 32 ++--
 jax/_src/lax/control_flow/common.py | 8 +-
 jax/_src/lax/control_flow/for_loop.py | 10 +-
 jax/_src/lax/convolution.py | 35 ++--
 jax/_src/lax/fft.py | 5 +-
 jax/_src/lax/lax.py | 12 +-
 jax/_src/lax/linalg.py | 12 +-
 jax/_src/lax/other.py | 26 +--
 jax/_src/lax/parallel.py | 5 +-
 jax/_src/lax/qdwh.py | 5 +-
 jax/_src/lax/slicing.py | 50 +++---
 jax/_src/lax/svd.py | 12 +-
 jax/_src/lax/windowed_reductions.py | 32 ++--
 jax/_src/lib/__init__.py | 6 +-
 jax/_src/maps.py | 27 +--
 jax/_src/monitoring.py | 13 +-
 jax/_src/nn/functions.py | 40 ++---
 jax/_src/nn/initializers.py | 48 +++---
 jax/_src/numpy/array_methods.py | 14 +-
 jax/_src/numpy/fft.py | 89 +++++-----
 jax/_src/numpy/index_tricks.py | 10 +-
 jax/_src/numpy/lax_numpy.py | 3 +-
 jax/_src/numpy/linalg.py | 29 ++--
 jax/_src/numpy/polynomial.py | 10 +-
 jax/_src/numpy/reductions.py | 156 +++++++++---------
 jax/_src/numpy/setops.py | 20 ++-
 jax/_src/numpy/ufunc_api.py | 44 ++---
 jax/_src/numpy/ufuncs.py | 8 +-
 jax/_src/op_shardings.py | 8 +-
 jax/_src/ops/scatter.py | 40 ++---
 jax/_src/ops/special.py | 16 +-
 jax/_src/pjit.py | 44 ++---
 jax/_src/pretty_printer.py | 34 ++--
 jax/_src/profiler.py | 12 +-
 jax/_src/random.py | 3 +-
 jax/_src/scipy/fft.py | 25 +--
 jax/_src/scipy/linalg.py | 81 ++++-----
 jax/_src/scipy/optimize/_lbfgs.py | 27 +--
 jax/_src/scipy/optimize/bfgs.py | 23 +--
 jax/_src/scipy/optimize/line_search.py | 74 +++++----
 jax/_src/scipy/optimize/minimize.py | 20 ++-
 jax/_src/scipy/signal.py | 36 ++--
 jax/_src/scipy/spatial/transform.py | 8 +-
 jax/_src/scipy/special.py | 10 +-
 jax/_src/scipy/stats/_core.py | 7 +-
 jax/_src/sharding_specs.py | 10 +-
 jax/_src/source_info_util.py | 26 +--
 jax/_src/state/primitives.py | 3 +-
 jax/_src/state/types.py | 1 +
 jax/_src/third_party/scipy/signal_helper.py | 6 +-
 jax/_src/traceback_util.py | 6 +-
 jax/_src/util.py | 14 +-
 jax/_src/xla_bridge.py | 50 +++---
 jax/collect_profile.py | 7 +-
 jax/example_libraries/optimizers.py | 6 +-
 .../array_api/_manipulation_functions.py | 14 +-
 .../array_serialization/serialization.py | 24 +--
 jax/experimental/custom_partitioning.py | 10 +-
 jax/experimental/export/export.py | 38 +++--
 jax/experimental/export/serialization.py | 3 +-
 .../export/serialization_generated.py | 22 +--
 jax/experimental/export/shape_poly.py | 72 ++++----
 jax/experimental/host_callback.py | 7 +-
 jax/experimental/jax2tf/call_tf.py | 5 +-
 jax/experimental/jax2tf/examples/mnist_lib.py | 5 +-
 .../jax2tf/examples/saved_model_lib.py | 6 +-
 jax/experimental/jax2tf/impl_no_xla.py | 10 +-
 jax/experimental/jax2tf/jax2tf.py | 60 +++---
 .../jax2tf/tests/back_compat_tf_test.py | 6 +-
 .../jax2tf/tests/cross_compilation_check.py | 5 +-
 .../tests/flax_models/bilstm_classifier.py | 28 ++--
 .../tests/flax_models/transformer_lm1b.py | 6 +-
 .../tests/flax_models/transformer_nlp_seq.py | 6 +-
 .../tests/flax_models/transformer_wmt.py | 6 +-
 .../jax2tf/tests/jax2tf_limitations.py | 18 +-
 .../jax2tf/tests/model_harness.py | 6 +-
 .../jax2tf/tests/shape_poly_test.py | 36 ++--
 jax/experimental/jax2tf/tests/tf_test_util.py | 9 +-
 jax/experimental/key_reuse/_common.py | 8 +-
 jax/experimental/key_reuse/_core.py | 6 +-
 jax/experimental/key_reuse/_forwarding.py | 12 +-
 jax/experimental/key_reuse/_simple.py | 10 +-
 jax/experimental/mesh_utils.py | 8 +-
 jax/experimental/multihost_utils.py | 4 +-
 jax/experimental/pallas/ops/layer_norm.py | 14 +-
 jax/experimental/pallas/ops/rms_norm.py | 14 +-
 jax/experimental/serialize_executable.py | 4 +-
 jax/experimental/shard_map.py | 2 +-
 jax/experimental/sparse/ad.py | 12 +-
 jax/experimental/sparse/api.py | 9 +-
 jax/experimental/sparse/linalg.py | 10 +-
 jax/experimental/sparse/test_util.py | 10 +-
 jax/experimental/topologies.py | 4 +-
 jaxlib/gpu_prng.py | 6 +-
 jaxlib/hlo_helpers.py | 22 +--
 jaxlib/mosaic/python/apply_vector_layout.py | 41 ++---
 tests/api_test.py | 5 +-
 tests/batching_test.py | 14 +-
 tests/export_harnesses_multi_platform_test.py | 6 +-
 tests/host_callback_test.py | 6 +-
 tests/lax_numpy_indexing_test.py | 5 +-
 tests/lax_numpy_test.py | 7 +-
 tests/lax_vmap_test.py | 7 +-
 tests/pallas/indexing_test.py | 8 +-
 tests/pmap_test.py | 5 +-
 tests/random_test.py | 8 +-
 tests/shape_poly_test.py | 19 ++-
 tests/shard_map_test.py | 18 +-
 tests/state_test.py | 14 +-
 tests/tree_util_test.py | 2 +-
 tests/typing_test.py | 9 +-
 tests/xmap_test.py | 8 +-
 138 files changed, 1326 insertions(+), 1076 deletions(-)

diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 44a31b19fce1..996a80fce0b3 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence import functools from functools import partial import logging -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import types import numpy as np @@ -131,8 +133,8 @@ def policy(prim, *args, **params): @api_boundary def checkpoint(fun: Callable, *, prevent_cse: bool = True, - policy: Optional[Callable[..., bool]] = None, - static_argnums: Union[int, tuple[int, ...]] = (), + policy: Callable[..., bool] | None = None, + static_argnums: int | tuple[int, ...] = (), ) -> Callable: """Make ``fun`` recompute internal linearization points when differentiated.
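The jax/_src/ad_checkpoint.py hunks above are representative of the whole patch. As a minimal, hypothetical sketch (the function below is made up for illustration, not taken from the diff), this is the kind of rewrite pyupgrade --py39-plus produces once from __future__ import annotations is present: the future import (PEP 563) leaves annotations unevaluated at runtime, so the PEP 604 spelling X | None is safe on Python 3.9 even though such unions only become runtime values in Python 3.10.

```python
# Hypothetical example -- not a file touched by this patch.
#
# Before the upgrade:
#
#   from typing import Callable, Optional, Union
#
#   def checkpointed(fun: Callable,
#                    policy: Optional[Callable[..., bool]] = None,
#                    static_argnums: Union[int, tuple[int, ...]] = ()) -> Callable:
#       ...
#
# After adding the future import and running pyupgrade --py39-plus:
from __future__ import annotations

from typing import Callable


def checkpointed(fun: Callable,
                 policy: Callable[..., bool] | None = None,
                 static_argnums: int | tuple[int, ...] = ()) -> Callable:
    # The annotations above are stored as strings and never evaluated (PEP 563),
    # so the `X | Y` union syntax parses and runs fine on Python 3.9, and the
    # Optional/Union imports become unnecessary.
    return fun
```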
@@ -574,8 +576,8 @@ def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params): ad.reducing_transposes[remat_p] = remat_transpose # TODO(mattjj): move this to ad.py -def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: Union[bool, Sequence[bool]], - out_zeros: Union[bool, Sequence[bool]], +def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool], + out_zeros: bool | Sequence[bool], reduce_axes: Sequence[core.AxisName], ) -> tuple[core.ClosedJaxpr, list[bool]]: if type(in_linear) is bool: @@ -639,7 +641,7 @@ def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn - ) -> tuple[list[bool], Optional[core.JaxprEqn]]: + ) -> tuple[list[bool], core.JaxprEqn | None]: new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) new_params = dict(eqn.params, jaxpr=new_jaxpr) if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects: @@ -779,8 +781,8 @@ def checkpoint_wrapper( *, concrete: bool = False, prevent_cse: bool = True, - static_argnums: Union[int, tuple[int, ...]] = (), - policy: Optional[Callable[..., bool]] = None, + static_argnums: int | tuple[int, ...] = (), + policy: Callable[..., bool] | None = None, ) -> Callable: if concrete: msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; " diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 5f5e631005a9..c65270f54101 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Iterable, Sequence import inspect import operator from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import warnings import numpy as np @@ -39,7 +41,7 @@ map = safe_map -def _ensure_index(x: Any) -> Union[int, tuple[int, ...]]: +def _ensure_index(x: Any) -> int | tuple[int, ...]: """Ensure x is either an index or a tuple of indices.""" x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.") try: @@ -60,7 +62,7 @@ def _ensure_str(x: str) -> str: raise TypeError(f"argument is not a string: {x}") return x -def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> tuple[str, ...]: +def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]: """Convert x to a tuple of strings.""" if isinstance(x, str): return (x,) @@ -97,7 +99,7 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args): def flattened_fun_in_tree( fn: lu.WrappedFun - ) -> Optional[tuple[PyTreeDef, Callable[[], PyTreeDef], bool]]: + ) -> tuple[PyTreeDef, Callable[[], PyTreeDef], bool] | None: # This implementation relies on internal details of linear_util.py's # WrappedFun, but it's for the worthy cause of better user error messages. # It can fail (i.e. 
return None) if its WrappedFun argument is not transformed @@ -473,8 +475,8 @@ def check_callable(fun): def infer_argnums_and_argnames( sig: inspect.Signature, - argnums: Union[int, Iterable[int], None], - argnames: Union[str, Iterable[str], None], + argnums: int | Iterable[int] | None, + argnames: str | Iterable[str] | None, ) -> tuple[tuple[int, ...], tuple[str, ...]]: """Infer missing argnums and argnames for a function with inspect.""" if argnums is None and argnames is None: @@ -612,7 +614,7 @@ def api_hook(fun, tag: str): def debug_info(traced_for: str, fun: Callable, args: tuple[Any], kwargs: dict[str, Any], static_argnums: tuple[int, ...], - static_argnames: tuple[str, ...]) -> Optional[TracingDebugInfo]: + static_argnames: tuple[str, ...]) -> TracingDebugInfo | None: """Try to build trace-time debug info for fun when applied to args/kwargs.""" src = fun_sourceinfo(fun) arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames) @@ -620,7 +622,7 @@ def debug_info(traced_for: str, fun: Callable, args: tuple[Any], return TracingDebugInfo(traced_for, src, arg_names, None) # TODO(mattjj): make this function internal to this module -def fun_sourceinfo(fun: Callable) -> Optional[str]: +def fun_sourceinfo(fun: Callable) -> str | None: while isinstance(fun, partial): fun = fun.func fun = inspect.unwrap(fun) @@ -632,7 +634,7 @@ def fun_sourceinfo(fun: Callable) -> Optional[str]: return None def _arg_names(fn, args, kwargs, static_argnums, static_argnames, - ) -> Optional[tuple[str, ...]]: + ) -> tuple[str, ...] | None: static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) @@ -651,8 +653,8 @@ def result_paths(*args, **kwargs): ans = yield args, kwargs yield ans, [keystr(path) for path, _ in generate_key_paths(ans)] -def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo], - result_paths: Optional[tuple[Optional[str], ...]] = None, +def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None, + result_paths: tuple[str | None, ...] 
| None = None, ) -> core.Jaxpr: """Add debug info to jaxpr, given trace-time debug info and result paths.""" if trace_debug is None: @@ -665,7 +667,7 @@ def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo], trace_debug.arg_names, tuple(result_paths)) return jaxpr.replace(debug_info=debug_info) -def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo], +def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None, res_paths: Callable[[], tuple[str, ...]]) -> lu.WrappedFun: "Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun" if dbg is None: return f diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index ea881a723a1d..cc33d753b899 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -14,6 +14,8 @@ # Note that type annotations for this file are defined in basearray.pyi +from __future__ import annotations + import abc import numpy as np from typing import Any, Union @@ -73,7 +75,7 @@ def shape(self) -> tuple[int, ...]: # Documentation for sharding-related methods and properties defined on ArrayImpl: @abc.abstractmethod - def addressable_data(self, index: int) -> "Array": + def addressable_data(self, index: int) -> Array: """Return an array of the addressable data at a particular index.""" @property diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index e151ea3f2c31..62f881b7ae18 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -17,7 +17,7 @@ import dataclasses import functools import itertools as it -from typing import Union, Callable, TypeVar, Any +from typing import Callable, TypeVar, Any, Union import numpy as np diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index bf0406c63a2d..a978cf4beff7 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import re import socket import time -from typing import Optional from jax._src import clusters from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm @@ -99,7 +100,7 @@ def get_process_id(cls) -> int: return int(get_metadata('agent-worker-number')) @classmethod - def get_local_process_id(cls) -> Optional[int]: + def get_local_process_id(cls) -> int | None: return None class MultisliceGceTpuCluster(clusters.ClusterEnv): @@ -147,7 +148,7 @@ def get_process_id(cls) -> int: return process_id_in_slice + slice_id * processes_per_slice @classmethod - def get_local_process_id(cls) -> Optional[int]: + def get_local_process_id(cls) -> int | None: return None @staticmethod diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 5a2f0e774c4f..fb566276a862 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence import logging -from typing import Optional from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ class ClusterEnv: :class:`ClusterEnv` subclasses are automatically detected when imported. 
""" - _cluster_types: list[type['ClusterEnv']] = [] + _cluster_types: list[type[ClusterEnv]] = [] def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -38,12 +39,12 @@ def __init_subclass__(cls, **kwargs): @classmethod # pytype: disable=bad-return-type def auto_detect_unset_distributed_params(cls, - coordinator_address: Optional[str], - num_processes: Optional[int], - process_id: Optional[int], - local_device_ids: Optional[Sequence[int]] - ) -> tuple[Optional[str], Optional[int], Optional[int], - Optional[Sequence[int]]]: + coordinator_address: str | None, + num_processes: int | None, + process_id: int | None, + local_device_ids: Sequence[int] | None + ) -> tuple[str | None, int | None, int | None, + Sequence[int] | None]: if all(p is not None for p in (coordinator_address, num_processes, process_id, local_device_ids)): return (coordinator_address, num_processes, process_id, @@ -100,7 +101,7 @@ def get_process_id(cls) -> int: raise NotImplementedError("ClusterEnv subclasses must implement get_process_id") @classmethod - def get_local_process_id(cls) -> Optional[int]: + def get_local_process_id(cls) -> int | None: """ Get index of current process inside a host. The method is only useful to support single device per process. diff --git a/jax/_src/clusters/ompi_cluster.py b/jax/_src/clusters/ompi_cluster.py index 27e118e56b7b..908af28a027b 100644 --- a/jax/_src/clusters/ompi_cluster.py +++ b/jax/_src/clusters/ompi_cluster.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import re -from typing import Optional from jax._src import clusters # OMPI_MCA_orte_hnp_uri exists only when processes are launched via mpirun or mpiexec @@ -55,5 +56,5 @@ def get_process_id(cls) -> int: return int(os.environ[_PROCESS_ID]) @classmethod - def get_local_process_id(cls) -> Optional[int]: + def get_local_process_id(cls) -> int | None: return int(os.environ[_LOCAL_PROCESS_ID]) diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py index 2c31458ee77c..5edacb4f5d7a 100644 --- a/jax/_src/clusters/slurm_cluster.py +++ b/jax/_src/clusters/slurm_cluster.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os -from typing import Optional from jax._src import clusters _JOBID_PARAM = 'SLURM_JOB_ID' @@ -58,5 +59,5 @@ def get_process_id(cls) -> int: return int(os.environ[_PROCESS_ID]) @classmethod - def get_local_process_id(cls) -> Optional[int]: + def get_local_process_id(cls) -> int | None: return int(os.environ[_LOCAL_PROCESS_ID]) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 7c38fca1f554..0f465e94e17f 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import logging import threading -from typing import Optional import zlib import numpy as np @@ -35,7 +36,7 @@ logger = logging.getLogger(__name__) -_cache: Optional[CacheInterface] = None +_cache: CacheInterface | None = None _cache_initialized: bool = False @@ -102,7 +103,7 @@ def _initialize_cache() -> None: logger.debug("Initialized persistent compilation cache at %s", path) -def _get_cache() -> Optional[CacheInterface]: +def _get_cache() -> CacheInterface | None: # TODO(b/289098047): consider making this an API and changing the callers of # get_executable_and_time() and put_executable_and_time() to call get_cache() # and passing the result to them. @@ -113,7 +114,7 @@ def _get_cache() -> Optional[CacheInterface]: def get_executable_and_time( cache_key: str, compile_options, backend -) -> tuple[Optional[xla_client.LoadedExecutable], Optional[int]]: +) -> tuple[xla_client.LoadedExecutable | None, int | None]: """Returns the cached executable and its compilation time if present, or None otherwise. """ diff --git a/jax/_src/core.py b/jax/_src/core.py index c2e586a2ff9e..154f1512412c 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -13,8 +13,8 @@ # limitations under the License. from __future__ import annotations -import collections -from collections import namedtuple +import collections # noqa: F401 +from collections import defaultdict, namedtuple from collections.abc import Generator, Hashable, Iterable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass @@ -28,8 +28,8 @@ from operator import attrgetter import threading import types -from typing import (Any, Callable, ClassVar, DefaultDict, Generic, NamedTuple, - TypeVar, Union, cast, overload) +from typing import (Any, Callable, ClassVar, Generic, NamedTuple, TypeVar, + cast, overload, Union) import warnings from weakref import ref @@ -3031,10 +3031,10 @@ class JaxprPpSettings(NamedTuple): # A JaxprPpContext allows us to globally uniquify variable names within nested # Jaxprs. class JaxprPpContext: - var_ids: DefaultDict[Var, int] + var_ids: defaultdict[Var, int] def __init__(self): - self.var_ids = collections.defaultdict(it.count().__next__, {}) + self.var_ids = defaultdict(it.count().__next__, {}) def pp_var(v: Var, context: JaxprPpContext) -> str: diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 34d3ca9fac02..c8a998712898 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import functools import operator -from typing import Callable, Optional +from typing import Callable from jax import lax from jax._src import api @@ -47,7 +49,7 @@ @custom_api_util.register_custom_decorator_type class custom_vmap: fun: Callable - vmap_rule: Optional[Callable] + vmap_rule: Callable | None def __init__(self, fun: Callable): functools.update_wrapper(self, fun) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index d8ee8ec398ce..ce7317a2c307 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from collections.abc import Sequence import dataclasses from functools import update_wrapper, reduce, partial import inspect -from typing import Any, Callable, Generic, Optional, TypeVar +from typing import Any, Callable, Generic, TypeVar from jax._src import config from jax._src import core @@ -136,7 +138,7 @@ def f_jvp(primals, tangents): """ fun: Callable[..., ReturnValue] nondiff_argnums: tuple[int, ...] - jvp: Optional[Callable[..., tuple[ReturnValue, ReturnValue]]] = None + jvp: Callable[..., tuple[ReturnValue, ReturnValue]] | None = None symbolic_zeros: bool = False def __init__(self, @@ -194,7 +196,7 @@ def f_jvp(primals, tangents): self.symbolic_zeros = symbolic_zeros return jvp - def defjvps(self, *jvps: Optional[Callable[..., ReturnValue]]): + def defjvps(self, *jvps: Callable[..., ReturnValue] | None): """Convenience wrapper for defining JVPs for each argument separately. This convenience wrapper cannot be used together with ``nondiff_argnums``. @@ -493,8 +495,8 @@ def __init__(self, update_wrapper(self, fun) self.fun = fun self.nondiff_argnums = nondiff_argnums - self.fwd: Optional[Callable[..., tuple[ReturnValue, Any]]] = None - self.bwd: Optional[Callable[..., tuple[Any, ...]]] = None + self.fwd: Callable[..., tuple[ReturnValue, Any]] | None = None + self.bwd: Callable[..., tuple[Any, ...]] | None = None self.symbolic_zeros = False __getattr__ = custom_api_util.forward_attr diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 48097aae6d25..210d40632c37 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import functools -from typing import Any, Callable, Optional +from typing import Any, Callable from jax._src import ad_util from jax._src import api_util @@ -65,7 +67,7 @@ def transformation_with_aux( @custom_api_util.register_custom_decorator_type class custom_transpose: fun: Callable - transpose: Optional[Callable] = None + transpose: Callable | None = None def __init__(self, fun: Callable): functools.update_wrapper(self, fun) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index ccf536ea27b5..17a9fb02813c 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -13,11 +13,13 @@ # limitations under the License. 
"""Module for JAX debugging primitives and related functionality.""" +from __future__ import annotations + from collections.abc import Sequence import functools import string import sys -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union import weakref import numpy as np @@ -434,7 +436,7 @@ def _slice_to_chunk_idx(size: int, slc: slice) -> int: assert size % slice_size == 0 return slc.start // slice_size -def _raise_to_slice(slc: Union[slice, int]): +def _raise_to_slice(slc: slice | int): if isinstance(slc, int): return slice(slc, slc + 1) return slc @@ -465,7 +467,7 @@ def make_color_iter(color_map, num_rows, num_cols): def visualize_sharding(shape: Sequence[int], sharding: Sharding, *, use_color: bool = True, scale: float = 1., min_width: int = 9, max_width: int = 80, - color_map: Optional[ColorMap] = None): + color_map: ColorMap | None = None): """Visualizes a ``Sharding`` using ``rich``.""" if not RICH_ENABLED: raise ValueError("`visualize_sharding` requires `rich` to be installed.") @@ -491,7 +493,7 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *, device_indices_map = sharding.devices_indices_map(tuple(shape)) slices: dict[tuple[int, ...], set[int]] = {} - heights: dict[tuple[int, ...], Optional[float]] = {} + heights: dict[tuple[int, ...], float | None] = {} widths: dict[tuple[int, ...], float] = {} for i, (dev, slcs) in enumerate(device_indices_map.items()): diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 49828b7e8d1a..af9a80856532 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import atexit from collections.abc import Sequence import logging import os -from typing import Any, Optional, Union +from typing import Any from jax._src import clusters from jax._src import config @@ -29,16 +31,16 @@ class State: process_id: int = 0 num_processes: int = 1 - service: Optional[Any] = None - client: Optional[Any] = None - preemption_sync_manager: Optional[Any] = None - coordinator_address: Optional[str] = None + service: Any | None = None + client: Any | None = None + preemption_sync_manager: Any | None = None + coordinator_address: str | None = None def initialize(self, - coordinator_address: Optional[str] = None, - num_processes: Optional[int] = None, - process_id: Optional[int] = None, - local_device_ids: Optional[Union[int, Sequence[int]]] = None, + coordinator_address: str | None = None, + num_processes: int | None = None, + process_id: int | None = None, + local_device_ids: int | Sequence[int] | None = None, initialization_timeout: int = 300): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS', None)) @@ -108,10 +110,10 @@ def initialize_preemption_sync_manager(self): global_state = State() -def initialize(coordinator_address: Optional[str] = None, - num_processes: Optional[int] = None, - process_id: Optional[int] = None, - local_device_ids: Optional[Union[int, Sequence[int]]] = None, +def initialize(coordinator_address: str | None = None, + num_processes: int | None = None, + process_id: int | None = None, + local_device_ids: int | Sequence[int] | None = None, initialization_timeout: int = 300): """Initializes the JAX distributed system. 
diff --git a/jax/_src/environment_info.py b/jax/_src/environment_info.py index edf5933eaabd..3fc881ee57cd 100644 --- a/jax/_src/environment_info.py +++ b/jax/_src/environment_info.py @@ -12,24 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import subprocess import sys import textwrap -from typing import Optional, Union from jax import version from jax._src import lib from jax._src import xla_bridge import numpy as np -def try_nvidia_smi() -> Optional[str]: +def try_nvidia_smi() -> str | None: try: return subprocess.check_output(['nvidia-smi']).decode() except Exception: return None -def print_environment_info(return_string: bool = False) -> Union[None, str]: +def print_environment_info(return_string: bool = False) -> str | None: """Returns a string containing local environment & JAX installation information. This is useful information to include when asking a question or filing a bug. diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index 7e3c0bda890c..1ef1db916f73 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence from functools import partial import enum -from typing import Callable, Union +from typing import Callable import numpy as np @@ -171,7 +173,7 @@ def from_string(s: str): def scale_and_translate(image, shape: core.Shape, spatial_dims: Sequence[int], scale, translation, - method: Union[str, ResizeMethod], + method: str | ResizeMethod, antialias: bool = True, precision=lax.Precision.HIGHEST): """Apply a scale and translation to an image. @@ -269,7 +271,7 @@ def _resize_nearest(x, output_shape: core.Shape): @partial(jit, static_argnums=(1, 2, 3, 4)) -def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod], +def _resize(image, shape: core.Shape, method: str | ResizeMethod, antialias: bool, precision): if len(shape) != image.ndim: msg = ('shape must have length equal to the number of dimensions of x; ' @@ -295,7 +297,7 @@ def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod], antialias, precision) -def resize(image, shape: core.Shape, method: Union[str, ResizeMethod], +def resize(image, shape: core.Shape, method: str | ResizeMethod, antialias: bool = True, precision = lax.Precision.HIGHEST): """Image resize. diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 613b80f7ed1a..fbef4b05cc6e 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -65,13 +65,15 @@ def func(...): ... 
""" +from __future__ import annotations + from collections.abc import Iterable, Sequence import dataclasses import datetime import os import re import sys -from typing import Any, Callable, Optional +from typing import Any, Callable from absl import logging @@ -163,12 +165,12 @@ def load_testdata_nested(self, testdata_nest) -> Iterable[CompatTestData]: def run_one_test(self, func: Callable[..., jax.Array], data: CompatTestData, - polymorphic_shapes: Optional[Sequence[str]] = None, - rtol: Optional[float] = None, - atol: Optional[float] = None, + polymorphic_shapes: Sequence[str] | None = None, + rtol: float | None = None, + atol: float | None = None, allow_unstable_custom_call_targets: Sequence[str] = (), - check_results: Optional[Callable[..., None]] = None, - expect_current_custom_calls: Optional[Sequence[str]] = None): + check_results: Callable[..., None] | None = None, + expect_current_custom_calls: Sequence[str] | None = None): """Run one compatibility test. Args: @@ -271,7 +273,7 @@ def run_current(self, func: Callable, data: CompatTestData): def serialize(self, func: Callable, data: CompatTestData, *, - polymorphic_shapes: Optional[Sequence[str]] = None, + polymorphic_shapes: Sequence[str] | None = None, allow_unstable_custom_call_targets: Sequence[str] = () ) -> tuple[bytes, str, int, int]: """Serializes the test function. @@ -303,7 +305,7 @@ def serialize(self, return serialized, module_str, module_version, nr_devices def run_serialized(self, data: CompatTestData, - polymorphic_shapes: Optional[Sequence[str]] = None): + polymorphic_shapes: Sequence[str] | None = None): args_specs = export.args_specs(data.inputs, polymorphic_shapes) def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray: return core.ShapedArray(a.shape, a.dtype) diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index 100dffb097fb..40056f3a6dc0 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -17,9 +17,11 @@ # only, and may be changed or removed at any time and without any deprecation # cycle. +from __future__ import annotations + import collections import itertools -from typing import Optional, cast +from typing import cast from jax import lax from jax._src import dtypes @@ -355,7 +357,7 @@ def lax_ops(): def all_bdims(*shapes): - bdims = (itertools.chain([cast(Optional[int], None)], + bdims = (itertools.chain([cast(int | None, None)], range(len(shape) + 1)) for shape in shapes) return (t for t in itertools.product(*bdims) if not all(e is None for e in t)) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 09ff46837a28..a9aa7d385281 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -36,11 +36,13 @@ to fail. A Limitation is specific to a harness. """ +from __future__ import annotations + from collections.abc import Iterable, Sequence import operator import os from functools import partial -from typing import Any, Callable, Optional, NamedTuple, Union +from typing import Any, Callable, NamedTuple, Union from absl import testing import numpy as np @@ -151,7 +153,7 @@ class Harness: dtype: DType # A set of limitations describing the cases that are not supported or # partially implemented in JAX for this harness. - jax_unimplemented: Sequence["Limitation"] + jax_unimplemented: Sequence[Limitation] rng_factory: Callable # Carry some arbitrary parameters that the test can access. 
params: dict[str, Any] @@ -164,7 +166,7 @@ def __init__(self, *, dtype, rng_factory=jtu.rand_default, - jax_unimplemented: Sequence["Limitation"] = (), + jax_unimplemented: Sequence[Limitation] = (), **params): """See class docstring.""" self.group_name = jtu.sanitize_test_name(group_name) @@ -224,7 +226,7 @@ def filter(self, device_under_test: str, *, include_jax_unimpl: bool = False, - one_containing: Optional[str] = None) -> bool: + one_containing: str | None = None) -> bool: if not include_jax_unimpl: if any( device_under_test in l.devices @@ -285,7 +287,7 @@ def define( *, dtype, rng_factory=jtu.rand_default, - jax_unimplemented: Sequence["Limitation"] = (), + jax_unimplemented: Sequence[Limitation] = (), **params): """Defines a harness and stores it in `all_harnesses`. See Harness.""" group_name = str(group_name) @@ -312,7 +314,7 @@ def __init__( description: str, *, enabled: bool = True, - devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"), + devices: str | Sequence[str] = ("cpu", "gpu", "tpu"), dtypes: Sequence[DType] = (), skip_run: bool = False, ): @@ -354,8 +356,8 @@ def __str__(self): __repr__ = __str__ def filter(self, - device: Optional[str] = None, - dtype: Optional[DType] = None) -> bool: + device: str | None = None, + dtype: DType | None = None) -> bool: """Check that a limitation is enabled for the given dtype and device.""" return (self.enabled and (not self.dtypes or dtype is None or dtype in self.dtypes) and @@ -364,7 +366,7 @@ def filter(self, def parameterized(harnesses: Iterable[Harness], *, - one_containing: Optional[str] = None, + one_containing: str | None = None, include_jax_unimpl: bool = False): """Decorator for tests. diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index a92aa16edbd9..c5a743931263 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence import contextlib import functools import itertools as it from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import jax from jax._src import config @@ -46,7 +48,7 @@ def identity(x): return x def _update_annotation( f: lu.WrappedFun, - orig_type: Optional[tuple[tuple[core.AbstractValue, bool], ...]], + orig_type: tuple[tuple[core.AbstractValue, bool], ...] 
| None, explicit_nonzeros: list[bool] ) -> lu.WrappedFun: if orig_type is None: @@ -694,7 +696,7 @@ def unmap_zero(zero, in_axis): def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool], - instantiate: Union[bool, Sequence[bool]] + instantiate: bool | Sequence[bool] ) -> tuple[core.ClosedJaxpr, list[bool]]: if type(instantiate) is bool: instantiate = (instantiate,) * len(jaxpr.out_avals) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index aa33bb4dd726..969d0aaec491 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -966,7 +966,10 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False) ### utilities for defining primitives' batching rules -BatchingRule = Callable[..., tuple[Any, Union[None, int, tuple[Union[None, int], ...]]]] +BatchingRule = Callable[ + ..., + tuple[Any, Union[int, None, tuple[Union[int, None], ...]]] +] primitive_batchers : dict[core.Primitive, BatchingRule] = {} axis_primitive_batchers: dict[core.Primitive, Callable] = {} spmd_axis_primitive_batchers: dict[core.Primitive, Callable] = {} diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a628fe2f40e4..41c71cf97e22 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -72,7 +72,7 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr: return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) -def dense_int_array(xs) -> Union[ir.DenseIntElementsAttr, ir.DenseI64ArrayAttr]: +def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr: if hlo.get_api_version() < 5: return dense_int_elements(xs) return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) @@ -947,7 +947,7 @@ class TokenSet: primitives. A `TokenSet` encapsulates a set of HLO tokens that will be used by the lowering rules. """ - _tokens: typing.OrderedDict[core.Effect, Token] + _tokens: collections.OrderedDict[core.Effect, Token] def __init__(self, *args, **kwargs): self._tokens = collections.OrderedDict(*args, **kwargs) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 0fae2a4a8120..fb5de2dfeb3e 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -20,8 +20,7 @@ import inspect import itertools as it import operator as op -from typing import (Any, Callable, NamedTuple, Optional, - Union) +from typing import Any, Callable, NamedTuple, Union from weakref import ref import numpy as np @@ -670,8 +669,9 @@ def abstract_eval_fun(fun, *avals, debug_info=None, **params): return avals_out -JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar', - 'ConstVar', Literal] +JaxprTracerRecipe = Union[ + 'JaxprEqnRecipe', 'LambdaBinding', 'FreeVar', 'ConstVar', Literal, +] class JaxprTracer(Tracer): __slots__ = ['pval', 'recipe'] @@ -1330,7 +1330,7 @@ def ensure_enum(case: bool | RematCases) -> RematCases: # * a list of Var instances representing residuals to be added (i.e. to be # plumbed as outputs of the 'known' side jaxpr and added as input binders to # the 'unknown' jaxpr). 
-PartialEvalCustomResult = tuple[Optional[JaxprEqn], Optional[JaxprEqn], +PartialEvalCustomResult = tuple[Union[JaxprEqn, None], Union[JaxprEqn, None], Sequence[bool], Sequence[bool], list[Var]] PartialEvalCustomRule = Callable[ [Callable[..., RematCases_], Sequence[bool], Sequence[bool], JaxprEqn], @@ -1566,7 +1566,8 @@ def has_effects(e: JaxprEqn) -> bool: return new_jaxpr, used_inputs -DCERule = Callable[[list[bool], JaxprEqn], tuple[list[bool], Optional[JaxprEqn]]] +DCERule = Callable[[list[bool], JaxprEqn], + tuple[list[bool], Union[JaxprEqn, None]]] def _default_dce_rule( used_outs: list[bool], eqn: JaxprEqn @@ -1844,12 +1845,16 @@ def apply_var_sub(a: Atom) -> Atom: jaxpr_effects, jaxpr.debug_info) return new_jaxpr, new_constvals -ConstFoldRule = Callable[[list[Optional[Any]], JaxprEqn], - tuple[list[Optional[Any]], Optional[JaxprEqn]]] +ConstFoldRule = Callable[ + [list[Union[Any, None]], JaxprEqn], + tuple[list[Union[Any, None]], Union[JaxprEqn, None]], +] const_fold_rules: dict[Primitive, ConstFoldRule] = {} -ForwardingRule = Callable[[JaxprEqn], - tuple[list[Optional[Var]], Optional[JaxprEqn]]] +ForwardingRule = Callable[ + [JaxprEqn], + tuple[list[Union[Var, None]], Union[JaxprEqn, None]] +] forwarding_rules: dict[Primitive, ForwardingRule] = {} @@ -2377,8 +2382,12 @@ def trace_to_jaxpr_final2( AbstractedAxisName = Hashable -AbstractedAxesSpec = Union[dict[int, AbstractedAxisName], - tuple[AbstractedAxisName, ...]] +AbstractedAxesSpec = Union[ + dict[int, AbstractedAxisName], + tuple[AbstractedAxisName, ...], +] + + def infer_lambda_input_type( axes_specs: Sequence[AbstractedAxesSpec] | None, args: Sequence[Any] diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 3a9e5f2dfcab..c7c368945a97 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -25,7 +25,7 @@ import logging import math import threading -from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar) +from typing import Any, Callable, NamedTuple, TypeVar, Union, cast from collections.abc import Iterator import warnings @@ -1609,7 +1609,9 @@ class DeviceAssignmentMismatchError(Exception): ShardingInfo = tuple[ Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO], - MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports + MismatchType, + Union[Any, None], # Any is dispatch.SourceInfo to avoid circular imports +] def _get_default_device() -> xc.Device: @@ -1870,7 +1872,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings): return False return True -MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]] +MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]] class AllArgsInfo(NamedTuple): diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 22d256edc7da..f30fc5b5d6fb 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -14,6 +14,8 @@ # Lowering of jaxprs into XLA (HLO) computations. 
+from __future__ import annotations + from collections import defaultdict from collections.abc import Sequence import dataclasses @@ -21,7 +23,7 @@ from functools import partial import itertools as it import operator -from typing import Any, Callable, Optional, Protocol, Union +from typing import Any, Callable, Protocol, Union import numpy as np @@ -78,9 +80,8 @@ def parameter(builder, num, shape, name=None, replicated=None): # arbitrary tuple nesting, but JAX only uses one level of tupling (and our type # checkers don't support recursive types), so we only represent one level of # nesting in this type definition. -SpatialSharding = Union[Shape, - None, - tuple[Optional[Shape], ...]] +SpatialSharding = Union[Shape, None, tuple[Union[Shape, None], ...]] + def sharding_to_proto(sharding: SpatialSharding): """Converts a SpatialSharding to an OpSharding. @@ -204,7 +205,7 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: (t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types) -def primitive_subcomputation(platform: str, axis_env: 'AxisEnv', +def primitive_subcomputation(platform: str, axis_env: AxisEnv, prim: core.Primitive, avals_in: Sequence[core.AbstractValue], avals_out: Sequence[core.AbstractValue], @@ -237,9 +238,9 @@ class TranslationContext: builder: xc.XlaBuilder # TODO(phawkins): make platform non-optional. We should always be translating # with a specific platform in mind. - platform: Optional[str] + platform: str | None axis_env: AxisEnv - name_stack: Union[str, source_info_util.NameStack] + name_stack: str | source_info_util.NameStack def replace(self, **kw): return dataclasses.replace(self, **kw) @@ -272,7 +273,7 @@ def register_initial_style_primitive(prim: core.Primitive): initial_style_primitives.add(prim) def register_translation(prim: core.Primitive, rule: TranslationRule, *, - platform: Optional[str] = None) -> None: + platform: str | None = None) -> None: if platform is None: _translations[prim] = rule else: diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index d7b30e878f87..b9a7763f596d 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -14,12 +14,14 @@ """Utilities for the Jaxpr IR.""" -import collections +from __future__ import annotations + +from collections import Counter, defaultdict import gzip import itertools import json import types -from typing import Any, Callable, DefaultDict, Optional +from typing import Any, Callable, Union from jax._src import core from jax._src import util @@ -37,7 +39,7 @@ def all_eqns(jaxpr: core.Jaxpr): yield from all_eqns(subjaxpr) def collect_eqns(jaxpr: core.Jaxpr, key: Callable): - d = collections.defaultdict(list) + d = defaultdict(list) for _, eqn in all_eqns(jaxpr): d[key(eqn)].append(eqn) return dict(d) @@ -68,7 +70,7 @@ def key(eqn): return source_info_util.summarize(eqn.source_info) return histogram(jaxpr, key) -MaybeEqn = Optional[core.JaxprEqn] +MaybeEqn = Union[core.JaxprEqn, None] def var_defs_and_refs(jaxpr: core.Jaxpr): defs: dict[core.Var, MaybeEqn] = {} @@ -128,19 +130,19 @@ def print_histogram(histogram: dict[Any, int]): def _pprof_profile( - profile: dict[tuple[Optional[xla_client.Traceback], core.Primitive], int] + profile: dict[tuple[xla_client.Traceback | None, core.Primitive], int] ) -> bytes: """Converts a profile into a compressed pprof protocol buffer. The input profile is a map from (traceback, primitive) pairs to counts. 
""" - s: DefaultDict[str, int] - func: DefaultDict[types.CodeType, int] - loc: DefaultDict[tuple[types.CodeType, int], int] + s: defaultdict[str, int] + func: defaultdict[types.CodeType, int] + loc: defaultdict[tuple[types.CodeType, int], int] - s = collections.defaultdict(itertools.count(1).__next__) - func = collections.defaultdict(itertools.count(1).__next__) - loc = collections.defaultdict(itertools.count(1).__next__) + s = defaultdict(itertools.count(1).__next__) + func = defaultdict(itertools.count(1).__next__) + loc = defaultdict(itertools.count(1).__next__) s[""] = 0 primitive_key = s["primitive"] samples = [] @@ -201,8 +203,8 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes: A gzip-compressed pprof Profile protocol buffer, suitable for passing to pprof tool for visualization. """ - d: DefaultDict[tuple[Optional[xla_client.Traceback], core.Primitive], int] - d = collections.defaultdict(int) - for _, eqn in all_eqns(jaxpr): - d[(eqn.source_info.traceback, eqn.primitive)] += 1 + d = Counter( + (eqn.source_info.traceback, eqn.primitive) + for _, eqn in all_eqns(jaxpr) + ) return _pprof_profile(d) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index 52ab4871e720..7dbc062f503b 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -13,10 +13,12 @@ # limitations under the License. """Module for the common control flow utilities.""" +from __future__ import annotations + from collections.abc import Sequence import os from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Callable from jax._src import core from jax._src import linear_util as lu @@ -52,7 +54,7 @@ def _typecheck_param(prim, param, name, msg_required, pred): @weakref_lru_cache def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, - primitive_name: Optional[str] = None): + primitive_name: str | None = None): wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) debug = pe.debug_info(fun, in_tree, out_tree, False, primitive_name or "") @@ -61,7 +63,7 @@ def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, @weakref_lru_cache def _initial_style_jaxpr(fun: Callable, in_tree, in_avals, - primitive_name: Optional[str] = None): + primitive_name: str | None = None): jaxpr, consts, out_tree = _initial_style_open_jaxpr( fun, in_tree, in_avals, primitive_name) closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index d4635a8660b7..632ca77282cf 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -13,10 +13,12 @@ # limitations under the License. """Module for the `for_loop` primitive.""" +from __future__ import annotations + from collections.abc import Sequence import functools import operator -from typing import Any, Callable, Generic, Optional, TypeVar, Union +from typing import Any, Callable, Generic, TypeVar import jax.numpy as jnp from jax import lax @@ -100,7 +102,7 @@ def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef, f, state_avals) return jaxpr, consts, out_tree_thunk() -def for_loop(nsteps: Union[int, Sequence[int]], +def for_loop(nsteps: int | Sequence[int], body: Callable[[Array, Ref[S]], None], init_state: S, *, reverse: bool = False, unroll: int = 1) -> S: """A for-loop combinator that allows read/write semantics in the loop body. 
@@ -176,7 +178,7 @@ def wrapped_body(i, refs): def scan(f: Callable[[Carry, X], tuple[Carry, Y]], init: Carry, xs: X, - length: Optional[int] = None, + length: int | None = None, reverse: bool = False, unroll: int = 1) -> tuple[Carry, Y]: if not callable(f): @@ -254,7 +256,7 @@ def _for_abstract_eval(*avals, jaxpr, **__): def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr, reverse: bool, which_linear: Sequence[bool], nsteps: int, unroll: int - ) -> tuple[Sequence[Optional[Any]], Sequence[Any]]: + ) -> tuple[Sequence[Any | None], Sequence[Any]]: out_vals = for_p.bind(*args, jaxpr=jaxpr, reverse=reverse, which_linear=which_linear, nsteps=nsteps, unroll=unroll) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 10064285a753..56b1a2643db4 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence from functools import partial import operator -from typing import NamedTuple, Optional, Union +from typing import NamedTuple, Union import numpy as np @@ -46,17 +48,20 @@ class ConvDimensionNumbers(NamedTuple): out_spec: Sequence[int] ConvGeneralDilatedDimensionNumbers = Union[ - None, ConvDimensionNumbers, tuple[str, str, str]] + tuple[str, str, str], + ConvDimensionNumbers, + None, +] def conv_general_dilated( lhs: Array, rhs: Array, window_strides: Sequence[int], - padding: Union[str, Sequence[tuple[int, int]]], - lhs_dilation: Optional[Sequence[int]] = None, - rhs_dilation: Optional[Sequence[int]] = None, + padding: str | Sequence[tuple[int, int]], + lhs_dilation: Sequence[int] | None = None, + rhs_dilation: Sequence[int] | None = None, dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: lax.PrecisionLike = None, - preferred_element_type: Optional[DTypeLike] = None) -> Array: + preferred_element_type: DTypeLike | None = None) -> Array: """General n-dimensional convolution operator, with optional dilation. Wraps XLA's `Conv @@ -168,7 +173,7 @@ def conv_general_dilated( def conv(lhs: Array, rhs: Array, window_strides: Sequence[int], padding: str, precision: lax.PrecisionLike = None, - preferred_element_type: Optional[DTypeLike] = None) -> Array: + preferred_element_type: DTypeLike | None = None) -> Array: """Convenience wrapper around `conv_general_dilated`. Args: @@ -194,11 +199,11 @@ def conv(lhs: Array, rhs: Array, window_strides: Sequence[int], def conv_with_general_padding(lhs: Array, rhs: Array, window_strides: Sequence[int], - padding: Union[str, Sequence[tuple[int, int]]], - lhs_dilation: Optional[Sequence[int]], - rhs_dilation: Optional[Sequence[int]], + padding: str | Sequence[tuple[int, int]], + lhs_dilation: Sequence[int] | None, + rhs_dilation: Sequence[int] | None, precision: lax.PrecisionLike = None, - preferred_element_type: Optional[DTypeLike] = None) -> Array: + preferred_element_type: DTypeLike | None = None) -> Array: """Convenience wrapper around `conv_general_dilated`. 
Args: @@ -266,12 +271,12 @@ def _flip_axes(x, axes): def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], - padding: Union[str, Sequence[tuple[int, int]]], - rhs_dilation: Optional[Sequence[int]] = None, + padding: str | Sequence[tuple[int, int]], + rhs_dilation: Sequence[int] | None = None, dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, transpose_kernel: bool = False, precision: lax.PrecisionLike = None, - preferred_element_type: Optional[DTypeLike] = None) -> Array: + preferred_element_type: DTypeLike | None = None) -> Array: """Convenience wrapper for calculating the N-d convolution "transpose". This function directly calculates a fractionally strided conv rather than @@ -325,7 +330,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], k_shape = np.take(rhs.shape, dn.rhs_spec) k_sdims = k_shape[2:] # type: ignore[index] # Calculate correct output shape given padding and strides. - pads: Union[str, Sequence[tuple[int, int]]] + pads: str | Sequence[tuple[int, int]] if isinstance(padding, str) and padding in {'SAME', 'VALID'}: if rhs_dilation is None: rhs_dilation = (1,) * (rhs.ndim - 2) diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 69d7c2156af9..3cefd45b7e85 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence from functools import partial import math -from typing import Union import numpy as np @@ -51,7 +52,7 @@ def _str_to_fft_type(s: str) -> xla_client.FftType: raise ValueError(f"Unknown FFT type '{s}'") @partial(jit, static_argnums=(1, 2)) -def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int]): +def fft(x, fft_type: xla_client.FftType | str, fft_lengths: Sequence[int]): if isinstance(fft_type, str): typ = _str_to_fft_type(fft_type) elif isinstance(fft_type, xla_client.FftType): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ab94d9b4225b..220218d00eb4 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -22,8 +22,7 @@ import itertools import math import operator -from typing import (Any, Callable, TypeVar, Union, - cast as type_cast, overload) +from typing import Any, Callable, TypeVar, Union, cast as type_cast, overload import warnings import numpy as np @@ -673,8 +672,13 @@ def __str__(self) -> str: PrecisionType = Precision -PrecisionLike = Union[None, str, PrecisionType, tuple[str, str], - tuple[PrecisionType, PrecisionType]] +PrecisionLike = Union[ + str, + PrecisionType, + tuple[str, str], + tuple[PrecisionType, PrecisionType], + None, +] def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 3cece6a47a8c..8c1a06492b13 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import inspect import functools from functools import partial import math -from typing import cast, Any, Callable, Literal, Optional, TypeVar, Union, overload +from typing import cast, Any, Callable, Literal, TypeVar, overload import warnings import numpy as np @@ -167,7 +169,7 @@ def eigh( lower: bool = True, symmetrize_input: bool = True, sort_eigenvalues: bool = True, - subset_by_index: Optional[tuple[int, int]] = None, + subset_by_index: tuple[int, int] | None = None, ) -> tuple[Array, Array]: r"""Eigendecomposition of a Hermitian matrix. @@ -303,11 +305,11 @@ def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[False]) -> Array: ... @overload -def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, tuple[Array, Array, Array]]: ... +def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]: ... # TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD. @_warn_on_positional_kwargs -def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, tuple[Array, Array, Array]]: +def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]: """Singular value decomposition. Returns the singular values if compute_uv is False, otherwise returns a triple @@ -2116,7 +2118,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: def schur(x: ArrayLike, *, compute_schur_vectors: bool = True, sort_eig_vals: bool = False, - select_callable: Optional[Callable[..., Any]] = None) -> tuple[Array, Array]: + select_callable: Callable[..., Any] | None = None) -> tuple[Array, Array]: return schur_p.bind( x, compute_schur_vectors=compute_schur_vectors, diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 63ecc0fd038b..8ebb36ff7f06 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence import math -from typing import Any, Optional, Union, cast as type_cast +from typing import Any, cast as type_cast import jax from jax._src.numpy import lax_numpy as jnp @@ -27,12 +29,12 @@ def conv_general_dilated_patches( lhs: jax.typing.ArrayLike, filter_shape: Sequence[int], window_strides: Sequence[int], - padding: Union[str, Sequence[tuple[int, int]]], - lhs_dilation: Optional[Sequence[int]] = None, - rhs_dilation: Optional[Sequence[int]] = None, - dimension_numbers: Optional[convolution.ConvGeneralDilatedDimensionNumbers] = None, - precision: Optional[lax.PrecisionType] = None, - preferred_element_type: Optional[DType] = None, + padding: str | Sequence[tuple[int, int]], + lhs_dilation: Sequence[int] | None = None, + rhs_dilation: Sequence[int] | None = None, + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers | None = None, + precision: lax.PrecisionType | None = None, + preferred_element_type: DType | None = None, ) -> jax.Array: """Extract patches subject to the receptive field of `conv_general_dilated`. 
@@ -123,11 +125,11 @@ def conv_general_dilated_local( lhs: jax.typing.ArrayLike, rhs: jax.typing.ArrayLike, window_strides: Sequence[int], - padding: Union[str, Sequence[tuple[int, int]]], + padding: str | Sequence[tuple[int, int]], filter_shape: Sequence[int], - lhs_dilation: Optional[Sequence[int]] = None, - rhs_dilation: Optional[Sequence[int]] = None, - dimension_numbers: Optional[convolution.ConvGeneralDilatedDimensionNumbers] = None, + lhs_dilation: Sequence[int] | None = None, + rhs_dilation: Sequence[int] | None = None, + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers | None = None, precision: lax.PrecisionLike = None ) -> jax.Array: """General n-dimensional unshared convolution operator with optional dilation. @@ -202,7 +204,7 @@ def conv_general_dilated_local( c_precision = lax.canonicalize_precision(precision) lhs_precision = type_cast( - Optional[lax.PrecisionType], + lax.PrecisionType | None, (c_precision[0] if (isinstance(c_precision, tuple) and len(c_precision) == 2) else c_precision)) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index d29d8bdb6429..278e1d070225 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -15,12 +15,13 @@ Parallelization primitives. """ +from __future__ import annotations + from collections.abc import Sequence from functools import partial import itertools import math import string -from typing import Union import numpy as np @@ -606,7 +607,7 @@ def parse_args(self): return arg_specs -def pgather(src, idx, axes: Union[int, AxisName]): +def pgather(src, idx, axes: int | AxisName): """Uses the last positional axis of idx to index into src's axes.""" if not isinstance(axes, (tuple, list)): axes = (axes,) diff --git a/jax/_src/lax/qdwh.py b/jax/_src/lax/qdwh.py index 049f7604ad8c..92c41db22224 100644 --- a/jax/_src/lax/qdwh.py +++ b/jax/_src/lax/qdwh.py @@ -24,8 +24,9 @@ https://epubs.siam.org/doi/abs/10.1137/090774999 """ +from __future__ import annotations + import functools -from typing import Optional import jax import jax.numpy as jnp @@ -196,7 +197,7 @@ def false_fn(u): # TODO: Add pivoting. @functools.partial(jax.jit, static_argnames=('is_hermitian',)) def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None, - dynamic_shape: Optional[tuple[int, int]] = None): + dynamic_shape: tuple[int, int] | None = None): """QR-based dynamically weighted Halley iteration for polar decomposition. Args: diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 56240fe0adbd..b158d6c6b194 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence import enum import operator from functools import partial import math -from typing import Callable, NamedTuple, Optional, Union +from typing import Callable, NamedTuple import weakref import numpy as np @@ -54,7 +56,7 @@ def slice(operand: ArrayLike, start_indices: Sequence[int], limit_indices: Sequence[int], - strides: Optional[Sequence[int]] = None) -> Array: + strides: Sequence[int] | None = None) -> Array: """Wraps XLA's `Slice `_ operator. 
@@ -109,8 +111,8 @@ def slice(operand: ArrayLike, start_indices: Sequence[int], def dynamic_slice( - operand: Union[Array, np.ndarray], - start_indices: Union[Union[Array, np.ndarray], Sequence[ArrayLike]], + operand: Array | np.ndarray, + start_indices: Array | np.ndarray | Sequence[ArrayLike], slice_sizes: Shape, ) -> Array: """Wraps XLA's `DynamicSlice @@ -166,8 +168,8 @@ def dynamic_slice( slice_sizes=tuple(static_sizes)) -def dynamic_update_slice(operand: Union[Array, np.ndarray], update: ArrayLike, - start_indices: Union[Array, Sequence[ArrayLike]]) -> Array: +def dynamic_update_slice(operand: Array | np.ndarray, update: ArrayLike, + start_indices: Array | Sequence[ArrayLike]) -> Array: """Wraps XLA's `DynamicUpdateSlice `_ operator. @@ -268,7 +270,7 @@ class GatherScatterMode(enum.Enum): PROMISE_IN_BOUNDS = enum.auto() @staticmethod - def from_any(s: Optional[Union[str, 'GatherScatterMode']]): + def from_any(s: str | GatherScatterMode | None): if isinstance(s, GatherScatterMode): return s if s == "clip": @@ -287,7 +289,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, *, unique_indices: bool = False, indices_are_sorted: bool = False, - mode: Optional[Union[str, GatherScatterMode]] = None, + mode: str | GatherScatterMode | None = None, fill_value = None) -> Array: """Gather operator. @@ -382,7 +384,7 @@ def scatter_add( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: + mode: str | GatherScatterMode | None = None) -> Array: """Scatter-add operator. Wraps `XLA's Scatter operator @@ -429,7 +431,7 @@ def scatter_mul( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: + mode: str | GatherScatterMode | None = None) -> Array: """Scatter-multiply operator. Wraps `XLA's Scatter operator @@ -476,7 +478,7 @@ def scatter_min( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: + mode: str | GatherScatterMode | None = None) -> Array: """Scatter-min operator. Wraps `XLA's Scatter operator @@ -523,7 +525,7 @@ def scatter_max( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: + mode: str | GatherScatterMode | None = None) -> Array: """Scatter-max operator. Wraps `XLA's Scatter operator @@ -575,7 +577,7 @@ def scatter_apply( dimension_numbers: ScatterDimensionNumbers, *, update_shape: Shape = (), indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: + mode: str | GatherScatterMode | None = None) -> Array: """Scatter-apply operator. 
Wraps `XLA's Scatter operator @@ -637,7 +639,7 @@ def scatter( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: + mode: str | GatherScatterMode | None = None) -> Array: """Scatter-update operator. Wraps `XLA's Scatter operator @@ -700,8 +702,8 @@ def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array: ### convenience wrappers around traceables -def slice_in_dim(operand: Union[Array, np.ndarray], start_index: Optional[int], - limit_index: Optional[int], +def slice_in_dim(operand: Array | np.ndarray, start_index: int | None, + limit_index: int | None, stride: int = 1, axis: int = 0) -> Array: """Convenience wrapper around :func:`lax.slice` applying to only one dimension. @@ -775,7 +777,7 @@ def slice_in_dim(operand: Union[Array, np.ndarray], start_index: Optional[int], return slice(operand, start_indices, limit_indices, strides) -def index_in_dim(operand: Union[Array, np.ndarray], index: int, axis: int = 0, +def index_in_dim(operand: Array | np.ndarray, index: int, axis: int = 0, keepdims: bool = True) -> Array: """Convenience wrapper around :func:`lax.slice` to perform int indexing. @@ -835,7 +837,7 @@ def index_in_dim(operand: Union[Array, np.ndarray], index: int, axis: int = 0, return lax.squeeze(result, (axis,)) -def dynamic_slice_in_dim(operand: Union[Array, np.ndarray], +def dynamic_slice_in_dim(operand: Array | np.ndarray, start_index: ArrayLike, slice_size: int, axis: int = 0) -> Array: """Convenience wrapper around :func:`lax.dynamic_slice` applied to one dimension. @@ -893,8 +895,8 @@ def dynamic_slice_in_dim(operand: Union[Array, np.ndarray], return dynamic_slice(operand, start_indices, slice_sizes) -def dynamic_index_in_dim(operand: Union[Array, np.ndarray], - index: Union[int, Array], +def dynamic_index_in_dim(operand: Array | np.ndarray, + index: int | Array, axis: int = 0, keepdims: bool = True) -> Array: """Convenience wrapper around dynamic_slice to perform int indexing. @@ -945,7 +947,7 @@ def dynamic_index_in_dim(operand: Union[Array, np.ndarray], return lax.squeeze(result, (axis,)) -def dynamic_update_slice_in_dim(operand: Union[Array, np.ndarray], +def dynamic_update_slice_in_dim(operand: Array | np.ndarray, update: ArrayLike, start_index: ArrayLike, axis: int) -> Array: """Convenience wrapper around :func:`dynamic_update_slice` to update @@ -1007,7 +1009,7 @@ def dynamic_update_slice_in_dim(operand: Union[Array, np.ndarray], return dynamic_update_slice(operand, update, start_indices) -def dynamic_update_index_in_dim(operand: Union[Array, np.ndarray], +def dynamic_update_index_in_dim(operand: Array | np.ndarray, update: ArrayLike, index: ArrayLike, axis: int) -> Array: """Convenience wrapper around :func:`dynamic_update_slice` to update a slice @@ -2566,8 +2568,8 @@ def _scatter(operand_part, updates_part): def _dynamic_slice_indices( - operand: Union[Array, np.ndarray], - start_indices: Union[Union[Array, np.ndarray], Sequence[ArrayLike]] + operand: Array | np.ndarray, + start_indices: Array | np.ndarray | Sequence[ArrayLike] ) -> list[ArrayLike]: # Normalize the start_indices w.r.t. 
operand.shape if len(start_indices) != operand.ndim: diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index f35b1ec0ede0..1e3472e257b9 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -33,9 +33,11 @@ https://epubs.siam.org/doi/abs/10.1137/090774999 """ +from __future__ import annotations + from collections.abc import Sequence import functools -from typing import Any, Union +from typing import Any import jax import jax.numpy as jnp @@ -46,7 +48,7 @@ @functools.partial(jax.jit, static_argnums=(2, 3)) def _constant_svd( a: Any, return_nan: bool, full_matrices: bool, compute_uv: bool = True -) -> Union[Any, Sequence[Any]]: +) -> Any | Sequence[Any]: """SVD on matrix of all zeros.""" m, n = a.shape k = min(m, n) @@ -91,7 +93,7 @@ def _constant_svd( @functools.partial(jax.jit, static_argnums=(1, 2, 3)) def _svd_tall_and_square_input( a: Any, hermitian: bool, compute_uv: bool, max_iterations: int -) -> Union[Any, Sequence[Any]]: +) -> Any | Sequence[Any]: """Singular value decomposition for m x n matrix and m >= n. Args: @@ -151,7 +153,7 @@ def _qdwh_svd(a: Any, full_matrices: bool, compute_uv: bool = True, hermitian: bool = False, - max_iterations: int = 10) -> Union[Any, Sequence[Any]]: + max_iterations: int = 10) -> Any | Sequence[Any]: """Singular value decomposition. Args: @@ -217,7 +219,7 @@ def svd(a: Any, full_matrices: bool, compute_uv: bool = True, hermitian: bool = False, - max_iterations: int = 10) -> Union[Any, Sequence[Any]]: + max_iterations: int = 10) -> Any | Sequence[Any]: """Singular value decomposition. Args: diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index ce9c1c34cf1c..853bcd762163 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence from functools import partial -from typing import Callable, Optional, Union +from typing import Callable import warnings import numpy as np @@ -44,9 +46,9 @@ def reduce_window(operand, init_value, computation: Callable, window_dimensions: core.Shape, window_strides: Sequence[int], - padding: Union[str, Sequence[tuple[int, int]]], - base_dilation: Optional[Sequence[int]] = None, - window_dilation: Optional[Sequence[int]] = None) -> Array: + padding: str | Sequence[tuple[int, int]], + base_dilation: Sequence[int] | None = None, + window_dilation: Sequence[int] | None = None) -> Array: """Wraps XLA's `ReduceWindowWithGeneralPadding `_ operator. 
@@ -94,7 +96,7 @@ def reduce_window(operand, init_value, computation: Callable, return tree_util.tree_unflatten(out_tree, out_flat) def _get_monoid_window_reducer(monoid_op: Callable, - xs: Sequence[Array]) -> Optional[Callable]: + xs: Sequence[Array]) -> Callable | None: if len(xs) != 1: return None x, = xs @@ -113,8 +115,8 @@ def _get_monoid_window_reducer(monoid_op: Callable, def _reduce_window_sum(operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]], - base_dilation: Optional[Sequence[int]] = None, - window_dilation: Optional[Sequence[int]] = None) -> Array: + base_dilation: Sequence[int] | None = None, + window_dilation: Sequence[int] | None = None) -> Array: if base_dilation is None: base_dilation = (1,) * len(window_dimensions) if window_dilation is None: @@ -128,8 +130,8 @@ def _reduce_window_sum(operand: Array, window_dimensions: core.Shape, def _reduce_window_prod(operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]], - base_dilation: Optional[Sequence[int]] = None, - window_dilation: Optional[Sequence[int]] = None) -> Array: + base_dilation: Sequence[int] | None = None, + window_dilation: Sequence[int] | None = None) -> Array: init_value = lax._const(operand, 1) jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(init_value)) if base_dilation is None: @@ -147,8 +149,8 @@ def _reduce_window_prod(operand: Array, window_dimensions: core.Shape, def _reduce_window_max(operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]], - base_dilation: Optional[Sequence[int]] = None, - window_dilation: Optional[Sequence[int]] = None) -> Array: + base_dilation: Sequence[int] | None = None, + window_dilation: Sequence[int] | None = None) -> Array: if base_dilation is None: base_dilation = (1,) * len(window_dimensions) if window_dilation is None: @@ -162,8 +164,8 @@ def _reduce_window_max(operand: Array, window_dimensions: core.Shape, def _reduce_window_min(operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]], - base_dilation: Optional[Sequence[int]] = None, - window_dilation: Optional[Sequence[int]] = None) -> Array: + base_dilation: Sequence[int] | None = None, + window_dilation: Sequence[int] | None = None) -> Array: if base_dilation is None: base_dilation = (1,) * len(window_dimensions) if window_dilation is None: @@ -178,8 +180,8 @@ def _reduce_window_logaddexp( operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]], - base_dilation: Optional[Sequence[int]] = None, - window_dilation: Optional[Sequence[int]] = None) -> Array: + base_dilation: Sequence[int] | None = None, + window_dilation: Sequence[int] | None = None) -> Array: init_value = lax._const(operand, -np.inf) jaxpr, consts = lax._reduction_jaxpr(logaddexp, lax._abstractify(init_value)) if base_dilation is None: diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index e6a391c81033..07d6d568e2b1 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -15,11 +15,11 @@ # This module is largely a wrapper around `jaxlib` that performs version # checking on import. -import datetime +from __future__ import annotations + import gc import pathlib import re -from typing import Optional try: import jaxlib as jaxlib @@ -125,7 +125,7 @@ def _xla_gc_callback(*args): # TODO(rocm): check if we need the same for rocm. 
-def _cuda_path() -> Optional[str]: +def _cuda_path() -> str | None: _jaxlib_path = pathlib.Path(jaxlib.__file__).parent # If the pip package nvidia-cuda-nvcc-cu11 is installed, it should have # both of the things XLA looks for in the cuda path, namely bin/ptxas and diff --git a/jax/_src/maps.py b/jax/_src/maps.py index d11457b4b195..ade9fba87be7 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections import OrderedDict, abc from collections.abc import Iterable, Sequence, Mapping import contextlib from functools import wraps, partial, partialmethod, lru_cache import itertools as it import math -from typing import (Callable, Optional, Any, - NamedTuple, Union) +from typing import Callable, Any, NamedTuple, Union import numpy as np @@ -205,7 +206,7 @@ def fresh_resource_name(tag=None): # pytree instance for it, so that it is treated as a leaf. class AxisNamePos(FrozenDict): user_repr: str - expected_rank: Optional[int] = None + expected_rank: int | None = None def __init__(self, *args, user_repr, **kwargs): super().__init__(*args, **kwargs) @@ -278,10 +279,10 @@ def xmap(fun: Callable, in_axes, out_axes, *, - axis_sizes: Optional[Mapping[AxisName, int]] = None, - axis_resources: Optional[Mapping[AxisName, ResourceSet]] = None, - donate_argnums: Union[int, Sequence[int]] = (), - backend: Optional[str] = None) -> stages.Wrapped: + axis_sizes: Mapping[AxisName, int] | None = None, + axis_resources: Mapping[AxisName, ResourceSet] | None = None, + donate_argnums: int | Sequence[int] = (), + backend: str | None = None) -> stages.Wrapped: """Assign a positional signature to a program that uses named array axes. .. 
warning:: @@ -717,7 +718,7 @@ class EvaluationPlan(NamedTuple): physical_axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]] loop_axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]] axis_subst_dict: dict[AxisName, tuple[ResourceAxisName, ...]] - axis_vmap_size: dict[AxisName, Optional[int]] + axis_vmap_size: dict[AxisName, int | None] @property def axis_subst(self) -> core.AxisSubst: @@ -743,7 +744,7 @@ def from_axis_resources(cls, axis_resource_count = _get_axis_resource_count( axis_resources, resource_env) axis_subst_dict = dict(axis_resources) - axis_vmap_size: dict[AxisName, Optional[int]] = {} + axis_vmap_size: dict[AxisName, int | None] = {} for naxis, raxes in sorted(axis_resources.items(), key=lambda x: str(x[0])): num_resources = axis_resource_count[naxis] assert global_axis_sizes[naxis] % num_resources.nglobal == 0 @@ -1402,7 +1403,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, # XXX: We modify mesh_in_axes and mesh_out_axes here def add_spmd_axes( flat_mesh_axes: Sequence[ArrayMapping], - flat_extra_axes: Optional[Sequence[Sequence[Sequence[MeshAxisName]]]]): + flat_extra_axes: Sequence[Sequence[Sequence[MeshAxisName]]] | None): if flat_extra_axes is None: return for axes, extra in zip(flat_mesh_axes, flat_extra_axes): @@ -1562,7 +1563,7 @@ def _insert_aval_axes(aval, axes: AxisNamePos, local_axis_sizes): class ResourceCount(NamedTuple): nglobal: int - nlocal: Optional[int] + nlocal: int | None distributed: bool def to_local(self, global_size): @@ -1673,7 +1674,7 @@ def _to_resource_axes(axes_specs: Sequence[AxisNamePos], for axes in axes_specs) -def _merge_leading_axis(x, axis: Optional[int]): +def _merge_leading_axis(x, axis: int | None): if axis is None: # We assume that the output does not vary along the leading axis return lax.index_in_dim(x, 0, axis=0, keepdims=False) @@ -1684,7 +1685,7 @@ def _merge_leading_axis(x, axis: Optional[int]): return x_moved.reshape(shape) -def _slice_tile(x, dim: Optional[int], i, n: int): +def _slice_tile(x, dim: int | None, i, n: int): """Selects an `i`th (out of `n`) tiles of `x` along `dim`.""" if dim is None: return x (tile_size, rem) = divmod(x.shape[dim], n) diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py index e927c4a3df40..3b291de0061a 100644 --- a/jax/_src/monitoring.py +++ b/jax/_src/monitoring.py @@ -20,19 +20,22 @@ A typical listener callback is to send an event to a metrics collector for aggregation/exporting. """ -from typing import Protocol, Union + +from __future__ import annotations + +from typing import Protocol class EventListenerWithMetadata(Protocol): - def __call__(self, event: str, **kwargs: Union[str, int]) -> None: + def __call__(self, event: str, **kwargs: str | int) -> None: ... class EventDurationListenerWithMetadata(Protocol): def __call__(self, event: str, duration_secs: float, - **kwargs: Union[str, int]) -> None: + **kwargs: str | int) -> None: ... @@ -40,7 +43,7 @@ def __call__(self, event: str, duration_secs: float, _event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = [] -def record_event(event: str, **kwargs: Union[str, int]) -> None: +def record_event(event: str, **kwargs: str | int) -> None: """Record an event. If **kwargs are specified, all of the named arguments have to be passed in the @@ -51,7 +54,7 @@ def record_event(event: str, **kwargs: Union[str, int]) -> None: def record_event_duration_secs(event: str, duration: float, - **kwargs: Union[str, int]) -> None: + **kwargs: str | int) -> None: """Record an event duration in seconds (float). 
If **kwargs are specified, all of the named arguments have to be passed in the diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 278c180f453a..a9cc7f088e3d 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -14,10 +14,12 @@ """Shared neural network activations and other functions.""" +from __future__ import annotations + from functools import partial import operator import numpy as np -from typing import Any, Optional, Union +from typing import Any import jax import jax.numpy as jnp @@ -401,9 +403,9 @@ def glu(x: ArrayLike, axis: int = -1) -> Array: @partial(jax.jit, static_argnames=("axis",)) def log_softmax(x: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = -1, - where: Optional[ArrayLike] = None, - initial: Optional[ArrayLike] = None) -> Array: + axis: int | tuple[int, ...] | None = -1, + where: ArrayLike | None = None, + initial: ArrayLike | None = None) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -442,9 +444,9 @@ def log_softmax(x: ArrayLike, # TODO(phawkins): this jit was found to change numerics in a test. Debug this. #@partial(jax.jit, static_argnames=("axis",)) def softmax(x: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = -1, - where: Optional[ArrayLike] = None, - initial: Optional[ArrayLike] = None) -> Array: + axis: int | tuple[int, ...] | None = -1, + where: ArrayLike | None = None, + initial: ArrayLike | None = None) -> Array: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -480,9 +482,9 @@ def softmax(x: ArrayLike, @partial(jax.custom_jvp, nondiff_argnums=(1,)) def _softmax( x: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = -1, - where: Optional[ArrayLike] = None, - initial: Optional[ArrayLike] = None) -> Array: + axis: int | tuple[int, ...] | None = -1, + where: ArrayLike | None = None, + initial: ArrayLike | None = None) -> Array: x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) unnormalized = jnp.exp(x - x_max) result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True) @@ -498,9 +500,9 @@ def _softmax_jvp(axis, primals, tangents): def _softmax_deprecated( x: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = -1, - where: Optional[ArrayLike] = None, - initial: Optional[ArrayLike] = None) -> Array: + axis: int | tuple[int, ...] | None = -1, + where: ArrayLike | None = None, + initial: ArrayLike | None = None) -> Array: x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) unnormalized = jnp.exp(x - lax.stop_gradient(x_max)) result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True) @@ -511,11 +513,11 @@ def _softmax_deprecated( @partial(jax.jit, static_argnames=("axis",)) def standardize(x: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = -1, - mean: Optional[ArrayLike] = None, - variance: Optional[ArrayLike] = None, + axis: int | tuple[int, ...] | None = -1, + mean: ArrayLike | None = None, + variance: ArrayLike | None = None, epsilon: ArrayLike = 1e-5, - where: Optional[ArrayLike] = None) -> Array: + where: ArrayLike | None = None) -> Array: r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.""" numpy_util.check_arraylike("standardize", x) numpy_util.check_arraylike_or_none("standardize", mean, variance, where) @@ -533,7 +535,7 @@ def standardize(x: ArrayLike, # TODO(slebedev): Change the type of `x` to `ArrayLike`. 
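
# A small usage sketch (illustrative only, not part of this patch) for the
# softmax-family signatures above: `axis` now reads `int | tuple[int, ...] | None`
# but the call-side behaviour is unchanged.
import jax.numpy as jnp
from jax import nn

logits = jnp.array([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
probs = nn.softmax(logits, axis=-1)       # each row sums to 1
logp = nn.log_softmax(logits, axis=-1)    # matches jnp.log(probs) up to numerics
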
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis")) def _one_hot(x: Any, num_classes: int, *, - dtype: Any, axis: Union[int, AxisName]) -> Array: + dtype: Any, axis: int | AxisName) -> Array: num_classes = core.concrete_dim_or_error( num_classes, "The error arose in jax.nn.one_hot argument `num_classes`.") @@ -557,7 +559,7 @@ def _one_hot(x: Any, num_classes: int, *, # TODO(slebedev): Change the type of `x` to `ArrayLike`. def one_hot(x: Any, num_classes: int, *, - dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array: + dtype: Any = jnp.float_, axis: int | AxisName = -1) -> Array: """One-hot encodes the given indices. Each index in the input ``x`` is encoded as a vector of zeros of length diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 3c966668eaf9..95c85e34e946 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -17,9 +17,11 @@ used in Keras and Sonnet. """ +from __future__ import annotations + from collections.abc import Sequence import math -from typing import Any, Literal, Protocol, Union +from typing import Any, Literal, Protocol import numpy as np @@ -194,9 +196,9 @@ def init(key: KeyArray, @export def _compute_fans(shape: core.NamedShape, - in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, - batch_axis: Union[int, Sequence[int]] = () + in_axis: int | Sequence[int] = -2, + out_axis: int | Sequence[int] = -1, + batch_axis: int | Sequence[int] = () ) -> tuple[Array, Array]: """ Compute effective input and output sizes for a linear or convolutional layer. @@ -226,7 +228,7 @@ def _compute_fans(shape: core.NamedShape, return fan_in, fan_out def _complex_uniform(key: KeyArray, - shape: Union[Sequence[int], core.NamedShape], + shape: Sequence[int] | core.NamedShape, dtype: DTypeLikeInexact) -> Array: """ Sample uniform random values within a disk on the complex plane, @@ -240,7 +242,7 @@ def _complex_uniform(key: KeyArray, return r * jnp.exp(1j * theta) def _complex_truncated_normal(key: KeyArray, upper: ArrayLike, - shape: Union[Sequence[int], core.NamedShape], + shape: Sequence[int] | core.NamedShape, dtype: DTypeLikeInexact) -> Array: """ Sample random values from a centered normal distribution on the complex plane, @@ -259,11 +261,11 @@ def _complex_truncated_normal(key: KeyArray, upper: ArrayLike, @export def variance_scaling( scale: RealNumeric, - mode: Union[Literal["fan_in"], Literal["fan_out"], Literal["fan_avg"]], - distribution: Union[Literal["truncated_normal"], Literal["normal"], - Literal["uniform"]], - in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, + mode: Literal["fan_in"] | Literal["fan_out"] | Literal["fan_avg"], + distribution: (Literal["truncated_normal"] | Literal["normal"] | + Literal["uniform"]), + in_axis: int | Sequence[int] = -2, + out_axis: int | Sequence[int] = -1, batch_axis: Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_ ) -> Initializer: @@ -345,8 +347,8 @@ def init(key: KeyArray, return init @export -def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, +def glorot_uniform(in_axis: int | Sequence[int] = -2, + out_axis: int | Sequence[int] = -1, batch_axis: Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Glorot uniform initializer (aka Xavier uniform initializer). 
@@ -383,8 +385,8 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2, xavier_uniform = glorot_uniform @export -def glorot_normal(in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, +def glorot_normal(in_axis: int | Sequence[int] = -2, + out_axis: int | Sequence[int] = -1, batch_axis: Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Glorot normal initializer (aka Xavier normal initializer). @@ -421,8 +423,8 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2, xavier_normal = glorot_normal @export -def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, +def lecun_uniform(in_axis: int | Sequence[int] = -2, + out_axis: int | Sequence[int] = -1, batch_axis: Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Lecun uniform initializer. @@ -457,8 +459,8 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2, out_axis=out_axis, batch_axis=batch_axis, dtype=dtype) @export -def lecun_normal(in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, +def lecun_normal(in_axis: int | Sequence[int] = -2, + out_axis: int | Sequence[int] = -1, batch_axis: Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Lecun normal initializer. @@ -493,8 +495,8 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2, out_axis=out_axis, batch_axis=batch_axis, dtype=dtype) @export -def he_uniform(in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, +def he_uniform(in_axis: int | Sequence[int] = -2, + out_axis: int | Sequence[int] = -1, batch_axis: Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a He uniform initializer (aka Kaiming uniform initializer). @@ -531,8 +533,8 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2, kaiming_uniform = he_uniform @export -def he_normal(in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, +def he_normal(in_axis: int | Sequence[int] = -2, + out_axis: int | Sequence[int] = -1, batch_axis: Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a He normal initializer (aka Kaiming normal initializer). diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 98ceacf04bf3..7de5be2bc32c 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -19,12 +19,14 @@ This is done dynamically in order to avoid circular imports. """ +from __future__ import annotations + __all__ = ['register_jax_array_methods'] import abc from functools import partial, wraps import math -from typing import Any, Optional, Union +from typing import Any import numpy as np import jax @@ -89,7 +91,7 @@ def _itemsize(arr: ArrayLike) -> int: def _clip(number: ArrayLike, - min: Optional[ArrayLike] = None, max: Optional[ArrayLike] = None, + min: ArrayLike | None = None, max: ArrayLike | None = None, out: None = None) -> Array: """Return an array whose values are limited to a specified range. 
@@ -111,7 +113,7 @@ def _transpose(a: Array, *args: Any) -> Array: return lax_numpy.transpose(a, axis) -def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape: +def _compute_newshape(a: ArrayLike, newshape: DimSize | Shape) -> Shape: """Fixes a -1 value in newshape, if present.""" orig_newshape = newshape # for error messages try: @@ -162,7 +164,7 @@ def _reshape(a: Array, *args: Any, order: str = "C") -> Array: raise ValueError(f"Unexpected value for 'order' argument: {order}.") -def _view(arr: Array, dtype: Optional[DTypeLike] = None, type: None = None) -> Array: +def _view(arr: Array, dtype: DTypeLike | None = None, type: None = None) -> Array: """Return a bitwise copy of the array, viewed as a new dtype. This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`. @@ -282,7 +284,7 @@ def _unimplemented_setitem(self, i, x): "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html") raise TypeError(msg.format(type(self))) -def _operator_round(number: ArrayLike, ndigits: Optional[int] = None) -> Array: +def _operator_round(number: ArrayLike, ndigits: int | None = None) -> Array: out = lax_numpy.round(number, decimals=ndigits or 0) # If `ndigits` is None, for a builtin float round(7.5) returns an integer. return out.astype(int) if ndigits is None else out @@ -308,7 +310,7 @@ def __array_module__(self, types): def _compress_method(a: ArrayLike, condition: ArrayLike, - axis: Optional[int] = None, out: None = None) -> Array: + axis: int | None = None, out: None = None) -> Array: """Return selected slices of this array along given axis. Refer to :func:`jax.numpy.compress` for full documentation.""" diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 97479fe69323..41b04c9d5778 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence import operator -from typing import Optional, Union import numpy as np from jax import dtypes @@ -40,8 +41,8 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, - s: Optional[Shape], axes: Optional[Sequence[int]], - norm: Optional[str]) -> Array: + s: Shape | None, axes: Sequence[int] | None, + norm: str | None) -> Array: full_name = "jax.numpy.fft." 
+ func_name check_arraylike(full_name, a) arr = jnp.asarray(a) @@ -105,34 +106,34 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, @_wraps(np.fft.fftn) -def fftn(a: ArrayLike, s: Optional[Shape] = None, - axes: Optional[Sequence[int]] = None, - norm: Optional[str] = None) -> Array: +def fftn(a: ArrayLike, s: Shape | None = None, + axes: Sequence[int] | None = None, + norm: str | None = None) -> Array: return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm) @_wraps(np.fft.ifftn) -def ifftn(a: ArrayLike, s: Optional[Shape] = None, - axes: Optional[Sequence[int]] = None, - norm: Optional[str] = None) -> Array: +def ifftn(a: ArrayLike, s: Shape | None = None, + axes: Sequence[int] | None = None, + norm: str | None = None) -> Array: return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm) @_wraps(np.fft.rfftn) -def rfftn(a: ArrayLike, s: Optional[Shape] = None, - axes: Optional[Sequence[int]] = None, - norm: Optional[str] = None) -> Array: +def rfftn(a: ArrayLike, s: Shape | None = None, + axes: Sequence[int] | None = None, + norm: str | None = None) -> Array: return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm) @_wraps(np.fft.irfftn) -def irfftn(a: ArrayLike, s: Optional[Shape] = None, - axes: Optional[Sequence[int]] = None, - norm: Optional[str] = None) -> Array: +def irfftn(a: ArrayLike, s: Shape | None = None, + axes: Sequence[int] | None = None, + norm: str | None = None) -> Array: return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm) -def _axis_check_1d(func_name: str, axis: Optional[int]): +def _axis_check_1d(func_name: str, axis: int | None): full_name = "jax.numpy.fft." + func_name if isinstance(axis, (list, tuple)): raise ValueError( @@ -141,8 +142,8 @@ def _axis_check_1d(func_name: str, axis: Optional[int]): ) def _fft_core_1d(func_name: str, fft_type: xla_client.FftType, - a: ArrayLike, n: Optional[int], axis: Optional[int], - norm: Optional[str]) -> Array: + a: ArrayLike, n: int | None, axis: int | None, + norm: str | None) -> Array: _axis_check_1d(func_name, axis) axes = None if axis is None else [axis] s = None if n is None else [n] @@ -150,32 +151,32 @@ def _fft_core_1d(func_name: str, fft_type: xla_client.FftType, @_wraps(np.fft.fft) -def fft(a: ArrayLike, n: Optional[int] = None, - axis: int = -1, norm: Optional[str] = None) -> Array: +def fft(a: ArrayLike, n: int | None = None, + axis: int = -1, norm: str | None = None) -> Array: return _fft_core_1d('fft', xla_client.FftType.FFT, a, n=n, axis=axis, norm=norm) @_wraps(np.fft.ifft) -def ifft(a: ArrayLike, n: Optional[int] = None, - axis: int = -1, norm: Optional[str] = None) -> Array: +def ifft(a: ArrayLike, n: int | None = None, + axis: int = -1, norm: str | None = None) -> Array: return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, n=n, axis=axis, norm=norm) @_wraps(np.fft.rfft) -def rfft(a: ArrayLike, n: Optional[int] = None, - axis: int = -1, norm: Optional[str] = None) -> Array: +def rfft(a: ArrayLike, n: int | None = None, + axis: int = -1, norm: str | None = None) -> Array: return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, n=n, axis=axis, norm=norm) @_wraps(np.fft.irfft) -def irfft(a: ArrayLike, n: Optional[int] = None, - axis: int = -1, norm: Optional[str] = None) -> Array: +def irfft(a: ArrayLike, n: int | None = None, + axis: int = -1, norm: str | None = None) -> Array: return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, n=n, axis=axis, norm=norm) @_wraps(np.fft.hfft) -def hfft(a: ArrayLike, n: Optional[int] = None, - 
axis: int = -1, norm: Optional[str] = None) -> Array: +def hfft(a: ArrayLike, n: int | None = None, + axis: int = -1, norm: str | None = None) -> Array: conj_a = ufuncs.conj(a) _axis_check_1d('hfft', axis) nn = (conj_a.shape[axis] - 1) * 2 if n is None else n @@ -183,8 +184,8 @@ def hfft(a: ArrayLike, n: Optional[int] = None, norm=norm) * nn @_wraps(np.fft.ihfft) -def ihfft(a: ArrayLike, n: Optional[int] = None, - axis: int = -1, norm: Optional[str] = None) -> Array: +def ihfft(a: ArrayLike, n: int | None = None, + axis: int = -1, norm: str | None = None) -> Array: _axis_check_1d('ihfft', axis) arr = jnp.asarray(a) nn = arr.shape[axis] if n is None else n @@ -194,8 +195,8 @@ def ihfft(a: ArrayLike, n: Optional[int] = None, def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, - s: Optional[Shape], axes: Sequence[int], - norm: Optional[str]) -> Array: + s: Shape | None, axes: Sequence[int], + norm: str | None) -> Array: full_name = "jax.numpy.fft." + func_name if len(axes) != 2: raise ValueError( @@ -206,26 +207,26 @@ def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, @_wraps(np.fft.fft2) -def fft2(a: ArrayLike, s: Optional[Shape] = None, axes: Sequence[int] = (-2,-1), - norm: Optional[str] = None) -> Array: +def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), + norm: str | None = None) -> Array: return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes, norm=norm) @_wraps(np.fft.ifft2) -def ifft2(a: ArrayLike, s: Optional[Shape] = None, axes: Sequence[int] = (-2,-1), - norm: Optional[str] = None) -> Array: +def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), + norm: str | None = None) -> Array: return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes, norm=norm) @_wraps(np.fft.rfft2) -def rfft2(a: ArrayLike, s: Optional[Shape] = None, axes: Sequence[int] = (-2,-1), - norm: Optional[str] = None) -> Array: +def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), + norm: str | None = None) -> Array: return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes, norm=norm) @_wraps(np.fft.irfft2) -def irfft2(a: ArrayLike, s: Optional[Shape] = None, axes: Sequence[int] = (-2,-1), - norm: Optional[str] = None) -> Array: +def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), + norm: str | None = None) -> Array: return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes, norm=norm) @@ -292,10 +293,10 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array: @_wraps(np.fft.fftshift) -def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array: +def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: check_arraylike("fftshift", x) x = jnp.asarray(x) - shift: Union[int, Sequence[int]] + shift: int | Sequence[int] if axes is None: axes = tuple(range(x.ndim)) shift = [dim // 2 for dim in x.shape] @@ -308,10 +309,10 @@ def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Arra @_wraps(np.fft.ifftshift) -def ifftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array: +def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: check_arraylike("ifftshift", x) x = jnp.asarray(x) - shift: Union[int, Sequence[int]] + shift: int | Sequence[int] if axes is None: axes = tuple(range(x.ndim)) shift = [-(dim // 2) for dim in x.shape] diff --git a/jax/_src/numpy/index_tricks.py 
b/jax/_src/numpy/index_tricks.py index c9aa12719287..e67237546279 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import abc from collections.abc import Iterable from typing import Any, Union @@ -74,7 +76,7 @@ class _Mgrid: [0, 1, 2]]], dtype=int32) """ - def __getitem__(self, key: Union[slice, tuple[slice, ...]]) -> Array: + def __getitem__(self, key: slice | tuple[slice, ...]) -> Array: if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name="mgrid") output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="mgrid") for k in key) @@ -118,8 +120,8 @@ class _Ogrid: """ def __getitem__( - self, key: Union[slice, tuple[slice, ...]] - ) -> Union[Array, list[Array]]: + self, key: slice | tuple[slice, ...] + ) -> Array | list[Array]: if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name="ogrid") output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key) @@ -141,7 +143,7 @@ class _AxisConcat(abc.ABC): trans1d: int op_name: str - def __getitem__(self, key: Union[_IndexType, tuple[_IndexType, ...]]) -> Array: + def __getitem__(self, key: _IndexType | tuple[_IndexType, ...]) -> Array: key_tup: tuple[_IndexType, ...] = key if isinstance(key, tuple) else (key,) params = [self.axis, self.ndmin, self.trans1d, -1] diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3cf1f0052b63..3b7e8bc161a5 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -32,7 +32,8 @@ import math import operator import types -from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol, TypeVar, Union) +from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol, + TypeVar, Union) from textwrap import dedent as _dedent import warnings diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 0934e97d35cc..bcfc8c18fc93 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from functools import partial import numpy as np import textwrap import operator -from typing import Literal, Optional, Union, cast, overload +from typing import Literal, cast, overload import jax from jax import jit, custom_jvp @@ -61,12 +62,12 @@ def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], hermitian: bool = False) -> Array: ... @overload def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, - hermitian: bool = False) -> Union[Array, tuple[Array, Array, Array]]: ... + hermitian: bool = False) -> Array | tuple[Array, Array, Array]: ... 
@_wraps(np.linalg.svd) @partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian')) def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, - hermitian: bool = False) -> Union[Array, tuple[Array, Array, Array]]: + hermitian: bool = False) -> Array | tuple[Array, Array, Array]: check_arraylike("jnp.linalg.svd", a) a, = promote_dtypes_inexact(jnp.asarray(a)) if hermitian: @@ -129,7 +130,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: @_wraps(np.linalg.matrix_rank) @jit -def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array: +def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array: check_arraylike("jnp.linalg.matrix_rank", M) M, = promote_dtypes_inexact(jnp.asarray(M)) if M.ndim < 2: @@ -193,7 +194,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]: LU decomposition if ``None``. """)) @partial(jit, static_argnames=('method',)) -def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> tuple[Array, Array]: +def slogdet(a: ArrayLike, *, method: str | None = None) -> tuple[Array, Array]: check_arraylike("jnp.linalg.slogdet", a) a, = promote_dtypes_inexact(jnp.asarray(a)) a_shape = jnp.shape(a) @@ -382,7 +383,7 @@ def eigvals(a: ArrayLike) -> Array: @_wraps(np.linalg.eigh) @partial(jit, static_argnames=('UPLO', 'symmetrize_input')) -def eigh(a: ArrayLike, UPLO: Optional[str] = None, +def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> tuple[Array, Array]: check_arraylike("jnp.linalg.eigh", a) if UPLO is None or UPLO == "L": @@ -400,7 +401,7 @@ def eigh(a: ArrayLike, UPLO: Optional[str] = None, @_wraps(np.linalg.eigvalsh) @partial(jit, static_argnames=('UPLO',)) -def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array: +def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: check_arraylike("jnp.linalg.eigvalsh", a) w, _ = eigh(a, UPLO) return w @@ -413,7 +414,7 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array: `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`. """)) @partial(jit, static_argnames=('hermitian',)) -def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None, +def pinv(a: ArrayLike, rcond: ArrayLike | None = None, hermitian: bool = False) -> Array: # Uses same algorithm as # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979 @@ -481,8 +482,8 @@ def inv(a: ArrayLike) -> Array: @_wraps(np.linalg.norm) @partial(jit, static_argnames=('ord', 'axis', 'keepdims')) -def norm(x: ArrayLike, ord: Union[int, str, None] = None, - axis: Union[None, tuple[int, ...], int] = None, +def norm(x: ArrayLike, ord: int | str | None = None, + axis: None | tuple[int, ...] | int = None, keepdims: bool = False) -> Array: check_arraylike("jnp.linalg.norm", x) x, = promote_dtypes_inexact(jnp.asarray(x)) @@ -579,11 +580,11 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None, @overload def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ... @overload -def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, tuple[Array, Array]]: ... +def qr(a: ArrayLike, mode: str = "reduced") -> Array | tuple[Array, Array]: ... 
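
# A small usage sketch (illustrative only, not part of this patch) for the qr
# overloads above: the return annotation is `Array | tuple[Array, Array]`
# because `mode="r"` yields only R, while the default mode yields (Q, R).
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(3, 2)
q, r = jnp.linalg.qr(a)               # default "reduced" mode: a (Q, R) pair
r_only = jnp.linalg.qr(a, mode="r")   # a single array
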
@_wraps(np.linalg.qr) @partial(jit, static_argnames=('mode',)) -def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, tuple[Array, Array]]: +def qr(a: ArrayLike, mode: str = "reduced") -> Array | tuple[Array, Array]: check_arraylike("jnp.linalg.qr", a) a, = promote_dtypes_inexact(jnp.asarray(a)) if mode == "raw": @@ -611,7 +612,7 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array: return lax_linalg._solve(a, b) -def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *, +def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]: # TODO: add lstsq to lax_linalg and implement this function via those wrappers. # TODO: add custom jvp rule for more robust lstsq differentiation @@ -671,7 +672,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *, The lstsq function does not currently have a custom JVP rule, so the gradient is poorly behaved for some inputs, particularly for low-rank `a`. """)) -def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *, +def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]: check_arraylike("jnp.linalg.lstsq", a, b) if numpy_resid: diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index c823ca9abafc..41603c9dcf1c 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from functools import partial import operator -from typing import Optional, Union import numpy as np @@ -108,9 +108,9 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: """ @_wraps(np.polyfit, lax_description=_POLYFIT_DOC) @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) -def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = None, - full: bool = False, w: Optional[Array] = None, cov: bool = False - ) -> Union[Array, tuple[Array, ...]]: +def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, + full: bool = False, w: Array | None = None, cov: bool = False + ) -> Array | tuple[Array, ...]: check_arraylike("polyfit", x, y) deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 @@ -244,7 +244,7 @@ def polyadd(a1: Array, a2: Array) -> Array: @_wraps(np.polyint) @partial(jit, static_argnames=('m',)) -def polyint(p: Array, m: int = 1, k: Optional[int] = None) -> Array: +def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint") k = 0 if k is None else k check_arraylike("polyint", p, k) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index f4f034454852..6a87d373cc54 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import builtins from collections.abc import Sequence from functools import partial import math import operator -from typing import overload, Any, Callable, Literal, Optional, Protocol, Union +from typing import overload, Any, Callable, Literal, Protocol, Union import warnings import numpy as np @@ -41,7 +43,7 @@ _lax_const = lax_internal._const -Axis = Union[None, int, Sequence[int]] +Axis = Union[int, Sequence[int], None] def _isscalar(element: Any) -> bool: if hasattr(element, '__jax_array__'): @@ -67,13 +69,13 @@ def _upcast_f16(dtype: DTypeLike) -> DType: def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike, *, has_identity: bool = True, - preproc: Optional[Callable[[ArrayLike], ArrayLike]] = None, - bool_op: Optional[ReductionOp] = None, + preproc: Callable[[ArrayLike], ArrayLike] | None = None, + bool_op: ReductionOp | None = None, upcast_f16_for_computation: bool = False, - axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where_: Optional[ArrayLike] = None, - parallel_reduce: Optional[Callable[..., Array]] = None, + axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where_: ArrayLike | None = None, + parallel_reduce: Callable[..., Array] | None = None, promote_integers: bool = False) -> Array: bool_op = bool_op or op # Note: we must accept out=None as an argument, because numpy reductions delegate to @@ -203,9 +205,9 @@ def force(x): @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) -def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, +def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, - initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None, + initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric, bool_op=lax.bitwise_or, upcast_f16_for_computation=True, @@ -214,18 +216,18 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = No promote_integers=promote_integers) @_wraps(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) -def sum(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, - out: None = None, keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None, promote_integers: bool = True) -> Array: +def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None, promote_integers: bool = True) -> Array: return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) -def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, +def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, - initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None, + initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: 
return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric, bool_op=lax.bitwise_and, upcast_f16_for_computation=True, @@ -233,9 +235,9 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = N initial=initial, where_=where, promote_integers=promote_integers) @_wraps(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) -def prod(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, +def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, - initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None, + initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, @@ -244,62 +246,62 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) @_wraps(np.max, skip_params=['out']) def max(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) @_wraps(np.min, skip_params=['out']) def min(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) @_wraps(np.all, skip_params=['out']) def all(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: return _reduce_all(a, 
axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) @_wraps(np.any, skip_params=['out']) def any(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) amin = min amax = max -def _axis_size(a: ArrayLike, axis: Union[int, Sequence[int]]): +def _axis_size(a: ArrayLike, axis: int | Sequence[int]): if not isinstance(axis, (tuple, list)): axis_seq: Sequence[int] = (axis,) # type: ignore[assignment] else: @@ -311,17 +313,17 @@ def _axis_size(a: ArrayLike, axis: Union[int, Sequence[int]]): return size @_wraps(np.mean, skip_params=['out']) -def mean(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, +def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, - where: Optional[ArrayLike] = None) -> Array: + where: ArrayLike | None = None) -> Array: return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True) -def _mean(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, +def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, upcast_f16_for_computation: bool = True, - where: Optional[ArrayLike] = None) -> Array: + where: ArrayLike | None = None) -> Array: check_arraylike("mean", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") @@ -351,22 +353,22 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, ).astype(result_dtype) @overload -def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, +def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: Literal[False] = False, keepdims: bool = False) -> Array: ... @overload -def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, *, +def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, *, returned: Literal[True], keepdims: bool = False) -> Array: ... @overload -def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, - returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]: ... +def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, + returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ... 
@_wraps(np.average) -def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, - returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]: +def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, + returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims) @partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True) -def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, - returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]: +def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, + returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: if weights is None: # Treat all weights as 1 check_arraylike("average", a) a, = promote_dtypes_inexact(a) @@ -420,16 +422,16 @@ def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = Non @_wraps(np.var, skip_params=['out']) -def var(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, +def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: Optional[ArrayLike] = None) -> Array: + where: ArrayLike | None = None) -> Array: return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _var(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, +def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: Optional[ArrayLike] = None) -> Array: + where: ArrayLike | None = None) -> Array: check_arraylike("var", a) dtypes.check_user_dtype_supported(dtype, "var") if out is not None: @@ -459,7 +461,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, return lax.div(result, normalizer).astype(dtype) -def _var_promote_types(a_dtype: DTypeLike, dtype: Optional[DTypeLike]) -> tuple[DType, DType]: +def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DType, DType]: if dtype: if (not dtypes.issubdtype(dtype, np.complexfloating) and dtypes.issubdtype(a_dtype, np.complexfloating)): @@ -481,16 +483,16 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: Optional[DTypeLike]) -> tuple[ @_wraps(np.std, skip_params=['out']) -def std(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, +def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: Optional[ArrayLike] = None) -> Array: + where: ArrayLike | None = None) -> Array: return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _std(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, +def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: Optional[ArrayLike] = None) -> Array: + where: ArrayLike | None = None) -> Array: check_arraylike("std", a) dtypes.check_user_dtype_supported(dtype, "std") if dtype is not None and not dtypes.issubdtype(dtype, np.inexact): @@ -543,8 +545,8 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], @_wraps(np.nanmin, 
skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: return _nan_reduction(a, 'nanmin', min, np.inf, nan_if_all_nan=initial is None, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) @@ -552,17 +554,17 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, @_wraps(np.nanmax, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: return _nan_reduction(a, 'nanmax', max, -np.inf, nan_if_all_nan=initial is None, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) @_wraps(np.nansum, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nansum(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: +def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "nanprod") return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False, axis=axis, dtype=dtype, out=out, keepdims=keepdims, @@ -574,9 +576,9 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, o @_wraps(np.nanprod, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanprod(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: +def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "nanprod") return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False, axis=axis, dtype=dtype, out=out, keepdims=keepdims, @@ -584,8 +586,8 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, @_wraps(np.nanmean, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanmean(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None, - keepdims: bool = False, where: Optional[ArrayLike] = None) -> Array: +def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, + keepdims: bool = False, where: ArrayLike | None = None) -> Array: check_arraylike("nanmean", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") @@ -604,9 +606,9 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, @_wraps(np.nanvar, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanvar(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None, +def 
nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, - where: Optional[ArrayLike] = None) -> Array: + where: ArrayLike | None = None) -> Array: check_arraylike("nanvar", a) dtypes.check_user_dtype_supported(dtype, "nanvar") if out is not None: @@ -635,9 +637,9 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, o @_wraps(np.nanstd, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanstd(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None, +def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, - where: Optional[ArrayLike] = None) -> Array: + where: ArrayLike | None = None) -> Array: check_arraylike("nanstd", a) dtypes.check_user_dtype_supported(dtype, "nanstd") if out is not None: @@ -647,7 +649,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, o class CumulativeReduction(Protocol): def __call__(self, a: ArrayLike, axis: Axis = None, - dtype: Optional[DTypeLike] = None, out: None = None) -> Array: ... + dtype: DTypeLike | None = None, out: None = None) -> Array: ... # TODO(jakevdp): should we change these semantics to match those of numpy? @@ -661,12 +663,12 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array @_wraps(np_reduction, skip_params=['out'], lax_description=CUML_REDUCTION_LAX_DESCRIPTION) def cumulative_reduction(a: ArrayLike, axis: Axis = None, - dtype: Optional[DTypeLike] = None, out: None = None) -> Array: + dtype: DTypeLike | None = None, out: None = None) -> Array: return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out) @partial(api.jit, static_argnames=('axis', 'dtype')) def _cumulative_reduction(a: ArrayLike, axis: Axis = None, - dtype: Optional[DTypeLike] = None, out: None = None) -> Array: + dtype: DTypeLike | None = None, out: None = None) -> Array: check_arraylike(np_reduction.__name__, a) if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} " @@ -706,7 +708,7 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None, @_wraps(np.quantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) -def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None, +def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, interpolation: None = None) -> Array: check_arraylike("quantile", a, q) @@ -722,7 +724,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, .. @_wraps(np.nanquantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) -def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None, +def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] 
| None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, interpolation: None = None) -> Array: check_arraylike("nanquantile", a, q) @@ -735,7 +737,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, "Use 'method=' instead.", DeprecationWarning) return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, True) -def _quantile(a: Array, q: Array, axis: Optional[Union[int, tuple[int, ...]]], +def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, interpolation: str, keepdims: bool, squash_nans: bool) -> Array: if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]: raise ValueError("interpolation can only be 'linear', 'lower', 'higher', " @@ -860,7 +862,7 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, tuple[int, ...]]], @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, interpolation: None = None) -> Array: check_arraylike("percentile", a, q) @@ -872,7 +874,7 @@ def percentile(a: ArrayLike, q: ArrayLike, @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, interpolation: None = None) -> Array: check_arraylike("nanpercentile", a, q) @@ -883,7 +885,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, @_wraps(np.median, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) -def median(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None, +def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, keepdims: bool = False) -> Array: check_arraylike("median", a) @@ -892,7 +894,7 @@ def median(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None, @_wraps(np.nanmedian, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) -def nanmedian(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None, +def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, keepdims: bool = False) -> Array: check_arraylike("nanmedian", a) diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 475f7af5aade..28e42a036bbe 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from functools import partial import math import operator from textwrap import dedent as _dedent -from typing import Optional, Union, cast +from typing import cast import numpy as np @@ -79,7 +81,7 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array: When ``size`` is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with ``fill_value``, which defaults to zero.""")) def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, - *, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None) -> Array: + *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: check_arraylike("setdiff1d", ar1, ar2) if size is None: ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()") @@ -117,7 +119,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, remaining elements will be filled with ``fill_value``, which defaults to the minimum value of the union.""")) def union1d(ar1: ArrayLike, ar2: ArrayLike, - *, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None) -> Array: + *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: check_arraylike("union1d", ar1, ar2) if size is None: ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()") @@ -175,7 +177,7 @@ def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: boo @_wraps(np.intersect1d) def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, - return_indices: bool = False) -> Union[Array, tuple[Array, Array, Array]]: + return_indices: bool = False) -> Array | tuple[Array, Array, Array]: check_arraylike("intersect1d", ar1, ar2) ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()") ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()") @@ -253,9 +255,9 @@ def _unique_sorted_mask(ar: Array, axis: int) -> tuple[Array, Array, Array]: return aux, mask, perm def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bool = False, - return_counts: bool = False, size: Optional[int] = None, - fill_value: Optional[ArrayLike] = None, return_true_size: bool = False - ) -> Union[Array, tuple[Array, ...]]: + return_counts: bool = False, size: int | None = None, + fill_value: ArrayLike | None = None, return_true_size: bool = False + ) -> Array | tuple[Array, ...]: """ Find the unique elements of an array along a particular axis. """ @@ -325,8 +327,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo remaining elements will be filled with ``fill_value``. 
The default is the minimum value along the specified axis of the input.""")) def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False, - return_counts: bool = False, axis: Optional[int] = None, - *, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None): + return_counts: bool = False, axis: int | None = None, + *, size: int | None = None, fill_value: ArrayLike | None = None): check_arraylike("unique", ar) if size is None: ar = core.concrete_or_error(None, ar, diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index ae287b106254..7bc2a59b713a 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -14,15 +14,12 @@ """Tools to create numpy-style ufuncs.""" -_AT_INPLACE_WARNING = """\ -Because JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like -np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g. ->>> arr = jnp.add.at(arr, ind, val, inplace=False) -""" +from __future__ import annotations + from functools import partial import math import operator -from typing import Any, Callable, Optional +from typing import Any, Callable import jax from jax._src.typing import Array, ArrayLike, DTypeLike @@ -36,7 +33,14 @@ import numpy as np -def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> Optional[jax.core.Primitive]: +_AT_INPLACE_WARNING = """\ +Because JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like +np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g. +>>> arr = jnp.add.at(arr, ind, val, inplace=False) +""" + + +def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Primitive | None: """ If fun(*args) lowers to a single primitive with inputs and outputs matching function inputs and outputs, return that primitive. Otherwise return None. 
@@ -78,8 +82,8 @@ class ufunc: """ def __init__(self, func: Callable[..., Any], /, nin: int, nout: int, *, - name: Optional[str] = None, - nargs: Optional[int] = None, + name: str | None = None, + nargs: int | None = None, identity: Any = None, update_doc=False): # We want ufunc instances to work properly when marked as static, # and for this reason it's important that their properties not be @@ -129,9 +133,9 @@ def __call__(self, *args: ArrayLike, @_wraps(np.ufunc.reduce, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) - def reduce(self, a: ArrayLike, axis: int = 0, dtype: Optional[DTypeLike] = None, - out: None = None, keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: + def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: check_arraylike(f"{self.__name__}.reduce", a) if self.nin != 2: raise ValueError("reduce only supported for binary ufuncs") @@ -155,9 +159,9 @@ def reduce(self, a: ArrayLike, axis: int = 0, dtype: Optional[DTypeLike] = None, reducer = _primitive_reducers.get(primitive, self._reduce_via_scan) return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: Optional[DTypeLike] = None, - keepdims: bool = False, initial: Optional[ArrayLike] = None, - where: Optional[ArrayLike] = None) -> Array: + def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 arr = lax_internal.asarray(arr) if initial is None: @@ -217,7 +221,7 @@ def body_fun(i, val): @_wraps(np.ufunc.accumulate, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) - def accumulate(self, a: ArrayLike, axis: int = 0, dtype: Optional[DTypeLike] = None, + def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: if self.nin != 2: raise ValueError("accumulate only supported for binary ufuncs") @@ -233,7 +237,7 @@ def accumulate(self, a: ArrayLike, axis: int = 0, dtype: Optional[DTypeLike] = N return accumulator(a, axis=axis, dtype=dtype) def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, - dtype: Optional[DTypeLike] = None) -> Array: + dtype: DTypeLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 check_arraylike(f"{self.__name__}.accumulate", arr) arr = lax_internal.asarray(arr) @@ -255,7 +259,7 @@ def scan_fun(carry, _): @_wraps(np.ufunc.accumulate, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) - def at(self, a: ArrayLike, indices: Any, b: Optional[ArrayLike] = None, /, *, + def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: if inplace: raise NotImplementedError(_AT_INPLACE_WARNING) @@ -295,7 +299,7 @@ def scan_fun(carry, x): @_wraps(np.ufunc.reduceat, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, - dtype: Optional[DTypeLike] = None, out: None = None) -> Array: + dtype: DTypeLike | None = None, out: None = None) -> Array: if self.nin != 2: raise ValueError("reduceat only supported for binary 
ufuncs") if self.nout != 1: @@ -305,7 +309,7 @@ def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, return self._reduceat_via_scan(a, indices, axis=axis, dtype=dtype) def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, - dtype: Optional[DTypeLike] = None) -> Array: + dtype: DTypeLike | None = None) -> Array: check_arraylike(f"{self.__name__}.reduceat", a, indices) a = lax_internal.asarray(a) idx_tuple = _eliminate_deprecated_list_indexing(indices) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 8bf217977ea4..cb6da41929ad 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -16,10 +16,12 @@ Implements ufuncs for jax.numpy. """ +from __future__ import annotations + from functools import partial import operator from textwrap import dedent -from typing import Any, Callable, Union, overload +from typing import Any, Callable, overload import numpy as np @@ -124,9 +126,9 @@ def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp) -> UnOp: ... @overload def _logical_op(np_op: Callable[..., Any], bitwise_op: BinOp) -> BinOp: ... @overload -def _logical_op(np_op: Callable[..., Any], bitwise_op: Union[UnOp, BinOp]) -> Union[UnOp, BinOp]: ... +def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: ... -def _logical_op(np_op: Callable[..., Any], bitwise_op: Union[UnOp, BinOp]) -> Union[UnOp, BinOp]: +def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: @_wraps(np_op, update_doc=False, module='numpy') @partial(jit, inline=True) def op(*args): diff --git a/jax/_src/op_shardings.py b/jax/_src/op_shardings.py index 05ecfca8d82f..74d4a6320645 100644 --- a/jax/_src/op_shardings.py +++ b/jax/_src/op_shardings.py @@ -13,6 +13,8 @@ # limitations under the License. """Sharding utilities""" +from __future__ import annotations + from collections.abc import Sequence import itertools from typing import Union @@ -43,7 +45,7 @@ def get_num_ways_dim_sharded( return partitions, num_replicas -def is_op_sharding_replicated(op: Union[xc.OpSharding, xc.HloSharding]) -> bool: +def is_op_sharding_replicated(op: xc.OpSharding | xc.HloSharding) -> bool: if isinstance(op, xc.OpSharding): op = xc.HloSharding.from_proto(op) if op.num_devices() == 1: @@ -51,8 +53,8 @@ def is_op_sharding_replicated(op: Union[xc.OpSharding, xc.HloSharding]) -> bool: return op.is_replicated() # type: ignore -def are_op_shardings_equal(op1: Union[xc.OpSharding, xc.HloSharding], - op2: Union[xc.OpSharding, xc.HloSharding]) -> bool: +def are_op_shardings_equal(op1: xc.OpSharding | xc.HloSharding, + op2: xc.OpSharding | xc.HloSharding) -> bool: if id(op1) == id(op2): return True if is_op_sharding_replicated(op1) and is_op_sharding_replicated(op2): diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 9a2a2fa34626..c0a24667034d 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -14,9 +14,11 @@ # Helpers for indexed updates. 
+from __future__ import annotations + from collections.abc import Sequence import sys -from typing import Callable, Optional, Union +from typing import Callable, Union import warnings import numpy as np @@ -36,9 +38,9 @@ if sys.version_info >= (3, 10): from types import EllipsisType - SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType] + SingleIndex = int | slice | Sequence[int] | Array | EllipsisType | None else: - SingleIndex = Union[None, int, slice, Sequence[int], Array] + SingleIndex = Union[int, slice, Sequence[int], Array, None] Index = Union[SingleIndex, tuple[SingleIndex, ...]] Scalar = Union[complex, float, int, np.number] @@ -158,12 +160,12 @@ def _segment_update(name: str, data: ArrayLike, segment_ids: ArrayLike, scatter_op: Callable, - num_segments: Optional[int] = None, + num_segments: int | None = None, indices_are_sorted: bool = False, unique_indices: bool = False, - bucket_size: Optional[int] = None, - reducer: Optional[Callable] = None, - mode: Optional[lax.GatherScatterMode] = None) -> Array: + bucket_size: int | None = None, + reducer: Callable | None = None, + mode: lax.GatherScatterMode | None = None) -> Array: check_arraylike(name, data, segment_ids) mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode data = jnp.asarray(data) @@ -198,11 +200,11 @@ def _segment_update(name: str, def segment_sum(data: ArrayLike, segment_ids: ArrayLike, - num_segments: Optional[int] = None, + num_segments: int | None = None, indices_are_sorted: bool = False, unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> Array: + bucket_size: int | None = None, + mode: lax.GatherScatterMode | None = None) -> Array: """Computes the sum within segments of an array. Similar to TensorFlow's `segment_sum @@ -253,11 +255,11 @@ def segment_sum(data: ArrayLike, def segment_prod(data: ArrayLike, segment_ids: ArrayLike, - num_segments: Optional[int] = None, + num_segments: int | None = None, indices_are_sorted: bool = False, unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> Array: + bucket_size: int | None = None, + mode: lax.GatherScatterMode | None = None) -> Array: """Computes the product within segments of an array. Similar to TensorFlow's `segment_prod @@ -309,11 +311,11 @@ def segment_prod(data: ArrayLike, def segment_max(data: ArrayLike, segment_ids: ArrayLike, - num_segments: Optional[int] = None, + num_segments: int | None = None, indices_are_sorted: bool = False, unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> Array: + bucket_size: int | None = None, + mode: lax.GatherScatterMode | None = None) -> Array: """Computes the maximum within segments of an array. Similar to TensorFlow's `segment_max @@ -364,11 +366,11 @@ def segment_max(data: ArrayLike, def segment_min(data: ArrayLike, segment_ids: ArrayLike, - num_segments: Optional[int] = None, + num_segments: int | None = None, indices_are_sorted: bool = False, unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> Array: + bucket_size: int | None = None, + mode: lax.GatherScatterMode | None = None) -> Array: """Computes the minimum within segments of an array. 
Similar to TensorFlow's `segment_min diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index 545844ad3e01..525e492ba0fb 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import overload, Literal, Optional, Union +from __future__ import annotations + +from typing import overload, Literal import jax from jax import lax @@ -27,19 +29,19 @@ # unnecessary scipy dependencies. @overload -def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None, +def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False) -> Array: ... @overload -def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None, +def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ... @overload -def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None, - keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: ... +def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, + keepdims: bool = False, return_sign: bool = False) -> Array | tuple[Array, Array]: ... -def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None, - keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: +def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, + keepdims: bool = False, return_sign: bool = False) -> Array | tuple[Array, Array]: r"""Log-sum-exp reduction. Computes diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 95cac8afc392..a8e29031ef81 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence, Iterable import dataclasses from functools import partial, lru_cache @@ -19,7 +21,7 @@ import itertools as it import logging import weakref -from typing import Callable, Union, cast, Optional, NamedTuple, Any +from typing import Callable, cast, NamedTuple, Any, Union import threading import warnings @@ -384,12 +386,12 @@ class PjitInfo(NamedTuple): static_argnames: tuple[str, ...] donate_argnums: tuple[int, ...] donate_argnames: tuple[str, ...] 
- device: Optional[xc.Device] - backend: Optional[str] + device: xc.Device | None + backend: str | None keep_unused: bool inline: bool resource_env: Any - abstracted_axes: Optional[Any] + abstracted_axes: Any | None in_layouts: Any # pytree[XlaCompatibleLayout] | None out_layouts: Any # pytree[XlaCompatibleLayout] | None @@ -554,7 +556,7 @@ def _extract_implicit_args( return [x for x, (_, e) in zip(args, in_type) if not e] # type: ignore def _flat_axes_specs(abstracted_axes, *args, **kwargs - ) -> Optional[list[pe.AbstractedAxesSpec]]: + ) -> list[pe.AbstractedAxesSpec] | None: if abstracted_axes is None: return None if kwargs: raise NotImplementedError def ax_leaf(l): @@ -569,15 +571,15 @@ def pjit( fun: Callable, in_shardings=UNSPECIFIED, out_shardings=UNSPECIFIED, - static_argnums: Union[int, Sequence[int], None] = None, - static_argnames: Union[str, Iterable[str], None] = None, - donate_argnums: Union[int, Sequence[int], None] = None, - donate_argnames: Union[str, Iterable[str], None] = None, + static_argnums: int | Sequence[int] | None = None, + static_argnames: str | Iterable[str] | None = None, + donate_argnums: int | Sequence[int] | None = None, + donate_argnames: str | Iterable[str] | None = None, keep_unused: bool = False, - device: Optional[xc.Device] = None, - backend: Optional[str] = None, + device: xc.Device | None = None, + backend: str | None = None, inline: bool = False, - abstracted_axes: Optional[Any] = None, + abstracted_axes: Any | None = None, ) -> stages.Wrapped: """Makes ``fun`` compiled and automatically partitioned across multiple devices. @@ -1001,7 +1003,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, out_layouts_thunk, in_type, debug_info def pjit_check_aval_sharding( - shardings, flat_avals, names: Optional[tuple[str, ...]], + shardings, flat_avals, names: tuple[str, ...] | None, what_aval: str, allow_uneven_sharding: bool): new_names = [''] * len(shardings) if names is None else names for aval, s, name in zip(flat_avals, shardings, new_names): @@ -1045,7 +1047,7 @@ def pjit_check_aval_sharding( def _resolve_in_shardings( args, pjit_in_shardings: Sequence[PjitSharding], out_shardings: Sequence[PjitSharding], - pjit_mesh: Optional[pxla.Mesh]) -> Sequence[PjitSharding]: + pjit_mesh: pxla.Mesh | None) -> Sequence[PjitSharding]: # If True, means that device or backend is set by the user on pjit and it # has the same semantics as device_put i.e. doesn't matter which device the # arg is on, reshard it to the device mentioned. So don't do any of the @@ -1253,7 +1255,7 @@ class SameDeviceAssignmentTuple: # device_assignment is Optional because shardings can contain `AUTO` and in # that case `mesh` is compulsory to be used. So in that case # `_pjit_lower_cached` cache, resource_env will check against the devices. - device_assignment: Optional[XLADeviceAssignment] + device_assignment: XLADeviceAssignment | None def __hash__(self): shardings_hash = tuple( @@ -1303,8 +1305,8 @@ def _pjit_lower_cached( inline: bool, *, lowering_parameters: mlir.LoweringParameters, - in_layouts: Optional[pxla.MaybeLayout] = None, - out_layouts: Optional[pxla.MaybeLayout] = None): + in_layouts: pxla.MaybeLayout | None = None, + out_layouts: pxla.MaybeLayout | None = None): in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast( tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings) out_shardings: tuple[PjitSharding, ...] 
= sdat_out_shardings.shardings @@ -1504,7 +1506,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name, pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None) def _pjit_batcher_for_sharding( - s: Union[GSPMDSharding, UnspecifiedValue], + s: GSPMDSharding | UnspecifiedValue, dim: int, val: tuple[str, ...], mesh, ndim: int): if is_unspecified(s): return s @@ -1574,7 +1576,7 @@ def _filter_zeros(is_nz_l, l): @weakref_lru_cache def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr, - in_fwd: tuple[Optional[int]]) -> core.ClosedJaxpr: + in_fwd: tuple[int | None]) -> core.ClosedJaxpr: updated_jaxpr = known_jaxpr.jaxpr.replace( outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, in_fwd) if i is None]) @@ -1766,7 +1768,7 @@ def _dce_jaxpr_pjit( def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn - ) -> tuple[list[bool], Optional[core.JaxprEqn]]: + ) -> tuple[list[bool], core.JaxprEqn | None]: dced_jaxpr, used_inputs = _dce_jaxpr_pjit( eqn.params['jaxpr'], tuple(used_outputs)) @@ -2047,7 +2049,7 @@ def get_unconstrained_dims(sharding: NamedSharding): def _fast_path_get_device_assignment( - shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]: + shardings: Iterable[PjitSharding]) -> XLADeviceAssignment | None: da = None for i in shardings: if is_unspecified(i): diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index 7a1d840530c4..b046afdfcec2 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -25,12 +25,14 @@ # Annotations. https://ayazhafiz.com/articles/21/strictly-annotated # +from __future__ import annotations + import abc from collections.abc import Sequence import enum from functools import partial import sys -from typing import NamedTuple, Optional, Union +from typing import NamedTuple from jax._src import config @@ -67,7 +69,7 @@ def _can_use_color() -> bool: class Doc(abc.ABC): __slots__ = () - def format(self, width: int = 80, use_color: Optional[bool] = None, + def format(self, width: int = 80, use_color: bool | None = None, annotation_prefix=" # ") -> str: if use_color is None: use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value @@ -77,7 +79,7 @@ def format(self, width: int = 80, use_color: Optional[bool] = None, def __str__(self): return self.format() - def __add__(self, other: 'Doc') -> 'Doc': + def __add__(self, other: Doc) -> Doc: return concat([self, other]) class _NilDoc(Doc): @@ -88,9 +90,9 @@ def __repr__(self): return "nil" class _TextDoc(Doc): __slots__ = ("text", "annotation") text: str - annotation: Optional[str] + annotation: str | None - def __init__(self, text: str, annotation: Optional[str] = None): + def __init__(self, text: str, annotation: str | None = None): assert isinstance(text, str), text assert annotation is None or isinstance(annotation, str), annotation self.text = text @@ -151,14 +153,14 @@ def __repr__(self): return f"nest({self.n, self.child})" class _ColorDoc(Doc): __slots__ = ("foreground", "background", "intensity", "child") - foreground: Optional[Color] - background: Optional[Color] - intensity: Optional[Intensity] + foreground: Color | None + background: Color | None + intensity: Intensity | None child: Doc - def __init__(self, child: Doc, *, foreground: Optional[Color] = None, - background: Optional[Color] = None, - intensity: Optional[Intensity] = None): + def __init__(self, child: Doc, *, foreground: Color | None = None, + background: Color | None = None, + intensity: Intensity | None = None): assert isinstance(child, Doc), child self.child = child 
self.foreground = foreground @@ -243,7 +245,7 @@ class _State(NamedTuple): class _Line(NamedTuple): text: str width: int - annotations: Union[Optional[str], list[str]] + annotations: str | None | list[str] def _update_color(use_color: bool, state: _ColorState, update: _ColorState @@ -359,7 +361,7 @@ def nil() -> Doc: """An empty document.""" return _nil -def text(s: str, annotation: Optional[str] = None) -> Doc: +def text(s: str, annotation: str | None = None) -> Doc: """Literal text.""" return _TextDoc(s, annotation) @@ -391,9 +393,9 @@ def nest(n: int, doc: Doc) -> Doc: return _NestDoc(n, doc) -def color(doc: Doc, *, foreground: Optional[Color] = None, - background: Optional[Color] = None, - intensity: Optional[Intensity] = None): +def color(doc: Doc, *, foreground: Color | None = None, + background: Color | None = None, + intensity: Intensity | None = None): """ANSI colors. Overrides the foreground/background/intensity of the text for the child doc. diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 76442a561e68..0b013b786dd5 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from contextlib import contextmanager from functools import wraps import glob @@ -23,7 +25,7 @@ import socketserver import threading -from typing import Callable, Optional +from typing import Callable from jax._src import traceback_util traceback_util.register_exclusion(__file__) @@ -31,7 +33,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client -_profiler_server: Optional[xla_client.profiler.ProfilerServer] = None +_profiler_server: xla_client.profiler.ProfilerServer | None = None logger = logging.getLogger(__name__) @@ -300,7 +302,7 @@ def __init__(self, name: str, **kwargs): super().__init__(name, _r=1, **kwargs) -def annotate_function(func: Callable, name: Optional[str] = None, +def annotate_function(func: Callable, name: str | None = None, **decorator_kwargs): """Decorator that generates a trace event for the execution of a function. @@ -336,7 +338,7 @@ def wrapper(*args, **kwargs): return wrapper -def device_memory_profile(backend: Optional[str] = None) -> bytes: +def device_memory_profile(backend: str | None = None) -> bytes: """Captures a JAX device memory profile as ``pprof``-format protocol buffer. A device memory profile is a snapshot of the state of memory, that describes the JAX @@ -364,7 +366,7 @@ def device_memory_profile(backend: Optional[str] = None) -> bytes: return xla_client.heap_profile(xla_bridge.get_backend(backend)) -def save_device_memory_profile(filename, backend: Optional[str] = None) -> None: +def save_device_memory_profile(filename, backend: str | None = None) -> None: """Collects a device memory profile and writes it to a file. 
:func:`save_device_memory_profile` is a convenience wrapper around :func:`device_memory_profile` diff --git a/jax/_src/random.py b/jax/_src/random.py index fa23fecb2916..c57761db7107 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -14,13 +14,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Hashable, Sequence from functools import partial import math from operator import index import typing from typing import Union -from collections.abc import Hashable import warnings import numpy as np diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index d1c0a9623c57..6c2d3cd5a53a 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence from functools import partial import math -from typing import Optional import scipy.fft as osp_fft from jax import lax @@ -42,8 +43,8 @@ def _dct_ortho_norm(out: Array, axis: int) -> Array: # John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980) @_wraps(osp_fft.dct) -def dct(x: Array, type: int = 2, n: Optional[int] = None, - axis: int = -1, norm: Optional[str] = None) -> Array: +def dct(x: Array, type: int = 2, n: int | None = None, + axis: int = -1, norm: str | None = None) -> Array: if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') @@ -64,7 +65,7 @@ def dct(x: Array, type: int = 2, n: Optional[int] = None, return out -def _dct2(x: Array, axes: Sequence[int], norm: Optional[str]) -> Array: +def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array: axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes) N1, N2 = x.shape[axis1], x.shape[axis2] v = _dct_interleave(_dct_interleave(x, axis1), axis2) @@ -82,9 +83,9 @@ def _dct2(x: Array, axes: Sequence[int], norm: Optional[str]) -> Array: @_wraps(osp_fft.dctn) def dctn(x: Array, type: int = 2, - s: Optional[Sequence[int]]=None, - axes: Optional[Sequence[int]] = None, - norm: Optional[str] = None) -> Array: + s: Sequence[int] | None=None, + axes: Sequence[int] | None = None, + norm: str | None = None) -> Array: if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') @@ -109,8 +110,8 @@ def dctn(x: Array, type: int = 2, @_wraps(osp_fft.dct) -def idct(x: Array, type: int = 2, n: Optional[int] = None, - axis: int = -1, norm: Optional[str] = None) -> Array: +def idct(x: Array, type: int = 2, n: int | None = None, + axis: int = -1, norm: str | None = None) -> Array: if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') @@ -140,9 +141,9 @@ def idct(x: Array, type: int = 2, n: Optional[int] = None, @_wraps(osp_fft.idctn) def idctn(x: Array, type: int = 2, - s: Optional[Sequence[int]]=None, - axes: Optional[Sequence[int]] = None, - norm: Optional[str] = None) -> Array: + s: Sequence[int] | None=None, + axes: Sequence[int] | None = None, + norm: str | None = None) -> Array: if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index a7f93f9fce00..6c40ad890f54 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations from functools import partial import numpy as np import scipy.linalg import textwrap -from typing import cast, overload, Any, Literal, Optional, Union +from typing import cast, overload, Any, Literal import jax import jax.numpy as jnp @@ -84,10 +85,10 @@ def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[True]) -> tup def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Array: ... @overload -def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, tuple[Array, Array, Array]]: ... +def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: ... @partial(jit, static_argnames=('full_matrices', 'compute_uv')) -def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, tuple[Array, Array, Array]]: +def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: a, = promote_dtypes_inexact(jnp.asarray(a)) return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) @@ -109,13 +110,13 @@ def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False], @overload def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, - lapack_driver: str = 'gesdd') -> Union[Array, tuple[Array, Array, Array]]: ... + lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: ... @_wraps(scipy.linalg.svd, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver')) def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, - lapack_driver: str = 'gesdd') -> Union[Array, tuple[Array, Array, Array]]: + lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: del overwrite_a, check_finite, lapack_driver # unused return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv) @@ -127,20 +128,20 @@ def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> A @overload -def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: Literal[True], +def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[True], eigvals: None, type: int) -> Array: ... @overload -def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: Literal[False], +def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[False], eigvals: None, type: int) -> tuple[Array, Array]: ... @overload -def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool, - eigvals: None, type: int) -> Union[Array, tuple[Array, Array]]: ... +def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: bool, + eigvals: None, type: int) -> Array | tuple[Array, Array]: ... 
@partial(jit, static_argnames=('lower', 'eigvals_only', 'eigvals', 'type')) -def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool, - eigvals: None, type: int) -> Union[Array, tuple[Array, Array]]: +def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: bool, + eigvals: None, type: int) -> Array | tuple[Array, Array]: if b is not None: raise NotImplementedError("Only the b=None case of eigh is implemented") if type != 1: @@ -158,36 +159,36 @@ def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool, return w, v @overload -def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True, +def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: Literal[False] = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> tuple[Array, Array]: ... @overload -def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True, *, +def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, *, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array: ... @overload -def eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, +def eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array: ... @overload -def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True, +def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, - type: int = 1, check_finite: bool = True) -> Union[Array, tuple[Array, Array]]: ... + type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: ... @_wraps(scipy.linalg.eigh, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'turbo', 'check_finite')) -def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True, +def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, - type: int = 1, check_finite: bool = True) -> Union[Array, tuple[Array, Array]]: + type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: del overwrite_a, overwrite_b, turbo, check_finite # unused return _eigh(a, b, lower, eigvals_only, eigvals, type) @@ -239,10 +240,10 @@ def _lu(a: ArrayLike, permute_l: Literal[True]) -> tuple[Array, Array]: ... def _lu(a: ArrayLike, permute_l: Literal[False]) -> tuple[Array, Array, Array]: ... @overload -def _lu(a: ArrayLike, permute_l: bool) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]: ... +def _lu(a: ArrayLike, permute_l: bool) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... 
@partial(jit, static_argnums=(1,)) -def _lu(a: ArrayLike, permute_l: bool) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]: +def _lu(a: ArrayLike, permute_l: bool) -> tuple[Array, Array] | tuple[Array, Array, Array]: a, = promote_dtypes_inexact(jnp.asarray(a)) lu, _, permutation = lax_linalg.lu(a) dtype = lax.dtype(a) @@ -266,13 +267,13 @@ def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, @overload def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, - check_finite: bool = True) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]: ... + check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... @_wraps(scipy.linalg.lu, update_doc=False, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) @partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite')) def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, - check_finite: bool = True) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]: + check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: del overwrite_a, check_finite # unused return _lu(a, permute_l) @@ -283,10 +284,10 @@ def _qr(a: ArrayLike, mode: Literal["r"], pivoting: bool) -> tuple[Array]: ... def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: bool) -> tuple[Array, Array]: ... @overload -def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[tuple[Array], tuple[Array, Array]]: ... +def _qr(a: ArrayLike, mode: str, pivoting: bool) -> tuple[Array] | tuple[Array, Array]: ... @partial(jit, static_argnames=('mode', 'pivoting')) -def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[tuple[Array], tuple[Array, Array]]: +def _qr(a: ArrayLike, mode: str, pivoting: bool) -> tuple[Array] | tuple[Array, Array]: if pivoting: raise NotImplementedError( "The pivoting=True case of qr is not implemented.") @@ -317,12 +318,12 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Lit @overload def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", - pivoting: bool = False, check_finite: bool = True) -> Union[tuple[Array], tuple[Array, Array]]: ... + pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: ... 
@_wraps(scipy.linalg.qr, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lwork')) def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", - pivoting: bool = False, check_finite: bool = True) -> Union[tuple[Array], tuple[Array, Array]]: + pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: del overwrite_a, lwork, check_finite # unused return _qr(a, mode, pivoting) @@ -364,7 +365,7 @@ def solve(a: ArrayLike, b: ArrayLike, lower: bool = False, return _solve(a, b, assume_a, lower) @partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal')) -def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: Union[int, str], +def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str, lower: bool, unit_diagonal: bool) -> Array: if trans == 0 or trans == "N": transpose_a, conjugate_a = False, False @@ -392,7 +393,7 @@ def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: Union[int, str], @_wraps(scipy.linalg.solve_triangular, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'debug', 'check_finite')) -def solve_triangular(a: ArrayLike, b: ArrayLike, trans: Union[int, str] = 0, lower: bool = False, +def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bool = False, unit_diagonal: bool = False, overwrite_b: bool = False, debug: Any = None, check_finite: bool = True) -> Array: del overwrite_b, debug, check_finite # unused @@ -570,21 +571,21 @@ def _pade13(A: Array) -> tuple[Array, Array]: """) @overload -def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None, +def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) -> tuple[Array, Array]: ... @overload -def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None, +def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False]) -> Array: ... @overload -def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None, - compute_expm: bool = True) -> Union[Array, tuple[Array, Array]]: ... +def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, + compute_expm: bool = True) -> Array | tuple[Array, Array]: ... 
@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description) @partial(jit, static_argnames=('method', 'compute_expm')) -def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None, - compute_expm: bool = True) -> Union[Array, tuple[Array, Array]]: +def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, + compute_expm: bool = True) -> Array | tuple[Array, Array]: A_arr = jnp.asarray(A) E_arr = jnp.asarray(E) if A_arr.ndim != 2 or A_arr.shape[0] != A_arr.shape[1]: @@ -631,8 +632,8 @@ def block_diag(*arrs: ArrayLike) -> Array: @_wraps(scipy.linalg.eigh_tridiagonal) @partial(jit, static_argnames=("eigvals_only", "select", "select_range")) def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False, - select: str = 'a', select_range: Optional[tuple[float, float]] = None, - tol: Optional[float] = None) -> Array: + select: str = 'a', select_range: tuple[float, float] | None = None, + tol: float | None = None) -> Array: if not eigvals_only: raise NotImplementedError("Calculation of eigenvectors is not implemented") @@ -786,8 +787,8 @@ def body(args): @partial(jit, static_argnames=('side', 'method')) @jax.default_matmul_precision("float32") -def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: Optional[float] = None, - max_iterations: Optional[int] = None) -> tuple[Array, Array]: +def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: float | None = None, + max_iterations: int | None = None) -> tuple[Array, Array]: r"""Computes the polar decomposition. Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar @@ -999,7 +1000,7 @@ def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False @_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc) @partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a')) def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, - check_finite: bool = True) -> Union[Array, tuple[Array, Array]]: + check_finite: bool = True) -> Array | tuple[Array, Array]: del overwrite_a, check_finite n = jnp.shape(a)[-1] if n == 0: @@ -1020,7 +1021,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, return h @_wraps(scipy.linalg.toeplitz) -def toeplitz(c: ArrayLike, r: Optional[ArrayLike] = None) -> Array: +def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: if r is None: check_arraylike("toeplitz", c) r = jnp.conjugate(jnp.asarray(c)) diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py index b50c54ec1700..a1719d5280ec 100644 --- a/jax/_src/scipy/optimize/_lbfgs.py +++ b/jax/_src/scipy/optimize/_lbfgs.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm.""" -from typing import Callable, NamedTuple, Optional, Union + +from __future__ import annotations + +from typing import Callable, NamedTuple from functools import partial import jax @@ -50,32 +53,32 @@ class LBFGSResults(NamedTuple): 5 = line search failed ls_status: integer describing the end status of the last line search """ - converged: Union[bool, Array] - failed: Union[bool, Array] - k: Union[int, Array] - nfev: Union[int, Array] - ngev: Union[int, Array] + converged: bool | Array + failed: bool | Array + k: int | Array + nfev: int | Array + ngev: int | Array x_k: Array f_k: Array g_k: Array s_history: Array y_history: Array rho_history: Array - gamma: Union[float, Array] - status: Union[int, Array] - ls_status: Union[int, Array] + gamma: float | Array + status: int | Array + ls_status: int | Array def _minimize_lbfgs( fun: Callable, x0: Array, - maxiter: Optional[float] = None, + maxiter: float | None = None, norm=jnp.inf, maxcor: int = 10, ftol: float = 2.220446049250313e-09, gtol: float = 1e-05, - maxfun: Optional[float] = None, - maxgrad: Optional[float] = None, + maxfun: float | None = None, + maxgrad: float | None = None, maxls: int = 20, ): """ diff --git a/jax/_src/scipy/optimize/bfgs.py b/jax/_src/scipy/optimize/bfgs.py index de2177cbecd6..b6fd9f9dda17 100644 --- a/jax/_src/scipy/optimize/bfgs.py +++ b/jax/_src/scipy/optimize/bfgs.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """The Broyden-Fletcher-Goldfarb-Shanno minimization algorithm.""" + +from __future__ import annotations + from functools import partial -from typing import Callable, NamedTuple, Optional, Union +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -45,19 +48,19 @@ class _BFGSResults(NamedTuple): line_search_status: int describing line search end state (only means something if line search fails). """ - converged: Union[bool, jax.Array] - failed: Union[bool, jax.Array] - k: Union[int, jax.Array] - nfev: Union[int, jax.Array] - ngev: Union[int, jax.Array] - nhev: Union[int, jax.Array] + converged: bool | jax.Array + failed: bool | jax.Array + k: int | jax.Array + nfev: int | jax.Array + ngev: int | jax.Array + nhev: int | jax.Array x_k: jax.Array f_k: jax.Array g_k: jax.Array H_k: jax.Array old_old_fval: jax.Array - status: Union[int, jax.Array] - line_search_status: Union[int, jax.Array] + status: int | jax.Array + line_search_status: int | jax.Array _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) @@ -67,7 +70,7 @@ class _BFGSResults(NamedTuple): def minimize_bfgs( fun: Callable, x0: jax.Array, - maxiter: Optional[int] = None, + maxiter: int | None = None, norm=jnp.inf, gtol: float = 1e-5, line_search_maxiter: int = 10, diff --git a/jax/_src/scipy/optimize/line_search.py b/jax/_src/scipy/optimize/line_search.py index ae337bbe5c68..078d23d97a96 100644 --- a/jax/_src/scipy/optimize/line_search.py +++ b/jax/_src/scipy/optimize/line_search.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import NamedTuple, Union +from __future__ import annotations + +from typing import NamedTuple from functools import partial from jax._src.numpy.util import promote_dtypes_inexact @@ -57,23 +59,23 @@ def _binary_replace(replace_bit, original_dict, new_dict, keys=None): class _ZoomState(NamedTuple): - done: Union[bool, jax.Array] - failed: Union[bool, jax.Array] - j: Union[int, jax.Array] - a_lo: Union[float, jax.Array] - phi_lo: Union[float, jax.Array] - dphi_lo: Union[float, jax.Array] - a_hi: Union[float, jax.Array] - phi_hi: Union[float, jax.Array] - dphi_hi: Union[float, jax.Array] - a_rec: Union[float, jax.Array] - phi_rec: Union[float, jax.Array] - a_star: Union[float, jax.Array] - phi_star: Union[float, jax.Array] - dphi_star: Union[float, jax.Array] - g_star: Union[float, jax.Array] - nfev: Union[int, jax.Array] - ngev: Union[int, jax.Array] + done: bool | jax.Array + failed: bool | jax.Array + j: int | jax.Array + a_lo: float | jax.Array + phi_lo: float | jax.Array + dphi_lo: float | jax.Array + a_hi: float | jax.Array + phi_hi: float | jax.Array + dphi_hi: float | jax.Array + a_rec: float | jax.Array + phi_rec: float | jax.Array + a_star: float | jax.Array + phi_star: float | jax.Array + dphi_star: float | jax.Array + g_star: float | jax.Array + nfev: int | jax.Array + ngev: int | jax.Array def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo, @@ -213,17 +215,17 @@ def body(state): class _LineSearchState(NamedTuple): - done: Union[bool, jax.Array] - failed: Union[bool, jax.Array] - i: Union[int, jax.Array] - a_i1: Union[float, jax.Array] - phi_i1: Union[float, jax.Array] - dphi_i1: Union[float, jax.Array] - nfev: Union[int, jax.Array] - ngev: Union[int, jax.Array] - a_star: Union[float, jax.Array] - phi_star: Union[float, jax.Array] - dphi_star: Union[float, jax.Array] + done: bool | jax.Array + failed: bool | jax.Array + i: int | jax.Array + a_i1: float | jax.Array + phi_i1: float | jax.Array + dphi_i1: float | jax.Array + nfev: int | jax.Array + ngev: int | jax.Array + a_star: float | jax.Array + phi_star: float | jax.Array + dphi_star: float | jax.Array g_star: jax.Array @@ -241,15 +243,15 @@ class _LineSearchResults(NamedTuple): g_k: final gradient value status: integer end status """ - failed: Union[bool, jax.Array] - nit: Union[int, jax.Array] - nfev: Union[int, jax.Array] - ngev: Union[int, jax.Array] - k: Union[int, jax.Array] - a_k: Union[int, jax.Array] + failed: bool | jax.Array + nit: int | jax.Array + nfev: int | jax.Array + ngev: int | jax.Array + k: int | jax.Array + a_k: int | jax.Array f_k: jax.Array g_k: jax.Array - status: Union[bool, jax.Array] + status: bool | jax.Array def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4, diff --git a/jax/_src/scipy/optimize/minimize.py b/jax/_src/scipy/optimize/minimize.py index 7118a52c6a98..830f1228424a 100644 --- a/jax/_src/scipy/optimize/minimize.py +++ b/jax/_src/scipy/optimize/minimize.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Mapping -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import jax from jax._src.scipy.optimize.bfgs import minimize_bfgs @@ -39,14 +41,14 @@ class OptimizeResults(NamedTuple): nit: integer number of iterations of the optimization algorithm. 
""" x: jax.Array - success: Union[bool, jax.Array] - status: Union[int, jax.Array] + success: bool | jax.Array + status: int | jax.Array fun: jax.Array jac: jax.Array - hess_inv: Optional[jax.Array] - nfev: Union[int, jax.Array] - njev: Union[int, jax.Array] - nit: Union[int, jax.Array] + hess_inv: jax.Array | None + nfev: int | jax.Array + njev: int | jax.Array + nit: int | jax.Array def minimize( @@ -55,8 +57,8 @@ def minimize( args: tuple = (), *, method: str, - tol: Optional[float] = None, - options: Optional[Mapping[str, Any]] = None, + tol: float | None = None, + options: Mapping[str, Any] | None = None, ) -> OptimizeResults: """Minimization of scalar function of one or more variables. diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index f3d918e25e17..48b7e6e300be 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence from functools import partial import math import operator -from typing import Callable, Optional, Union +from typing import Callable import warnings import numpy as np @@ -40,7 +42,7 @@ @_wraps(osp_signal.fftconvolve) def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full", - axes: Optional[Sequence[int]] = None) -> Array: + axes: Sequence[int] | None = None) -> Array: check_arraylike('fftconvolve', in1, in2) in1, in2 = promote_dtypes_inexact(in1, in2) if in1.ndim != in2.ndim: @@ -221,7 +223,7 @@ def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0, def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array], - nperseg: int, noverlap: int, nfft: Optional[int], sides: str) -> Array: + nperseg: int, noverlap: int, nfft: int | None, sides: str) -> Array: """Calculate windowed FFT in the same way the original SciPy does. """ if x.dtype.kind == 'i': @@ -284,12 +286,12 @@ def odd_ext(x: Array, n: int, axis: int = -1) -> Array: return ext -def _spectral_helper(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0, - window: str = 'hann', nperseg: Optional[int] = None, - noverlap: Optional[int] = None, nfft: Optional[int] = None, - detrend_type: Union[bool, str, Callable[[Array], Array]] = 'constant', +def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, + window: str = 'hann', nperseg: int | None = None, + noverlap: int | None = None, nfft: int | None = None, + detrend_type: bool | str | Callable[[Array], Array] = 'constant', return_onesided: bool = True, scaling: str = 'density', - axis: int = -1, mode: str = 'psd', boundary: Optional[str] = None, + axis: int = -1, mode: str = 'psd', boundary: str | None = None, padded: bool = False) -> tuple[Array, Array, Array]: """LAX-backend implementation of `scipy.signal._spectral_helper`. 
@@ -499,8 +501,8 @@ def detrend_func(d): @_wraps(osp_signal.stft) def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256, - noverlap: Optional[int] = None, nfft: Optional[int] = None, - detrend: bool = False, return_onesided: bool = True, boundary: Optional[str] = 'zeros', + noverlap: int | None = None, nfft: int | None = None, + detrend: bool = False, return_onesided: bool = True, boundary: str | None = 'zeros', padded: bool = True, axis: int = -1) -> tuple[Array, Array, Array]: return _spectral_helper(x, None, fs, window, nperseg, noverlap, nfft, detrend, return_onesided, @@ -517,9 +519,9 @@ def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256 @_wraps(osp_signal.csd, lax_description=_csd_description) -def csd(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0, window: str = 'hann', - nperseg: Optional[int] = None, noverlap: Optional[int] = None, - nfft: Optional[int] = None, detrend: str = 'constant', +def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann', + nperseg: int | None = None, noverlap: int | None = None, + nfft: int | None = None, detrend: str = 'constant', return_onesided: bool = True, scaling: str = 'density', axis: int = -1, average: str = 'mean') -> tuple[Array, Array]: freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft, @@ -551,8 +553,8 @@ def csd(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0, window: str = 'ha @_wraps(osp_signal.welch) def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', - nperseg: Optional[int] = None, noverlap: Optional[int] = None, - nfft: Optional[int] = None, detrend: str = 'constant', + nperseg: int | None = None, noverlap: int | None = None, + nfft: int | None = None, detrend: str = 'constant', return_onesided: bool = True, scaling: str = 'density', axis: int = -1, average: str = 'mean') -> tuple[Array, Array]: freqs, Pxx = csd(x, None, fs=fs, window=window, nperseg=nperseg, @@ -613,8 +615,8 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: @_wraps(osp_signal.istft) def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', - nperseg: Optional[int] = None, noverlap: Optional[int] = None, - nfft: Optional[int] = None, input_onesided: bool = True, + nperseg: int | None = None, noverlap: int | None = None, + nfft: int | None = None, input_onesided: bool = True, boundary: bool = True, time_axis: int = -1, freq_axis: int = -2) -> tuple[Array, Array]: # Input validation diff --git a/jax/_src/scipy/spatial/transform.py b/jax/_src/scipy/spatial/transform.py index 6aea6bc496f0..c80e5c6c60ac 100644 --- a/jax/_src/scipy/spatial/transform.py +++ b/jax/_src/scipy/spatial/transform.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import functools import re import typing @@ -74,14 +76,14 @@ def from_rotvec(cls, rotvec: jax.Array, degrees: bool = False): return cls(_from_rotvec(rotvec, degrees)) @classmethod - def identity(cls, num: typing.Optional[int] = None, dtype=float): + def identity(cls, num: int | None = None, dtype=float): """Get identity rotation(s).""" assert num is None quat = jnp.array([0., 0., 0., 1.], dtype=dtype) return cls(quat) @classmethod - def random(cls, random_key: jax.Array, num: typing.Optional[int] = None): + def random(cls, random_key: jax.Array, num: int | None = None): """Generate uniformly distributed rotations.""" # Need to implement scipy.stats.special_ortho_group for this to work... raise NotImplementedError @@ -147,7 +149,7 @@ def magnitude(self) -> jax.Array: """Get the magnitude(s) of the rotation(s).""" return _magnitude(self.quat) - def mean(self, weights: typing.Optional[jax.Array] = None): + def mean(self, weights: jax.Array | None = None): """Get the mean of the rotations.""" w = jnp.ones(self.quat.shape[0], dtype=self.quat.dtype) if weights is None else jnp.asarray(weights, dtype=self.quat.dtype) if w.ndim != 1: diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 70fce360a339..fd4566b98fa2 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from functools import partial import operator -from typing import cast, Any, Optional +from typing import cast, Any import numpy as np import scipy.special as osp_special @@ -267,7 +269,7 @@ def rel_entr( @custom_derivatives.custom_jvp @_wraps(osp_special.zeta, module='scipy.special') -def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array: +def zeta(x: ArrayLike, q: ArrayLike | None = None) -> Array: if q is None: raise NotImplementedError( "Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.") @@ -277,7 +279,7 @@ def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array: # There is no general closed-form derivative for the zeta function, so we compute # derivatives via a series expansion -def _zeta_series_expansion(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array: +def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array: if q is None: raise NotImplementedError( "Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.") @@ -1199,7 +1201,7 @@ def sph_harm(m: Array, n: Array, theta: Array, phi: Array, - n_max: Optional[int] = None) -> Array: + n_max: int | None = None) -> Array: r"""Computes the spherical harmonics. The JAX version has one extra argument `n_max`, the maximum value in `n`. diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index e83e7a8cb1f3..541f409dcaa7 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from collections import namedtuple from functools import partial import math -from typing import Optional import jax import jax.numpy as jnp @@ -34,7 +35,7 @@ Currently the only supported nan_policy is 'propagate' """) @partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims']) -def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult: +def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult: check_arraylike("mode", a) x = jnp.atleast_1d(a) @@ -97,7 +98,7 @@ def rankdata( a: ArrayLike, method: str = "average", *, - axis: Optional[int] = None, + axis: int | None = None, nan_policy: str = "propagate", ) -> Array: diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py index e01c67eb65a9..40a2b7df0d8b 100644 --- a/jax/_src/sharding_specs.py +++ b/jax/_src/sharding_specs.py @@ -27,12 +27,14 @@ # This encoding is assumed by various parts of the system, e.g. generating # replica groups for collective operations. +from __future__ import annotations + import collections from collections.abc import Mapping, Sequence import functools import itertools import math -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast import numpy as np @@ -82,7 +84,7 @@ def get_logical_mesh_ids(mesh_shape): _MeshAxisName = Any def sharding_spec_sharding_proto( - self, special_axes: Optional[Mapping[int, OpShardingType]] = None + self, special_axes: Mapping[int, OpShardingType] | None = None ) -> xc.HloSharding: """Converts a ShardingSpec to an OpSharding proto. @@ -273,7 +275,7 @@ def new_mesh_sharding_specs(axis_sizes, axis_names): return functools.partial(make_sharding_spec, axis_sizes, mesh_axis_pos) def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int], - map_axis: Optional[int]) -> ShardingSpec: + map_axis: int | None) -> ShardingSpec: """Sharding spec for arguments or results of a pmap. Args: nrep: number of local XLA replicas (product of local axis sizes) @@ -310,7 +312,7 @@ def shift_sharded_axis(a: MeshDimAssignment): def create_pmap_sharding_spec(shape: tuple[int, ...], sharded_dim: int = 0, - sharded_dim_size: Optional[int] = None): + sharded_dim_size: int | None = None): if sharded_dim is not None: sharded_shape = shape[:sharded_dim] + shape[sharded_dim+1:] if sharded_dim_size is None: diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index e428b1147bd3..f28e78a20a26 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Iterator import contextlib import dataclasses @@ -22,7 +24,7 @@ import sysconfig import threading import types -from typing import Optional, NamedTuple, Union +from typing import NamedTuple import jax.version from jax._src.lib import xla_client @@ -80,9 +82,9 @@ def wrap(self, stack: tuple[str, ...]) -> tuple[str, ...]: @dataclasses.dataclass(frozen=True) class NameStack: - stack: tuple[Union[Scope, Transform], ...] = () + stack: tuple[Scope | Transform, ...] = () - def extend(self, name: Union[tuple[str, ...], str]) -> 'NameStack': + def extend(self, name: tuple[str, ...] 
| str) -> NameStack: if not isinstance(name, tuple): name = (name,) scopes = tuple(map(Scope, name)) @@ -93,19 +95,19 @@ def wrap_name(self, name: str) -> str: return name return f'{self}/{name}' - def transform(self, transform_name: str) -> 'NameStack': + def transform(self, transform_name: str) -> NameStack: return NameStack((*self.stack, Transform(transform_name))) - def __getitem__(self, idx: slice) -> 'NameStack': + def __getitem__(self, idx: slice) -> NameStack: return NameStack(self.stack[idx]) def __len__(self): return len(self.stack) - def __add__(self, other: 'NameStack') -> 'NameStack': + def __add__(self, other: NameStack) -> NameStack: return NameStack(self.stack + other.stack) - def __radd__(self, other: 'NameStack') -> 'NameStack': + def __radd__(self, other: NameStack) -> NameStack: return NameStack(other.stack + self.stack) def __str__(self) -> str: @@ -123,11 +125,11 @@ def new_name_stack(name: str = '') -> NameStack: class SourceInfo(NamedTuple): - traceback: Optional[Traceback] + traceback: Traceback | None name_stack: NameStack - def replace(self, *, traceback: Optional[Traceback] = None, - name_stack: Optional[NameStack] = None) -> 'SourceInfo': + def replace(self, *, traceback: Traceback | None = None, + name_stack: NameStack | None = None) -> SourceInfo: return SourceInfo( self.traceback if traceback is None else traceback, self.name_stack if name_stack is None else name_stack @@ -172,7 +174,7 @@ def user_frames(source_info: SourceInfo) -> Iterator[Frame]: if is_user_filename(code[i].co_filename)) @functools.lru_cache(maxsize=64) -def user_frame(source_info: SourceInfo) -> Optional[Frame]: +def user_frame(source_info: SourceInfo) -> Frame | None: return next(user_frames(source_info), None) def _summarize_frame(frame: Frame) -> str: @@ -217,7 +219,7 @@ def has_user_context(e): return False @contextlib.contextmanager -def user_context(c: Optional[Traceback], *, name_stack: Optional[NameStack] = None): +def user_context(c: Traceback | None, *, name_stack: NameStack | None = None): prev = _source_info_context.context _source_info_context.context = _source_info_context.context.replace( traceback=c, name_stack=name_stack) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index a32dcb77f533..b5463e71b460 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for state primitives.""" -from functools import partial +from __future__ import annotations +from functools import partial from typing import Any, Union import numpy as np diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index fda7ef10d99a..c2cb653a99e6 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Module for state types.""" + from __future__ import annotations from collections.abc import Sequence diff --git a/jax/_src/third_party/scipy/signal_helper.py b/jax/_src/third_party/scipy/signal_helper.py index d9099903d90e..5205c9079b61 100644 --- a/jax/_src/third_party/scipy/signal_helper.py +++ b/jax/_src/third_party/scipy/signal_helper.py @@ -1,14 +1,16 @@ """Utility functions adopted from scipy.signal.""" +from __future__ import annotations + import scipy.signal as osp_signal -from typing import Any, Optional, Union +from typing import Any import warnings import jax.numpy as jnp from jax._src.typing import Array, ArrayLike, DTypeLike -def _triage_segments(window: Union[ArrayLike, str, tuple[Any, ...]], nperseg: Optional[int], +def _triage_segments(window: ArrayLike | str | tuple[Any, ...], nperseg: int | None, input_length: int, dtype: DTypeLike) -> tuple[Array, int]: """ Parses window and nperseg arguments for spectrogram and _spectral_helper. diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index c0770dadefa4..5ab04cd9ab64 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import functools import os import sys import traceback import types -from typing import Any, Callable, Optional, TypeVar, cast +from typing import Any, Callable, TypeVar, cast from jax._src import config from jax._src import util @@ -67,7 +69,7 @@ def _add_tracebackhide_to_hidden_frames(tb: types.TracebackType): if not include_frame(f): f.f_locals["__tracebackhide__"] = True -def filter_traceback(tb: types.TracebackType) -> Optional[types.TracebackType]: +def filter_traceback(tb: types.TracebackType) -> types.TracebackType | None: out = None # Scan the traceback and collect relevant frames. frames = list(traceback.walk_tb(tb)) diff --git a/jax/_src/util.py b/jax/_src/util.py index 23ec012a276a..f492da14154c 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from collections.abc import Iterable, Iterator, Sequence import functools from functools import partial import itertools as it import logging import operator -from typing import (Any, Callable, Generic, Optional, TypeVar, overload, TYPE_CHECKING, cast) +from typing import (Any, Callable, Generic, TypeVar, overload, TYPE_CHECKING, cast) import weakref import numpy as np @@ -144,7 +146,7 @@ def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T] return out def subs_list( - subs: Sequence[Optional[int]], src: Sequence[T], base: Sequence[T], + subs: Sequence[int | None], src: Sequence[T], base: Sequence[T], ) -> list[T]: base_ = iter(base) out = [src[i] if i is not None else next(base_) for i in subs] @@ -153,7 +155,7 @@ def subs_list( return out def subs_list2( - subs1: Sequence[Optional[int]], subs2: Sequence[Optional[int]], + subs1: Sequence[int | None], subs2: Sequence[int | None], src1: Sequence[T], src2: Sequence[T], base: Sequence[T], ) -> list[T]: assert len(subs1) == len(subs2) @@ -374,8 +376,8 @@ def ceil_of_ratio(x, y): def wraps( wrapped: Callable, - namestr: Optional[str] = None, - docstr: Optional[str] = None, + namestr: str | None = None, + docstr: str | None = None, **kwargs, ) -> Callable[[T], T]: """ @@ -533,7 +535,7 @@ def __contains__(self, elt: T) -> bool: class HashableWrapper: x: Any - hash: Optional[int] + hash: int | None def __init__(self, x): self.x = x try: self.hash = hash(x) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 8835afe11e96..29cc0302b71e 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -19,6 +19,8 @@ XLA. There are also a handful of related casting utilities. """ +from __future__ import annotations + from collections.abc import Mapping import dataclasses from functools import lru_cache, partial @@ -32,7 +34,7 @@ import platform as py_platform import sys import threading -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union import warnings from jax._src import config @@ -47,7 +49,7 @@ logger = logging.getLogger(__name__) -jax_plugins: Optional[Any] +jax_plugins: Any | None try: import jax_plugins # type: ignore except ModuleNotFoundError: @@ -106,7 +108,7 @@ # Backends -def _get_tpu_library_path() -> Optional[str]: +def _get_tpu_library_path() -> str | None: path_from_env = os.getenv("TPU_LIBRARY_PATH") if path_from_env is not None: return path_from_env @@ -128,7 +130,7 @@ def _get_tpu_library_path() -> Optional[str]: return None -def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]: +def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None: def _log_warning(): warnings.warn( f'TPU backend initialization is taking more than {timer_secs} seconds. ' @@ -154,7 +156,7 @@ def _log_warning(): # example, there could be multiple backends that provide the same kind of # device. 
-BackendFactory = Callable[[], Optional[xla_client.Client]] +BackendFactory = Callable[[], Union[xla_client.Client, None]] @dataclasses.dataclass class BackendRegistration: @@ -176,7 +178,7 @@ class BackendRegistration: experimental: bool = False _backend_factories: dict[str, BackendRegistration] = {} -_default_backend: Optional[xla_client.Client] = None +_default_backend: xla_client.Client | None = None _backends : dict[str, xla_client.Client] = {} _backend_errors : dict[str, str] = {} _backend_lock = threading.Lock() @@ -376,7 +378,7 @@ def _get_pjrt_plugin_names_and_library_paths( def _get_pjrt_plugin_config( json_path: str, ) -> tuple[ - str, Optional[Mapping[str, Union[str, int, list[int], float, bool]]] + str, Mapping[str, str | int | list[int] | float | bool] | None ]: """Gets PJRT plugin configuration from a json file. @@ -473,8 +475,8 @@ def register_plugin( plugin_name: str, *, priority: int = 400, - library_path: Optional[str] = None, - options: Optional[Mapping[str, Union[str, int, list[int], float, bool]]] = None, + library_path: str | None = None, + options: Mapping[str, str | int | list[int] | float | bool] | None = None, ) -> None: """Registers a backend factory for the PJRT plugin. @@ -771,7 +773,7 @@ def _init_backend(platform: str) -> xla_client.Client: def _get_backend_uncached( - platform: Union[None, str, xla_client.Client] = None + platform: None | str | xla_client.Client = None ) -> xla_client.Client: # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how # 'backend' values are handled @@ -798,13 +800,13 @@ def _get_backend_uncached( @lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence. def get_backend( - platform: Union[None, str, xla_client.Client] = None + platform: None | str | xla_client.Client = None ) -> xla_client.Client: return _get_backend_uncached(platform) def get_device_backend( - device: Optional[xla_client.Device] = None, + device: xla_client.Device | None = None, ) -> xla_client.Client: """Returns the Backend associated with `device`, or the default Backend.""" if device is not None: @@ -813,7 +815,7 @@ def get_device_backend( def device_count( - backend: Optional[Union[str, xla_client.Client]] = None + backend: str | xla_client.Client | None = None ) -> int: """Returns the total number of devices. @@ -835,14 +837,14 @@ def device_count( def local_device_count( - backend: Optional[Union[str, xla_client.Client]] = None + backend: str | xla_client.Client | None = None ) -> int: """Returns the number of devices addressable by this process.""" return int(get_backend(backend).local_device_count()) def devices( - backend: Optional[Union[str, xla_client.Client]] = None + backend: str | xla_client.Client | None = None ) -> list[xla_client.Device]: """Returns a list of all devices for a given backend. @@ -874,7 +876,7 @@ def default_backend() -> str: return get_backend(None).platform -def backend_pjrt_c_api_version(platform=None) -> Optional[tuple[int, int]]: +def backend_pjrt_c_api_version(platform=None) -> tuple[int, int] | None: """Returns the PJRT C API version of the backend. Returns None if the backend does not use PJRT C API. 
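The `BackendFactory` hunk above keeps `Union` instead of switching to `|` because a module-level type alias is an ordinary runtime expression evaluated at import time; the future import defers only annotations. A small sketch of that distinction, with an illustrative stand-in `Client` class that is not part of this patch:

from __future__ import annotations

from typing import Callable, Union


class Client:
    """Stand-in for a backend client type (illustrative only)."""


# Runtime expression: keep Union here, since `Client | None` would raise
# TypeError when this module is imported on Python 3.9.
BackendFactory = Callable[[], Union[Client, None]]


def make_client() -> Client | None:
    # Annotation only: deferred evaluation makes the new spelling safe.
    return Client()


assert isinstance(make_client(), Client)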
@@ -888,9 +890,9 @@ def backend_pjrt_c_api_version(platform=None) -> Optional[tuple[int, int]]: @lru_cache -def local_devices(process_index: Optional[int] = None, - backend: Optional[Union[str, xla_client.Client]] = None, - host_id: Optional[int] = None) -> list[xla_client.Device]: +def local_devices(process_index: int | None = None, + backend: str | xla_client.Client | None = None, + host_id: int | None = None) -> list[xla_client.Device]: """Like :py:func:`jax.devices`, but only returns devices local to a given process. If ``process_index`` is ``None``, returns devices local to this process. @@ -919,7 +921,7 @@ def local_devices(process_index: Optional[int] = None, def process_index( - backend: Optional[Union[str, xla_client.Client]] = None + backend: str | xla_client.Client | None = None ) -> int: """Returns the integer process index of this process. @@ -938,7 +940,7 @@ def process_index( # TODO: remove this sometime after jax 0.2.13 is released -def host_id(backend: Optional[Union[str, xla_client.Client]] = None) -> int: +def host_id(backend: str | xla_client.Client | None = None) -> int: warnings.warn( "jax.host_id has been renamed to jax.process_index. This alias " "will eventually be removed; please update your code.") @@ -947,14 +949,14 @@ def host_id(backend: Optional[Union[str, xla_client.Client]] = None) -> int: @lru_cache def process_count( - backend: Optional[Union[str, xla_client.Client]] = None + backend: str | xla_client.Client | None = None ) -> int: """Returns the number of JAX processes associated with the backend.""" return max(d.process_index for d in devices(backend)) + 1 # TODO: remove this sometime after jax 0.2.13 is released -def host_count(backend: Optional[Union[str, xla_client.Client]] = None) -> int: +def host_count(backend: str | xla_client.Client | None = None) -> int: warnings.warn( "jax.host_count has been renamed to jax.process_count. This alias " "will eventually be removed; please update your code.") @@ -963,7 +965,7 @@ def host_count(backend: Optional[Union[str, xla_client.Client]] = None) -> int: # TODO: remove this sometime after jax 0.2.13 is released def host_ids( - backend: Optional[Union[str, xla_client.Client]] = None + backend: str | xla_client.Client | None = None ) -> list[int]: warnings.warn( "jax.host_ids has been deprecated; please use range(jax.process_count()) " diff --git a/jax/collect_profile.py b/jax/collect_profile.py index 8a31b63d3910..a7777085ce90 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + import argparse import gzip import os import pathlib import tempfile -from typing import Optional - # pytype: disable=import-error from jax._src import profiler as jax_profiler try: @@ -65,7 +66,7 @@ help="Profiler Python tracer level", type=int) def collect_profile(port: int, duration_in_ms: int, host: str, - log_dir: Optional[str], host_tracer_level: int, + log_dir: str | None, host_tracer_level: int, device_tracer_level: int, python_tracer_level: int, no_perfetto_link: bool): options = profiler.ProfilerOptions( diff --git a/jax/example_libraries/optimizers.py b/jax/example_libraries/optimizers.py index 82fd5eb71c7f..15537e54e3f3 100644 --- a/jax/example_libraries/optimizers.py +++ b/jax/example_libraries/optimizers.py @@ -89,7 +89,9 @@ def step(step, opt_state): .. 
_Optax: https://github.com/deepmind/optax """ -from typing import Any, Callable, NamedTuple, Union +from __future__ import annotations + +from typing import Any, Callable, NamedTuple from collections import namedtuple import functools @@ -550,7 +552,7 @@ def schedule(i): return values[jnp.sum(i > boundaries)] return schedule -def make_schedule(scalar_or_schedule: Union[float, Schedule]) -> Schedule: +def make_schedule(scalar_or_schedule: float | Schedule) -> Schedule: if callable(scalar_or_schedule): return scalar_or_schedule elif jnp.ndim(scalar_or_schedule) == 0: diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index d405b846c0ac..c5ccb39b08dd 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import operator from typing import List, Optional, Tuple, Union @@ -30,7 +32,7 @@ def broadcast_to(x: Array, /, shape: tuple[int]) -> Array: return jax.numpy.broadcast_to(x, shape=shape) -def concat(arrays: Union[tuple[Array, ...], list[Array]], /, *, axis: Optional[int] = 0) -> Array: +def concat(arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0) -> Array: """Joins a sequence of arrays along an existing axis.""" dtype = _result_type(*arrays) if axis is None: @@ -46,7 +48,7 @@ def expand_dims(x: Array, /, *, axis: int = 0) -> Array: return jax.numpy.expand_dims(x, axis=axis) -def flip(x: Array, /, *, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array: +def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array: """Reverses the order of elements in an array along the given axis.""" return jax.numpy.flip(x, axis=axis) @@ -56,24 +58,24 @@ def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: return jax.lax.transpose(x, axes) -def reshape(x: Array, /, shape: tuple[int, ...], *, copy: Optional[bool] = None) -> Array: +def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array: """Reshapes an array without changing its data.""" del copy # unused return jax.numpy.reshape(x, shape) -def roll(x: Array, /, shift: Union[int, tuple[int]], *, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array: +def roll(x: Array, /, shift: int | tuple[int], *, axis: int | tuple[int, ...] | None = None) -> Array: """Rolls array elements along a specified axis.""" return jax.numpy.roll(x, shift=shift, axis=axis) -def squeeze(x: Array, /, axis: Union[int, tuple[int, ...]]) -> Array: +def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: """Removes singleton dimensions (axes) from x.""" dimensions = axis if isinstance(axis, tuple) else (axis,) return jax.lax.squeeze(x, dimensions=dimensions) -def stack(arrays: Union[tuple[Array, ...], list[Array]], /, *, axis: int = 0) -> Array: +def stack(arrays: tuple[Array, ...] 
| list[Array], /, *, axis: int = 0) -> Array: """Joins a sequence of arrays along a new axis.""" dtype = _result_type(*arrays) return jax.numpy.stack(arrays, axis=axis, dtype=dtype) diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 6fabfc3dd6b8..357a9a6bf03b 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -13,6 +13,8 @@ # limitations under the License. """Array serialization and deserialization.""" +from __future__ import annotations + import abc import asyncio from collections.abc import Awaitable, Sequence @@ -257,10 +259,10 @@ def estimate_read_memory_footprint(t: ts.TensorStore, async def async_deserialize( in_sharding: sharding_impls.XLACompatibleSharding, - tensorstore_spec: Union[ts.Spec, dict[str, Any]], - global_shape: Optional[Sequence[int]] = None, + tensorstore_spec: ts.Spec | dict[str, Any], + global_shape: Sequence[int] | None = None, dtype=None, - byte_limiter: Optional[_LimitInFlightBytes] = None, + byte_limiter: _LimitInFlightBytes | None = None, context=TS_CONTEXT, assume_metadata: bool = False, ): @@ -310,8 +312,8 @@ async def cb(index: array.Index, device: jax.Device): def run_deserialization(shardings: Sequence[sharding.Sharding], tensorstore_specs: Sequence[dict[str, Any]], - global_shapes: Optional[Sequence[array.Shape]] = None, - dtypes: Optional[Sequence[typing.DTypeLike]] = None, + global_shapes: Sequence[array.Shape] | None = None, + dtypes: Sequence[typing.DTypeLike] | None = None, concurrent_gb: int = 32): concurrent_bytes = concurrent_gb * 10**9 @@ -392,8 +394,8 @@ def serialize(self, arrays, tensorstore_specs, *, @abc.abstractmethod def deserialize(self, shardings: Sequence[sharding.Sharding], tensorstore_specs: Sequence[dict[str, Any]], - global_shapes: Optional[Sequence[array.Shape]] = None, - dtypes: Optional[Sequence[typing.DTypeLike]] = None): + global_shapes: Sequence[array.Shape] | None = None, + dtypes: Sequence[typing.DTypeLike] | None = None): """Deserializes GDAs from TensorStore.""" @@ -548,8 +550,8 @@ def serialize_with_paths(self, arrays: Sequence[jax.Array], def deserialize(self, shardings: Sequence[sharding.Sharding], tensorstore_specs: Sequence[dict[str, Any]], - global_shapes: Optional[Sequence[array.Shape]] = None, - dtypes: Optional[Sequence[typing.DTypeLike]] = None, + global_shapes: Sequence[array.Shape] | None = None, + dtypes: Sequence[typing.DTypeLike] | None = None, concurrent_gb: int = 32): self.wait_until_finished() return run_deserialization(shardings, tensorstore_specs, @@ -558,8 +560,8 @@ def deserialize(self, shardings: Sequence[sharding.Sharding], def deserialize_with_paths( self, shardings: Sequence[sharding.Sharding], paths: Sequence[str], - global_shapes: Optional[Sequence[array.Shape]] = None, - dtypes: Optional[Sequence[typing.DTypeLike]] = None, + global_shapes: Sequence[array.Shape] | None = None, + dtypes: Sequence[typing.DTypeLike] | None = None, concurrent_gb: int = 32): tspecs = jax.tree_map(get_tensorstore_spec, paths) return self.deserialize(shardings, tspecs, global_shapes, dtypes, diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 36604ae80ca4..6209a37d9e26 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import jax +from __future__ import annotations + import inspect from typing import Optional +import weakref + +import jax from jax._src import core from jax import tree_util from jax._src import linear_util as lu @@ -32,8 +36,6 @@ from jax._src.api_util import flatten_fun_nokwargs from jax._src.api_util import argnums_partial -import weakref - def _resolve_kwargs(fun, args, kwargs): ba = inspect.signature(fun).bind(*args, **kwargs) @@ -494,7 +496,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, return mlir.lower_fun( core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values) - def to_mesh_pspec_sharding(hlo_sharding: Optional[xc.HloSharding], ndim): + def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): if hlo_sharding is None: return hlo_sharding if mesh.empty or not decode_shardings: diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py index 9491e474e99e..cf6df4478890 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/export.py @@ -15,6 +15,8 @@ """ +from __future__ import annotations + from collections.abc import Sequence import copy import dataclasses @@ -68,7 +70,7 @@ class DisabledSafetyCheck: _impl: str @classmethod - def platform(cls) -> "DisabledSafetyCheck": + def platform(cls) -> DisabledSafetyCheck: """Allows the execution platform to differ from the serialization platform. Has effect only on deserialization. @@ -76,7 +78,7 @@ def platform(cls) -> "DisabledSafetyCheck": return DisabledSafetyCheck("platform") @classmethod - def custom_call(cls, target_name: str) -> "DisabledSafetyCheck": + def custom_call(cls, target_name: str) -> DisabledSafetyCheck: """Allows the serialization of a call target not known to be stable. Has effect only on serialization. @@ -86,7 +88,7 @@ def custom_call(cls, target_name: str) -> "DisabledSafetyCheck": return DisabledSafetyCheck(f"custom_call:{target_name}") @classmethod - def shape_assertions(cls) -> "DisabledSafetyCheck": + def shape_assertions(cls) -> DisabledSafetyCheck: """Allows invocations with shapes that do not meet the constraints. Has effect on serialization (to suppress the generation of the assertions) @@ -94,7 +96,7 @@ def shape_assertions(cls) -> "DisabledSafetyCheck": """ return DisabledSafetyCheck("shape_assertions") - def is_custom_call(self) -> Optional[str]: + def is_custom_call(self) -> str | None: """Returns the custom call target allowed by this directive.""" m = re.match(r'custom_call:(.+)$', self._impl) return m.group(1) if m else None @@ -125,7 +127,7 @@ def __hash__(self) -> int: LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] # None means unspecified sharding -Sharding = Optional[xla_client.HloSharding] +Sharding = Union[xla_client.HloSharding, None] @dataclasses.dataclass(frozen=True) class Exported: @@ -288,7 +290,7 @@ class Exported: module_kept_var_idx: tuple[int, ...] uses_shape_polymorphism: bool - _get_vjp: Optional[Callable[["Exported"], "Exported"]] + _get_vjp: Callable[[Exported], Exported] | None def mlir_module(self) -> ir.Module: return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized) @@ -301,7 +303,7 @@ def __str__(self): def has_vjp(self) -> bool: return self._get_vjp is not None - def vjp(self) -> "Exported": + def vjp(self) -> Exported: """Gets the exported VJP. 
Returns None if not available, which can happen if the Exported has been @@ -316,9 +318,9 @@ def default_lowering_platform() -> str: return xb.canonicalize_platform(jax.default_backend()) def symbolic_shape( - shape_spec: Optional[str], + shape_spec: str | None, *, - like: Optional[Sequence[Optional[int]]] = None) -> Shape: + like: Sequence[int | None] | None = None) -> Shape: """Constructs a jax.ShapeDtypeStruct with polymorphic shapes. Args: @@ -338,7 +340,7 @@ def symbolic_shape( """ return shape_poly.symbolic_shape(shape_spec, like=like) -def shape_and_dtype_jax_array(a) -> tuple[Sequence[Optional[int]], DType]: +def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: """Returns the shape and dtype of a jax.Array.""" aval = core.raise_to_shaped(core.get_aval(a)) return aval.shape, aval.dtype @@ -404,7 +406,7 @@ def _keep_main_tokens(serialization_version: int) -> bool: def export(fun_jax: Callable, *, - lowering_platforms: Optional[Sequence[str]] = None, + lowering_platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), ) -> Callable[..., Exported]: """Exports native serialization for a JAX function. @@ -965,10 +967,10 @@ def canonical_shardings( device_assignment: Sequence[jax.Device], in_shardings: Sequence[Sharding], out_shardings: Sequence[Sharding] - ) -> tuple[Union[pxla.UnspecifiedValue, - Sequence[sharding.XLACompatibleSharding]], - Union[pxla.UnspecifiedValue, - Sequence[sharding.XLACompatibleSharding]]]: + ) -> tuple[(pxla.UnspecifiedValue | + Sequence[sharding.XLACompatibleSharding]), + (pxla.UnspecifiedValue | + Sequence[sharding.XLACompatibleSharding])]: """Prepares canonical in_ and out_shardings for a pjit invocation. The pjit front-end is picky about what in- and out-shardings it accepts, @@ -980,8 +982,8 @@ def canonical_shardings( """ replicated_s = sharding.GSPMDSharding.get_replicated(device_assignment) def canonicalize( - ss: Sequence[Sharding]) -> Union[pxla.UnspecifiedValue, - Sequence[sharding.XLACompatibleSharding]]: + ss: Sequence[Sharding]) -> (pxla.UnspecifiedValue | + Sequence[sharding.XLACompatibleSharding]): if all(s is None for s in ss): return sharding_impls.UNSPECIFIED return tuple( @@ -1108,7 +1110,7 @@ def _call_exported_abstract_eval( assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure # Check that the expected shapes match the actual ones for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)): - def pp_arg_dim(dim_idx: Optional[int]) -> str: + def pp_arg_dim(dim_idx: int | None) -> str: return shape_poly.pretty_print_dimension_descriptor(exported.in_tree, arg_idx, dim_idx) if len(exp_aval.shape) != len(actual_aval.shape): diff --git a/jax/experimental/export/serialization.py b/jax/experimental/export/serialization.py index 5548e58258d6..f8dfe399df5d 100644 --- a/jax/experimental/export/serialization.py +++ b/jax/experimental/export/serialization.py @@ -14,7 +14,8 @@ # Serialization and deserialization of export.Exported -from typing import Callable, Sequence, TypeVar +from typing import Callable, TypeVar +from collections.abc import Sequence try: import flatbuffers diff --git a/jax/experimental/export/serialization_generated.py b/jax/experimental/export/serialization_generated.py index 21eb5a6ce9ba..02b00f7880bc 100644 --- a/jax/experimental/export/serialization_generated.py +++ b/jax/experimental/export/serialization_generated.py @@ -20,7 +20,7 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class 
PyTreeDefKind(object): +class PyTreeDefKind: leaf = 0 none = 1 tuple = 2 @@ -28,12 +28,12 @@ class PyTreeDefKind(object): dict = 4 -class AbstractValueKind(object): +class AbstractValueKind: shapedArray = 0 abstractToken = 1 -class DType(object): +class DType: bool = 0 i8 = 1 i16 = 2 @@ -58,18 +58,18 @@ class DType(object): f8_e5m2fnuz = 21 -class ShardingKind(object): +class ShardingKind: unspecified = 0 hlo_sharding = 1 -class DisabledSafetyCheckKind(object): +class DisabledSafetyCheckKind: platform = 0 custom_call = 1 shape_assertions = 2 -class PyTreeDef(object): +class PyTreeDef: __slots__ = ['_tab'] @classmethod @@ -161,7 +161,7 @@ def PyTreeDefEnd(builder): -class AbstractValue(object): +class AbstractValue: __slots__ = ['_tab'] @classmethod @@ -233,7 +233,7 @@ def AbstractValueEnd(builder): -class Sharding(object): +class Sharding: __slots__ = ['_tab'] @classmethod @@ -302,7 +302,7 @@ def ShardingEnd(builder): -class Effect(object): +class Effect: __slots__ = ['_tab'] @classmethod @@ -338,7 +338,7 @@ def EffectEnd(builder): -class DisabledSafetyCheck(object): +class DisabledSafetyCheck: __slots__ = ['_tab'] @classmethod @@ -384,7 +384,7 @@ def DisabledSafetyCheckEnd(builder): -class Exported(object): +class Exported: __slots__ = ['_tab'] @classmethod diff --git a/jax/experimental/export/shape_poly.py b/jax/experimental/export/shape_poly.py index ee1813c8cee8..4d06ec319341 100644 --- a/jax/experimental/export/shape_poly.py +++ b/jax/experimental/export/shape_poly.py @@ -31,6 +31,8 @@ [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). """ +from __future__ import annotations + import collections from collections.abc import Iterable, Sequence import dataclasses @@ -111,9 +113,9 @@ class _DimAtom: MOD = "mod" NON_NEGATIVE = "non_negative" # The max of the operand and 0 - def __init__(self, *operands: '_DimExpr', - var: Optional[str] = None, - operation: Optional[str] = None): + def __init__(self, *operands: _DimExpr, + var: str | None = None, + operation: str | None = None): if var is not None: assert operation is None assert not operands @@ -124,10 +126,10 @@ def __init__(self, *operands: '_DimExpr', self.operands = operands @classmethod - def from_var(cls, v: str) -> '_DimAtom': + def from_var(cls, v: str) -> _DimAtom: return _DimAtom(var=v) - def to_var(self) -> Optional[str]: + def to_var(self) -> str | None: return self.var def get_vars(self) -> set[str]: @@ -141,7 +143,7 @@ def get_vars(self) -> set[str]: return acc @classmethod - def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimAtom': + def from_operation(cls, operation: str, *operands: _DimExpr) -> _DimAtom: return _DimAtom(*operands, operation=operation) def __str__(self): @@ -161,7 +163,7 @@ def __eq__(self, other: Any): if self.var is not None: return self.var == other.var else: - def symbolic_equal(e1: '_DimExpr', e2: '_DimExpr') -> bool: + def symbolic_equal(e1: _DimExpr, e2: _DimExpr) -> bool: try: return e1 == e2 except InconclusiveDimensionOperation: @@ -170,7 +172,7 @@ def symbolic_equal(e1: '_DimExpr', e2: '_DimExpr') -> bool: all(symbolic_equal(self_o, other_o) for self_o, other_o in zip(self.operands, other.operands))) - def __lt__(self, other: '_DimAtom'): + def __lt__(self, other: _DimAtom): """ Comparison to another atom in graded reverse lexicographic order. 
Used only for determining a sorting order, does not relate to the @@ -266,14 +268,14 @@ def __str__(self): for key, exponent in sorted(self.items())) @classmethod - def from_var(cls, v: str) -> '_DimMon': + def from_var(cls, v: str) -> _DimMon: return _DimMon({_DimAtom.from_var(v): 1}) @classmethod def from_atom(clscls, a: _DimAtom, aexp: int): return _DimMon({a: aexp}) - def to_var(self) -> Optional[str]: + def to_var(self) -> str | None: """Extract the variable name "x", from a monomial "x". Return None, if the monomial is not a single variable.""" items = self.items() @@ -292,14 +294,14 @@ def get_vars(self) -> set[str]: return acc @classmethod - def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimMon': + def from_operation(cls, operation: str, *operands: _DimExpr) -> _DimMon: return _DimMon({_DimAtom.from_operation(operation, *operands): 1}) @property def degree(self): return sum(self.values()) - def __lt__(self, other: '_DimMon'): + def __lt__(self, other: _DimMon): """ Comparison to another monomial in graded reverse lexicographic order. Used only for determining a sorting order, does not relate to the @@ -309,13 +311,13 @@ def __lt__(self, other: '_DimMon'): other_key = -other.degree, tuple(sorted(other)) return self_key > other_key - def mul(self, other: '_DimMon') -> '_DimMon': + def mul(self, other: _DimMon) -> _DimMon: """ Returns the product with another monomial. Example: (n^2*m) * n == n^3 * m. """ return _DimMon(collections.Counter(self) + collections.Counter(other)) - def divide(self, divisor: '_DimMon') -> '_DimMon': + def divide(self, divisor: _DimMon) -> _DimMon: """ Divides by another monomial. Raises a InconclusiveDimensionOperation if the result is not a monomial. @@ -432,14 +434,14 @@ def from_monomial(cls, mon: _DimMon, exp: int): return _DimExpr.normalize({mon: exp}) @classmethod - def from_var(cls, v: str) -> '_DimExpr': + def from_var(cls, v: str) -> _DimExpr: return _DimExpr({_DimMon.from_var(v): 1}) @classmethod - def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimExpr': + def from_operation(cls, operation: str, *operands: _DimExpr) -> _DimExpr: return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1) - def to_var(self) -> Optional[str]: + def to_var(self) -> str | None: """Extract the variable name "x", from a symbolic expression.""" items = self.monomials() if len(items) != 1: # type: ignore @@ -552,7 +554,7 @@ def __rsub__(self, other): return self.__jax_array__().__rsub__(other) return _ensure_poly(other, "sub").__sub__(self) - def __neg__(self) -> '_DimExpr': + def __neg__(self) -> _DimExpr: return _DimExpr({mon: -coeff for mon, coeff in self.monomials()}) def __mul__(self, other): @@ -648,7 +650,7 @@ def __lt__(self, other: DimSize): except InconclusiveDimensionOperation as e: raise self.inconclusive_comparison("<", other) from e - def divmod(self, divisor: "_DimExpr") -> tuple[DimSize, int]: + def divmod(self, divisor: _DimExpr) -> tuple[DimSize, int]: """ Floor division with remainder (divmod) generalized to polynomials. If the `divisor` is not a constant, the remainder must be 0. 
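The `shape_poly.py` hunks also drop the quotes around self-referential types such as `'_DimExpr'` and `'_DimMon'`: with deferred annotations a class may be named in its own method signatures before the class object exists. A minimal sketch with an illustrative `Mon` class modelled loosely on `_DimMon`, not part of this patch:

from __future__ import annotations

import collections


class Mon(collections.Counter):
    def mul(self, other: Mon) -> Mon:
        # Product of two monomials: exponents add. Without the future
        # import, the unquoted `Mon` annotation above would raise NameError
        # at class-definition time.
        return Mon(collections.Counter(self) + collections.Counter(other))


assert Mon({"n": 2}).mul(Mon({"n": 1, "m": 1})) == Mon({"n": 3, "m": 1})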
@@ -742,11 +744,11 @@ def evaluate(self, env: DimVarEnv): for mon, coeff in self.monomials()] return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0] - def non_negative(self) -> "_DimExpr": + def non_negative(self) -> _DimExpr: return _DimExpr.from_operation(_DimAtom.NON_NEGATIVE, self) @staticmethod - def get_aval(dim: "_DimExpr"): + def get_aval(dim: _DimExpr): return core.dim_value_aval() def dimension_as_value(self): @@ -773,10 +775,10 @@ class _Decomposition: def _decompose_expr(e: _DimExpr, operation: str, *, - with_factor: Optional[int] = None, - with_exp: Optional[int] = None, - with_rest_monomial: Optional[Union[_DimExpr, int]] = None, - with_rest_expr: Optional[Union[_DimExpr, int]] = None, + with_factor: int | None = None, + with_exp: int | None = None, + with_rest_monomial: _DimExpr | int | None = None, + with_rest_expr: _DimExpr | int | None = None, ) -> Iterable[_Decomposition]: """Computes the decompositions of `e` into `_Decomposition`. @@ -994,9 +996,9 @@ def __str__(self): return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")" -def symbolic_shape(shape_spec: Union[str, PolyShape, None], +def symbolic_shape(shape_spec: str | PolyShape | None, *, - like: Optional[Sequence[Optional[int]]] = None + like: Sequence[int | None] | None = None ) -> Sequence[DimSize]: """Parses the shape polymorphic specification into a symbolic shape. @@ -1028,7 +1030,7 @@ def symbolic_shape(shape_spec: Union[str, PolyShape, None], class _Parser: def __init__(self, shape_spec: str, - like_shape: Optional[Sequence[Optional[int]]], + like_shape: Sequence[int | None] | None, shape_spec_repr: str): self.shape_spec = shape_spec self.shape_spec_repr = shape_spec_repr # For error messages @@ -1043,7 +1045,7 @@ def parse(self) -> Sequence[DimSize]: self.expect_token(tok, [tokenize.ENDMARKER]) return sh - def add_dim(self, expr: Optional[DimSize], tok: tokenize.TokenInfo): + def add_dim(self, expr: DimSize | None, tok: tokenize.TokenInfo): if expr is None: raise self.parse_err(tok, ("unexpected placeholder for unknown dimension; " @@ -1057,7 +1059,7 @@ def add_dim(self, expr: Optional[DimSize], tok: tokenize.TokenInfo): f"like={self.like_shape}")) self.dimensions.append(expr) - def parse_err(self, tok: Optional[tokenize.TokenInfo], detail: str) -> Exception: + def parse_err(self, tok: tokenize.TokenInfo | None, detail: str) -> Exception: msg = ( f"syntax error in symbolic shape {self.shape_spec_repr} " f"in dimension {len(self.dimensions)}: {detail}. ") @@ -1289,7 +1291,7 @@ class Comparator(Enum): right: DimSize # `error_message_pieces` is a list of strings and DimSize. The error message # is formed by evaluating the DimSize and concatenating the sequence. - error_message_pieces: Sequence[Union[str, DimSize]] + error_message_pieces: Sequence[str | DimSize] def check_statically(self, eval: CachingShapeEvaluator) -> None: """Evaluates a constraint statically.""" @@ -1307,7 +1309,7 @@ def check_statically(self, eval: CachingShapeEvaluator) -> None: if not ok: raise self.make_error(eval) - def compute(self, eval: CachingShapeEvaluator) -> Optional[jax.Array]: + def compute(self, eval: CachingShapeEvaluator) -> jax.Array | None: """Computes if the constraint is satisfied. 
If the constraint can be resolved statically returns None @@ -1383,7 +1385,7 @@ def __init__(self): def add_constraint(self, comp: ShapeConstraint.Comparator, left: DimSize, right: DimSize, - error_message_pieces: Sequence[Union[str, DimSize]]): + error_message_pieces: Sequence[str | DimSize]): c = ShapeConstraint(comp, left, right, error_message_pieces) self.constraints.append(c) @@ -1445,7 +1447,7 @@ def _cached_pretty_print_dimension_descriptor( def pretty_print_dimension_descriptor( args_kwargs_tree: tree_util.PyTreeDef, - flat_arg_idx: int, dim_idx: Optional[int]) -> str: + flat_arg_idx: int, dim_idx: int | None) -> str: arg_str = _cached_pretty_print_dimension_descriptor(args_kwargs_tree, flat_arg_idx) if dim_idx is not None: arg_str += f".shape[{dim_idx}]" @@ -1551,7 +1553,7 @@ def _solve_dim_equations( # Returns a shape environment and the shape constraints if it can solve all # dimension variables. Raises an exception if it cannot. shapeenv: DimVarEnv = {} - solution_error_message_pieces: list[Union[str, _DimExpr]] = [ + solution_error_message_pieces: list[str | _DimExpr] = [ " Obtained dimension variables: " ] # Error message describing the solution # Prepare error message piece describing the polymorphic shape specs diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 89f9bd6dea49..856ab266c24a 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -494,6 +494,9 @@ def power3_with_cotangents(x): * Explore implementation with XLA CustomCall for CPU and GPU. """ + +from __future__ import annotations + import atexit from collections.abc import Sequence import functools @@ -1699,7 +1702,7 @@ class _CallbackHandlerData: initialized: bool on_exit: bool lock: threading.Lock - last_callback_exception: Optional[tuple[Exception, str]] + last_callback_exception: tuple[Exception, str] | None clients: tuple[XlaLocalClient, ...] devices: tuple[XlaDevice, ...] consumer_registry: dict[Callable, int] @@ -1818,7 +1821,7 @@ def exit_handler(): _callback_handler_data.initialized = True -def barrier_wait(logging_name: Optional[str] = None): +def barrier_wait(logging_name: str | None = None): """Blocks the calling thread until all current outfeed is processed. Waits until all callbacks from computations already running on all devices diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 6da5af5a9d71..f1f58b3a7f17 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -22,6 +22,9 @@ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax. """ + +from __future__ import annotations + from collections.abc import Sequence import functools from typing import Any, Callable, Optional @@ -271,7 +274,7 @@ def fix_float0(arg_jax, ct_arg_jax): return util.wraps(callable_tf)(make_call) -def check_tf_result(idx: int, r_tf: TfVal, r_aval: Optional[core.ShapedArray]) -> TfVal: +def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> TfVal: # Check that the TF function returns values of expected types. This # improves error reporting, preventing hard-to-diagnose errors downstream try: diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py index c11c5c4e3f3a..d69f55278dc0 100644 --- a/jax/experimental/jax2tf/examples/mnist_lib.py +++ b/jax/experimental/jax2tf/examples/mnist_lib.py @@ -18,6 +18,9 @@ See README.md for how these are used. 
""" + +from __future__ import annotations + from collections.abc import Sequence import functools import logging @@ -288,7 +291,7 @@ def plot_images(ds, nr_rows: int, nr_cols: int, title: str, - inference_fn: Optional[Callable] = None): + inference_fn: Callable | None = None): """Plots a grid of images with their predictions. Params: diff --git a/jax/experimental/jax2tf/examples/saved_model_lib.py b/jax/experimental/jax2tf/examples/saved_model_lib.py index b44b36112c72..b3c94b5e153a 100644 --- a/jax/experimental/jax2tf/examples/saved_model_lib.py +++ b/jax/experimental/jax2tf/examples/saved_model_lib.py @@ -24,6 +24,8 @@ customize this function as needed. """ +from __future__ import annotations + from collections.abc import Sequence from typing import Any, Callable, Optional, Union @@ -37,11 +39,11 @@ def convert_and_save_model( model_dir: str, *, input_signatures: Sequence[tf.TensorSpec], - polymorphic_shapes: Optional[Union[str, jax2tf.PolyShape]] = None, + polymorphic_shapes: str | jax2tf.PolyShape | None = None, with_gradient: bool = False, enable_xla: bool = True, compile_model: bool = True, - saved_model_options: Optional[tf.saved_model.SaveOptions] = None): + saved_model_options: tf.saved_model.SaveOptions | None = None): """Convert a JAX function and saves a SavedModel. This is an example, we do not promise backwards compatibility for this code. diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index a85a7d06d251..b449aeb7011b 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -13,6 +13,8 @@ # limitations under the License. """Workarounds for jax2tf transforms when XLA is not linked in.""" +from __future__ import annotations + import builtins from collections.abc import Sequence import dataclasses @@ -250,8 +252,8 @@ def _conv_general_dilated( lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers: lax.ConvDimensionNumbers, feature_group_count: int, batch_group_count: int, - precision: Optional[tuple[PrecisionType, PrecisionType]], - preferred_element_type: Optional[DType], + precision: tuple[PrecisionType, PrecisionType] | None, + preferred_element_type: DType | None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.conv_general_dilated_p using XlaConv.""" # In presence of shape polymorphism, lhs.shape and rhs.shape may contain @@ -360,8 +362,8 @@ def _conv_general_dilated( def _dot_general(lhs, rhs, *, dimension_numbers, - precision: Optional[tuple[PrecisionType, PrecisionType]], - preferred_element_type: Optional[DType], + precision: tuple[PrecisionType, PrecisionType] | None, + preferred_element_type: DType | None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index e8f6044fb27c..72502d153c7b 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -13,6 +13,8 @@ # limitations under the License. """Provides JAX and TensorFlow interoperation APIs.""" +from __future__ import annotations + from collections.abc import Iterable, Sequence from functools import partial import contextlib @@ -204,11 +206,11 @@ def __init__(self): # A dict collecting all tf concrete_functions called by stablehlo.custom_call # This is used only by native serialization (unlike all the other # thread-local state). 
- self.call_tf_concrete_function_list: Optional[list[Any]] = None + self.call_tf_concrete_function_list: list[Any] | None = None _thread_local_state = _ThreadLocalState() -def _get_current_name_stack() -> Union[NameStack, str]: +def _get_current_name_stack() -> NameStack | str: return source_info_util.current_name_stack() @contextlib.contextmanager @@ -223,7 +225,7 @@ def inside_call_tf(): def get_thread_local_state_call_tf_concrete_function_list() -> ( - Optional[list[Any]] + list[Any] | None ): return _thread_local_state.call_tf_concrete_function_list @@ -231,11 +233,11 @@ def get_thread_local_state_call_tf_concrete_function_list() -> ( @partial(api_util.api_hook, tag="jax2tf_convert") def convert(fun_jax: Callable, *, - polymorphic_shapes: Optional[str] = None, + polymorphic_shapes: str | None = None, with_gradient: bool = True, enable_xla: bool = True, - native_serialization: Union[bool, _DefaultNativeSerialization] = DEFAULT_NATIVE_SERIALIZATION, - native_serialization_platforms: Optional[Sequence[str]] = None, + native_serialization: bool | _DefaultNativeSerialization = DEFAULT_NATIVE_SERIALIZATION, + native_serialization_platforms: Sequence[str] | None = None, native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (), ) -> Callable: """Allows calling a JAX function from a TensorFlow program. @@ -360,7 +362,7 @@ def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal: source_info_util.register_exclusion(os.path.dirname(tf.__file__)) _has_registered_tf_source_path = True - def shape_and_dtype_tf(a: TfVal) -> tuple[Sequence[Optional[int]], DType]: + def shape_and_dtype_tf(a: TfVal) -> tuple[Sequence[int | None], DType]: # The shape and JAX dtype for a TF argument tf_arg_shape = np.shape(a) # Fix the shape for TF1 @@ -476,7 +478,7 @@ def get_vjp_fun(self) -> tuple[Callable, class NativeSerializationImpl(SerializationImpl): def __init__(self, fun_jax, *, args_specs, kwargs_specs, - native_serialization_platforms: Optional[Sequence[str]], + native_serialization_platforms: Sequence[str] | None, native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]): self.convert_kwargs = dict(native_serialization=True, native_serialization_platforms=native_serialization_platforms, @@ -773,7 +775,7 @@ def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal): return grad_fn_tf @contextlib.contextmanager -def _extended_name_stack(extra_name_stack: Optional[str]): +def _extended_name_stack(extra_name_stack: str | None): name_ctx = (source_info_util.extend_name_stack(extra_name_stack) if extra_name_stack else contextlib.nullcontext()) @@ -786,7 +788,7 @@ def _interpret_fun_jax( fun_jax: Callable, args_tf: Sequence[TfVal], args_avals: Sequence[core.ShapedArray], - extra_name_stack: Optional[str], + extra_name_stack: str | None, fresh_constant_cache: bool = False, ) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]: with core.new_base_main(TensorFlowTrace) as main: # type: ignore @@ -943,7 +945,7 @@ def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun, def _convert_jax_impl(impl_jax: Callable, *, multiple_results=True, with_physical_avals=False, - extra_name_stack: Optional[str] = None) -> Callable: + extra_name_stack: str | None = None) -> Callable: """Convert the JAX implementation of a primitive. 
Args: @@ -997,7 +999,7 @@ def _interpret_subtrace(main: core.MainTrace, def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal, - extra_name_stack: Optional[str], + extra_name_stack: str | None, fresh_constant_cache: bool = True) -> Sequence[TfVal]: """Evaluates a Jaxpr with tf.Tensor arguments. @@ -1032,7 +1034,7 @@ def _jax_physical_dtype(dtype): return _jax_physical_aval(core.ShapedArray((), dtype)).dtype -def _aval_to_tf_shape(aval: core.ShapedArray) -> tuple[Optional[int], ...]: +def _aval_to_tf_shape(aval: core.ShapedArray) -> tuple[int | None, ...]: """Generate a TF shape, possibly containing None for polymorphic dimensions.""" aval = _jax_physical_aval(aval) @@ -1068,7 +1070,7 @@ def _to_jax_dtype(tf_dtype): def _tfval_to_tensor_jax_dtype(val: TfVal, - jax_dtype: Optional[DType] = None, + jax_dtype: DType | None = None, memoize_constants=False) -> tuple[TfVal, DType]: """Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type. @@ -1155,7 +1157,7 @@ def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize """Asserts that shape matches x.shape in the known dimensions and has dimension polynomials elsewhere.""" # Ensures that the shape does not contain None; it should contain symbolic expressions. - def check_one(xd: Optional[int], sd: Any): + def check_one(xd: int | None, sd: Any): if core.is_constant_dim(sd): return xd == sd else: @@ -1190,7 +1192,7 @@ class TensorFlowTracer(core.Tracer): # _aval: core.ShapedArray __slots__ = ["val", "_aval"] - def __init__(self, trace: "TensorFlowTrace", val: TfVal, + def __init__(self, trace: TensorFlowTrace, val: TfVal, aval: core.AbstractValue): self._trace = trace self._aval = aval @@ -2035,8 +2037,8 @@ def _conv_general_dimension_numbers_proto(dimension_numbers): return proto -def _precision_config_proto(precision: Optional[tuple[PrecisionType, - PrecisionType]]): +def _precision_config_proto(precision: None | (tuple[PrecisionType, + PrecisionType])): """Convert an integer to an XLA.PrecisionConfig.""" if precision is None: return None @@ -2053,8 +2055,8 @@ def _conv_general_dilated(lhs, rhs, *, dimension_numbers: lax.ConvDimensionNumbers, feature_group_count: int, batch_group_count: int, - precision: Optional[tuple[PrecisionType, PrecisionType]], - preferred_element_type: Optional[DType], + precision: tuple[PrecisionType, PrecisionType] | None, + preferred_element_type: DType | None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.conv_general_dilated_p using XlaConv.""" @@ -2062,7 +2064,7 @@ def _conv_general_dilated(lhs, rhs, *, dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers) precision_config_proto = _precision_config_proto(precision) - def gen_conv(lhs, rhs, preferred_element_type: Optional[DType]): + def gen_conv(lhs, rhs, preferred_element_type: DType | None): tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2]) if tf_version >= (2, 8): # TODO(necula): remove when 2.8.0 is the stable TF version (and supports @@ -2098,7 +2100,7 @@ def gen_conv(lhs, rhs, preferred_element_type: Optional[DType]): # Follow the lowering for complex convolutions from # lax._conv_general_dilated_translation. We can use the same conversion on all # platforms because on XLA:TPU the compiler does the same as a rewrite. 
- preferred_float_et: Optional[Any] + preferred_float_et: Any | None if np.issubdtype(_in_avals[0].dtype, np.complexfloating): if preferred_element_type is not None: # Convert complex dtype to types used for real and imaginary parts @@ -2122,8 +2124,8 @@ def gen_conv(lhs, rhs, preferred_element_type: Optional[DType]): def _dot_general(lhs, rhs, *, dimension_numbers, - precision: Optional[tuple[PrecisionType, PrecisionType]], - preferred_element_type: Optional[DType], + precision: tuple[PrecisionType, PrecisionType] | None, + preferred_element_type: DType | None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" @@ -3381,7 +3383,7 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]: tf_impl[ad.custom_lin_p] = _custom_lin -PartitionsOrReplicated = Optional[tuple[int, ...]] +PartitionsOrReplicated = Union[tuple[int, ...], None] def split_to_logical_devices(tensor: TfVal, partition_dimensions: PartitionsOrReplicated): @@ -3414,13 +3416,13 @@ def split_to_logical_devices(tensor: TfVal, def _xla_compatible_sharding_to_hlo_sharding( s: sharding.XLACompatibleSharding, - aval: core.ShapedArray) -> Optional[xla_client.HloSharding]: + aval: core.ShapedArray) -> xla_client.HloSharding | None: if sharding_impls.is_unspecified(s): return None return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] def _shard_value(val: TfVal, - sd: Optional[xla_client.HloSharding], *, + sd: xla_client.HloSharding | None, *, skip_replicated_sharding: bool) -> TfVal: """Apply sharding to a TfVal.""" if sd is None: @@ -3476,7 +3478,7 @@ def _pjit(*args: TfVal, _out_aval: Sequence[core.ShapedArray]) -> TfVal: del donated_invars # Apply sharding annotation to the arguments - in_hlo_shardings: Sequence[Optional[xla_client.HloSharding]] = map( + in_hlo_shardings: Sequence[xla_client.HloSharding | None] = map( _xla_compatible_sharding_to_hlo_sharding, in_shardings, _in_avals) sharded_args: Sequence[TfVal] = tuple( map(partial(_shard_value, @@ -3485,7 +3487,7 @@ def _pjit(*args: TfVal, results = _interpret_jaxpr(jaxpr, *sharded_args, extra_name_stack=util.wrap_name(name, "pjit"), fresh_constant_cache=False) - out_hlo_shardings: Sequence[Optional[xla_client.HloSharding]] = map( + out_hlo_shardings: Sequence[xla_client.HloSharding | None] = map( _xla_compatible_sharding_to_hlo_sharding, out_shardings, _out_aval) sharded_results: Sequence[TfVal] = tuple( map(partial(_shard_value, diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index 7f7d13f0c92e..47c5c8360cf5 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -17,6 +17,8 @@ these tests. 
""" +from __future__ import annotations + import base64 from collections.abc import Sequence import io @@ -82,7 +84,7 @@ def serialize( self, func: Callable, data: bctu.CompatTestData, - polymorphic_shapes: Optional[Sequence[str]] = None, + polymorphic_shapes: Sequence[str] | None = None, allow_unstable_custom_call_targets: Sequence[str] = (), ): # We serialize as a tf.Graph @@ -125,7 +127,7 @@ def serialize( def run_serialized( self, data: bctu.CompatTestData, - polymorphic_shapes: Optional[Sequence[str]] = None, + polymorphic_shapes: Sequence[str] | None = None, ): root_dir = self.create_tempdir() deserialize_directory(data.mlir_module_serialized, root_dir) diff --git a/jax/experimental/jax2tf/tests/cross_compilation_check.py b/jax/experimental/jax2tf/tests/cross_compilation_check.py index 9b612809d262..63e8928ee371 100644 --- a/jax/experimental/jax2tf/tests/cross_compilation_check.py +++ b/jax/experimental/jax2tf/tests/cross_compilation_check.py @@ -23,6 +23,9 @@ currently saved file with the saved one. """ + +from __future__ import annotations + from collections.abc import Sequence import contextlib import dataclasses @@ -173,7 +176,7 @@ def write_and_check_harness(harness: primitive_harness.Harness, def write_and_check_harnesses(io: Io, save_directory: str, *, - filter_harness: Optional[Callable[[str], bool]] = None, + filter_harness: Callable[[str], bool] | None = None, for_platforms: Sequence[str] = ("cpu", "tpu"), verbose = False): logging.info("Writing and checking harnesses at %s", save_directory) diff --git a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py index fdd16b642cff..9bb7466125c5 100644 --- a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py +++ b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py @@ -16,6 +16,8 @@ https://github.com/google/flax/tree/main/examples/sst2 """ +from __future__ import annotations + import functools from typing import Any, Callable, Optional @@ -88,10 +90,10 @@ class WordDropout(nn.Module): """ dropout_rate: float unk_idx: int - deterministic: Optional[bool] = None + deterministic: bool | None = None @nn.compact - def __call__(self, inputs: Array, deterministic: Optional[bool] = None): + def __call__(self, inputs: Array, deterministic: bool | None = None): deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic) if deterministic or self.dropout_rate == 0.: @@ -120,8 +122,8 @@ class Embedder(nn.Module): frozen: bool = False dropout_rate: float = 0. word_dropout_rate: float = 0. - unk_idx: Optional[int] = None - deterministic: Optional[bool] = None + unk_idx: int | None = None + deterministic: bool | None = None dtype: jnp.dtype = jnp.dtype('float32') def setup(self): @@ -137,7 +139,7 @@ def setup(self): unk_idx=self.unk_idx) def __call__(self, inputs: Array, - deterministic: Optional[bool] = None) -> Array: + deterministic: bool | None = None) -> Array: """Embeds the input sequences and applies word dropout and dropout. 
Args: @@ -222,14 +224,14 @@ class MLP(nn.Module): activation: Callable[..., Any] = nn.tanh dropout_rate: float = 0.0 output_bias: bool = False - deterministic: Optional[bool] = None + deterministic: bool | None = None def setup(self): self.intermediate_layer = nn.Dense(self.hidden_size) self.output_layer = nn.Dense(self.output_size, use_bias=self.output_bias) self.dropout_layer = nn.Dropout(rate=self.dropout_rate) - def __call__(self, inputs: Array, deterministic: Optional[bool] = None): + def __call__(self, inputs: Array, deterministic: bool | None = None): """Applies the MLP to the last dimension of the inputs. Args: @@ -306,7 +308,7 @@ class AttentionClassifier(nn.Module): hidden_size: int output_size: int dropout_rate: float = 0. - deterministic: Optional[bool] = None + deterministic: bool | None = None def setup(self): self.dropout_layer = nn.Dropout(rate=self.dropout_rate) @@ -319,7 +321,7 @@ def setup(self): dropout_rate=self.dropout_rate) def __call__(self, encoded_inputs: Array, lengths: Array, - deterministic: Optional[bool] = None) -> Array: + deterministic: bool | None = None) -> Array: """Applies model to the encoded inputs. Args: @@ -362,7 +364,7 @@ class TextClassifier(nn.Module): dropout_rate: float word_dropout_rate: float unk_idx: int = 1 - deterministic: Optional[bool] = None + deterministic: bool | None = None def setup(self): self.embedder = Embedder( @@ -378,14 +380,14 @@ def setup(self): dropout_rate=self.dropout_rate) def embed_token_ids(self, token_ids: Array, - deterministic: Optional[bool] = None) -> Array: + deterministic: bool | None = None) -> Array: deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic) return self.embedder(token_ids, deterministic=deterministic) def logits_from_embedded_inputs( self, embedded_inputs: Array, lengths: Array, - deterministic: Optional[bool] = None) -> Array: + deterministic: bool | None = None) -> Array: deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic) encoded_inputs = self.encoder(embedded_inputs, lengths) @@ -393,7 +395,7 @@ def logits_from_embedded_inputs( encoded_inputs, lengths, deterministic=deterministic) def __call__(self, token_ids: Array, lengths: Array, - deterministic: Optional[bool] = None) -> Array: + deterministic: bool | None = None) -> Array: """Embeds the token IDs, encodes them, and classifies with attention.""" embedded_inputs = self.embed_token_ids( token_ids, deterministic=deterministic) diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py index 4ccdd3ab9e0f..334248219962 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py @@ -22,6 +22,8 @@ # pytype: disable=wrong-keyword-args # pytype: disable=attribute-error +from __future__ import annotations + from typing import Callable, Any, Optional from flax import linen as nn @@ -50,7 +52,7 @@ class TransformerConfig: decode: bool = False kernel_init: Callable = nn.initializers.xavier_uniform() bias_init: Callable = nn.initializers.normal(stddev=1e-6) - posemb_init: Optional[Callable] = None + posemb_init: Callable | None = None def shift_right(x, axis=1): @@ -173,7 +175,7 @@ class MlpBlock(nn.Module): out_dim: optionally specify out dimension. 
""" config: TransformerConfig - out_dim: Optional[int] = None + out_dim: int | None = None @nn.compact def __call__(self, inputs): diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py index 5fd9a0d40dc5..04111e6c4d5b 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py @@ -16,6 +16,8 @@ https://github.com/google/flax/tree/main/examples/lm1b """ +from __future__ import annotations + from typing import Callable, Any, Optional from flax import linen as nn @@ -40,7 +42,7 @@ class TransformerConfig: attention_dropout_rate: float = 0.3 kernel_init: Callable = nn.initializers.xavier_uniform() bias_init: Callable = nn.initializers.normal(stddev=1e-6) - posemb_init: Optional[Callable] = None + posemb_init: Callable | None = None def sinusoidal_init(max_len=2048): @@ -117,7 +119,7 @@ class MlpBlock(nn.Module): out_dim: optionally specify out dimension. """ config: TransformerConfig - out_dim: Optional[int] = None + out_dim: int | None = None @nn.compact def __call__(self, inputs, deterministic=True): diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py index b620bade14b4..58e50dacd914 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py @@ -22,6 +22,8 @@ # pytype: disable=wrong-keyword-args # pytype: disable=attribute-error +from __future__ import annotations + from typing import Callable, Any, Optional from flax import linen as nn @@ -51,7 +53,7 @@ class TransformerConfig: decode: bool = False kernel_init: Callable = nn.initializers.xavier_uniform() bias_init: Callable = nn.initializers.normal(stddev=1e-6) - posemb_init: Optional[Callable] = None + posemb_init: Callable | None = None def shift_right(x, axis=1): @@ -164,7 +166,7 @@ class MlpBlock(nn.Module): out_dim: optionally specify out dimension. """ config: TransformerConfig - out_dim: Optional[int] = None + out_dim: int | None = None @nn.compact def __call__(self, inputs): diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index c1d4c11ace52..1d69d454e256 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -13,6 +13,8 @@ # limitations under the License. """See primitives_test docstring for how the Jax2TfLimitations are used.""" +from __future__ import annotations + from collections.abc import Sequence import itertools from typing import Any, Callable, Optional, Union @@ -43,7 +45,7 @@ def __init__( self, description: str, *, - devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"), + devices: str | Sequence[str] = ("cpu", "gpu", "tpu"), dtypes: Sequence[DType] = (), enabled: bool = True, # jax2tf specific @@ -52,7 +54,7 @@ def __init__( skip_tf_run=False, expect_tf_error: bool = True, skip_comparison=False, - custom_assert: Optional[Callable] = None, + custom_assert: Callable | None = None, tol=None): """See the test_harnesses.Limitation common arguments. 
@@ -92,8 +94,8 @@ def __init__( self.skip_comparison = skip_comparison def get_max_tolerance_limitation( - self, limitations: Sequence["Jax2TfLimitation"] - ) -> Optional["Jax2TfLimitation"]: + self, limitations: Sequence[Jax2TfLimitation] + ) -> Jax2TfLimitation | None: """Pick the tolerance limitation that establishes the maximum tolerance.""" # TODO: it would be best if the limitations with tolerance are mutually exclusive # and we don't have to compute the maximum @@ -108,9 +110,9 @@ def get_max_tolerance_limitation( def filter( # type: ignore[override] self, - dtype: Optional[DType] = None, - device: Optional[str] = None, - mode: Optional[str] = None) -> bool: + dtype: DType | None = None, + device: str | None = None, + mode: str | None = None) -> bool: """Checks if this limitation is enabled for dtype and device and mode.""" native_serialization_mask = ( Jax2TfLimitation.FOR_NATIVE @@ -122,7 +124,7 @@ def filter( # type: ignore[override] @classmethod def limitations_for_harness( - cls, harness: test_harnesses.Harness) -> Sequence["Jax2TfLimitation"]: + cls, harness: test_harnesses.Harness) -> Sequence[Jax2TfLimitation]: group_method = getattr(cls, harness.group_name, None) if harness.group_name in cls.harness_groups_no_limitations: assert group_method is None, ( diff --git a/jax/experimental/jax2tf/tests/model_harness.py b/jax/experimental/jax2tf/tests/model_harness.py index 5f29a7f8d2ae..9af7229c0530 100644 --- a/jax/experimental/jax2tf/tests/model_harness.py +++ b/jax/experimental/jax2tf/tests/model_harness.py @@ -13,6 +13,8 @@ # limitations under the License. """All the models to convert.""" +from __future__ import annotations + from collections.abc import Sequence import dataclasses import functools @@ -46,8 +48,8 @@ class ModelHarness: variables: dict[str, Any] inputs: Sequence[np.ndarray] rtol: float = 1e-4 - polymorphic_shapes: Optional[Sequence[Union[str, None]]] = None - tensor_spec: Optional[Sequence[tf.TensorSpec]] = None + polymorphic_shapes: Sequence[str | None] | None = None + tensor_spec: Sequence[tf.TensorSpec] | None = None def __post_init__(self): # When providing polymorphic shapes, tensor_spec should be provided as well. diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 4635f0b2a501..cb7d395ed4fe 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -13,6 +13,8 @@ # limitations under the License. 
"""Tests for the shape-polymorphic jax2tf conversion.""" +from __future__ import annotations + from collections.abc import Sequence import contextlib import math @@ -80,14 +82,14 @@ def __init__(self, fun: Callable, *, arg_descriptors: Sequence[test_harnesses.ArgDescriptor] = (), - polymorphic_shapes: Sequence[Optional[str]] = (), - input_signature: Optional[Sequence[tf.TensorSpec]] = None, - expected_output_signature: Optional[tf.TensorSpec] = None, + polymorphic_shapes: Sequence[str | None] = (), + input_signature: Sequence[tf.TensorSpec] | None = None, + expected_output_signature: tf.TensorSpec | None = None, enable_xla: bool = True, - expect_error: tuple[Optional[Any], Optional[str]] = (None, None), + expect_error: tuple[Any | None, str | None] = (None, None), skip_jax_run: bool = False, check_result: bool = True, - tol: Optional[float] = None, + tol: float | None = None, limitations: Sequence[Jax2TfLimitation] = (), override_jax_config_flags: dict[str, Any] = {}): """Args: @@ -129,7 +131,7 @@ def __init__(self, self.override_jax_config_flags = override_jax_config_flags # Replicate the harness for both enable and disable xla - def both_enable_and_disable_xla(self) -> tuple["PolyHarness", "PolyHarness"]: + def both_enable_and_disable_xla(self) -> tuple[PolyHarness, PolyHarness]: assert self.enable_xla other = PolyHarness(self.group_name, f"{self.name}_enable_xla_False", @@ -144,7 +146,7 @@ def both_enable_and_disable_xla(self) -> tuple["PolyHarness", "PolyHarness"]: self.name = f"{self.name}_enable_xla_True" return (self, other) - def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> Optional[jax.Array]: + def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> jax.Array | None: def log_message(extra: str): return f"[{tst._testMethodName}]: {extra}" @@ -243,10 +245,10 @@ def log_message(extra: str): def check_shape_poly(tst, f_jax: Callable, *, arg_descriptors: Sequence[test_harnesses.ArgDescriptor] = (), skip_jax_run: bool = False, - polymorphic_shapes: Sequence[Optional[str]] = (), - input_signature: Optional[Sequence[tf.TensorSpec]] = None, - expected_output_signature: Optional[tf.TensorSpec] = None, - expect_error=(None, None)) -> Optional[jax.Array]: + polymorphic_shapes: Sequence[str | None] = (), + input_signature: Sequence[tf.TensorSpec] | None = None, + expected_output_signature: tf.TensorSpec | None = None, + expect_error=(None, None)) -> jax.Array | None: # Makes and tests a harness. See PolyHarness documentation. h = PolyHarness("", "", f_jax, arg_descriptors=arg_descriptors, @@ -427,10 +429,10 @@ def f_jax(x, *, y): def test_arg_avals_non_native(self): """Test conversion of actual arguments to abstract values.""" - def check_avals(*, arg_shapes: Sequence[Sequence[Optional[int]]], - polymorphic_shapes: Sequence[Optional[Union[str, PS]]], - expected_avals: Optional[Sequence[core.ShapedArray]] = None, - expected_shapeenv: Optional[dict[str, int]] = None, + def check_avals(*, arg_shapes: Sequence[Sequence[int | None]], + polymorphic_shapes: Sequence[str | PS | None], + expected_avals: Sequence[core.ShapedArray] | None = None, + expected_shapeenv: dict[str, int] | None = None, eager_mode: bool = False): # Use eager mode only for when all arg_shapes are known, in order to # check expected_shapeenv. 
@@ -637,7 +639,7 @@ def conv_and_run(*, arg_shape: core.Shape, )), ]) def test_shape_constraints_errors(self, *, - shape, poly_spec: str, expect_error: Optional[str] = None): + shape, poly_spec: str, expect_error: str | None = None): def f_jax(x): # x: f32[a + 2*b, a, a + b + c] return 0. @@ -2622,7 +2624,7 @@ def _make_vmap_primitive_harnesses() -> Sequence[PolyHarness]: continue def make_batched_arg_descriptor( - ad: test_harnesses.ArgDescriptor) -> Optional[test_harnesses.ArgDescriptor]: + ad: test_harnesses.ArgDescriptor) -> test_harnesses.ArgDescriptor | None: if isinstance(ad, RandArg): return RandArg((batch_size,) + ad.shape, ad.dtype) elif isinstance(ad, CustomArg): diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index a27a2668a50c..9bfcd37fe7d6 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence import contextlib import dataclasses -import functools import re import os from typing import Any, Callable, Optional @@ -91,8 +92,8 @@ def SaveAndLoadModel(model: tf.Module, return restored_model def SaveAndLoadFunction(f_tf: Callable, *, - input_signature: Optional[Sequence[tf.TensorSpec]] = None, - input_args: Optional[Sequence[Any]] = None, + input_signature: Sequence[tf.TensorSpec] | None = None, + input_args: Sequence[Any] | None = None, variables: Sequence[tf.Variable] = (), save_gradients=True) -> tuple[Callable, tf.train.Checkpoint]: # Roundtrip through saved model on disk. Return the Checkpoint also @@ -365,7 +366,7 @@ def log_message(extra): return result_jax, result_tf def TransformConvertAndCompare(self, func: Callable, arg, - transform: Optional[str]): + transform: str | None): """Like ConvertAndCompare but first applies a transformation. `func` must be a function from one argument to one result. `arg` is diff --git a/jax/experimental/key_reuse/_common.py b/jax/experimental/key_reuse/_common.py index f5aeb9eec226..cff87cf0e9d1 100644 --- a/jax/experimental/key_reuse/_common.py +++ b/jax/experimental/key_reuse/_common.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import NamedTuple, Union +from __future__ import annotations + +from typing import NamedTuple from jax import core from jax.interpreters import batching, mlir import numpy as np @@ -20,12 +22,12 @@ class Sink(NamedTuple): idx: int - mask: Union[bool, np.ndarray] = True + mask: bool | np.ndarray = True class Source(NamedTuple): idx: int - mask: Union[bool, np.ndarray] = True + mask: bool | np.ndarray = True class KeyReuseSignature(NamedTuple): diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index a6c1409c9045..f6a778161047 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Optional, Union +from __future__ import annotations + +from typing import Any, Callable from jax import core from jax.experimental.key_reuse import _forwarding @@ -39,7 +41,7 @@ def check_key_reuse_jaxpr(jaxpr: core.Jaxpr, *, use_forwarding: bool = True): def get_jaxpr_type_signature( jaxpr: core.Jaxpr, *, - consumed_inputs: Optional[list[Union[bool, np.ndarray]]] = None, + consumed_inputs: list[bool | np.ndarray] | None = None, use_forwarding: bool = True, ) -> KeyReuseSignature: """Parse the jaxpr to determine key reuse signature""" diff --git a/jax/experimental/key_reuse/_forwarding.py b/jax/experimental/key_reuse/_forwarding.py index 4b4c682f2466..69b0a29b733b 100644 --- a/jax/experimental/key_reuse/_forwarding.py +++ b/jax/experimental/key_reuse/_forwarding.py @@ -16,7 +16,7 @@ from collections import defaultdict from functools import reduce -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import Any, Callable, NamedTuple import jax from jax import core @@ -83,11 +83,11 @@ def is_key(var: core.Atom): def get_jaxpr_type_signature( jaxpr: core.Jaxpr, - consumed_inputs: Optional[list[Union[bool, np.ndarray]]] = None, - forwarded_inputs: Optional[dict[int, int]] = None, + consumed_inputs: list[bool | np.ndarray] | None = None, + forwarded_inputs: dict[int, int] | None = None, ) -> KeyReuseSignatureWithForwards: """Parse the jaxpr to determine key reuse signature""" - consumed: dict[core.Atom, Union[bool, np.ndarray]] = {} + consumed: dict[core.Atom, bool | np.ndarray] = {} forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs. def resolve_forwards(var: core.Atom) -> core.Atom: @@ -221,8 +221,8 @@ def _assert_consumed_value_key_type_signature(eqn, args_consumed): def _cond_key_type_signature(eqn, args_consumed): signatures = [get_jaxpr_type_signature(branch.jaxpr, consumed_inputs=args_consumed[1:]) for branch in eqn.params['branches']] - sinks = defaultdict(lambda: []) - sources = defaultdict(lambda: []) + sinks = defaultdict(list) + sources = defaultdict(list) for sig in signatures: for sink in sig.sinks: sinks[sink.idx].append(sink.mask) diff --git a/jax/experimental/key_reuse/_simple.py b/jax/experimental/key_reuse/_simple.py index b06630e35ce0..86b1a19842e5 100644 --- a/jax/experimental/key_reuse/_simple.py +++ b/jax/experimental/key_reuse/_simple.py @@ -16,7 +16,7 @@ from collections import defaultdict from functools import reduce -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import Any, Callable import jax from jax import core @@ -75,10 +75,10 @@ def is_key(var: core.Atom): def get_jaxpr_type_signature( jaxpr: core.Jaxpr, - consumed_inputs: Optional[list[Union[bool, np.ndarray]]] = None, + consumed_inputs: list[bool | np.ndarray] | None = None, ) -> KeyReuseSignature: """Parse the jaxpr to determine key reuse signature""" - consumed: dict[core.Atom, Union[bool, np.ndarray]] = {} + consumed: dict[core.Atom, bool | np.ndarray] = {} def is_key(var: core.Atom): return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key) @@ -196,8 +196,8 @@ def _assert_consumed_value_key_type_signature(eqn, args_consumed): def _cond_key_type_signature(eqn, args_consumed): signatures = [get_jaxpr_type_signature(branch.jaxpr, consumed_inputs=args_consumed[1:]) for branch in eqn.params['branches']] - sinks = defaultdict(lambda: []) - sources = defaultdict(lambda: []) + sinks = defaultdict(list) + sources = defaultdict(list) for sig in signatures: for sink in 
sig.sinks: sinks[sink.idx].append(sink.mask) diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index faf48a841f07..8a3b8e2823ce 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -14,6 +14,8 @@ # ============================================================================== """Utils for building a device mesh.""" +from __future__ import annotations + import collections from collections.abc import Sequence import itertools @@ -95,7 +97,7 @@ def _tpu_v2_v3_create_device_mesh( # return None; in that case, it will fall back to using the default logic. device_kind_handler_dict: dict[ str, - Callable[..., Optional[np.ndarray]], + Callable[..., np.ndarray | None], ] = { _TPU_V2: _tpu_v2_v3_create_device_mesh, _TPU_V3: _tpu_v2_v3_create_device_mesh, @@ -269,7 +271,7 @@ def _transpose_trick(physical_mesh: np.ndarray, def create_device_mesh( mesh_shape: Sequence[int], - devices: Optional[Sequence[Any]] = None, *, + devices: Sequence[Any] | None = None, *, contiguous_submeshes: bool = False) -> np.ndarray: """Creates a performant device mesh for jax.sharding.Mesh. @@ -323,7 +325,7 @@ def create_device_mesh( def create_hybrid_device_mesh(mesh_shape: Sequence[int], dcn_mesh_shape: Sequence[int], - devices: Optional[Sequence[Any]] = None, *, + devices: Sequence[Any] | None = None, *, process_is_granule: bool = False) -> np.ndarray: """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 8b3fb735477b..8d60c21b415e 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -13,6 +13,8 @@ # limitations under the License. """Utilities for synchronizing and communication across multiple hosts.""" +from __future__ import annotations + from functools import partial, lru_cache from typing import Optional import zlib @@ -41,7 +43,7 @@ def _psum(x: Any) -> Any: return jax.tree_map(partial(jnp.sum, axis=0), x) -def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any: +def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any: """Broadcast data from a source host (host 0 by default) to all other hosts. 
Args: diff --git a/jax/experimental/pallas/ops/layer_norm.py b/jax/experimental/pallas/ops/layer_norm.py index dbae0e6feedc..31c6fe43c876 100644 --- a/jax/experimental/pallas/ops/layer_norm.py +++ b/jax/experimental/pallas/ops/layer_norm.py @@ -14,6 +14,8 @@ """Module containing fused layer norm forward and backward pass.""" +from __future__ import annotations + import functools from typing import Optional @@ -69,8 +71,8 @@ def body(i, _): def layer_norm_forward( x, weight, bias, - num_warps: Optional[int] = None, - num_stages: Optional[int] = 3, + num_warps: int | None = None, + num_stages: int | None = 3, eps: float = 1e-5, backward_pass_impl: str = 'triton', interpret: bool = False): @@ -179,8 +181,8 @@ def body(i, acc_ref): def layer_norm_backward( - num_warps: Optional[int], - num_stages: Optional[int], + num_warps: int | None, + num_stages: int | None, eps: float, backward_pass_impl: str, interpret: bool, @@ -246,8 +248,8 @@ def layer_norm_backward( "interpret"]) def layer_norm( x, weight, bias, - num_warps: Optional[int] = None, - num_stages: Optional[int] = 3, + num_warps: int | None = None, + num_stages: int | None = 3, eps: float = 1e-5, backward_pass_impl: str = 'triton', interpret: bool = False): diff --git a/jax/experimental/pallas/ops/rms_norm.py b/jax/experimental/pallas/ops/rms_norm.py index a1b850af905e..f6a5bc6f71f7 100644 --- a/jax/experimental/pallas/ops/rms_norm.py +++ b/jax/experimental/pallas/ops/rms_norm.py @@ -14,6 +14,8 @@ """Module containing rms forward and backward pass.""" +from __future__ import annotations + import functools from typing import Optional @@ -58,8 +60,8 @@ def body(i, _): def rms_norm_forward( x, weight, bias, - num_warps: Optional[int] = None, - num_stages: Optional[int] = 3, + num_warps: int | None = None, + num_stages: int | None = 3, eps: float = 1e-5, backward_pass_impl: str = 'triton', interpret: bool = False): @@ -161,8 +163,8 @@ def body(i, acc_ref): def rms_norm_backward( - num_warps: Optional[int], - num_stages: Optional[int], + num_warps: int | None, + num_stages: int | None, eps: float, backward_pass_impl: str, interpret: bool, @@ -227,8 +229,8 @@ def rms_norm_backward( "interpret"]) def rms_norm( x, weight, bias, - num_warps: Optional[int] = None, - num_stages: Optional[int] = 3, + num_warps: int | None = None, + num_stages: int | None = 3, eps: float = 1e-5, backward_pass_impl: str = 'triton', interpret: bool = False): diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index 86c8d9cc9372..1c0614da516f 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -13,6 +13,8 @@ # limitations under the License. 
"""Pickling support for precompiled binaries.""" +from __future__ import annotations + import pickle import io from typing import Optional, Union @@ -42,7 +44,7 @@ def serialize(compiled: jax.stages.Compiled): def deserialize_and_load(serialized, in_tree, out_tree, - backend: Optional[Union[str, xc.Client]] = None): + backend: str | xc.Client | None = None): """Constructs a jax.stages.Compiled from a serialized executable.""" if backend is None or isinstance(backend, str): diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 17deec81c0d2..daf96515b822 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -456,7 +456,7 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue # Type-checking -RepType = Optional[set[AxisName]] +RepType = Union[set[AxisName], None] def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py index 649c1dda2437..910f3bf233cf 100644 --- a/jax/experimental/sparse/ad.py +++ b/jax/experimental/sparse/ad.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence import itertools from typing import Any, Callable, Union @@ -29,7 +31,7 @@ is_sparse = lambda x: isinstance(x, JAXSparse) -def flatten_fun_for_sparse_ad(fun, argnums: Union[int, tuple[int]], args: tuple[Any]): +def flatten_fun_for_sparse_ad(fun, argnums: int | tuple[int], args: tuple[Any]): argnums_tup = _ensure_index_tuple(argnums) assert all(0 <= argnum < len(args) for argnum in argnums_tup) @@ -71,7 +73,7 @@ def postprocess_gradients(grads_out): return fun_flat, argnums_flat, args_flat, postprocess_gradients -def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, +def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, has_aux=False, **kwargs) -> Callable[..., tuple[Any, Any]]: """Sparse-aware version of :func:`jax.value_and_grad` @@ -99,7 +101,7 @@ def value_and_grad_fun(*args, **kwargs): return value_and_grad_fun -def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, +def grad(fun: Callable, argnums: int | Sequence[int] = 0, has_aux=False, **kwargs) -> Callable: """Sparse-aware version of :func:`jax.grad` @@ -129,7 +131,7 @@ def grad_fun(*args, **kwargs): return grad_fun -def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0, +def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, **kwargs) -> Callable: """Sparse-aware version of :func:`jax.jacfwd` @@ -152,7 +154,7 @@ def jacfwd_fun(*args, **kwargs): return jacfwd_fun -def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0, +def jacrev(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, **kwargs) -> Callable: """Sparse-aware version of :func:`jax.jacrev` diff --git a/jax/experimental/sparse/api.py b/jax/experimental/sparse/api.py index 62ac3bfabdbf..5e64e1e14910 100644 --- a/jax/experimental/sparse/api.py +++ b/jax/experimental/sparse/api.py @@ -29,6 +29,9 @@ Further down are some examples of potential high-level wrappers for sparse objects. (API should be considered unstable and subject to change). 
""" + +from __future__ import annotations + from functools import partial import operator from typing import Optional, Union @@ -56,7 +59,7 @@ todense_p = core.Primitive('todense') todense_p.multiple_results = False -def todense(arr: Union[JAXSparse, Array]) -> Array: +def todense(arr: JAXSparse | Array) -> Array: """Convert input to a dense matrix. If input is already dense, pass through.""" bufs, tree = tree_util.tree_flatten(arr) return todense_p.bind(*bufs, tree=tree) @@ -113,7 +116,7 @@ def _todense_batching_rule(batched_args, batch_dims, *, tree): _todense_impl, multiple_results=False)) -def empty(shape: Shape, dtype: Optional[DTypeLike]=None, index_dtype: DTypeLike = 'int32', +def empty(shape: Shape, dtype: DTypeLike | None=None, index_dtype: DTypeLike = 'int32', sparse_format: str = 'bcoo', **kwds) -> JAXSparse: """Create an empty sparse array. @@ -134,7 +137,7 @@ def empty(shape: Shape, dtype: Optional[DTypeLike]=None, index_dtype: DTypeLike return cls._empty(shape, dtype=dtype, index_dtype=index_dtype, **kwds) -def eye(N: int, M: Optional[int] = None, k: int = 0, dtype: Optional[DTypeLike] = None, +def eye(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None, index_dtype: DTypeLike = 'int32', sparse_format: str = 'bcoo', **kwds) -> JAXSparse: """Create 2D sparse identity matrix. diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index 047043651f08..a23a491abf1e 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -14,6 +14,8 @@ """Sparse linear algebra routines.""" +from __future__ import annotations + from typing import Union, Callable import functools @@ -33,10 +35,10 @@ def lobpcg_standard( - A: Union[jax.Array, Callable[[jax.Array], jax.Array]], + A: jax.Array | Callable[[jax.Array], jax.Array], X: jax.Array, m: int = 100, - tol: Union[jax.Array, float, None] = None): + tol: jax.Array | float | None = None): """Compute the top-k standard eigenvalues using the LOBPCG routine. LOBPCG [1] stands for Locally Optimal Block Preconditioned Conjugate Gradient. @@ -106,7 +108,7 @@ def _lobpcg_standard_matrix( A: jax.Array, X: jax.Array, m: int, - tol: Union[jax.Array, float, None], + tol: jax.Array | float | None, debug: bool = False): """Computes lobpcg_standard(), possibly with debug diagnostics.""" return _lobpcg_standard_callable( @@ -117,7 +119,7 @@ def _lobpcg_standard_callable( A: Callable[[jax.Array], jax.Array], X: jax.Array, m: int, - tol: Union[jax.Array, float, None], + tol: jax.Array | float | None, debug: bool = False): """Supports generic lobpcg_standard() callable interface.""" diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 3f4df1febece..52b704f5f52c 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -13,6 +13,8 @@ # limitations under the License. 
"""Sparse test utilities.""" +from __future__ import annotations + from collections.abc import Iterable, Iterator, Sequence import functools import itertools @@ -145,8 +147,8 @@ def batched_args_maker(): def _rand_sparse(shape: Sequence[int], dtype: DTypeLike, *, rng: np.random.RandomState, rand_method: Callable[..., Any], - nse: Union[int, float], n_batch: int, n_dense: int, - sparse_format: str) -> Union[sparse.BCOO, sparse.BCSR]: + nse: int | float, n_batch: int, n_dense: int, + sparse_format: str) -> sparse.BCOO | sparse.BCSR: if sparse_format not in ['bcoo', 'bcsr']: raise ValueError(f"Sparse format {sparse_format} not supported.") @@ -186,7 +188,7 @@ def _rand_sparse(shape: Sequence[int], dtype: DTypeLike, *, def rand_bcoo(rng: np.random.RandomState, rand_method: Callable[..., Any]=jtu.rand_default, - nse: Union[int, float]=0.5, n_batch: int=0, n_dense: int=0): + nse: int | float=0.5, n_batch: int=0, n_dense: int=0): """Generates a random BCOO array.""" return functools.partial(_rand_sparse, rng=rng, rand_method=rand_method, nse=nse, n_batch=n_batch, n_dense=n_dense, @@ -194,7 +196,7 @@ def rand_bcoo(rng: np.random.RandomState, def rand_bcsr(rng: np.random.RandomState, rand_method: Callable[..., Any]=jtu.rand_default, - nse: Union[int, float]=0.5, n_batch: int=0, n_dense: int=0): + nse: int | float=0.5, n_batch: int=0, n_dense: int=0): """Generates a random BCSR array.""" return functools.partial(_rand_sparse, rng=rng, rand_method=rand_method, nse=nse, n_batch=n_batch, n_dense=n_dense, diff --git a/jax/experimental/topologies.py b/jax/experimental/topologies.py index ae04d162effc..9bf2993bdc12 100644 --- a/jax/experimental/topologies.py +++ b/jax/experimental/topologies.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import abc from collections.abc import Sequence from typing import Optional @@ -34,7 +36,7 @@ def get_attached_topology(platform=None) -> TopologyDescription: def get_topology_desc( - topology_name: str = "", platform: Optional[str] = None, **kwargs + topology_name: str = "", platform: str | None = None, **kwargs ) -> TopologyDescription: if platform == "tpu" or platform is None: return TopologyDescription( diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 44cdb6b292e8..f1e0ed6a400a 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import functools from functools import partial import importlib import itertools import operator -from typing import Optional, Union import jaxlib.mlir.ir as ir @@ -51,8 +51,8 @@ _prod = lambda xs: functools.reduce(operator.mul, xs, 1) def _threefry2x32_lowering(prng, platform, keys, data, - length: Optional[Union[int, ir.Value]] = None, - output_shape: Optional[ir.Value] = None): + length: int | ir.Value | None = None, + output_shape: ir.Value | None = None): """ThreeFry2x32 kernel for GPU. 
In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape` diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index 1cd0d7ffd181..19979f61feb1 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -14,9 +14,11 @@ """A small library of helpers for use in jaxlib to build MLIR operations.""" +from __future__ import annotations + from collections.abc import Sequence from functools import partial -from typing import Callable, Optional, Union +from typing import Callable, Union import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo @@ -58,7 +60,7 @@ def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type: def mk_result_types_and_shapes( shape_type_pairs: Sequence[ShapeTypePair] -) -> tuple[list[ir.Type], Optional[list[ir.Value]]]: +) -> tuple[list[ir.Type], list[ir.Value] | None]: result_types: list[ir.Type] = [] result_shapes: list[ir.Value] = [] has_dynamic_shapes = any( @@ -76,7 +78,7 @@ def mk_result_types_and_shapes( result_shapes if has_dynamic_shapes else None) # TODO(necula): share this with mlir.shape_tensor -def shape_tensor(sizes: Sequence[Union[int, ir.Value]]) -> ir.Value: +def shape_tensor(sizes: Sequence[int | ir.Value]) -> ir.Value: int1d = shape_dtype_to_ir_type((1,), np.int32) i32_type = shape_dtype_to_ir_type((), np.int32) def dim_to_i32x1(d): @@ -108,7 +110,7 @@ def hlo_s32(x: int): def ensure_hlo_s32(x: DimensionSize): return hlo_s32(x) if isinstance(x, int) else x -def dense_int_array(xs) -> Union[ir.DenseIntElementsAttr, ir.DenseI64ArrayAttr]: +def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr: if hlo.get_api_version() < 5: return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) @@ -140,15 +142,15 @@ def custom_call( *, result_types: Sequence[ir.Type], operands: Sequence[ir.Value], - backend_config: Union[str, bytes, dict[str, ir.Attribute]] = "", + backend_config: str | bytes | dict[str, ir.Attribute] = "", has_side_effect: bool = False, - result_shapes: Optional[Sequence[ir.Value]] = None, + result_shapes: Sequence[ir.Value] | None = None, called_computations: Sequence[str] = (), api_version: int = 2, - operand_output_aliases: Optional[dict[int, int]] = None, - operand_layouts: Optional[Sequence[Sequence[int]]] = None, - result_layouts: Optional[Sequence[Sequence[int]]] = None, - extra_attributes: Optional[dict[str, ir.Attribute]] = None, + operand_output_aliases: dict[int, int] | None = None, + operand_layouts: Sequence[Sequence[int]] | None = None, + result_layouts: Sequence[Sequence[int]] | None = None, + extra_attributes: dict[str, ir.Attribute] | None = None, ) -> ir.Operation: """Helper function for building an hlo.CustomCall. diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index 20e60be011b0..c231ac39aae1 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -19,13 +19,16 @@ """ # mypy: ignore-errors + +from __future__ import annotations + import abc from collections.abc import Sequence import dataclasses import functools import math import re -from typing import Any, Callable, Literal, Optional, Union, overload +from typing import Any, Callable, Literal, Union, overload from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -139,7 +142,7 @@ class VectorLayout: bitwidth: int offsets: tuple[Offset, Offset] # Replication applies only within a tile. 
tiling: tuple[int, int] - implicit_dim: Optional[ImplicitDim] + implicit_dim: ImplicitDim | None def __post_init__(self): # TODO(b/275751535): Allow more bitwidths. @@ -252,8 +255,8 @@ def tile_array_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: else: raise AssertionError(f"Invalid implicit dim: {self.implicit_dim}") - def generalizes(self, other: "VectorLayout", - shape: Optional[tuple[int, ...]] = None) -> bool: + def generalizes(self, other: VectorLayout, + shape: tuple[int, ...] | None = None) -> bool: """Returns True if the other layout is a special case of this one. In here, other is considered "a special case" when the set of vector @@ -318,8 +321,8 @@ def generalizes(self, other: "VectorLayout", return False return True - def equivalent_to(self, other: "VectorLayout", - shape: Optional[tuple[int, ...]] = None) -> bool: + def equivalent_to(self, other: VectorLayout, + shape: tuple[int, ...] | None = None) -> bool: """Returns True if the two layouts are equivalent. That is, when all potential vector entries where the value can be stored @@ -338,8 +341,8 @@ def tile_data_bounds( self, full_shape: tuple[int, ...], ixs: tuple[int, ...], - allow_replicated: Union[bool, TargetTuple] = False, - ) -> "VRegDataBounds": + allow_replicated: bool | TargetTuple = False, + ) -> VRegDataBounds: """Returns the bounds of the given tile that hold useful data. Arguments: @@ -786,7 +789,7 @@ def get_sublane_mask(self) -> ir.Attribute: return ir.DenseBoolArrayAttr.get(mask) -Layout = Optional[VectorLayout] +Layout = Union[VectorLayout, None] PATTERN = re.compile( r'#tpu.vpad<"([0-9]+),{([*0-9]+),([*0-9]+)},\(([0-9]+),([0-9]+)\)(,-1|,-2)?">' @@ -1147,7 +1150,7 @@ def is_supported_reduced_sublanes_retile( def copy_one_sublane( src_vreg: ir.Value, src_sl_idx: int, - dst_vreg: Optional[ir.Value], + dst_vreg: ir.Value | None, dst_sl_idx: int, ) -> ir.Value: """Copy one sublane from a vreg to another vreg. 
@@ -1459,17 +1462,17 @@ class RewriteContext: func: func.FuncOp hardware_generation: int - def erase(self, op: Union[ir.Operation, ir.OpView]): + def erase(self, op: ir.Operation | ir.OpView): if isinstance(op, ir.OpView): op = op.operation op.erase() - def replace(self, old: Union[ir.Operation, ir.OpView], new: ValueLike): + def replace(self, old: ir.Operation | ir.OpView, new: ValueLike): self.replace_all_uses_with(old, new) self.erase(old) def replace_all_uses_with( - self, old: Union[ir.Operation, ir.OpView], new: ValueLike + self, old: ir.Operation | ir.OpView, new: ValueLike ): if isinstance(new, (ir.Operation, ir.OpView)): new = new.results @@ -1686,7 +1689,7 @@ def _arith_constant_rule(ctx: RewriteContext, op: arith.ConstantOp, # pylint: d def _elementwise_op_rule(factory, # pylint: disable=missing-function-docstring ctx: RewriteContext, op: ir.Operation, - layout_in: Union[Layout, tuple[Layout, ...]], + layout_in: Layout | tuple[Layout, ...], layout_out: Layout): if not isinstance(layout_in, tuple): layout_in = (layout_in,) @@ -1988,7 +1991,7 @@ def _scf_if_rule( # pylint: disable=missing-function-docstring ctx: RewriteContext, op: scf.IfOp, layout_in: None, # pylint: disable=unused-argument - layout_out: Union[Layout, tuple[Layout, ...]], + layout_out: Layout | tuple[Layout, ...], ): if len(op.results) == 1: layout_out = (layout_out,) @@ -2066,7 +2069,7 @@ def _scf_if_rule( # pylint: disable=missing-function-docstring def _scf_yield_rule( # pylint: disable=missing-function-docstring ctx: RewriteContext, op: scf.YieldOp, - layout_in: Union[Layout, tuple[Layout, ...]], + layout_in: Layout | tuple[Layout, ...], layout_out: None, # pylint: disable=unused-argument ): if not op.operands: @@ -2094,8 +2097,8 @@ def _scf_yield_rule( # pylint: disable=missing-function-docstring def _scf_for_rule( # pylint: disable=missing-function-docstring ctx: RewriteContext, op: scf.ForOp, - layout_in: Union[Layout, tuple[Layout, ...]], - layout_out: Union[Layout, tuple[Layout, ...]], + layout_in: Layout | tuple[Layout, ...], + layout_out: Layout | tuple[Layout, ...], ): # TODO(b/286175570) Support inputs and outputs in scf.for. assert layout_in and len(layout_in) == 3 @@ -3488,7 +3491,7 @@ def type_bitwidth(ty: ir.Type) -> int: raise NotImplementedError(ty) -def get_constant(ty: ir.Type, value: Union[int, float]) -> ir.Attribute: +def get_constant(ty: ir.Type, value: int | float) -> ir.Attribute: if ir.IntegerType.isinstance(ty): return ir.IntegerAttr.get(ty, value) elif ty == ir.IndexType.get(): diff --git a/tests/api_test.py b/tests/api_test.py index faabe83a7688..00758757e893 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import collections import collections.abc @@ -32,7 +33,7 @@ import subprocess import sys import types -from typing import Callable, NamedTuple, Optional +from typing import Callable, NamedTuple import unittest import weakref @@ -6226,7 +6227,7 @@ class DCETest(jtu.JaxTestCase): def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: list[bool], expected_used_inputs: list[bool], - expected_num_eqns: Optional[int] = None, + expected_num_eqns: int | None = None, check_diff: bool = True): jaxpr_dce, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs) core.check_jaxpr(jaxpr_dce) diff --git a/tests/batching_test.py b/tests/batching_test.py index e2376ee2eec7..3bcd4c5216cc 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from contextlib import contextmanager from functools import partial import itertools as it -from typing import Any, Optional, Callable, Union, TypeVar +from typing import Any, Callable, TypeVar, Union import numpy as np from absl.testing import absltest @@ -1348,10 +1350,10 @@ def __repr__(self) -> str: return f'NamedArray(names={self.names}, data={self.data})' class NamedMapSpec: - name: Optional[str] - axis: Optional[int] + name: str | None + axis: int | None - def __init__(self, name: str, axis: Optional[int]): + def __init__(self, name: str, axis: int | None): assert (name is None) == (axis is None) self.name = name self.axis = axis @@ -1366,7 +1368,7 @@ def named_mul(x: NamedArray, y: NamedArray) -> NamedArray: lambda names, xs: NamedArray(names, xs[0])) -def named_to_elt(cont: Callable[[Array, Optional[int]], ArrayElt], +def named_to_elt(cont: Callable[[Array, int | None], ArrayElt], _: Int, val: NamedArray, spec: NamedMapSpec) -> NamedArray: if spec.name is None: return val @@ -1376,7 +1378,7 @@ def named_to_elt(cont: Callable[[Array, Optional[int]], ArrayElt], elt = cont(val.data, spec.axis) return NamedArray(elt_names, elt) -def named_from_elt(cont: Callable[[int, ArrayElt, Optional[int]], Array], +def named_from_elt(cont: Callable[[int, ArrayElt, int | None], Array], axis_size: int, elt: NamedArray, annotation: NamedMapSpec ) -> NamedArray: data = cont(axis_size, elt.data, annotation.axis) diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 0ca617f878ee..406c092270f6 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -19,9 +19,11 @@ cross-platform lowering is tested in export_test.py. """ +from __future__ import annotations + import math import re -from typing import Callable, Optional +from typing import Callable from absl import logging from absl.testing import absltest @@ -127,7 +129,7 @@ def export_and_compare_to_native( *args: jax.Array, unimplemented_platforms: set[str] = set(), skip_run_on_platforms: set[str] = set(), - tol: Optional[float] = None): + tol: float | None = None): devices = [ d for d in self.__class__.devices diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index c2acd4f974f2..6ffd07521119 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from collections.abc import Sequence from functools import partial import itertools @@ -20,7 +22,7 @@ import re import threading import time -from typing import Callable, Optional +from typing import Callable import unittest from unittest import skip, SkipTest @@ -102,7 +104,7 @@ def fun1_equiv(a): # Numerical equivalent of fun1 def maybe_print(do_print: bool, arg, what: str, - tap_with_device: Optional[bool] = False, + tap_with_device: bool | None = False, device_index: int = 0): """Conditionally print on testing_string""" if do_print: diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 5e2707f80d5b..55413edd492b 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import enum from functools import partial import itertools import typing -from typing import Any, Optional +from typing import Any from absl.testing import absltest from absl.testing import parameterized @@ -54,7 +55,7 @@ class IndexSpec(typing.NamedTuple): shape: tuple[int, ...] indexer: Any - out_shape: Optional[tuple[int, ...]] = None + out_shape: tuple[int, ...] | None = None def check_grads(f, args, order, atol=None, rtol=None, eps=None): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 14065a1a5357..853f8d217919 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import collections from collections.abc import Iterator @@ -22,7 +23,7 @@ import itertools import math import platform -from typing import cast, Optional +from typing import cast import unittest from unittest import SkipTest @@ -3888,7 +3889,7 @@ def testUnpackbits(self, shape, dtype, axis, bitorder, count): [dict(shape=shape, axis=axis) for shape in [(3,), (3, 4), (3, 4, 5)] for axis in itertools.chain(range(-len(shape), len(shape)), - [cast(Optional[int], None)]) + [cast(int | None, None)]) ], index_shape=scalar_shapes + [(3,), (2, 1, 3)], dtype=all_dtypes, @@ -3940,7 +3941,7 @@ def testTakeOptionalArgs(self): filter(_shapes_are_broadcast_compatible, itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2))) for axis in itertools.chain(range(len(x_shape)), [-1], - [cast(Optional[int], None)]) + [cast(int | None, None)]) ], dtype=default_dtypes, index_dtype=int_dtypes, diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index b3fa18ed7c00..fae1ae32cc17 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations from functools import partial import itertools import math -from typing import Optional, cast +from typing import cast import unittest from absl.testing import absltest @@ -73,8 +74,8 @@ def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng, for lhs_shape in [(b * batch_group_count, i * feature_group_count, 6, 7)] for rhs_shape in [(j * batch_group_count * feature_group_count, i, 1, 2)]], [dict(lhs_bdim=lhs_bdim, rhs_bdim=rhs_bdim) - for lhs_bdim in itertools.chain([cast(Optional[int], None)], range(5)) - for rhs_bdim in itertools.chain([cast(Optional[int], None)], range(5)) + for lhs_bdim in itertools.chain([cast(int | None, None)], range(5)) + for rhs_bdim in itertools.chain([cast(int | None, None)], range(5)) if (lhs_bdim, rhs_bdim) != (None, None) ], [dict(dimension_numbers=dim_nums, perms=perms) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 5c4a9778c2bb..344774c0184f 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -13,7 +13,9 @@ # limitations under the License. """Tests for Pallas indexing logic and abstractions.""" -from typing import Union + +from __future__ import annotations + import unittest from absl.testing import absltest @@ -45,7 +47,7 @@ def int_indexer_strategy(dim) -> hps.SearchStrategy[int]: @hps.composite -def slice_indexer_strategy(draw, dim) -> Union[Slice, slice]: +def slice_indexer_strategy(draw, dim) -> Slice | slice: start = draw(int_indexer_strategy(dim)) size = draw(hps.integers(min_value=0, max_value=np.iinfo(np.int32).max)) return draw( @@ -64,7 +66,7 @@ def array_indexer_strategy(draw, shape) -> jax.Array: @hps.composite def indexer_strategy(draw, dim, int_indexer_shape - ) -> Union[int, Slice, jax.Array]: + ) -> int | Slice | jax.Array: return draw(hps.one_of( int_indexer_strategy(dim), slice_indexer_strategy(dim), diff --git a/tests/pmap_test.py b/tests/pmap_test.py index f79245f68654..6c277e099280 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from concurrent.futures import ThreadPoolExecutor from functools import partial @@ -21,7 +22,7 @@ import os from random import shuffle import re -from typing import Optional, cast +from typing import cast import unittest from unittest import SkipTest import weakref @@ -60,7 +61,7 @@ compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]] def all_bdims(*shapes, pmap): - bdims = (it.chain([cast(Optional[int], None)], range(len(shape) + 1)) + bdims = (it.chain([cast(int | None, None)], range(len(shape) + 1)) for shape in shapes) return (t for t in it.product(*bdims) if not all(e is None for e in t)) diff --git a/tests/random_test.py b/tests/random_test.py index 52ce462e3e43..886f82814e1b 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import copy import enum from functools import partial import math from unittest import skipIf -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple import zlib from absl.testing import absltest @@ -59,8 +61,8 @@ class RandomValuesCase(NamedTuple): params: dict expected: np.ndarray on_x64: OnX64 = OnX64.ALSO - atol: Optional[float] = None - rtol: Optional[float] = None + atol: float | None = None + rtol: float | None = None def _testname(self): if self.dtype is None: diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index c23bbbd44667..6df7a99bf20f 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for the shape-polymorphic export.""" + +from __future__ import annotations + import enum from collections.abc import Sequence import itertools import math -from typing import Any, Callable, Optional +from typing import Any, Callable import unittest from absl import logging @@ -527,10 +530,10 @@ def __init__(self, fun: Callable[..., Any], *, arg_descriptors: Sequence[test_harnesses.ArgDescriptor] = (), - polymorphic_shapes: Sequence[Optional[str]] = (), - expect_error: Optional[tuple[Any, str]] = None, + polymorphic_shapes: Sequence[str | None] = (), + expect_error: tuple[Any, str] | None = None, check_result: bool = True, - tol: Optional[float] = None, + tol: float | None = None, limitations: Sequence[test_harnesses.Limitation] = (), override_jax_config_flags: dict[str, Any] = {}): """Args: @@ -561,7 +564,7 @@ def __init__(self, self.limitations = limitations self.override_jax_config_flags = override_jax_config_flags - def run_test(self, tst: jtu.JaxTestCase) -> Optional[jax.Array]: + def run_test(self, tst: jtu.JaxTestCase) -> jax.Array | None: def log_message(extra: str): return f"[{tst._testMethodName}]: {extra}" @@ -613,8 +616,8 @@ def log_message(extra: str): def check_shape_poly(tst, f_jax: Callable, *, arg_descriptors: Sequence[test_harnesses.ArgDescriptor] = (), - polymorphic_shapes: Sequence[Optional[str]] = (), - expect_error=None) -> Optional[jax.Array]: + polymorphic_shapes: Sequence[str | None] = (), + expect_error=None) -> jax.Array | None: # Builds a PolyHarness and runs the test. See PolyHarness documentation. h = PolyHarness("", "", f_jax, arg_descriptors=arg_descriptors, @@ -2290,7 +2293,7 @@ def _make_vmap_primitive_harnesses() -> Sequence[PolyHarness]: continue def make_batched_arg_descriptor( - ad: test_harnesses.ArgDescriptor) -> Optional[test_harnesses.ArgDescriptor]: + ad: test_harnesses.ArgDescriptor) -> test_harnesses.ArgDescriptor | None: if isinstance(ad, RandArg): return RandArg((batch_size,) + ad.shape, ad.dtype) elif isinstance(ad, CustomArg): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index fbde159ec74b..e3f1d52b3970 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from collections.abc import Sequence, Iterable, Iterator, Generator from functools import partial import itertools as it @@ -19,7 +21,7 @@ import operator as op import os from types import SimpleNamespace -from typing import Any, NamedTuple, Callable, Optional, TypeVar, Union +from typing import Any, NamedTuple, Callable, TypeVar import unittest from absl.testing import absltest @@ -819,7 +821,7 @@ def g(rng): # TODO(mattjj): consider moving this method to be a helper in jtu def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: list[bool], expected_used_inputs: list[bool], - expected_num_eqns: Optional[int] = None, + expected_num_eqns: int | None = None, check_diff: bool = True): jaxpr_dce, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs) core.check_jaxpr(jaxpr_dce) @@ -1310,7 +1312,7 @@ class FunSpec(NamedTuple): num_inputs: int fun: Callable out_rep: Callable - valid_types: Optional[Callable] = None + valid_types: Callable | None = None fun_specs = [ FunSpec('id', 1, lambda x: x, lambda r: r), @@ -1476,8 +1478,8 @@ def dilate(mesh: Mesh, spec: P, shape: ShapeDtypeDuck) -> ShapeDtypeDuck: return jax.ShapeDtypeStruct(new_shape, shape.dtype) def make_out_specs( - mesh: MeshDuck, out_types: Union[ShapeDtypeDuck, Sequence[ShapeDtypeDuck]], - out_reps: Union[set[core.AxisName], Sequence[set[core.AxisName]]] + mesh: MeshDuck, out_types: ShapeDtypeDuck | Sequence[ShapeDtypeDuck], + out_reps: set[core.AxisName] | Sequence[set[core.AxisName]] ) -> Chooser: if type(out_types) is not tuple: out_spec = yield from make_out_spec(mesh, out_types, out_reps) # type: ignore @@ -1522,11 +1524,11 @@ def sample_shmap_batched(bdim_size: int) -> Chooser: return name + f'_vmap_{bdims}', bdims, *shmap_specs, batch_args, ref def all_bdims(*shapes: tuple[int, ...] - ) -> Iterator[Sequence[Optional[int]]]: + ) -> Iterator[Sequence[int | None]]: bdims = ((None, *range(len(shape) + 1)) for shape in shapes) return (t for t in it.product(*bdims) if not all(e is None for e in t)) -def batchify_arg(size: int, bdim: Optional[int], x: Arr) -> Arr: +def batchify_arg(size: int, bdim: int | None, x: Arr) -> Arr: if bdim is None: return x else: @@ -1534,7 +1536,7 @@ def batchify_arg(size: int, bdim: Optional[int], x: Arr) -> Arr: [1 if i != bdim else -1 for i in range(len(x.shape) + 1)]) return np.expand_dims(x, bdim) * iota -def args_slicer(args: Sequence[Arr], bdims: Sequence[Optional[int]] +def args_slicer(args: Sequence[Arr], bdims: Sequence[int | None] ) -> Callable[[int], Sequence[Arr]]: def slicer(x, bdim): if bdim is None: diff --git a/tests/state_test.py b/tests/state_test.py index 03e1f9d8bb5f..4570c7958ecf 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Sequence from functools import partial import itertools as it -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import Any, Callable, NamedTuple, Union from absl.testing import absltest from absl.testing import parameterized @@ -709,8 +711,8 @@ def index_params(draw): class VmappableIndexParam(NamedTuple): index_param: IndexParam - ref_bdim: Optional[int] - non_slice_idx_bdims: tuple[Optional[int], ...] + ref_bdim: int | None + non_slice_idx_bdims: tuple[int | None, ...] 
slice_bdim: int bat_ref_aval: shaped_array_ref bat_ref_shape: Shape @@ -719,7 +721,7 @@ class VmappableIndexParam(NamedTuple): bat_slice_aval: core.ShapedArray bat_slice_shape: Shape - def maybe_tuple_insert(t: tuple[Any, ...], idx: Optional[int], + def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None, val: Any) -> tuple[Any, ...]: if idx is None: return t @@ -815,12 +817,12 @@ def set_vmap_params(draw): Indexer = tuple[Union[int, slice, np.ndarray]] def _unpack_idx(idx: Indexer - ) -> tuple[Sequence[Union[int, np.ndarray]], Sequence[bool]]: + ) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]: indexed_dims = [type(i) != slice for i in idx] non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b] return non_slice_idx, indexed_dims - def _pack_idx(non_slice_idx: Sequence[Union[int, np.ndarray]], + def _pack_idx(non_slice_idx: Sequence[int | np.ndarray], indexed_dims: Sequence[bool]) -> Indexer: idx_ = iter(non_slice_idx) idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 7706c3791aa3..cc2c21fc4981 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -579,7 +579,7 @@ def testFlattenDictKeyOrder(self): self.assertEqual(list(restored_d.keys()), ["a", "b", "c"]) def testFlattenDefaultDictKeyOrder(self): - d = collections.defaultdict(lambda: 0, + d = collections.defaultdict(int, {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}) leaves, treedef = tree_util.tree_flatten(d) self.assertEqual(leaves, [1, 2, 1, 2]) diff --git a/tests/typing_test.py b/tests/typing_test.py index bb3eabe890f6..562c6c56d2d9 100644 --- a/tests/typing_test.py +++ b/tests/typing_test.py @@ -17,7 +17,10 @@ This test is meant to be both a runtime test and a static type annotation test, so it should be checked with pytype/mypy as well as being run with pytest. """ -from typing import Any, Optional, Union, TYPE_CHECKING + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING import jax from jax._src import core @@ -100,7 +103,7 @@ def testArrayLike(self) -> None: self.assertArraysEqual(out9, jnp.float32(0)) def testArrayInstanceChecks(self): - def is_array(x: typing.ArrayLike) -> Union[bool, typing.Array]: + def is_array(x: typing.ArrayLike) -> bool | typing.Array: return isinstance(x, typing.Array) x = jnp.arange(5) @@ -114,7 +117,7 @@ def is_array(x: typing.ArrayLike) -> Union[bool, typing.Array]: def testAnnotations(self): # This test is mainly meant for static type checking: we want to ensure that # Tracer and ArrayImpl are valid as array.Array. - def f(x: Any) -> Optional[typing.Array]: + def f(x: Any) -> typing.Array | None: if isinstance(x, core.Tracer): return x elif isinstance(x, ArrayImpl): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 884b09084eaa..981a39d801ca 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from collections.abc import Generator, Iterator import functools import itertools as it +from itertools import product, permutations import math import os import re -from itertools import product, permutations -from typing import Union, Optional from unittest import SkipTest +from typing import Union import numpy as np from absl.testing import absltest @@ -1912,7 +1914,7 @@ def testVjpReduceAxesCollective(self): # lax.psum has the wrong transpose, so test with a corrected version for now @functools.partial(jax.custom_vjp, nondiff_argnums=(1,)) - def psum_idrev(x, axis_name: Optional[AxisNames] = None): + def psum_idrev(x, axis_name: AxisNames | None = None): if axis_name is None: return x return jax.lax.psum(x, axis_name)
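
The hunks above all follow one rewrite pattern. The short sketch below is not part of the patch (the module and class names are hypothetical); it is only meant to illustrate, under the assumption of a Python 3.9 interpreter, why `from __future__ import annotations` makes the PEP 604 `X | None` spelling and unquoted self-references safe inside annotations, while runtime type aliases such as `Layout` in apply_vector_layout.py above still keep `typing.Union`:

    from __future__ import annotations

    from typing import Union  # still needed for expressions evaluated at import time on 3.9

    class Node:
        # With postponed evaluation, annotations are stored as strings and never
        # executed on import, so `Node | None` and the unquoted reference to the
        # class being defined are both fine on Python 3.9.
        def append(self, child: Node | None = None) -> Node:
            return self

    # A module-level type alias is an ordinary runtime expression; `Node | None`
    # would raise TypeError on 3.9, so it keeps typing.Union (cf. `Layout` above).
    MaybeNode = Union[Node, None]

    node = Node().append()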