
Jax equivalent of tf.scatter_nd #3658

Answered by shoyer
jpuigcerver asked this question in General

`lax.scatter` exists, but it is indeed rather complex. For cases where you would use `tf.scatter_nd`, we recommend using the indexed update functions, or the equivalent syntactic sugar via the `.at` property.
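
A minimal sketch of the `.at` syntax, assuming a recent JAX release where `.at` is the supported form of indexed updates. JAX arrays are immutable, so each update returns a new array and leaves `x` unchanged. The values are illustrative only.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# .set replaces the selected entry; x itself is not modified.
y = x.at[1].set(10.0)  # y is [0, 10, 0, 0, 0]

# .add accumulates into the selected entries; the repeated index 2 is summed.
z = x.at[jnp.array([0, 2, 2])].add(1.0)  # z is [1, 0, 2, 0, 0]
```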

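The indexed-update API is very similar, but it uses a different axis order that matches NumPy's advanced indexing: `tf.scatter_nd` stores each coordinate along the last axis of `indices`, whereas NumPy-style indexing expects a tuple with one index array per dimension. To reproduce `scatter_nd` in JAX, you could use: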

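```python
import jax.numpy as jnp

def scatter_nd(indices, updates, shape):
    # Start from an all-zeros array with the target shape and the updates' dtype.
    zeros = jnp.zeros(shape, updates.dtype)
    # Turn TF-style indices of shape (..., index_depth) into a NumPy-style
    # tuple holding one index array per indexed dimension.
    key = tuple(jnp.moveaxis(indices, -1, 0))
    # Use .add rather than .set so duplicate indices are summed, as in tf.scatter_nd.
    return zeros.at[key].add(updates)
```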
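
A short usage sketch with illustrative inputs (not from the original thread). It repeats the helper so the snippet runs on its own, shows how `moveaxis` turns TF-style coordinate rows into one index array per dimension, and checks three cases: full index depth on a vector, full index depth on a matrix, and index depth 1 on a matrix, where each index selects a whole row. The duplicate index 0 in the last case is summed, which is why the helper uses `.add`: `tf.scatter_nd` likewise accumulates updates that land on the same location.

```python
import jax.numpy as jnp

def scatter_nd(indices, updates, shape):
    zeros = jnp.zeros(shape, updates.dtype)
    key = tuple(jnp.moveaxis(indices, -1, 0))
    return zeros.at[key].add(updates)

# TF-style coordinates: one (row, col) pair per row of `pairs` ...
pairs = jnp.array([[0, 1], [2, 3]])
# ... become one index array per dimension after moving the last axis first.
rows, cols = jnp.moveaxis(pairs, -1, 0)  # rows is [0, 2], cols is [1, 3]

# Full index depth on a length-8 vector.
out1 = scatter_nd(jnp.array([[4], [3], [1], [7]]),
                  jnp.array([9, 10, 11, 12]), (8,))
# out1 is [0, 11, 0, 10, 9, 0, 0, 12]

# Full index depth on a (3, 4) matrix: element (0, 1) gets 5, (2, 3) gets 7.
out2 = scatter_nd(pairs, jnp.array([5, 7]), (3, 4))

# Index depth 1 on a (3, 2) target: each index selects a whole row,
# and the two updates aimed at row 0 are summed.
out3 = scatter_nd(jnp.array([[0], [2], [0]]),
                  jnp.array([[1, 2], [3, 4], [5, 6]]), (3, 2))
# out3 is [[6, 8], [0, 0], [3, 4]]
```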

Answer selected by mattjj