Tapestry: Sharding Matmul over the in dimension

This post develops part of this document:

Sharding $Linear$, and $Matmul$, over the $in$ dimension

In the previous posts on Index Projection Functions and Sharding Linear Operations over Weight Out Dimensions, we developed affine projection sharding over the $batch$ and $out$ dimensions of a $Linear$ tensor-valued operation, assuming dimensions: $X_{[batch,in]}$, $W_{[in,out]}$, $b_{[out]}$, $Y_{[batch,out]}$:

$$Y_{[batch,out]} := Linear(X, W, b) := X \times W + b$$

To examine sharding over the $in$ dimension, we’ll need to focus on the nature of the matrix multiplication operation, and discuss the $Matmul$ and $Sum$ operations.

What’s important here is that, while $Matmul$ is linearly shardable in its $batch$ and $out$ dimensions, it contains an implicit reduce-sum reduction operation in its $in$ dimension.

📝 Note: careful readers may note that there exists a large body of work dedicated to the question of how to implement $Matmul$ more efficiently. The point of this exercise is to use $Matmul$ and $Sum$ as a lens to examine data covariance in sharding block operations; and a naive treatment of $Matmul$ is useful to these needs.
In a fully developed tensor expression sharding environment, it could be useful to hoist some operations, such as $Matmul$, to a level at which the compiler is directly aware of them, and could more aggressively use the existing research in those spaces; but that is not necessary for developing the foundations of such an environment.

Returning to $Linear$, we can rewrite it as a composition of $Matmul$ and $Sum$:

$$Linear(X, W, b) := Sum(Matmul(X, W),\ b)$$
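As a quick sanity check of this rewrite, here is a minimal PyTorch sketch; the shapes are arbitrary illustrative choices, not taken from the original:

>>> import torch
>>> x = torch.rand(10, 2)    # X[batch, in]
>>> w = torch.rand(2, 3)     # W[in, out]
>>> b = torch.rand(3)        # b[out]
>>> one_step = x @ w + b                        # Linear(X, W, b)
>>> rewrite = torch.add(torch.matmul(x, w), b)  # Sum(Matmul(X, W), b)
>>> torch.allclose(one_step, rewrite)
True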

Applying this re-write would restructure our expression graph from this:

To this:

A block operation sharding solution for $Matmul$ over the $in$ dimension should translate to a solution for $Linear$ over the $in$ dimension.

We can decompose $Matmul$ by distinguishing between the matrix multiplication operator ($\times$) and the cell-wise product operation ($\cdot$), and generate an intermediate product with shape $[batch, in, out]$.
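To spell out one step of the reasoning: each output cell of the matrix multiplication is a sum, over the $in$ dimension, of cell-wise products:

$$(X \times W)_{b,o} = \sum_{i=1}^{in} x_{b,i} \cdot w_{i,o}$$

Each term $x_{b,i} \cdot w_{i,o}$ is indexed by $batch$, $in$, and $out$ simultaneously, which is why the intermediate product carries the shape $[batch, in, out]$.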

To do this, we need to extend and broadcast $X$ and $W$ to the combined shape $[batch, in, out]$, to produce an intermediate result $V$:

$$V_{[batch,in,out]} := Prod(X_{[batch,in,1]},\ W_{[1,in,out]})$$

And we need to introduce a new operator, $SumDim(T, dim)$, which sums along, and removes, one dim of $T$.

We can now define $Matmul$ in terms of this intermediate result, and $SumDim$:

$$Matmul(X, W)_{[batch,out]} := SumDim(V_{[batch,in,out]},\ dim=in)$$

This decomposition yields the following expression graph:

In this decomposition, $Prod$ is a well-behaved block operation; but $SumDim$ is represented differently: it is not a block operation as we’ve represented them before, but a reduction operation.

Sharding $Prod$

Consider $Prod$; a simple cell-wise multiplication. We expect the output to have the same shape and dimensions as the inputs:

$$Prod(A_{[m,n,o]}, B_{[m,n,o]})_{[m,n,o]} := \begin{pmatrix}
(a_{1,1,o} \cdot b_{1,1,o}) & \cdots & (a_{1,n,o} \cdot b_{1,n,o}) \\
\vdots & \ddots & \vdots \\
(a_{m,1,o} \cdot b_{m,1,o}) & \cdots & (a_{m,n,o} \cdot b_{m,n,o})
\end{pmatrix}$$

To achieve this in tensor operations over inputs whose shapes are not initially the same, but can be manipulated to be the same, it’s common to use broadcasting: any dimension which is $1$ for one input, but non-$1$ for another input, is treated as though it were broadcast, or spread, to cover the size of the other:

$$Prod(A_{[1,n,o]}, B_{[m,1,o]})_{[m,n,o]} := \begin{pmatrix}
(a_{1,1,o} \cdot b_{1,1,o}) & \cdots & (a_{1,n,o} \cdot b_{1,1,o}) \\
\vdots & \ddots & \vdots \\
(a_{1,1,o} \cdot b_{m,1,o}) & \cdots & (a_{1,n,o} \cdot b_{m,1,o})
\end{pmatrix}$$

It is also common in tensor operations to perform various permutations, transpositions, and reversals to achieve appropriate alignment for broadcasting operations; all tensor libraries have a host of features for this, some more convenient than others.

>>> import torch
>>> batch = 10
>>> input = 2
>>> output = 3

>>> x = torch.rand((batch, input))
>>> x.shape
torch.Size([10, 2])
>>> x.unsqueeze(-1).shape
torch.Size([10, 2, 1])

>>> w = torch.rand((input, output))
>>> w.shape
torch.Size([2, 3])
>>> w.unsqueeze(0).shape
torch.Size([1, 2, 3])

>>> (x.unsqueeze(-1) * w.unsqueeze(0)).shape
torch.Size([10, 2, 3])
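Continuing the session above, summing this intermediate product along the $in$ dimension recovers the plain matrix multiplication; a minimal sketch of the $SumDim \circ Prod$ decomposition:

>>> v = x.unsqueeze(-1) * w.unsqueeze(0)   # Prod: intermediate result, shape [batch, in, output]
>>> v.sum(dim=1).shape                     # sum along (and remove) the in dimension
torch.Size([10, 3])
>>> torch.allclose(v.sum(dim=1), x @ w)    # equals Matmul(X, W)
True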

Index projection functions permit working directly in the dimensions of the input and output tensors, provided there is enough space in the dimensionality of the index space to count all points in the block; so we can directly describe the above $Prod$ operation used by the $Matmul$ decomposition with a simple index space that covers the full shape of the output.

📝 Note: careful readers may note that this involves the same input data being read by multiple output cells.
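To make that note concrete, here is an illustrative sketch (the helper names are hypothetical, not the notation developed in this series) of affine index projections for the broadcast $Prod$, with one index point per output cell:

# Index space: one point per output cell, covering [batch, in, out].
def project_x(idx):
    b, i, o = idx
    return (b, i)       # X[b, i] is read by every point (b, i, *): one input cell feeds all `out` output cells

def project_w(idx):
    b, i, o = idx
    return (i, o)       # W[i, o] is read by every point (*, i, o): shared across the batch dimension

def project_v(idx):
    b, i, o = idx
    return (b, i, o)    # each index point writes exactly one output cell of V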

Reduction Operations

Reduction operations require information between cells; on the face of it, they don’t appear shardable. Consider the index projections for a $SumDim$ operation over two dimensions:
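In the same illustrative style (again with hypothetical names), a reduction such as $SumDim$ over an input of shape $[m, k]$, reducing over $k$, maps many index points onto the same output cell:

# Index space: one point per *input* cell, covering [m, k].
def project_t(idx):
    m, k = idx
    return (m, k)    # each index point reads one input cell

def project_out(idx):
    m, k = idx
    return (m,)      # every point along k writes into the same output cell [m]

Every point along the reduced dimension collides on one output cell, which is what makes this a reduction rather than a block operation.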

$SumDim$, as a block operation, cannot be sharded along the reduction dimension.

Additional information about $SumDim$, and about rewrites to $SumDim$ which are semantics-preserving, beyond what can be expressed about Block Operations, would permit us to break it apart.

In modeling tensor expression graphs, we’re interested in recurrent classes of operations; a solution specific to $SumDim$ might be useful, but a larger class of answers would hold more value.

Suppose we notice that the summation reduction follows the monoid laws (it is associative and commutative), such that we can re-order and regroup it as we see fit:
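Concretely, associativity and commutativity let us split and regroup the sum arbitrarily; for example:

$$a + b = b + a, \qquad (a + b) + c = a + (b + c), \qquad \sum_{i=1}^{n} x_i = \Bigl(\sum_{i=1}^{k} x_i\Bigr) + \Bigl(\sum_{i=k+1}^{n} x_i\Bigr)$$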

Any operation with this property, no matter what its implementation is doing, permits us to mechanically rewrite the evaluation order.

If we can attest that $SumDim$ is a reduction operation along its reduction dimension, then we know we can split the operation into intermediate results.

Suppose we introduced a block index dimension, to model partial reductions over blocks of the reduction dimension, producing an intermediate result with an additional dimension, and then applied a second stage to complete the reduction:
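A minimal PyTorch sketch of this two-stage schedule (the block split of the reduction dimension is an arbitrary illustrative choice):

>>> import torch
>>> t = torch.rand(10, 6)            # reduce over the last dimension (size 6)
>>> one_stage = t.sum(dim=1)
>>> blocks = t.reshape(10, 2, 3)     # introduce a block index dimension: 2 blocks of 3
>>> partial = blocks.sum(dim=2)      # stage 1: partial reductions, shape [10, 2]
>>> two_stage = partial.sum(dim=1)   # stage 2: complete the reduction
>>> torch.allclose(one_stage, two_stage)
True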

When an operation is known to be a monoidal reduction along a given dimension of the input, a broad family of equivalent rewrite schedules becomes possible; but it complicates the representation of the index space, as the reduced dimension is no longer a simple countable dimension.

Rewriting $Matmul$

Returning to the definition of $Matmul$,

We can now construct $Matmul$ from the combination of a block operation and a reduce operation:

$$Matmul(X, W) := SumDim(Prod(X, W),\ dim=in)$$

Sharding $Linear$ over $in$

Putting this together with the definition of $Linear$,

We can now express $Linear$ as a form of high-level reduction operation, over the $batch$, $in$, and $out$ dimensions:

When sharding is desired over the $in$ dimension, $Linear$ expands to the following sub-graph of $Prod$, $SumDim$, and $Sum$ operations:
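As a concrete, simplified PyTorch sketch of sharding over $in$ (the shapes are arbitrary, and each shard’s $Prod$ followed by $SumDim$ is collapsed into a per-shard matmul for brevity): the shards produce partial results, the partials are reduced, and the bias is added last:

>>> import torch
>>> x = torch.rand(10, 4)                      # X[batch, in], in = 4
>>> w = torch.rand(4, 3)                       # W[in, out]
>>> b = torch.rand(3)                          # b[out]
>>> x0, x1 = x.split(2, dim=1)                 # shard X along in
>>> w0, w1 = w.split(2, dim=0)                 # shard W along in
>>> partials = torch.stack([x0 @ w0, x1 @ w1]) # per-shard partial results
>>> sharded = partials.sum(dim=0) + b          # reduce the partials, then add the bias
>>> torch.allclose(sharded, x @ w + b)
True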

And when sharding is desired over the $batch$ or $out$ dimensions, $Linear$ expands to a graph over the one-step $Linear$ operation, which behaves the way our previous description of $Linear$ behaved:

Being able to express this re-write option, when the $in$ dimension is not sharded, will require us to develop a higher-order meta-operator representation above the index projection function formalism.

Next

The full decomposition of $Linear$ provides a pathway to sharding potentially large operations, at the cost of decomposing operations which, when left intact, can be represented by highly space- and time-efficient kernel implementations.

Were we able to select between this decomposition, when $in$ was large enough to require sharding, and the block representation of $Linear$, when $in$ fit within our execution boundaries, we’d have a flexible mechanism to handle both the large and small cases.

Decorating operators with re-write production rules will be the subject of future work in this series.