Index Projection Functions
This post develops part of this document, exploring Index Projection Functions as a pathway to developing a tensor expression evaluation environment:
- See Part 1: Sharding Tensor Expressions.
Restricting to Shardable Operators
Suppose we’ve got a toy tensor expression language:

```
X, W, b, Z: Tensor
```
And we’re interested in mechanical sharding optimizations of the resultant expression graph.
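For concreteness, one such expression (a hypothetical example of my own, using the $Linear$ operator defined in the next section; the full grammar isn’t reproduced here) might read:

```
X, W, b, Z: Tensor

Z = Linear(X, W, b)
```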
Shardable Operators
Let $Linear(X, W, b) := X \cdot W + b$.

As discussed in the previous post, we’re attempting to find a family of operators for which we can mechanically answer the following questions (sketched as an interface after the list):
- Given the shapes of the parameters, what are the expected shapes of the results?
- Given the shapes of the parameters, what independent shards are possible which can be fused back into the same results?
- How do the shards share resources (which sharding choices are more or less expensive)?
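Here is a hypothetical Python interface capturing these three questions; the class and method names are my own illustration, not part of the tapestry design:

```python
from typing import Any, Protocol, Sequence

Shape = tuple[int, ...]

class ShardableOperator(Protocol):
    """Hypothetical interface: what a mechanically shardable operator must answer."""

    def result_shapes(self, input_shapes: Sequence[Shape]) -> Sequence[Shape]:
        """Predict the shapes of the results from the shapes of the parameters."""
        ...

    def enumerate_shards(self, input_shapes: Sequence[Shape]) -> Sequence[Any]:
        """Enumerate independent shards which can be fused back into the same results."""
        ...

    def shard_cost(self, shard: Any) -> float:
        """Estimate the relative resource cost of a shard, before evaluation."""
        ...
```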
Consider an abstract operator which produces a single result tensor. We’re interested in families of such operators for which these questions can be answered mechanically.
Operator Index Counting
Crucially, the goal is to be able to shard:
- With a strong ability to predict execution costs before evaluation; and
- Without examining anything about the implementation of the operator itself.
This can be reframed as a counting problem:
- Can we enumerate all simple sub-problems of a given call to the operator?
To make this concrete, let’s reconsider the $Linear$ operator, and ask some questions about its block index space $I$:
- What is the shape of $I$?
  - How many dimensions does $I$ have?
  - What are their sizes?
- What relationship does the shape of $I$ have to the inputs ($X$, $W$, $b$) and outputs ($Z$)?
- What portions of the inputs and outputs are associated with each point in $I$?
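As a concrete sketch for $Linear$ (my own illustration; taking the index space $I = [batch, out]$, one point per cell of $Z$, is one natural choice), we can verify numerically that per-point shards fuse back into the full result:

```python
import numpy as np

batch, n_in, n_out = 4, 3, 2
X = np.random.rand(batch, n_in)   # input:  [batch, in]
W = np.random.rand(n_in, n_out)   # weight: [in, out]
b = np.random.rand(n_out)         # bias:   [out]

Z = X @ W + b                     # Linear(X, W, b): [batch, out]

# Index space I = [batch, out]: one point per output cell.
# Each point (i, j) reads X[i, :], W[:, j], b[j] and writes Z[i, j].
Z_sharded = np.empty((batch, n_out))
for i in range(batch):
    for j in range(n_out):
        Z_sharded[i, j] = X[i, :] @ W[:, j] + b[j]

assert np.allclose(Z, Z_sharded)  # the shards fuse back into the same result
```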
Given a block index space $I$, and a projection function for each input and output tensor mapping points in $I$ to regions of that tensor, we can mechanically enumerate the shards of an operation.
It is important to state that the top-down approach (starting with an existing operator) and the bottom-up approach (starting with known projection functions) are two views of the same design problem:
- Top-Down: Given this operator, can I find projection functions for it?
- Bottom-Up: Given a menagerie of known projection functions, what operators can I construct?
Affine Projection Functions
One design approach for solving this problem is to restrict projection functions to affine expressions.
Affine projection functions are an approach I explored in depth working at 3Scan, and an approach that’s also been incorporated into the MLIR project’s Polyhedral Types.
What components make up an affine projection function?
- an affine expression mapping points in index space to starts in the coordinate space of an input/output tensor; and
- a fixed shape defining the region selected relative to the mapped point.
The simplest representation of this is a simple affine transform plus a shape: a projection matrix and offset vector mapping an index point to a start coordinate, paired with the fixed shape of the region selected from that start.
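A minimal Python sketch of this representation (my own illustration, using numpy; not the tapestry implementation):

```python
from dataclasses import dataclass

import numpy as np

@dataclass(frozen=True)
class AffineProjection:
    """Maps a point in index space to a fixed-shape region of a tensor."""

    matrix: np.ndarray  # [coord_dims, index_dims] projection matrix
    offset: np.ndarray  # [coord_dims] offset vector
    shape: np.ndarray   # [coord_dims] fixed shape of the selected region

    def project(self, index_point: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return the (start, end) corners of the selected region."""
        start = self.matrix @ index_point + self.offset
        return start, start + self.shape
```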
Are affine expressions the right or best solution to the design of projection functions? We don’t know; affine expressions can only be compared to other proposals, not to all possible families of functions, and there may be better ideas yet to be surfaced. We do know that affine expressions make some common patterns easy to express and to compute the shards of; and that they make some performance-critical patterns tractable to express and to compute the shards of.
Affine projection functions have an important non-obvious property: it is generally tractable to arrange them such that coherent range blocks in the index space map to coherent space blocks in the input or output tensors. This property falls out of the fact that affine projection functions have constant marginal delta strides (the incremental change resulting from changing an input coordinate by one step is constant). Coherent input/output blocks dramatically simplify processing expectations, particularly in the face of shared input (as with convolution operations).
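A small demonstration of the constant-stride property, continuing the hypothetical `AffineProjection` sketch from above:

```python
# Constant marginal delta strides: stepping the index by a fixed delta
# moves the selected region by a constant amount, independent of position.
proj = AffineProjection(
    matrix=np.array([[2, 0], [0, 1]]),
    offset=np.array([1, 0]),
    shape=np.array([2, 1]),
)

step = np.array([1, 0])
for i in [np.array([0, 0]), np.array([3, 5])]:
    s0, _ = proj.project(i)
    s1, _ = proj.project(i + step)
    assert np.array_equal(s1 - s0, proj.matrix @ step)  # same delta everywhere
```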
As with many matrix transform operations, the basic definitions are simple, but some of the implications can be complex to unpack. We’ll explore a few here.
Linear Strides Over a Batch Dimension
Consider the $Linear$ operator again, where the input $X$ carries a batch dimension. We’d like to be able to describe a sharding of the operation along that batch dimension.

A very simple linear projection is sufficient to describe the mapping from a point in index space to a batch row of the input $X$; see the sketch below.
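Continuing with the hypothetical `AffineProjection` sketch from above, a batch-row projection for an input $X$ of shape `[batch, in]` might look like:

```python
# Index space is one-dimensional over the batch dimension;
# each index point (i,) selects the single batch row X[i, :].
n_in = 3
P_X = AffineProjection(
    matrix=np.array([[1], [0]]),  # the batch index maps to the row coordinate
    offset=np.array([0, 0]),
    shape=np.array([1, n_in]),    # one full row of X
)

start, end = P_X.project(np.array([2]))
# start == [2, 0], end == [3, 3]: exactly the row X[2, :].

# A coherent index range maps to a coherent tensor range:
# index points 0..3 select rows X[0:4, :], one contiguous block.
```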
We also cleanly get the property that coherent ranges in the index space correspond to coherent tensor ranges in the mapped coordinate space.
I’ll continue developing this theme in future posts. More can be read in the tapestry work.