Skip to content

Commit

Permalink
Use the current minimum jaxlib version for type checking on the CI
Browse files Browse the repository at this point in the history
  • Loading branch information
superbobry committed Oct 7, 2024
1 parent 4cb3a6d commit 20c7310
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ repos:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.31, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.34, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
args: [--config=pyproject.toml]

- repo: https://github.com/mwouts/jupytext
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/pallas/mosaic/error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ def parse_location_string(location_string: str) -> tuple[str, list[RawFrame]]:
def traceback_from_raw_frames(frames: list[RawFrame]) -> types.TracebackType:
"""Constructs a traceback from a list of RawFrame objects."""
xla_frames = [
xla_client.Frame(frame.filename, frame.func_name, -1, frame.lineno
) # type: ignore [call-arg]
xla_client.Frame(frame.filename, frame.func_name, -1, frame.lineno)
for frame in frames
]
return xla_client.Traceback.traceback_from_frames(xla_frames)
2 changes: 1 addition & 1 deletion jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def cache_miss(*args, **kwargs):
use_resource_env=jit_info.use_resource_env)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore
jit_info.static_argnames, cache_key, tree_util.dispatch_registry,
pxla.cc_shard_arg,
_get_cpp_global_cache(cache_key.contains_explicit_attributes))

Expand Down

0 comments on commit 20c7310

Please sign in to comment.