diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c89aa934d95d..1e2127b310ed 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead - additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.31, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4] + additional_dependencies: [types-requests==2.31.0, jaxlib] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py index f8231f5b24b6..fcc304055453 100644 --- a/jax/_src/pallas/mosaic/error_handling.py +++ b/jax/_src/pallas/mosaic/error_handling.py @@ -151,8 +151,7 @@ def parse_location_string(location_string: str) -> tuple[str, list[RawFrame]]: def traceback_from_raw_frames(frames: list[RawFrame]) -> types.TracebackType: """Constructs a traceback from a list of RawFrame objects.""" xla_frames = [ - xla_client.Frame(frame.filename, frame.func_name, -1, frame.lineno - ) # type: ignore [call-arg] + xla_client.Frame(frame.filename, frame.func_name, -1, frame.lineno) for frame in frames ] return xla_client.Traceback.traceback_from_frames(xla_frames) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 6b0f71579d99..f1baf48f6857 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -361,7 +361,7 @@ def cache_miss(*args, **kwargs): use_resource_env=jit_info.use_resource_env) cpp_pjit_f = xc._xla.pjit( fun_name(fun), fun, cache_miss, jit_info.static_argnums, - jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore + jit_info.static_argnames, cache_key, tree_util.dispatch_registry, pxla.cc_shard_arg, _get_cpp_global_cache(cache_key.contains_explicit_attributes))