Practical Einops: Tensor Operations Based on Indices
People are familiar with vector and matrix operations but are less familiar with tensor operations. In machine learning, tensors often refer to batched vectors or batched matrices and are represented by an array-like object with multiple indices. For this reason, tensor operations in most Python packages, including NumPy, PyTorch and TensorFlow, are typically named after vector and matrix operations. However, tensors themselves have a particularly useful operation, called contraction, which uses index-based notations and can cover most vector and matrix operations. These index-based notations describe the relationship between the components of the input and output tensors intuitively and explicitly. Today's topic, the Python einops package, extends these notations and provides an elegant API for flexible and powerful tensor operations.
Notations. In this post, we use letters to denote tensors, for example, \(x, y\), and use brackets to denote their components, such as \(x[i,j,k], y[i,j]\). With a slight abuse of notation, we may also use \(x[i,j,k]\) directly to denote a tensor with three indices.
A quick example for summations
Consider the batched bilinear form \(x^\intercal Q y\) for batched vectors \(x\),
\(y\) and batched matrices \(Q\). These calculations often arise when
dealing with stochastic sequential data, e.g., \(x\), \(y\) and \(Q\) are
elements of stochastic processes and \(x[i,j,k]\) stands for the value
of the \(k\)-th component at time step \(i\) for the \(j\)-th sample
path. With the index-based notations, the output tensor can be written
as \[ \mathsf{BatchedBilinearForm}(x, Q, y)= \sum_{k,l} x[i,j,k]\,
Q[i,j,k,l]\, y[i,j,l], \] which is a tensor with two indices \(i\) and \(j\)
as the summation is performed over \(k\) and \(l\). The following code
demonstrates how to do this with and without einops for PyTorch tensors.
# without einops
output = (x * (Q @ y.unsqueeze(-1)).squeeze(-1)).sum(dim=-1)
# with einops
output = einops.einsum(x, Q, y, "i j k, i j k l, i j l -> i j")
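As a quick sanity check, the following sketch (with arbitrarily chosen sizes) builds random tensors and verifies that the two versions agree; it assumes a recent einops release that provides einops.einsum.
import torch
import einops

# arbitrary sizes: i=5 time steps, j=7 sample paths, k=3 and l=4 features
x = torch.randn(5, 7, 3)
Q = torch.randn(5, 7, 3, 4)
y = torch.randn(5, 7, 4)

out_plain = (x * (Q @ y.unsqueeze(-1)).squeeze(-1)).sum(dim=-1)
out_einops = einops.einsum(x, Q, y, "i j k, i j k l, i j l -> i j")
assert out_plain.shape == (5, 7)
assert torch.allclose(out_plain, out_einops, atol=1e-6)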
From this example, we can perceive some benefits and drawbacks of einops already.
- It gives the input and output shapes explicitly. The indices where the summations occur are deduced from these shapes. Therefore, while using and reading operations described by einops, we can have a crystal-clear understanding of the involved tensor shapes.
- It offers a unified way to perform summation operations, including but not limited to torch.dot, torch.bmm, torch.sum, and so on.
- It describes the operations by indices directly, which is not so transparent if we want to translate them back to vector and matrix operations.
- It describes the operations by a string following specific patterns, which might be very confusing for people who are unfamiliar with it.
- It might be inefficient and slow since it needs to parse a string. The decreased performance may not be significant if the same pattern is used multiple times with caching techniques (see the timing sketch after this list).
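To get a rough sense of the parsing overhead mentioned in the last point, one could time both versions on the same inputs; the snippet below is only an illustrative sketch using timeit on small, hypothetical shapes, not a rigorous benchmark.
import timeit
import torch
import einops

x = torch.randn(5, 7, 3)
Q = torch.randn(5, 7, 3, 4)
y = torch.randn(5, 7, 4)

t_plain = timeit.timeit(
    lambda: (x * (Q @ y.unsqueeze(-1)).squeeze(-1)).sum(dim=-1), number=1_000
)
t_einops = timeit.timeit(
    lambda: einops.einsum(x, Q, y, "i j k, i j k l, i j l -> i j"), number=1_000
)
print(f"plain PyTorch: {t_plain:.3f} s, einops: {t_einops:.3f} s")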
Despite the last performance issue, einops provides an alternative way to describe tensor operations based on indices. The pattern string gives the input and output shapes but requires additional effort to learn its usage.
Another example for permutations
The index-based notation is also intuitive for axis permutations. Consider a tensor \(x[t,b,c,i]\) with the shape of (time, batch, channel, feature). If we want to permute the axes and expect the output tensor \(y\) to have a shape of (channel, feature, batch, time), then we effectively mean \[ x[t,b,c,i] = y[c,i,b,t]. \] The following code demonstrates how to do this with and without einops for PyTorch tensors.
# without einops
y = x.permute(2, 3, 1, 0)
# with einops
y = einops.rearrange(x, "t b c i -> c i b t")
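As a small usage example (with made-up sizes), we can confirm that both versions produce the same result and the expected output shape:
import torch
import einops

x = torch.randn(10, 32, 3, 64)  # (time, batch, channel, feature)
y = einops.rearrange(x, "t b c i -> c i b t")
assert y.shape == (3, 64, 32, 10)
assert torch.equal(y, x.permute(2, 3, 1, 0))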
Once again, einops is more intuitive as it explicitly specifies the input and output shapes.
More features provided by einops
In fact, both NumPy and PyTorch provide a routine function einsum, which is actually the motivation behind einops.einsum. The two examples given above can also be achieved by torch.einsum directly (a short sketch follows this paragraph). However, the einops package extends the idea further and provides more advanced features. Below are some usage scenarios that I believe might be useful. There are also official tutorials and examples on GitHub.
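For comparison, here is a sketch of how the two earlier examples might be written with torch.einsum directly; note that torch.einsum uses single-letter subscripts without spaces, and the tensor sizes below are arbitrary.
import torch

# batched bilinear form, as in the first example
x = torch.randn(5, 7, 3)
Q = torch.randn(5, 7, 3, 4)
y = torch.randn(5, 7, 4)
bilinear = torch.einsum("ijk,ijkl,ijl->ij", x, Q, y)

# axis permutation, as in the second example
z = torch.randn(10, 32, 3, 64)  # (time, batch, channel, feature)
permuted = torch.einsum("tbci->cibt", z)
assert permuted.shape == (3, 64, 32, 10)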
Use ellipsis. For the first example on the batched bilinear form, the demonstrated code with einops is slightly more restrictive than the pure PyTorch approach. Indeed, the pattern given there explicitly specifies that the input tensor \(x\) has three indices. This level of specification may be desired depending on the requirements. But if we want to remove this restriction, then the code can be modified to
output = einops.einsum(x, Q, y, "... k, ... k l, ... l -> ...")
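A minimal sketch of what the ellipsis buys us: the same pattern now also accepts inputs with, say, a single batch index (sizes below are arbitrary).
import torch
import einops

x = torch.randn(7, 3)      # (j, k)
Q = torch.randn(7, 3, 4)   # (j, k, l)
y = torch.randn(7, 4)      # (j, l)
output = einops.einsum(x, Q, y, "... k, ... k l, ... l -> ...")
assert output.shape == (7,)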
Reshape and check axis sizes. For the second example on axis permutations, we can also explicitly specify axis sizes and let einops check them, say,
y = einops.rearrange(x, "t b c i -> c i b t", t=10, c=30)
# this is equivalent to
t, b, c, i = x.size()
assert t == 10 and c == 30
y = einops.rearrange(x, "t b c i -> c i b t")
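As a small usage sketch (with hypothetical sizes), a mismatched axis size makes einops raise an error rather than silently reshaping; at the time of writing the exception type is einops.EinopsError, though this detail may vary between versions.
import torch
import einops

x = torch.randn(10, 32, 30, 64)  # t=10, b=32, c=30, i=64
y = einops.rearrange(x, "t b c i -> c i b t", t=10, c=30)  # size checks pass

try:
    einops.rearrange(x, "t b c i -> c i b t", t=20, c=30)   # wrong t triggers the check
except einops.EinopsError as err:
    print("size check failed:", err)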
Split axes. Sometimes, we may want to split each image into 4 pieces. For example, the following code demonstrates how to take an input tensor with the shape of (b, c, h, w) and return a tensor with the shape of (b, c, 2, 2, h/2, w/2).
def split_patches_without_einops(x: torch.Tensor) -> torch.Tensor:
    b, c, h, w = x.size()
    y = x.view(b, c, h // 2, 2, w // 2, 2)
    return y.permute(0, 1, 3, 5, 2, 4)

def split_patches_with_einops(x: torch.Tensor) -> torch.Tensor:
    return einops.rearrange(x, "b c (h s1) (w s2) -> b c s1 s2 h w", s1=2, s2=2)
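A quick usage check for the two functions defined above, with arbitrarily chosen sizes:
import torch
import einops

x = torch.randn(2, 3, 8, 8)  # (b, c, h, w)
y1 = split_patches_without_einops(x)
y2 = split_patches_with_einops(x)
assert y1.shape == (2, 3, 2, 2, 4, 4)
assert torch.equal(y1, y2)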
Join axes. Sometimes, we may want to flatten a tensor by joining multiple axes. For example, the following code demonstrates how to take an input tensor with the shape of (b, c, h, w) and return a tensor with the shape of (b, c*h*w).
def join_axes_without_einops(x: torch.Tensor) -> torch.Tensor:
    b, c, h, w = x.size()
    return x.view(b, c * h * w)

def join_axes_with_einops(x: torch.Tensor) -> torch.Tensor:
    return einops.rearrange(x, "b c h w -> b (c h w)")
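And a similar quick check for the join, again with arbitrary sizes and using the two functions defined above:
import torch
import einops

x = torch.randn(2, 3, 4, 5)  # (b, c, h, w)
y1 = join_axes_without_einops(x)
y2 = join_axes_with_einops(x)
assert y1.shape == (2, 60)
assert torch.equal(y1, y2)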
Layer. It is possible to create a torch.nn.Module instance for an einops.rearrange operation and put it into the torch.nn.Sequential container. For example, the following code demonstrates how to build a simple image classifier. Note that the first layer is included only to check axis sizes and can be skipped.
from einops.layers.torch import Rearrange

model = torch.nn.Sequential(
    Rearrange("b c h w -> b c h w", c=3, h=8, w=8),
    torch.nn.Conv2d(3, 16, 3, stride=2, padding=1),
    Rearrange("b c h w -> b (c h w)", c=16, h=4, w=4),
    torch.nn.Linear(16 * 4 * 4, 120),
    torch.nn.ReLU(),
    torch.nn.Linear(120, 10),
)
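As a usage example, feeding a dummy batch of four 3x8x8 images (the sizes assumed by the checks above) through the model should produce one logit vector of length 10 per image:
import torch

images = torch.randn(4, 3, 8, 8)  # a dummy batch of 4 RGB images of size 8x8
logits = model(images)
assert logits.shape == (4, 10)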
References
- Rogozhnikov, A. (2018). Einops. GitHub. https://github.com/arogozhnikov/einops
- Duran-Martin, G. (2021). Einsums in the wild. Notion. https://grrddm.notion.site/Einsums-in-the-wild-bd773f01ba4c463ca9e4c1b5a6d90f5f#3cc76f8130ac4a348888f531069f7c8a
- Noobbodyjourney. (2021). [Discussion] Why are Einstein sum notations not popular in ML? They changed my life. [Reddit post]. r/MachineLearning. https://www.reddit.com/r/MachineLearning/comments/r8tsv6/discussion_why_are_einstein_sum_notations_not/