Commit

Merge pull request #111 from adtzlr/fix-eigh
Fix missing `eigh` dual values
adtzlr committed May 21, 2024
2 parents e6fe648 + 5902c8c commit c944b86
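
Note: the change in `_linalg_tensor.py` below replaces the NaN placeholders in the Δx and Δδx dual values of the eigenvector-bases output of `eigh` with computed values ΔM and ΔδM. Restating the Δ- and Δδ-related code of the diff, with eigenvectors $N_\alpha$, $M_{\alpha\beta} = N_\alpha \otimes N_\beta$, $\lambda_{\alpha\beta} = \lambda_\alpha - \lambda_\beta$, $\Delta\lambda_{\alpha\beta} = \Delta\lambda_\alpha - \Delta\lambda_\beta$ and $\delta A_{\alpha\beta} = M_{\alpha\beta} : \delta A$ (analogously for $\Delta$ and $\Delta\delta$):

$$
\Delta N_\alpha = \sum_{\beta \ne \alpha} \frac{\Delta A_{\alpha\beta}}{\lambda_{\alpha\beta}} \, N_\beta ,
\qquad
\Delta\delta N_\alpha = \sum_{\beta \ne \alpha} \left(
    -\frac{\Delta\lambda_{\alpha\beta}}{\lambda_{\alpha\beta}^{2}} \, \delta A_{\alpha\beta} \, N_\beta
    + \frac{\delta A_{\alpha\beta}}{\lambda_{\alpha\beta}} \, \Delta N_\beta
    + \frac{\Delta\delta A_{\alpha\beta}}{\lambda_{\alpha\beta}} \, N_\beta
\right)
$$

$$
\Delta M_\alpha = \Delta N_\alpha \otimes N_\alpha + N_\alpha \otimes \Delta N_\alpha ,
\qquad
\Delta\delta M_\alpha = \delta N_\alpha \otimes \Delta N_\alpha + \Delta N_\alpha \otimes \delta N_\alpha
+ \Delta\delta N_\alpha \otimes N_\alpha + N_\alpha \otimes \Delta\delta N_\alpha
$$
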
Showing 2 changed files with 32 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/tensortrax/__about__.py
@@ -2,4 +2,4 @@
tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes.
"""

-__version__ = "0.22.0"
+__version__ = "0.23.0"
34 changes: 31 additions & 3 deletions src/tensortrax/math/linalg/_linalg_tensor.py
@@ -126,27 +126,55 @@ def eigh(A, eps=np.sqrt(np.finfo(float).eps)):

    # alpha = [0, 1, 2]
    # beta = [(1, 2), (2, 0), (0, 1)]

    alpha = np.arange(dim)
    beta = [
        np.concatenate([np.arange(a + 1, dim), np.arange(a)]) for a in np.arange(dim)
    ]

    δN = []
    ΔN = []
    for α in alpha:
        δNα = []
        ΔNα = []
        for β in beta[α]:
            Mαβ = einsum("i...,j...->ij...", N[α], N[β])
            δAαβ = einsum("ij...,ij...->...", Mαβ, δ(A))
            ΔAαβ = einsum("ij...,ij...->...", Mαβ, Δ(A))
            λαβ = λ[α] - λ[β]
            δNα.append(1 / λαβ * N[β] * δAαβ)
            ΔNα.append(1 / λαβ * N[β] * ΔAαβ)
        δN.append(sum(δNα, axis=0))
        ΔN.append(sum(ΔNα, axis=0))

    ΔδN = []
    for α in alpha:
        ΔδNα = []
        for β in beta[α]:
            Mαβ = einsum("i...,j...->ij...", N[α], N[β])
            δAαβ = einsum("ij...,ij...->...", Mαβ, δ(A))
            ΔδAαβ = einsum("ij...,ij...->...", Mαβ, Δδ(A))
            λαβ = λ[α] - λ[β]
            Δλαβ = Δλ[α] - Δλ[β]
            ΔδNα.append(
                -(λαβ**-2) * Δλαβ * N[β] * δAαβ
                + 1 / λαβ * ΔN[β] * δAαβ
                + 1 / λαβ * N[β] * ΔδAαβ
            )
        ΔδN.append(sum(ΔδNα, axis=0))

    δM = einsum("ai...,aj...->aij...", δN, N) + einsum("ai...,aj...->aij...", N, δN)
    ΔM = einsum("ai...,aj...->aij...", ΔN, N) + einsum("ai...,aj...->aij...", N, ΔN)
    Δδλ = einsum("aij...,ij...->a...", δM, Δ(A)) + einsum(
        "aij...,ij...->a...", M, Δδ(A)
    )

    ΔδM = (
        einsum("ai...,aj...->aij...", δN, ΔN)
        + einsum("ai...,aj...->aij...", ΔN, δN)
        + einsum("ai...,aj...->aij...", ΔδN, N)
        + einsum("ai...,aj...->aij...", N, ΔδN)
    )

    return (
        Tensor(
            x=λ,
@@ -158,8 +186,8 @@ def eigh(A, eps=np.sqrt(np.finfo(float).eps)):
        Tensor(
            x=M,
            δx=δM,
-            Δx=δM * np.nan,
-            Δδx=δM * np.nan,
+            Δx=ΔM,
+            Δδx=ΔδM,
            ntrax=A.ntrax,
        ),
)
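
As a plausibility check of the first-order eigenprojection derivative used above, here is a standalone NumPy sketch (not part of tensortrax or this commit; all names in it are illustrative). It compares δN_α = Σ_{β≠α} (N_α·δA·N_β)/(λ_α−λ_β) N_β and δM_α = δN_α⊗N_α + N_α⊗δN_α against a central finite difference of the eigenprojections, for a symmetric A with distinct eigenvalues:

import numpy as np

rng = np.random.default_rng(0)

# symmetric test matrix with well-separated eigenvalues (1, 2, 4)
Q, _ = np.linalg.qr(rng.random((3, 3)))
A = Q @ np.diag([1.0, 2.0, 4.0]) @ Q.T

# symmetric perturbation direction δA
dA = rng.random((3, 3))
dA = (dA + dA.T) / 2

lam, N = np.linalg.eigh(A)  # columns of N are the eigenvectors

# analytic directional derivative of the eigenprojections M_a = N_a ⊗ N_a
dM = np.zeros((3, 3, 3))
for a in range(3):
    dN_a = sum(
        (N[:, a] @ dA @ N[:, b]) / (lam[a] - lam[b]) * N[:, b]
        for b in range(3)
        if b != a
    )
    dM[a] = np.outer(dN_a, N[:, a]) + np.outer(N[:, a], dN_a)

# central finite difference of the (sign-invariant) eigenprojections
def projections(B):
    _, V = np.linalg.eigh(B)
    return np.einsum("ia,ja->aij", V, V)

h = 1e-6
dM_fd = (projections(A + h * dA) - projections(A - h * dA)) / (2 * h)
print(np.allclose(dM, dM_fd, atol=1e-6))  # expected: True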
