Skip to content

Commit

Permalink
Merge pull request #110 from adtzlr/add-sqrtm
Browse files Browse the repository at this point in the history
Add matrix square root `math.linalg.sqrtm()`
  • Loading branch information
adtzlr committed May 18, 2024
2 parents 15d12f2 + 7a8d7fe commit c8e1ef6
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/tensortrax/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes.
"""

__version__ = "0.21.3"
__version__ = "0.22.0"
15 changes: 13 additions & 2 deletions src/tensortrax/math/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
from ._linalg_array import det as _det
from ._linalg_array import inv as _inv
from ._linalg_array import pinv as _pinv
from ._linalg_tensor import det, eigh, eigvalsh, expm, inv, pinv
from ._linalg_tensor import det, eigh, eigvalsh, expm, inv, pinv, sqrtm

__all__ = ["_det", "_inv", "_pinv", "det", "eigh", "eigvalsh", "expm", "inv", "pinv"]
__all__ = [
"_det",
"_inv",
"_pinv",
"det",
"eigh",
"eigvalsh",
"expm",
"inv",
"pinv",
"sqrtm",
]
8 changes: 7 additions & 1 deletion src/tensortrax/math/linalg/_linalg_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from ..._tensor import Tensor, Δ, Δδ, einsum, f, matmul, δ
from .._math_tensor import exp, sum, transpose
from .._math_tensor import exp, sqrt, sum, transpose
from ..special._special_tensor import ddot
from . import _linalg_array as linalg

Expand Down Expand Up @@ -169,3 +169,9 @@ def expm(A):
"Compute the matrix exponential of a symmetric array."
λ, M = eigh(A)
return einsum("a...,aij...->ij...", exp(λ), M)


def sqrtm(A):
"Compute the matrix square root of a symmetric array."
λ, M = eigh(A)
return einsum("a...,aij...->ij...", sqrt(λ), M)
1 change: 1 addition & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_math():
assert tm.linalg.eigvalsh(T).shape == (3,)

assert tm.linalg.expm(T).shape == (3, 3)
assert tm.linalg.sqrtm(T).shape == (3, 3)

assert tm.base.cross(F, F).shape == F.shape
assert tm.base.eye(F).shape == F.shape
Expand Down

0 comments on commit c8e1ef6

Please sign in to comment.