jnp.linalg: add matrix_norm, matrix_transpose, vector_norm, vecdot #19005

Merged 1 commit on Dec 15, 2023
4 changes: 4 additions & 0 deletions docs/jax.numpy.rst
@@ -457,8 +457,10 @@ jax.numpy.linalg
   eigvalsh
   inv
   lstsq
   matrix_norm
   matrix_power
   matrix_rank
   matrix_transpose
   multi_dot
   norm
   outer
@@ -469,6 +471,8 @@ jax.numpy.linalg
   svd
   tensorinv
   tensorsolve
   vector_norm
   vecdot

JAX Array
---------
49 changes: 49 additions & 0 deletions jax/_src/numpy/linalg.py
@@ -699,3 +699,52 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  if x1.ndim != 1 or x2.ndim != 1:
    raise ValueError(f"Input arrays must be one-dimensional, but they are {x1.ndim=} {x2.ndim=}")
  return x1[:, None] * x2[None, :]


@_wraps(getattr(np.linalg, "matrix_norm", None))
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
"""
Computes the matrix norm of a matrix (or a stack of matrices) x.
"""
check_arraylike('jnp.linalg.matrix_norm', x)
return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1))


@_wraps(getattr(np.linalg, "matrix_transpose", None))
def matrix_transpose(x: ArrayLike, /) -> Array:
"""Transposes a matrix (or a stack of matrices) x."""
check_arraylike('jnp.linalg.matrix_transpose', x)
x_arr = jnp.asarray(x)
ndim = x_arr.ndim
if ndim < 2:
raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {ndim=}")
return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2))


@_wraps(getattr(np.linalg, "vector_norm", None))
def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
ord: int | str = 2) -> Array:
"""Computes the vector norm of a vector (or batch of vectors) x."""
check_arraylike('jnp.linalg.vector_norm', x)
if axis is None:
result = norm(jnp.ravel(x), ord=ord)
if keepdims:
result = lax.expand_dims(result, range(jnp.ndim(x)))
return result
return norm(x, axis=axis, keepdims=keepdims, ord=ord)


@_wraps(getattr(np.linalg, "vecdot", None))
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
"""Computes the (vector) dot product of two arrays."""
check_arraylike("jnp.linalg.vecdot", x1, x2)
x1_arr, x2_arr = jnp.asarray(x1), jnp.asarray(x2)
rank = max(x1_arr.ndim, x2_arr.ndim)
x1_arr = jax.lax.broadcast_to_rank(x1_arr, rank)
x2_arr = jax.lax.broadcast_to_rank(x2_arr, rank)
if x1_arr.shape[axis] != x2_arr.shape[axis]:
raise ValueError("x1 and x2 must have the same size along specified axis.")
x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
# TODO(jakevdp): call lax.dot_general directly
return jax.numpy.matmul(x1_arr[..., None, :], x2_arr[..., None])[..., 0, 0]
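
For orientation, here is a small usage sketch of the four new functions (illustrative only, not part of the diff; shapes and values follow from the implementations above):

import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)          # a single 2x3 matrix

# matrix_norm always reduces over the last two axes (Frobenius by default),
# equivalent to norm(x, ord='fro', axis=(-2, -1)).
jnp.linalg.matrix_norm(x)

# matrix_transpose swaps only the last two axes, preserving batch dimensions.
batch = jnp.zeros((4, 2, 3))
jnp.linalg.matrix_transpose(batch).shape   # (4, 3, 2)

# vector_norm with axis=None ravels the input before reducing.
jnp.linalg.vector_norm(x, ord=1)           # sum of absolute entries: 15.0

# vecdot contracts one axis and broadcasts the rest; here a (3,)-vector
# is contracted against a batch of (4, 3) vectors.
a = jnp.ones((4, 3))
b = jnp.ones((3,))
jnp.linalg.vecdot(a, b, axis=-1).shape     # (4,)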
18 changes: 4 additions & 14 deletions jax/experimental/array_api/_linear_algebra_functions.py
@@ -91,7 +91,7 @@ def matrix_norm(x, /, *, keepdims=False, ord='fro'):
"""
Computes the matrix norm of a matrix (or a stack of matrices) x.
"""
return jax.numpy.linalg.norm(x, ord=ord, keepdims=keepdims, axis=(-1, -2))
return jax.numpy.linalg.matrix_norm(x, ord=ord, keepdims=keepdims)

def matrix_power(x, n, /):
  """
@@ -107,9 +107,7 @@ def matrix_rank(x, /, *, rtol=None):

def matrix_transpose(x, /):
  """Transposes a matrix (or a stack of matrices) x."""
  if x.ndim < 2:
    raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {x.ndim=}")
  return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
  return jax.numpy.linalg.matrix_transpose(x)

def outer(x1, x2, /):
  """
@@ -177,16 +175,8 @@ def trace(x, /, *, offset=0, dtype=None):

def vecdot(x1, x2, /, *, axis=-1):
  """Computes the (vector) dot product of two arrays."""
  rank = max(x1.ndim, x2.ndim)
  x1 = jax.lax.broadcast_to_rank(x1, rank)
  x2 = jax.lax.broadcast_to_rank(x2, rank)
  if x1.shape[axis] != x2.shape[axis]:
    raise ValueError("x1 and x2 must have the same size along specified axis.")
  x1, x2 = jax.numpy.broadcast_arrays(x1, x2)
  x1 = jax.numpy.moveaxis(x1, axis, -1)
  x2 = jax.numpy.moveaxis(x2, axis, -1)
  return jax.numpy.matmul(x1[..., None, :], x2[..., None])[..., 0, 0]
  return jax.numpy.linalg.vecdot(x1, x2, axis=axis)

def vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
  """Computes the vector norm of a vector (or batch of vectors) x."""
  return jax.numpy.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)
  return jax.numpy.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
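
One behavioral note on the vector_norm delegation above: for a 2-D input with axis=None, jax.numpy.linalg.norm applies matrix-norm semantics, whereas the new vector_norm ravels first, matching the Array API's vector semantics. A minimal sketch of the difference (illustrative, not part of the diff):

import jax.numpy as jnp

x = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# norm with ord=1 on a 2-D input computes the matrix 1-norm
# (maximum absolute column sum): max(1+3, 2+4) = 6.0
jnp.linalg.norm(x, ord=1)

# vector_norm ravels first, giving the vector 1-norm:
# 1 + 2 + 3 + 4 = 10.0
jnp.linalg.vector_norm(x, ord=1)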
1 change: 0 additions & 1 deletion jax/experimental/array_api/linalg.py
@@ -20,7 +20,6 @@
  eigh as eigh,
  eigvalsh as eigvalsh,
  inv as inv,
  jax as jax,
  matmul as matmul,
  matrix_norm as matrix_norm,
  matrix_power as matrix_power,
4 changes: 4 additions & 0 deletions jax/numpy/linalg.py
@@ -25,15 +25,19 @@
  eigvalsh as eigvalsh,
  inv as inv,
  lstsq as lstsq,
  matrix_norm as matrix_norm,
  matrix_power as matrix_power,
  matrix_rank as matrix_rank,
  matrix_transpose as matrix_transpose,
  norm as norm,
  outer as outer,
  pinv as pinv,
  qr as qr,
  slogdet as slogdet,
  solve as solve,
  svd as svd,
  vector_norm as vector_norm,
  vecdot as vecdot,
)
from jax._src.third_party.numpy.linalg import (
  cond as cond,
1 change: 0 additions & 1 deletion tests/array_api_test.py
@@ -175,7 +175,6 @@
  'eigh',
  'eigvalsh',
  'inv',
  'jax',
  'matmul',
  'matrix_norm',
  'matrix_power',
79 changes: 79 additions & 0 deletions tests/linalg_test.py
@@ -612,6 +612,70 @@ def testStringInfNorm(self):
    with self.assertRaisesRegex(err, msg):
      jnp.linalg.norm(jnp.array([1.0, 2.0, 3.0]), ord="inf")

  @jtu.sample_product(
    shape=[(2, 3), (4, 2, 3), (2, 3, 4, 5)],
    dtype=float_types + complex_types,
    keepdims=[True, False],
    ord=[1, -1, 2, -2, np.inf, -np.inf, 'fro', 'nuc'],
  )
  def testMatrixNorm(self, shape, dtype, keepdims, ord):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    if jtu.numpy_version() < (2, 0, 0):
      np_fn = partial(np.linalg.norm, ord=ord, keepdims=keepdims, axis=(-2, -1))
    else:
      np_fn = partial(np.linalg.matrix_norm, ord=ord, keepdims=keepdims)
    jnp_fn = partial(jnp.linalg.matrix_norm, ord=ord, keepdims=keepdims)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
    self._CompileAndCheck(jnp_fn, args_maker)

  @jtu.sample_product(
    shape=[(3,), (3, 4), (2, 3, 4, 5)],
    dtype=float_types + complex_types,
    keepdims=[True, False],
    axis=[0, None],
    ord=[1, -1, 2, -2, np.inf, -np.inf],
  )
  def testVectorNorm(self, shape, dtype, keepdims, axis, ord):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    if jtu.numpy_version() < (2, 0, 0):
      def np_fn(x, *, ord, keepdims, axis):
        x = np.asarray(x)
        if axis is None:
          result = np_fn(x.ravel(), ord=ord, keepdims=False, axis=0)
          return np.reshape(result, (1,) * x.ndim) if keepdims else result
        return np.linalg.norm(x, ord=ord, keepdims=keepdims, axis=axis)
    else:
      np_fn = np.linalg.vector_norm
    np_fn = partial(np_fn, ord=ord, keepdims=keepdims, axis=axis)
    jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
    self._CompileAndCheck(jnp_fn, args_maker)

  @jtu.sample_product(
    [dict(lhs_shape=(2, 3, 4), rhs_shape=(1, 4), axis=-1),
     dict(lhs_shape=(2, 3, 4), rhs_shape=(2, 1, 1), axis=0),
     dict(lhs_shape=(2, 3, 4), rhs_shape=(3, 4), axis=1)],
    dtype=float_types + complex_types,
  )
  def testVecDot(self, lhs_shape, rhs_shape, axis, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    if jtu.numpy_version() < (2, 0, 0):
      def np_fn(x, y, axis=axis):
        x, y = np.broadcast_arrays(x, y)
        x = np.moveaxis(x, axis, -1)
        y = np.moveaxis(y, axis, -1)
        return np.matmul(x[..., None, :], y[..., None])[..., 0, 0]
    else:
      np_fn = partial(np.linalg.vecdot, axis=axis)
    np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
    jnp_fn = partial(jnp.linalg.vecdot, axis=axis)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
    self._CompileAndCheck(jnp_fn, args_maker)


  @jtu.sample_product(
    [
      dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
@@ -1905,6 +1969,21 @@ def testSchurBatching(self, shape, dtype):
    Ts, Ss = vmap(lax.linalg.schur)(args)
    self.assertAllClose(reconstruct(Ss, Ts), args, atol=1e-4)

  @jtu.sample_product(
    shape=[(2, 3), (2, 3, 4), (2, 3, 4, 5)],
    dtype=float_types + complex_types,
  )
  def testMatrixTranspose(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    jnp_fun = jnp.linalg.matrix_transpose
    if jtu.numpy_version() < (2, 0, 0):
      np_fun = lambda x: np.swapaxes(x, -1, -2)
    else:
      np_fun = np.linalg.matrix_transpose
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())