Skip to content

Commit

Permalink
jnp.linalg: add matrix_norm, matrix_transpose, vector_norm, vector_tr…
Browse files Browse the repository at this point in the history
…anspose

These have been added upstream to numpy.linalg in NumPy 2.0, as part of the Array API standard.
  • Loading branch information
jakevdp committed Dec 15, 2023
1 parent 9462aec commit 3d343c5
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 16 deletions.
4 changes: 4 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,10 @@ jax.numpy.linalg
eigvalsh
inv
lstsq
matrix_norm
matrix_power
matrix_rank
matrix_transpose
multi_dot
norm
outer
Expand All @@ -469,6 +471,8 @@ jax.numpy.linalg
svd
tensorinv
tensorsolve
vector_norm
vecdot

JAX Array
---------
Expand Down
49 changes: 49 additions & 0 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,52 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
if x1.ndim != 1 or x2.ndim != 1:
raise ValueError(f"Input arrays must be one-dimensional, but they are {x1.ndim=} {x2.ndim=}")
return x1[:, None] * x2[None, :]


@_wraps(getattr(np.linalg, "matrix_norm", None))
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
"""
Computes the matrix norm of a matrix (or a stack of matrices) x.
"""
check_arraylike('jnp.linalg.matrix_norm', x)
return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1))


@_wraps(getattr(np.linalg, "matrix_transpose", None))
def matrix_transpose(x: ArrayLike, /) -> Array:
"""Transposes a matrix (or a stack of matrices) x."""
check_arraylike('jnp.linalg.matrix_transpose', x)
x_arr = jnp.asarray(x)
ndim = x_arr.ndim
if ndim < 2:
raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {ndim=}")
return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2))


@_wraps(getattr(np.linalg, "vector_norm", None))
def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
ord: int | str = 2) -> Array:
"""Computes the vector norm of a vector (or batch of vectors) x."""
check_arraylike('jnp.linalg.vector_norm', x)
if axis is None:
result = jax.numpy.linalg.norm(lax.ravel(x), ord=ord)
if keepdims:
result = lax.expand_dims(result, range(jnp.ndim(x)))
return result
return jax.numpy.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)


@_wraps(getattr(np.linalg, "vecdot", None))
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
"""Computes the (vector) dot product of two arrays."""
check_arraylike("jnp.linalg.vecdot", x1, x2)
x1_arr, x2_arr = jnp.asarray(x1), jnp.asarray(x2)
rank = max(x1_arr.ndim, x2_arr.ndim)
x1_arr = jax.lax.broadcast_to_rank(x1_arr, rank)
x2_arr = jax.lax.broadcast_to_rank(x2_arr, rank)
if x1_arr.shape[axis] != x2_arr.shape[axis]:
raise ValueError("x1 and x2 must have the same size along specified axis.")
x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
# TODO(jakevdp): call lax.dot_general directly
return jax.numpy.matmul(x1_arr[..., None, :], x2_arr[..., None])[..., 0, 0]
18 changes: 4 additions & 14 deletions jax/experimental/array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def matrix_norm(x, /, *, keepdims=False, ord='fro'):
"""
Computes the matrix norm of a matrix (or a stack of matrices) x.
"""
return jax.numpy.linalg.norm(x, ord=ord, keepdims=keepdims, axis=(-1, -2))
return jax.numpy.linalg.matrix_norm(x, ord=ord, keepdims=keepdims)

def matrix_power(x, n, /):
"""
Expand All @@ -107,9 +107,7 @@ def matrix_rank(x, /, *, rtol=None):

def matrix_transpose(x, /):
"""Transposes a matrix (or a stack of matrices) x."""
if x.ndim < 2:
raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {x.ndim=}")
return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
return jax.numpy.linalg.matrix_transpose(x)

def outer(x1, x2, /):
"""
Expand Down Expand Up @@ -177,16 +175,8 @@ def trace(x, /, *, offset=0, dtype=None):

def vecdot(x1, x2, /, *, axis=-1):
"""Computes the (vector) dot product of two arrays."""
rank = max(x1.ndim, x2.ndim)
x1 = jax.lax.broadcast_to_rank(x1, rank)
x2 = jax.lax.broadcast_to_rank(x2, rank)
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along specified axis.")
x1, x2 = jax.numpy.broadcast_arrays(x1, x2)
x1 = jax.numpy.moveaxis(x1, axis, -1)
x2 = jax.numpy.moveaxis(x2, axis, -1)
return jax.numpy.matmul(x1[..., None, :], x2[..., None])[..., 0, 0]
return jax.numpy.linalg.vecdot(x1, x2, axis=axis)

def vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
"""Computes the vector norm of a vector (or batch of vectors) x."""
return jax.numpy.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)
return jax.numpy.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
1 change: 0 additions & 1 deletion jax/experimental/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
eigh as eigh,
eigvalsh as eigvalsh,
inv as inv,
jax as jax,
matmul as matmul,
matrix_norm as matrix_norm,
matrix_power as matrix_power,
Expand Down
4 changes: 4 additions & 0 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,19 @@
eigvalsh as eigvalsh,
inv as inv,
lstsq as lstsq,
matrix_norm as matrix_norm,
matrix_power as matrix_power,
matrix_rank as matrix_rank,
matrix_transpose as matrix_transpose,
norm as norm,
outer as outer,
pinv as pinv,
qr as qr,
slogdet as slogdet,
solve as solve,
svd as svd,
vector_norm as vector_norm,
vecdot as vecdot,
)
from jax._src.third_party.numpy.linalg import (
cond as cond,
Expand Down
1 change: 0 additions & 1 deletion tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@
'eigh',
'eigvalsh',
'inv',
'jax',
'matmul',
'matrix_norm',
'matrix_power',
Expand Down
79 changes: 79 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,70 @@ def testStringInfNorm(self):
with self.assertRaisesRegex(err, msg):
jnp.linalg.norm(jnp.array([1.0, 2.0, 3.0]), ord="inf")

@jtu.sample_product(
shape=[(2, 3), (4, 2, 3), (2, 3, 4, 5)],
dtype=float_types + complex_types,
keepdims=[True, False],
ord=[1, -1, 2, -2, np.inf, -np.inf, 'fro', 'nuc'],
)
def testMatrixNorm(self, shape, dtype, keepdims, ord):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
np_fn = partial(np.linalg.norm, ord=ord, keepdims=keepdims, axis=(-2, -1))
else:
np_fn = partial(np.linalg.matrix_norm, ord=ord, keepdims=keepdims)
np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
jnp_fn = partial(jnp.linalg.matrix_norm, ord=ord, keepdims=keepdims)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
self._CompileAndCheck(jnp_fn, args_maker)

@jtu.sample_product(
shape=[(3,), (3, 4), (2, 3, 4, 5)],
dtype=float_types + complex_types,
keepdims=[True, False],
axis=[0, None],
ord=[1, -1, 2, -2, np.inf, -np.inf],
)
def testVectorNorm(self, shape, dtype, keepdims, axis, ord):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
def np_fn(x, ord=ord, keepdims=keepdims, axis=axis):
if axis is None:
result = np_fn(x.ravel(), ord, keepdims=False, axis=0)
return np.reshape(result, (1,) * x.ndim) if keepdims else result
return np.linalg.norm(x, ord=ord, keepdims=keepdims, axis=axis)
else:
np_fn = partial(np.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
self._CompileAndCheck(jnp_fn, args_maker)

@jtu.sample_product(
[dict(lhs_shape=(2, 3, 4), rhs_shape=(1, 4), axis=-1),
dict(lhs_shape=(2, 3, 4), rhs_shape=(2, 1, 1), axis=0),
dict(lhs_shape=(2, 3, 4), rhs_shape=(3, 4), axis=1)],
dtype=float_types + complex_types,
)
def testVecDot(self, lhs_shape, rhs_shape, axis, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
def np_fn(x, y, axis=axis):
x, y = np.broadcast_arrays(x, y)
x = np.moveaxis(x, axis, -1)
y = np.moveaxis(y, axis, -1)
return np.matmul(x[..., None, :], y[..., None])[..., 0, 0]
else:
np_fn = partial(np.linalg.vecdot, axis=axis)
np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
jnp_fn = partial(jnp.linalg.vecdot, axis=axis)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
self._CompileAndCheck(jnp_fn, args_maker)


@jtu.sample_product(
[
dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
Expand Down Expand Up @@ -1905,6 +1969,21 @@ def testSchurBatching(self, shape, dtype):
Ts, Ss = vmap(lax.linalg.schur)(args)
self.assertAllClose(reconstruct(Ss, Ts), args, atol=1e-4)

@jtu.sample_product(
shape=[(2, 3), (2, 3, 4), (2, 3, 4, 5)],
dtype=float_types + complex_types,
)
def testMatrixTranspose(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_fun = jnp.linalg.matrix_transpose
if jtu.numpy_version() < (2, 0, 0):
np_fun = lambda x: np.swapaxes(x, -1, -2)
else:
np_fun = np.linalg.matrix_transpose
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 3d343c5

Please sign in to comment.