Fix interpolate() #7

Open
RyannDaGreat opened this issue Oct 3, 2022 · 0 comments

RyannDaGreat commented Oct 3, 2022

Hi! I found a bug in interpolate(): it breaks when given NumPy arrays. The variable inputs_are_torch is only assigned inside the "if not isinstance(v0, np.ndarray):" branch, so for NumPy inputs it is never defined, and the line "if inputs_are_torch:" raises an UnboundLocalError.
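The failure comes from Python's scoping rule: because inputs_are_torch is assigned somewhere in the function, it is a local variable throughout that function, and reading it on a path where the assignment was skipped raises UnboundLocalError. A minimal stdlib-only sketch of the same pattern, where the function names are hypothetical and float stands in for np.ndarray so it runs without NumPy:

```python
def flag_set_on_one_branch(x):
    if not isinstance(x, float):
        flag = True  # the only assignment, yet `flag` is local to the whole function
    return flag  # raises UnboundLocalError when the branch above was skipped


def flag_set_on_every_path(x):
    flag = False  # default first, as in the proposed fix
    if not isinstance(x, float):
        flag = True
    return flag


try:
    flag_set_on_one_branch(1.0)  # a float skips the branch, like an ndarray does
except UnboundLocalError as err:
    print("buggy pattern:", err)

print("fixed pattern:", flag_set_on_every_path(1.0))  # prints "fixed pattern: False"
```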

I would open a pull request, but I'm not sure how: I've already forked this repo for another purpose, and I won't be merging that fork because it contains changes that are irrelevant to this repo. Is there a way to fork a repo twice?

import numpy as np
import torch


def interpolate(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Helper function to (spherically) interpolate two arrays v0 and v1.
    
    Taken from: https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
    """

    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True  # only assigned on this branch
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:  # UnboundLocalError here for NumPy inputs
        v2 = torch.from_numpy(v2).to(input_device)

    return v2
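For reference, the numerical part of interpolate(), which the fix does not touch, is standard spherical linear interpolation (slerp). Here dot is the cosine of the angle theta_0 between v0 and v1, and since theta_t = t * theta_0, the weight s0 = sin(theta_0 - theta_t) / sin(theta_0) equals sin((1 - t) theta_0) / sin(theta_0). The linear fallback avoids dividing by sin(theta_0), which approaches zero when the vectors are nearly parallel or antiparallel:

```latex
\theta_0 = \arccos\!\left(\frac{v_0 \cdot v_1}{\lVert v_0 \rVert \, \lVert v_1 \rVert}\right),
\qquad
v_2(t) =
\begin{cases}
(1 - t)\, v_0 + t\, v_1, & \text{if } |\cos\theta_0| > \texttt{DOT\_THRESHOLD},\\[4pt]
\dfrac{\sin\big((1 - t)\,\theta_0\big)}{\sin\theta_0}\, v_0 + \dfrac{\sin(t\,\theta_0)}{\sin\theta_0}\, v_1, & \text{otherwise.}
\end{cases}
```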

It should become:

import numpy as np
import torch


def interpolate(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Helper function to (spherically) interpolate two arrays v0 and v1.
    
    Taken from: https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
    """

    inputs_are_torch = False  # defined on every path, so the final check is safe
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2
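One way to check the corrected version is to call it on plain NumPy arrays, the exact case that used to crash. The sketch below is a hypothetical variant of the corrected function, not code from this repo. It makes two changes: it sets the flag in a single expression, so the flag exists on every path by construction, and it imports torch only inside the tensor branch, so the NumPy path also runs where PyTorch is not installed. The tensor path assumes PyTorch tensors, as in the original.

```python
import numpy as np


def interpolate(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between v0 and v1 (NumPy arrays or torch tensors)."""
    # Decide once, up front: the flag is bound on every code path.
    inputs_are_torch = not isinstance(v0, np.ndarray)
    if inputs_are_torch:
        import torch  # only needed for tensor inputs

        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    # Cosine of the angle between v0 and v1.
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # Nearly (anti)parallel: fall back to linear interpolation.
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = np.sin(theta_t) / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)
    return v2


# NumPy-only usage: no UnboundLocalError, and a NumPy array comes back.
a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])
print(interpolate(0.5, a, b))  # prints approximately [0.7071 0.7071]
```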