28 Nov 2024

Practical Einops: Tensor Operations Based on Indices

Most people are familiar with vector and matrix operations but less so with tensor operations. In machine learning, tensors often refer to batched vectors or batched matrices and are represented by array-like objects with multiple indices. For this reason, tensor operations in most Python packages, including NumPy, PyTorch and TensorFlow, are typically named after vector and matrix operations. However, tensors have a particularly useful operation of their own, called contraction, which uses index-based notation and covers most vector and matrix operations. This index-based notation describes the relationship between the components of the input and output tensors intuitively and explicitly. Today's topic, the Python package einops, extends this notation and provides an elegant API for flexible and powerful tensor operations.

Notation. In this post, we use letters to denote tensors, for example \(x, y\), and brackets to denote their components, such as \(x[i,j,k]\) and \(y[i,j]\). With a slight abuse of notation, we may also write \(x[i,j,k]\) to denote a tensor with three indices.

A quick example for summations

Consider the batched bilinear form \(x^\intercal Q y\) for batched vectors \(x\), \(y\) and batched matrices \(Q\). Such calculations often arise when dealing with stochastic sequential data, e.g., when \(x\), \(y\) and \(Q\) are elements of stochastic processes and \(x[i,j,k]\) stands for the value of the \(k\)-th component at time step \(i\) for the \(j\)-th sample path. In index-based notation, the output tensor can be written as \[ \mathsf{BatchedBilinearForm}(x, Q, y)[i,j] = \sum_{k,l} x[i,j,k]\, Q[i,j,k,l]\, y[i,j,l], \] which is a tensor with two indices \(i\) and \(j\), since the summation is performed over \(k\) and \(l\). The following code demonstrates how to do this with and without einops for PyTorch tensors.

# without einops
output = (x * (Q @ y.unsqueeze(-1)).squeeze(-1)).sum(dim=-1)
# with einops
output = einops.einsum(x, Q, y, "i j k, i j k l, i j l -> i j")

From this example, we can perceive some benefits and drawbacks of einops already.

  1. It gives the input and output shapes explicitly. The indices over which the summations occur are deduced from these shapes. Therefore, when writing and reading operations described by einops, we have a crystal-clear understanding of the shapes of the tensors involved.
  2. It offers a unified way to perform summation operations, including but not limited to torch.dot, torch.bmm, torch.sum, and so on.
  3. It describes the operations by indices directly, which is not so transparent if we want to translate it back to vector and matrix operations.
  4. It describes the operations by a string following specific patterns, which might be very confusing for people who are unfamiliar with it.
  5. It might be inefficient and slow since it needs to parse a string. The performance loss might not be significant if the same pattern is used many times and the parsed pattern is cached.

Despite the last performance issue, einops provides an alternative way to describe tensor operations based on indices. The pattern string gives the input and output shapes but requires additional effort to learn its usage.

Another example for permutations

The index-based notation is also intuitive for axis permutations. Consider a tensor \(x[t,b,c,i]\) with the shape (time, batch, channel, feature). If we want to permute the axes so that the output tensor \(y\) has the shape (channel, feature, batch, time), then we effectively mean \[ x[t,b,c,i] = y[c,i,b,t]. \] The following code demonstrates how to do this with and without einops for PyTorch tensors.

# without einops
y = x.permute(2, 3, 1, 0)
# with einops
y = einops.rearrange(x, "t b c i -> c i b t")

Once again, einops is more intuitive as it explicitly specifies the input and output shapes.

More features provided by einops

In fact, both NumPy and PyTorch provide a routine function einsum, which is actually the motivation behind einops.einsum. The two examples given above can also be achieved by torch.einsum directly. However, the einops package extends the idea further and provides more advanced features. Below are some usage scenarios that I believe might be useful. There are also official tutorials and examples on GitHub.
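For comparison, here is a minimal sketch of the same two examples written with torch.einsum, which takes the pattern as its first argument and uses single-letter index names. The shapes are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 4, 5, dtype=torch.float64)
Q = torch.randn(3, 4, 5, 6, dtype=torch.float64)
y = torch.randn(3, 4, 6, dtype=torch.float64)

# Batched bilinear form: k and l are summed because they are not in the output.
bilinear = torch.einsum("ijk,ijkl,ijl->ij", x, Q, y)

# Axis permutation: no index is dropped, so nothing is summed.
z = torch.randn(2, 3, 4, 5)
permuted = torch.einsum("tbci->cibt", z)

reference = (x * (Q @ y.unsqueeze(-1)).squeeze(-1)).sum(dim=-1)
print(torch.allclose(bilinear, reference))
print(torch.equal(permuted, z.permute(2, 3, 1, 0)))
```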

  • Use ellipsis. For the first example on the batched bilinear form, the demonstrated code with einops is slightly more restrictive than the pure PyTorch approach. Indeed, the pattern given there explicitly specifies that the input tensor \(x\) has three indices. This level of specificity may be desirable, depending on the requirements. But if we want to remove this restriction, then the code can be modified to

    output = einops.einsum(x, Q, y, "... k, ... k l, ... l -> ...")
    
  • Reshape and check axis sizes. For the second example on axis permutations, we can also explicitly specify axis sizes and let einops check them, for example,

    y = einops.rearrange(x, "t b c i -> c i b t", t=10, c=30)
    # this is equivalent to
    t, b, c, i = x.size()
    assert t == 10 and c == 30
    y = einops.rearrange(x, "t b c i -> c i b t")
    
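  • Split axes. Sometimes, we may want to split each image into 4 pieces. For example, the following code demonstrates how to take an input tensor with the shape of (b, c, h, w) and return a tensor with the shape of (b, c, 2, 2, h/2, w/2). Note that with the pattern (h s1), each piece collects every second pixel along that axis; writing (s1 h) instead would produce contiguous quadrants.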

    def split_patches_without_einops(x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.size()
        y = x.view(b, c, h // 2, 2, w // 2, 2)
        return y.permute(0, 1, 3, 5, 2, 4)
    
    def split_patches_with_einops(x: torch.Tensor) -> torch.Tensor:
        return einops.rearrange(x, "b c (h s1) (w s2) -> b c s1 s2 h w", s1=2, s2=2)
    
  • Join axes. Sometimes, we may want to flatten a tensor by joining multiple axes. For example, the following code demonstrates how to take an input tensor with the shape of (b, c, h, w) and return a tensor with the shape of (b, c*h*w).

    def join_axes_without_einops(x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.size()
        return x.view(b, c * h * w)
    
    def join_axes_with_einops(x: torch.Tensor) -> torch.Tensor:
        return einops.rearrange(x, "b c h w -> b (c h w)")
    
  • Layer. It is possible to create a torch.nn.Module instance for an einops.rearrange operation and put it into the torch.nn.Sequential container. For example, the following code demonstrates how to build a simple image classifier. Note that the first layer is included only to check axis sizes and can be skipped.

    import torch
    from einops.layers.torch import Rearrange
    
    model = torch.nn.Sequential(
        Rearrange("b c h w -> b c h w", c=3, h=8, w=8),
        torch.nn.Conv2d(3, 16, 3, stride=2, padding=1),
        Rearrange("b c h w -> b (c h w)", c=16, h=4, w=4),
        torch.nn.Linear(16 * 4 * 4, 120),
        torch.nn.ReLU(),
        torch.nn.Linear(120, 10),
    )
    

Tags: ai