diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst
index 12e795ecf582..1de86ca84c19 100644
--- a/docs/jax.numpy.rst
+++ b/docs/jax.numpy.rst
@@ -457,8 +457,10 @@ jax.numpy.linalg
     eigvalsh
     inv
     lstsq
+    matrix_norm
     matrix_power
     matrix_rank
+    matrix_transpose
     multi_dot
     norm
     outer
@@ -469,6 +471,8 @@ jax.numpy.linalg
     svd
     tensorinv
     tensorsolve
+    vector_norm
+    vecdot
 
 JAX Array
 ---------
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index 4e4d114f1543..850ffb8377b1 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -699,3 +699,52 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
   if x1.ndim != 1 or x2.ndim != 1:
     raise ValueError(f"Input arrays must be one-dimensional, but they are {x1.ndim=} {x2.ndim=}")
   return x1[:, None] * x2[None, :]
+
+
+@_wraps(getattr(np.linalg, "matrix_norm", None))
+def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
+  """
+  Computes the matrix norm of a matrix (or a stack of matrices) x.
+  """
+  check_arraylike('jnp.linalg.matrix_norm', x)
+  return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1))
+
+
+@_wraps(getattr(np.linalg, "matrix_transpose", None))
+def matrix_transpose(x: ArrayLike, /) -> Array:
+  """Transposes a matrix (or a stack of matrices) x."""
+  check_arraylike('jnp.linalg.matrix_transpose', x)
+  x_arr = jnp.asarray(x)
+  ndim = x_arr.ndim
+  if ndim < 2:
+    raise ValueError(f"matrix_transpose requires at least 2 dimensions; got {ndim=}")
+  return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2))
+
+
+@_wraps(getattr(np.linalg, "vector_norm", None))
+def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
+                ord: int | str = 2) -> Array:
+  """Computes the vector norm of a vector (or batch of vectors) x."""
+  check_arraylike('jnp.linalg.vector_norm', x)
+  if axis is None:
+    result = norm(jnp.ravel(x), ord=ord)
+    if keepdims:
+      result = lax.expand_dims(result, range(jnp.ndim(x)))
+    return result
+  return norm(x, axis=axis, keepdims=keepdims, ord=ord)
+
+
+@_wraps(getattr(np.linalg, "vecdot", None))
+def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
+  """Computes the (vector) dot product of two arrays."""
+  check_arraylike("jnp.linalg.vecdot", x1, x2)
+  x1_arr, x2_arr = jnp.asarray(x1), jnp.asarray(x2)
+  rank = max(x1_arr.ndim, x2_arr.ndim)
+  x1_arr = jax.lax.broadcast_to_rank(x1_arr, rank)
+  x2_arr = jax.lax.broadcast_to_rank(x2_arr, rank)
+  if x1_arr.shape[axis] != x2_arr.shape[axis]:
+    raise ValueError("x1 and x2 must have the same size along specified axis.")
+  x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
+  x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
+  # TODO(jakevdp): call lax.dot_general directly
+  return jax.numpy.matmul(x1_arr[..., None, :], x2_arr[..., None])[..., 0, 0]
diff --git a/jax/experimental/array_api/_linear_algebra_functions.py b/jax/experimental/array_api/_linear_algebra_functions.py
index 154d963a765d..1a144773f44a 100644
--- a/jax/experimental/array_api/_linear_algebra_functions.py
+++ b/jax/experimental/array_api/_linear_algebra_functions.py
@@ -91,7 +91,7 @@ def matrix_norm(x, /, *, keepdims=False, ord='fro'):
   """
   Computes the matrix norm of a matrix (or a stack of matrices) x.
   """
-  return jax.numpy.linalg.norm(x, ord=ord, keepdims=keepdims, axis=(-1, -2))
+  return jax.numpy.linalg.matrix_norm(x, ord=ord, keepdims=keepdims)
 
 def matrix_power(x, n, /):
   """
@@ -107,9 +107,7 @@ def matrix_rank(x, /, *, rtol=None):
 
 def matrix_transpose(x, /):
   """Transposes a matrix (or a stack of matrices) x."""
-  if x.ndim < 2:
-    raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {x.ndim=}")
-  return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
+  return jax.numpy.linalg.matrix_transpose(x)
 
 def outer(x1, x2, /):
   """
@@ -177,16 +175,8 @@ def trace(x, /, *, offset=0, dtype=None):
 
 def vecdot(x1, x2, /, *, axis=-1):
   """Computes the (vector) dot product of two arrays."""
-  rank = max(x1.ndim, x2.ndim)
-  x1 = jax.lax.broadcast_to_rank(x1, rank)
-  x2 = jax.lax.broadcast_to_rank(x2, rank)
-  if x1.shape[axis] != x2.shape[axis]:
-    raise ValueError("x1 and x2 must have the same size along specified axis.")
-  x1, x2 = jax.numpy.broadcast_arrays(x1, x2)
-  x1 = jax.numpy.moveaxis(x1, axis, -1)
-  x2 = jax.numpy.moveaxis(x2, axis, -1)
-  return jax.numpy.matmul(x1[..., None, :], x2[..., None])[..., 0, 0]
+  return jax.numpy.linalg.vecdot(x1, x2, axis=axis)
 
 def vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
   """Computes the vector norm of a vector (or batch of vectors) x."""
-  return jax.numpy.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)
+  return jax.numpy.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
diff --git a/jax/experimental/array_api/linalg.py b/jax/experimental/array_api/linalg.py
index 30b531f502bc..49c93c5b1908 100644
--- a/jax/experimental/array_api/linalg.py
+++ b/jax/experimental/array_api/linalg.py
@@ -20,7 +20,6 @@
   eigh as eigh,
   eigvalsh as eigvalsh,
   inv as inv,
-  jax as jax,
   matmul as matmul,
   matrix_norm as matrix_norm,
   matrix_power as matrix_power,
diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py
index 42536822ce5a..a5ee94c8a075 100644
--- a/jax/numpy/linalg.py
+++ b/jax/numpy/linalg.py
@@ -25,8 +25,10 @@
   eigvalsh as eigvalsh,
   inv as inv,
   lstsq as lstsq,
+  matrix_norm as matrix_norm,
   matrix_power as matrix_power,
   matrix_rank as matrix_rank,
+  matrix_transpose as matrix_transpose,
   norm as norm,
   outer as outer,
   pinv as pinv,
@@ -34,6 +36,8 @@
   slogdet as slogdet,
   solve as solve,
   svd as svd,
+  vector_norm as vector_norm,
+  vecdot as vecdot,
 )
 from jax._src.third_party.numpy.linalg import (
   cond as cond,
diff --git a/tests/array_api_test.py b/tests/array_api_test.py
index 97fc682398e4..0d4893e4939e 100644
--- a/tests/array_api_test.py
+++ b/tests/array_api_test.py
@@ -175,7 +175,6 @@
   'eigh',
   'eigvalsh',
   'inv',
-  'jax',
   'matmul',
   'matrix_norm',
   'matrix_power',
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 1f73471a307a..948349cfe035 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -612,6 +612,70 @@ def testStringInfNorm(self):
     with self.assertRaisesRegex(err, msg):
       jnp.linalg.norm(jnp.array([1.0, 2.0, 3.0]), ord="inf")
 
+  @jtu.sample_product(
+      shape=[(2, 3), (4, 2, 3), (2, 3, 4, 5)],
+      dtype=float_types + complex_types,
+      keepdims=[True, False],
+      ord=[1, -1, 2, -2, np.inf, -np.inf, 'fro', 'nuc'],
+  )
+  def testMatrixNorm(self, shape, dtype, keepdims, ord):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype)]
+    if jtu.numpy_version() < (2, 0, 0):
+      np_fn = partial(np.linalg.norm, ord=ord, keepdims=keepdims, axis=(-2, -1))
+    else:
+      np_fn = partial(np.linalg.matrix_norm, ord=ord, keepdims=keepdims)
+    np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
+    jnp_fn = partial(jnp.linalg.matrix_norm, ord=ord, keepdims=keepdims)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
+    self._CompileAndCheck(jnp_fn, args_maker)
+
+  @jtu.sample_product(
+      shape=[(3,), (3, 4), (2, 3, 4, 5)],
+      dtype=float_types + complex_types,
+      keepdims=[True, False],
+      axis=[0, None],
+      ord=[1, -1, 2, -2, np.inf, -np.inf],
+  )
+  def testVectorNorm(self, shape, dtype, keepdims, axis, ord):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype)]
+    if jtu.numpy_version() < (2, 0, 0):
+      def np_fn(x, ord=ord, keepdims=keepdims, axis=axis):
+        if axis is None:
+          result = np_fn(x.ravel(), ord, keepdims=False, axis=0)
+          return np.reshape(result, (1,) * x.ndim) if keepdims else result
+        return np.linalg.norm(x, ord=ord, keepdims=keepdims, axis=axis)
+    else:
+      np_fn = partial(np.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
+    np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
+    jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
+    self._CompileAndCheck(jnp_fn, args_maker)
+
+  @jtu.sample_product(
+      [dict(lhs_shape=(2, 3, 4), rhs_shape=(1, 4), axis=-1),
+       dict(lhs_shape=(2, 3, 4), rhs_shape=(2, 1, 1), axis=0),
+       dict(lhs_shape=(2, 3, 4), rhs_shape=(3, 4), axis=1)],
+      dtype=float_types + complex_types,
+  )
+  def testVecDot(self, lhs_shape, rhs_shape, axis, dtype):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
+    if jtu.numpy_version() < (2, 0, 0):
+      def np_fn(x, y, axis=axis):
+        x, y = np.broadcast_arrays(x, y)
+        x = np.moveaxis(x, axis, -1)
+        y = np.moveaxis(y, axis, -1)
+        return np.matmul(x[..., None, :], y[..., None])[..., 0, 0]
+    else:
+      np_fn = partial(np.linalg.vecdot, axis=axis)
+    np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
+    jnp_fn = partial(jnp.linalg.vecdot, axis=axis)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
+    self._CompileAndCheck(jnp_fn, args_maker)
+
+
   @jtu.sample_product(
     [
       dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
@@ -1905,6 +1969,21 @@ def testSchurBatching(self, shape, dtype):
     Ts, Ss = vmap(lax.linalg.schur)(args)
     self.assertAllClose(reconstruct(Ss, Ts), args, atol=1e-4)
 
+  @jtu.sample_product(
+      shape=[(2, 3), (2, 3, 4), (2, 3, 4, 5)],
+      dtype=float_types + complex_types,
+  )
+  def testMatrixTranspose(self, shape, dtype):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype)]
+    jnp_fun = jnp.linalg.matrix_transpose
+    if jtu.numpy_version() < (2, 0, 0):
+      np_fun = lambda x: np.swapaxes(x, -1, -2)
+    else:
+      np_fun = np.linalg.matrix_transpose
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
+    self._CompileAndCheck(jnp_fun, args_maker)
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
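
Note (not part of the patch): a minimal usage sketch of the four public functions this change adds, assuming the patch above is applied. The shapes noted in the comments follow from the implementations in jax/_src/numpy/linalg.py.

    import jax.numpy as jnp

    x = jnp.arange(24.0).reshape(2, 3, 4)       # a stack of two 3x4 matrices

    jnp.linalg.matrix_transpose(x).shape        # (2, 4, 3): swaps the last two axes
    jnp.linalg.matrix_norm(x).shape             # (2,): Frobenius norm of each matrix
    jnp.linalg.vector_norm(x, axis=-1).shape    # (2, 3): 2-norm of each length-4 row
    jnp.linalg.vecdot(x, x, axis=-1).shape      # (2, 3): dot product along the last axis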