Sharding Linear Operations Over Weight Out Dimensions

This post develops part of this document:

Sharding over the out dimension

In the previous post on Index Projection Functions, we developed affine projections for the $batch$ dimension of a tensor-valued $\operatorname{Linear}$ operation, assuming dimensions $X: [batch, in]$, $W: [in, out]$, $b: [out]$, and $Y: [batch, out]$:

$$Y = \operatorname{Linear}(X, W, b) = X \cdot W + b$$
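As a concrete reference for those shapes, here is a minimal numpy sketch (the `linear` helper and the example sizes are illustrative, not from the original post):

```python
import numpy as np

def linear(x, w, b):
    # x: [batch, in], w: [in, out], b: [out] -> y: [batch, out]
    return x @ w + b

batch, in_dim, out_dim = 4, 3, 5
x = np.random.randn(batch, in_dim)
w = np.random.randn(in_dim, out_dim)
b = np.random.randn(out_dim)

assert linear(x, w, b).shape == (batch, out_dim)
```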

We’ll now consider the projection functions $P_W$, $P_b$, and $P_Y$, and how we’ll handle batching over the $out$ dimension.

The values of $Y$ in the $out$ dimension are independent of each other; each $out$ value is computed using one column of $W$ and one value in $b$. As a result, the op can be cleanly and trivially sharded by chunking $W$ and $b$:

$$\operatorname{Linear}(X, W, b) = \operatorname{Linear}(X, W_{[:, :k]}, b_{[:k]}) \;\|\; \operatorname{Linear}(X, W_{[:, k:]}, b_{[k:]})$$

where $\|$ denotes concatenation along the $out$ dimension.
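A small numpy check makes this concrete (a sketch under the definitions above; the split point `k` is arbitrary): each shard sees all of $X$, a block of columns of $W$, and the matching block of $b$, and concatenating the shard outputs along $out$ reproduces the unsharded result:

```python
import numpy as np

def linear(x, w, b):
    return x @ w + b

batch, in_dim, out_dim, k = 4, 3, 6, 2
x = np.random.randn(batch, in_dim)
w = np.random.randn(in_dim, out_dim)
b = np.random.randn(out_dim)

# Shard W and b along the out dimension at column k.
left = linear(x, w[:, :k], b[:k])    # [batch, k]
right = linear(x, w[:, k:], b[k:])   # [batch, out - k]

# Concatenating the shards along out reproduces the full result.
assert np.allclose(np.concatenate([left, right], axis=1), linear(x, w, b))
```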

By extending the index space to index the $out$ dimension, we can express the $P_W$, $P_b$, and $P_Y$ coordinates in terms of the indexed $out$ coordinate, and the shapes in terms of the $out$ dimension size.
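As a sketch of what those index functions could look like in code (the names `p_x` through `p_y` and the `(start, shape)` return convention are my own, not established by the post), each map takes an index point $i = (i_{batch}, i_{out})$ to an affine start coordinate plus a constant slice shape:

```python
IN = 3  # size of the in dimension (illustrative)

# Index space: one point i = (i_batch, i_out) per output value of Y.
# Each projection maps i to the (start, shape) of the tensor range it
# touches; starts are affine in i, shapes are constant.

def p_x(i):   # X[i_batch, :]      -- one full row of X
    return (i[0], 0), (1, IN)

def p_w(i):   # W[:, i_out]        -- one full column of W
    return (0, i[1]), (IN, 1)

def p_b(i):   # b[i_out]           -- one value of b
    return (i[1],), (1,)

def p_y(i):   # Y[i_batch, i_out]  -- one value of Y
    return (i[0], i[1]), (1, 1)
```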

We also cleanly get the property that coherent ranges in the index space correspond to coherent tensor ranges in the mapped coordinate space.
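Continuing the hypothetical sketch above, the affine structure is what delivers that property: mapping the corner points of a contiguous index block through a projection and taking their bounding box yields a single contiguous tensor slice:

```python
IN = 3  # size of the in dimension (illustrative)

def p_w(i):  # W[:, i_out] -- one full column of W
    return (0, i[1]), (IN, 1)

# A coherent index block: i_batch in [0, 3], i_out in [2, 4] (inclusive).
lo, hi = (0, 2), (3, 4)

(w_lo, shape), (w_hi, _) = p_w(lo), p_w(hi)
# Bounding box of the mapped corners, extended by the slice shape:
start = tuple(min(a, b) for a, b in zip(w_lo, w_hi))
end = tuple(max(a, b) + s for (a, b), s in zip(zip(w_lo, w_hi), shape))

# The block maps to W[0:IN, 2:5] -- one contiguous run of columns of W.
assert (start, end) == ((0, 2), (IN, 5))
```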

Sharding over the in dimension

Sharding over the $in$ dimension is more complex, as it requires sharding a reduce operation, which breaks our current block model. As a preview of a future post, we can see that the op can be rewritten as a reduction over $in$-dimension chunks:

$$\operatorname{Linear}(X, W, b) = \left( X_{[:, :k]} \cdot W_{[:k, :]} + X_{[:, k:]} \cdot W_{[k:, :]} \right) + b$$
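A quick numpy check of this identity (again a sketch; the split point `k` and sizes are arbitrary): the shards produce partial products over disjoint $in$-chunks, combined by a sum rather than a concatenation, with $b$ added once at the end:

```python
import numpy as np

batch, in_dim, out_dim, k = 4, 6, 5, 2
x = np.random.randn(batch, in_dim)
w = np.random.randn(in_dim, out_dim)
b = np.random.randn(out_dim)

# Shard X along its columns and W along its rows at position k.
partial_a = x[:, :k] @ w[:k, :]   # [batch, out]
partial_b = x[:, k:] @ w[k:, :]   # [batch, out]

# The partials combine by sum-reduction; b is added exactly once.
assert np.allclose(partial_a + partial_b + b, x @ w + b)
```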