(Almost) Differentiable Indexing of Tensors

Indexing tensors is at the core of numerous ML applications. In this post, I describe a strategy to index a tensor when we need to compute the gradient with respect to the index.

When indexing a tensor, we extract the portion of the tensor specified by the index. Using a programming-oriented notation, $\mathbf a[i]$ is the $i$-th element of the tensor $\mathbf a$.

During differentiation, we compute the rate of change of a function with respect to its input (i.e. the derivative). In the case of indexing, this equates to asking: how would my output change if I took a tiny step from my current index? The long story is that functions must satisfy specific properties to be differentiable, and indexing does not satisfy them. The short one is that indices are integer values, so an infinitesimally small step from the current index results in an ill-defined operation (which memory cell corresponds to $i + \mathrm{d}i$?).
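This limitation shows up directly in PyTorch (a small sketch of my own; the values are arbitrary): integer indexing is fine in the forward pass, but the index itself cannot be made a trainable parameter.

```python
import torch

a = torch.randn(5)
print(a[2])  # ordinary integer indexing works fine in the forward pass

# But the index itself cannot carry a gradient: PyTorch refuses to track
# gradients for integer tensors, precisely because i + di is ill-defined.
try:
    i = torch.tensor(2, requires_grad=True)  # an int64 tensor
except RuntimeError as e:
    print(e)
```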

With these ideas in mind, let’s explore a trick to index a tensor in a differentiable way. To do this, we consider the mathematical description of the indexing operation:

$$ f(\mathbf a; \mathbf w) = \mathbf a^T \mathbf w = a_0 w_0 + \dots + a_{N-1} w_{N-1}, \qquad \mathbf a, \mathbf w \in \mathbb{R}^N, \tag{1} $$

where $\mathbf a$ is the input tensor and $\mathbf w$ is the one-hot encoding of the index. That is, $\mathbf w$ is constrained (or quantized) to the rows of the $N \times N$ identity matrix, $I_N$. Therefore, for any $\widehat{\mathbf w} \in \mathbb{R}^N$,

$$ f(\mathbf a; \widehat{\mathbf w}) = \mathbf a^T \mathbf{Q}(\widehat{\mathbf w}), \qquad \mathbf{Q} : \mathbb{R}^N \to \mathbb{R}^N, \tag{2} $$

where $\mathbf{Q}(\cdot)$ is the quantization operation that snaps $\widehat{\mathbf w}$ to the closest row of $I_N$. There are many alternatives for the function $\mathbf{Q}(\cdot)$, and they depend on the definition of “closest.” Let’s define

$$ \mathbf{Q}(\mathbf x) = I_N[\arg\max \mathbf x]. \tag{3} $$

That is, $\mathbf{Q}(\mathbf x)$ maps each vector $\mathbf x$ to the row of the identity matrix whose $1$ sits at the position of the largest entry of $\mathbf x$. The rate of change of $\mathbf{Q}(\mathbf x)$ with respect to the entries of $\mathbf x$ is zero almost everywhere, because the output stays constant under small perturbations of the input (imagine a step function, whose derivative is zero everywhere except for an impulse at the jump).
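Eqs. (1)–(3) are easy to check numerically. Here is a small PyTorch sketch (the names and values are my own illustration, not from the post):

```python
import torch

def quantize(x: torch.Tensor) -> torch.Tensor:
    # Eq. (3): snap x to the row of I_N whose 1 sits at x's largest entry.
    return torch.eye(x.shape[0])[x.argmax()]

a = torch.tensor([4.0, 7.0, 1.0])        # input tensor
w_hat = torch.tensor([0.3, 1.2, -0.5])   # unconstrained index vector

q = quantize(w_hat)
print(q)              # tensor([0., 1., 0.]) — one-hot at the argmax
print(a @ q == a[1])  # the dot product with a one-hot vector is plain indexing
```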

Let’s say we want to optimize an objective, $\mathcal{L}$, using gradient descent. To do this, we compute the gradient of $\mathcal{L}$ with respect to the weights of our model (in our case, the indexing vector $\widehat{\mathbf w}$) and update the weights after each iteration as:

$$ \widehat{\mathbf w}_{t+1} \leftarrow \widehat{\mathbf w}_t - \eta \nabla_{\widehat{\mathbf w}} \mathcal{L}. \tag{4} $$

Since $\nabla_{\widehat{\mathbf w}} \mathcal{L} = \nabla_{\mathbf Q} \mathcal{L} \cdot \mathbf{J}(\mathbf{Q}) = 0$, we never update the weights. We can, however, use the Straight-Through Estimator (STE) of the gradient. That is, we compute the gradient as if we had not quantized the vector,

$$ \nabla_{\widehat{\mathbf w}} \mathcal{L} = \nabla_{\mathbf{Q}(\widehat{\mathbf w})} \mathcal{L}, \tag{5} $$

and use this gradient for the update in Eq. (4).
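To make the difference concrete, here is a small PyTorch comparison (names and values are my own): the hard quantization of Eq. (3) blocks the gradient entirely, while the straight-through version passes it along as if $\mathbf{Q}$ were the identity.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
w_hat = torch.tensor([0.2, 0.9, 0.1], requires_grad=True)
q = torch.eye(3)[w_hat.argmax()]  # Eq. (3); argmax severs the autograd graph

hard = a @ q
print(hard.requires_grad)  # False: no gradient can reach w_hat

ste = a @ (w_hat + (q - w_hat).detach())  # same forward value as `hard`
ste.backward()
print(w_hat.grad)  # tensor([1., 2., 3.]): the gradient w.r.t. Q(w_hat), Eq. (5)
```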

In code, the indexing operation with STE would look like this for one-dimensional vectors:

```python
import torch
from torch import Tensor

def index_ste(a: Tensor, w_hat: Tensor) -> Tensor:
    q = torch.eye(w_hat.shape[0])[w_hat.argmax()]  # Q(w_hat), Eq. (3)
    return a @ (w_hat + (q - w_hat).detach())      # forward: a^T Q(w_hat)
```

The role of `w_hat` in the return value is to instruct autograd to propagate the gradients through `w_hat` without the need to specify a custom backward pass: the detached term corrects the forward value to the quantized one while contributing no gradient.
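As a sanity check, we can run the update of Eq. (4) on a toy problem (entirely my own construction): learn $\widehat{\mathbf w}$ so that the indexed value matches the maximum of $\mathbf a$. The STE expression is inlined so the snippet stands alone.

```python
import torch

a = torch.tensor([1.0, 3.0, 2.0])
target = a.max()
w_hat = torch.tensor([1.0, 0.0, 0.0], requires_grad=True)  # starts by selecting a[0]

for _ in range(5):
    q = torch.eye(3)[w_hat.argmax()]
    out = a @ (w_hat + (q - w_hat).detach())  # STE indexing
    loss = (out - target) ** 2
    loss.backward()
    with torch.no_grad():                     # Eq. (4), with eta = 0.1
        w_hat -= 0.1 * w_hat.grad
        w_hat.grad.zero_()

print(w_hat.argmax())  # tensor(1): the index of a's maximum
```

After a couple of updates the gradient pushes the largest entry of $\widehat{\mathbf w}$ onto the position of the largest entry of $\mathbf a$, at which point the loss (and therefore the gradient) is zero and the index stops moving.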