Skip to content

Commit

Permalink
feat: faster rotation decomposition + refactor axis angle functions (#…
Browse files Browse the repository at this point in the history
…182)

* feat: faster rotation decomposition + refactor axis angle functions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lgrcia and pre-commit-ci[bot] committed Jun 14, 2024
1 parent 3254661 commit 50a673b
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 11 deletions.
97 changes: 90 additions & 7 deletions src/jaxoplanet/experimental/starry/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp

from jaxoplanet.types import Array
from jaxoplanet.utils import get_dtype_eps


Expand Down Expand Up @@ -67,7 +68,19 @@ def do_dot(M):
return do_dot


def right_project_axis_angle(inc, obl, theta, theta_z):
def full_rotation_axis_angle(inc: float, obl: float, theta: float, theta_z: float):
"""Return the axis-angle representation of the full rotation of the
spherical harmonic map
Args:
inc (float): map inclination
obl (float): map obliquity
theta (float): rotation angle about the map z-axis
theta_z (float): rotation angle about the sky y-axis
Returns:
tuple: x, y, z, angle
"""
f = 0.5 * math.sqrt(2)
si = jnp.sin(inc / 2)
ci = jnp.cos(inc / 2)
Expand All @@ -94,14 +107,84 @@ def right_project_axis_angle(inc, obl, theta, theta_z):
return axis_x, axis_y, axis_z, angle


def right_project(ydeg, inc, obl, theta, theta_z, x):
axis_x, axis_y, axis_z, angle = right_project_axis_angle(inc, obl, theta, theta_z)
return dot_rotation_matrix(ydeg, axis_x, axis_y, axis_z, angle)(x)
def sky_projection_axis_angle(inc: float, obl: float):
"""Return the axis-angle representation of the partial rotation of the
map due to inclination and obliquity
Args:
inc (float): map inclination
obl (float): map obliquity
Returns:
tuple: x, y, z, angle
"""
co = jnp.cos(obl / 2)
so = jnp.sin(obl / 2)
ci = jnp.cos(inc / 2)
si = jnp.sin(inc / 2)

denominator = jnp.sqrt(1 - ci**2 * co**2)

axis_x = si * co
axis_y = si * so
axis_z = -so * ci

angle = 2 * jnp.arccos(ci * co)

arg = jnp.linalg.norm(jnp.array([axis_x, axis_y, axis_z]))
axis_x = jnp.where(arg > 0, axis_x / denominator, 1.0)
axis_y = jnp.where(arg > 0, axis_y / denominator, 0.0)
axis_z = jnp.where(arg > 0, axis_z / denominator, 0.0)

return axis_x, axis_y, axis_z, angle


def left_project(
ydeg: int, inc: float, obl: float, theta: float, theta_z: float, y: Array
):
"""R @ y
Args:
ydeg (int): degree of the spherical harmonic map
inc (float): map inclination
obl (float): map obliquity
theta (float): rotation angle about the map z-axis
theta_z (float): rotation angle about the sky y-axis
x (Array): spherical harmonic map coefficients
Returns:
Array: rotated spherical harmonic map coefficients
"""
axis_x, axis_y, axis_z, angle = sky_projection_axis_angle(inc, obl)
y = dot_rotation_matrix(ydeg, 1.0, None, None, -0.5 * jnp.pi)(y)
y = dot_rotation_matrix(ydeg, None, None, 1.0, -theta)(y)
y = dot_rotation_matrix(ydeg, axis_x, axis_y, axis_z, angle)(y)
y = dot_rotation_matrix(ydeg, None, None, 1.0, -theta_z)(y)
return y


def left_project(ydeg, inc, obl, theta, theta_z, x):
axis_x, axis_y, axis_z, angle = right_project_axis_angle(inc, obl, theta, theta_z)
return dot_rotation_matrix(ydeg, -axis_x, -axis_y, -axis_z, angle)(x)
def right_project(
ydeg: int, inc: float, obl: float, theta: float, theta_z: float, y: Array
):
"""y @ R
Args:
ydeg (int): degree of the spherical harmonic map
inc (float): map inclination
obl (float): map obliquity
theta (float): rotation angle about the map z-axis
theta_z (float): rotation angle about the sky y-axis
x (Array): spherical harmonic map coefficients
Returns:
Array: rotated spherical harmonic map coefficients
"""
axis_x, axis_y, axis_z, angle = sky_projection_axis_angle(inc, obl)
y = dot_rotation_matrix(ydeg, None, None, 1.0, theta_z)(y)
y = dot_rotation_matrix(ydeg, -axis_x, -axis_y, -axis_z, angle)(y)
y = dot_rotation_matrix(ydeg, None, None, 1.0, theta)(y)
y = dot_rotation_matrix(ydeg, 1.0, None, None, 0.5 * jnp.pi)(y)
return y


@partial(jax.jit, static_argnums=(0,))
Expand Down
4 changes: 2 additions & 2 deletions src/jaxoplanet/experimental/starry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from jaxoplanet.experimental.starry.basis import A1, U0, poly_basis
from jaxoplanet.experimental.starry.pijk import Pijk
from jaxoplanet.experimental.starry.rotation import (
full_rotation_axis_angle,
left_project,
right_project_axis_angle,
)
from jaxoplanet.experimental.starry.utils import ortho_grid
from jaxoplanet.experimental.starry.ylm import Ylm
Expand Down Expand Up @@ -158,7 +158,7 @@ def intensity(self, lat: float, lon: float):
y = jnp.sin(lat) * jnp.sin(lon)
z = jnp.cos(lat) * jnp.ones_like(x)

axis = right_project_axis_angle(self.inc - jnp.pi / 2, self.obl, 0.0, 0.0)
axis = full_rotation_axis_angle(self.inc - jnp.pi / 2, self.obl, 0.0, 0.0)
axis = jnp.array(axis[0:3]) * axis[-1]
rotation = Rotation.from_rotvec(axis)
x, y, z = rotation.apply(jnp.array([x, y, z]).T).T
Expand Down
4 changes: 2 additions & 2 deletions tests/experimental/starry/rotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from jaxoplanet.experimental.starry.rotation import (
dot_rotation_matrix,
full_rotation_axis_angle,
left_project,
right_project,
right_project_axis_angle,
)
from jaxoplanet.test_utils import assert_allclose

Expand Down Expand Up @@ -62,7 +62,7 @@ def test_right_project_axis_angle_edge_case():
theta = theta_z = obl = 0.0
inc = jnp.pi / 2

x = jnp.array(right_project_axis_angle(inc, obl, theta, theta_z))
x = jnp.array(full_rotation_axis_angle(inc, obl, theta, theta_z))
assert jnp.all(jnp.isfinite(x))


Expand Down

0 comments on commit 50a673b

Please sign in to comment.