Discovering faster matrix multiplication algorithms with reinforcement learning

TensorGame

TensorGame is played as follows. The start position \(\mathscr{S}_0\) of the game corresponds to the tensor \(\mathscr{T}\) representing the bilinear operation of interest, expressed in some basis. In each step t of the game, the player writes down three vectors \((\mathbf{u}^{(t)}, \mathbf{v}^{(t)}, \mathbf{w}^{(t)})\), which specify the rank-1 tensor \(\mathbf{u}^{(t)} \otimes \mathbf{v}^{(t)} \otimes \mathbf{w}^{(t)}\), and the state of the game is updated by subtracting the newly written down factor:

$$\mathscr{S}_{t} \leftarrow \mathscr{S}_{t-1} - \mathbf{u}^{(t)} \otimes \mathbf{v}^{(t)} \otimes \mathbf{w}^{(t)}.$$

(2)

The game ends when the state reaches the zero tensor, \(\mathscr{S}_R = \mathbf{0}\). This means that the factors written down throughout the game form a factorization of the start tensor \(\mathscr{S}_0\), that is, \(\mathscr{S}_0 = \sum_{t=1}^{R} \mathbf{u}^{(t)} \otimes \mathbf{v}^{(t)} \otimes \mathbf{w}^{(t)}\). This factorization is then scored. For example, when optimizing for asymptotic time complexity the score is −R, and when optimizing for practical runtime the algorithm corresponding to the factorization \(\{(\mathbf{u}^{(t)}, \mathbf{v}^{(t)}, \mathbf{w}^{(t)})\}_{t=1}^{R}\) is constructed (see Algorithm 1) and then benchmarked on the fly (see Supplementary Information).
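As an illustration of equation (2), the following minimal sketch plays TensorGame on the 2 × 2 matrix multiplication tensor using the seven moves of Strassen's algorithm. The function names, the row-major indexing of the tensor and the move limit of 10 are assumptions made for this example, not the paper's implementation.

```python
import numpy as np

def matmul_tensor(n):
    """Return T with T[i*n+j, j*n+k, i*n+k] = 1, encoding C = A @ B for
    row-major flattened n x n matrices (an indexing choice made here)."""
    T = np.zeros((n * n, n * n, n * n), dtype=np.int64)
    for i in range(n):
        for j in range(n):
            for k in range(n):
                T[i * n + j, j * n + k, i * n + k] = 1
    return T

def play_tensor_game(start, moves, r_limit):
    """Apply equation (2) move by move. Returns (final state, moves used);
    the count is None if the zero tensor was not reached within r_limit."""
    state = start.copy()
    for t, (u, v, w) in enumerate(moves[:r_limit], start=1):
        state = state - np.einsum("i,j,k->ijk", u, v, w)
        if not state.any():
            return state, t
    return state, None

# Strassen's algorithm written as seven (u, v, w) triples; row r is move r.
U = np.array([[1, 0, 0, 1], [0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 0, 1],
              [1, 1, 0, 0], [-1, 0, 1, 0], [0, 1, 0, -1]])
V = np.array([[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, -1], [-1, 0, 1, 0],
              [0, 0, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1]])
W = np.array([[1, 0, 0, 1], [0, 0, 1, -1], [0, 1, 0, 1], [1, 0, 1, 0],
              [-1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]])

moves = list(zip(U, V, W))
final_state, steps = play_tensor_game(matmul_tensor(2), moves, r_limit=10)
score = -steps if steps is not None else None  # score is -R for asymptotic complexity
```

Here the game reaches the zero tensor after seven moves, so the factorization has rank 7 and the score is −7.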

In practice, we also impose a limit \(R_{\text{limit}}\) on the maximum number of moves in the game, so that a weak player is not stuck in unnecessarily (or even infinitely) long games. When a game ends because it has run out of moves, a penalty score is given so that it is never advantageous to deliberately exhaust the move limit. For example, when optimizing for asymptotic time complexity, this penalty is derived from an upper bound on the tensor rank of the final residual tensor \(\mathscr{S}_{R_{\text{limit}}}\). This upper bound on the tensor rank is obtained by summing the matrix ranks of the slices of the tensor.
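The slice-based bound can be sketched as follows. This is a minimal illustration that assumes real-valued matrix rank (via `numpy.linalg.matrix_rank`) and a hypothetical `axis` parameter for choosing the slicing direction. How the bound is turned into the penalty, and how ranks are computed in modular arithmetic, is not specified here.

```python
import numpy as np

def rank_upper_bound(residual, axis=0):
    """Upper bound on tensor rank: sum of the matrix ranks of the 2D slices
    taken along `axis`. Each slice of matrix rank r can be written with r
    rank-1 tensors, so the total bounds the rank of the whole tensor."""
    slices = np.moveaxis(np.asarray(residual, dtype=float), axis, 0)
    return int(sum(np.linalg.matrix_rank(m) for m in slices))

# A rank-1 tensor whose bound depends on the slicing direction.
u = np.array([1, 1, 1])
v = np.array([1, 0, 0])
w = np.array([0, 1, 0])
residual = np.einsum("i,j,k->ijk", u, v, w)
bounds = [rank_upper_bound(residual, axis=a) for a in range(3)]
```

In this example the bound is loose along the first axis (3) and tight along the other two (1), which shows that it is only an upper bound on the true rank.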

TensorGame over rings

We say that the decomposition of \(\mathscr{T}_n\) in equation (1) is in a ring \(\mathcal{E}\) (defining the arithmetic operations) if each of the factors \(\mathbf{u}^{(t)}\), \(\mathbf{v}^{(t)}\) and \(\mathbf{w}^{(t)}\) has entries belonging to the set \(\mathcal{E}\), and additions and multiplications are interpreted according to \(\mathcal{E}\). The tensor rank depends, in general, on the ring. At each step of TensorGame, the additions and multiplications in equation (2) are interpreted in \(\mathcal{E}\). For example, when working in \(\mathbb{Z}_2\) (in which case the factors \(\mathbf{u}^{(t)}\), \(\mathbf{v}^{(t)}\) and \(\mathbf{w}^{(t)}\) have entries in F = {0, 1}), a modulo 2 operation is applied after each state update (equation (2)).
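A minimal sketch of the state update in modular arithmetic is given below. The function name `apply_move` and the optional `modulus` argument are assumptions made for illustration.

```python
import numpy as np

def apply_move(state, u, v, w, modulus=None):
    """One TensorGame step (equation (2)). With modulus=2 the result is
    reduced modulo 2, i.e. the arithmetic is interpreted in Z_2."""
    new_state = state - np.einsum("i,j,k->ijk", u, v, w)
    return new_state if modulus is None else new_state % modulus

u = np.array([1, 1])
v = np.array([1, 0])
w = np.array([0, 1])
start = np.einsum("i,j,k->ijk", u, v, w) % 2

after_one = apply_move(start, u, v, w, modulus=2)      # zero tensor
after_two = apply_move(after_one, u, v, w, modulus=2)  # back to start, since -x = x in Z_2
```

Playing the same move twice returns to the starting state in \(\mathbb{Z}_2\), whereas in standard arithmetic it would not. This is one concrete way in which the game, and hence the tensor rank, depends on the ring.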

We note that integer-valued decompositions \(\mathbf{u}^{(t)}\), \(\mathbf{v}^{(t)}\) and \(\mathbf{w}^{(t)}\) lead to decompositions in arbitrary rings \(\mathcal{E}\). Hence, provided F only contains integers, algorithms we find in standard arithmetic apply more generally to any ring.

AlphaTensor

AlphaTensor builds on AlphaZero (ref. 1) and its extension Sampled AlphaZero (ref. 21), combining a deep neural network with a sample-based MCTS search algorithm.

The deep neural network, \(f_\theta(s) = (\pi, z)\), parameterized by θ, takes as input the current state s of the game and outputs a probability distribution \(\pi(\cdot \mid s)\) over actions and \(z(\cdot \mid s)\) over returns (sum of future rewards) G. The parameters θ of the deep neural network are trained by reinforcement learning from self-play games and synthetic demonstrations. Self-play games are played by actors, running a sample-based MCTS search at every state \(s_t\) encountered in the game. The MCTS search returns an improved probability distribution over moves from which an action \(a_t\) is selected and applied to the environment. The sub-tree under \(a_t\) is reused for the subsequent search at \(s_{t+1}\). At the end of the game, a return G is obtained and the trajectory is sent to the learner to update the neural network parameters θ. The distribution over returns \(z(\cdot \mid s_t)\) is learned through distributional reinforcement learning using the quantile regression distributional loss (ref. 34), and the network policy \(\pi(\cdot \mid s_t)\) is updated using a Kullback–Leibler divergence loss, to maximize its similarity to the search policy for self-play games or to the next action for synthetic demonstrations. We use the Adam optimizer (ref. 35) with decoupled weight decay (ref. 36) to optimize the parameters θ of the neural network.

Sample-based MCTS search

The sample-based MCTS search is very similar to the one described in Sampled AlphaZero. Specifically, the search consists of a series of simulated trajectories of TensorGame that are aggregated in a tree. The search tree therefore consists of nodes representing states and edges representing actions. Each state-action pair (s, a) stores a set of statistics \(N(s,a), Q(s,a), \hat{\pi}(s,a)\), where N(s, a) is the visit count, Q(s, a) is the action value and \(\hat{\pi}(s,a)\) is the empirical policy probability. Each simulation traverses the tree from the root state \(s_0\) until a leaf state \(s_{\mathrm{L}}\) is reached by recursively selecting in each state s an action a that has not been frequently explored, has high empirical policy probability and high value. Concretely, actions within the tree are selected by maximizing over the probabilistic upper confidence tree bound (refs. 21,37)

$$\mathop{\mathrm{argmax}}\limits_{a}\; Q(s,a) + c(s) \cdot \hat{\pi}(s,a) \frac{\sqrt{\sum_{b} N(s,b)}}{1 + N(s,a)},$$

where c(s) is an exploration factor controlling the influence of the empirical policy \(\hat{\pi}(s,a)\) relative to the values Q(s, a) as nodes are visited more often. In addition, a transposition table is used to recombine different action sequences if they reach the exact same tensor. This can happen particularly often in TensorGame as actions are commutative. Finally, when a leaf state \(s_{\mathrm{L}}\) is reached, it is evaluated by the neural network, which returns K actions \(\{a_i\}\) sampled from \(\pi(a \mid s_{\mathrm{L}})\), alongside the empirical distribution \(\hat{\pi}(a \mid s_{\mathrm{L}}) = \frac{1}{K}\sum_{i} \delta_{a,a_i}\) and a value \(v(s_{\mathrm{L}})\) constructed from \(z(\cdot \mid s_{\mathrm{L}})\). Differently from AlphaZero and Sampled AlphaZero, we chose v not to be the mean of the distribution of returns \(z(\cdot \mid s_{\mathrm{L}})\) as is usual in most reinforcement learning agents, but instead to be a risk-seeking value, leveraging the facts that TensorGame is a deterministic environment and that we are primarily interested in finding the best trajectory possible. The visit counts and values on the simulated trajectory are then updated in a backward pass as in Sampled AlphaZero.
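The selection rule and the empirical policy can be sketched as follows. This is a minimal illustration for a single node. The function names are hypothetical, the exploration factor is passed in as a plain number `c`, and the tree, transposition table and backward pass are omitted.

```python
from collections import Counter

import numpy as np

def empirical_policy(sampled_actions):
    """Empirical distribution over K sampled actions:
    pi_hat(a) = (1/K) * number of samples equal to a."""
    k = len(sampled_actions)
    return {a: count / k for a, count in Counter(sampled_actions).items()}

def select_action(q, n, pi_hat, c):
    """Index of the action maximizing
    Q(s,a) + c * pi_hat(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a))."""
    q = np.asarray(q, dtype=float)
    n = np.asarray(n, dtype=float)
    pi_hat = np.asarray(pi_hat, dtype=float)
    scores = q + c * pi_hat * np.sqrt(n.sum()) / (1.0 + n)
    return int(np.argmax(scores))

# Two actions: the first is well explored, the second has never been visited.
explore = select_action(q=[0.5, 0.4], n=[10, 0], pi_hat=[0.5, 0.5], c=1.0)
exploit = select_action(q=[0.5, 0.4], n=[10, 0], pi_hat=[0.5, 0.5], c=0.0)
```

With a positive exploration factor the unvisited action is preferred, whereas with `c = 0` the rule reduces to picking the action with the highest value.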

Policy improvement

After simulating N(s) trajectories from state s using MCTS, the normalized visit counts of the actions at the root of the search tree N(s, a)/N(s) form a sample-based improved policy. Differently from AlphaZero and Sampled AlphaZero, we use an adaptive temperature scheme to smooth the normalized visit counts distribution as some states can accumulate an order of magnitude more visits than others because of sub-tree reuse and the transposition table. Concretely, we define the improved policy as \(\mathcal{I}\hat{\pi}(s,a) = N^{1/\tau(s)}(s,a) / \sum_{b} N^{1/\tau(s)}(s,b)\), where \(\tau(s) = \log N(s) / \log \bar{N}\) if \(N(s) > \bar{N}\) and 1 otherwise, with \(\bar{N}\) being a hyperparameter. For training, we use \(\mathcal{I}\hat{\pi}\) directly as a target for the network policy π. For acting, we additionally discard all actions that have a value lower than the value of the most visited action, and sample proportionally to \(\mathcal{I}\hat{\pi}\) among those remaining high-value actions.
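The adaptive temperature scheme can be sketched as below. The sketch assumes that N(s) equals the sum of the root visit counts and that ties in value are kept when filtering. Both function names are hypothetical.

```python
import numpy as np

def improved_policy(visits, n_bar):
    """Improved policy N^(1/tau) / sum_b N^(1/tau), with
    tau = log N(s) / log n_bar if N(s) > n_bar, and tau = 1 otherwise.
    Assumes N(s) is the sum of the root visit counts."""
    visits = np.asarray(visits, dtype=float)
    total = visits.sum()
    tau = np.log(total) / np.log(n_bar) if total > n_bar else 1.0
    powered = visits ** (1.0 / tau)
    return powered / powered.sum()

def acting_policy(visits, values, n_bar):
    """Acting distribution: drop actions valued below the most visited
    action, then renormalize the improved policy over the rest."""
    values = np.asarray(values, dtype=float)
    policy = improved_policy(visits, n_bar)
    keep = values >= values[int(np.argmax(visits))]
    masked = np.where(keep, policy, 0.0)
    return masked / masked.sum()

few_visits = improved_policy([90, 10], n_bar=100)       # tau = 1
many_visits = improved_policy([9000, 1000], n_bar=100)  # tau = 2
```

For 10,000 total visits and \(\bar{N} = 100\), the temperature is 2, so the counts 9,000 and 1,000 are smoothed from (0.9, 0.1) to (0.75, 0.25).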

Learning one agent for multiple target tensors

We train a single agent to decompose the different tensors \(\mathscr{T}_{n,m,p}\) in a given arithmetic (standard or modular). As the network works with fixed-size inputs, we pad all tensors (with zeros) to the size of the largest tensor we consider (\(\mathscr{T}_5\), of size 25 × 25 × 25). At the beginning of each game, we sample uniformly at random a target \(\mathscr{T}_{n,m,p}\), and play TensorGame. Training a single agent on different targets leads to better results thanks to the transfer between targets. All our results reported in Fig. 3 are obtained using multiple runs of this multi-target setting. We also train a single agent to decompose tensors in both arithmetics. Owing to learned transfer between the two arithmetics, this agent discovers a different distribution of algorithms (of the same ranks) in standard arithmetic than the agent trained on standard arithmetic only, thereby increasing the overall diversity of discovered algorithms.
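Zero-padding to a fixed input size might look like the following minimal sketch. The function name and the placement of the original tensor in the leading corner are assumptions made for illustration.

```python
import numpy as np

def pad_to_size(tensor, target=25):
    """Zero-pad every mode of `tensor` up to `target`, keeping the original
    entries in the leading corner (a placement assumed for this sketch)."""
    pads = [(0, target - dim) for dim in tensor.shape]
    return np.pad(tensor, pads)  # default mode pads with zeros

small = np.ones((4, 6, 6), dtype=np.int64)  # stand-in for a smaller target tensor
padded = pad_to_size(small)
```

Padding with zeros leaves the rank unchanged, because any factorization of the small tensor extends to the padded one by appending zeros to each factor.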

Synthetic demonstrations

The synthetic demonstrations buffer contains tensor-factorization pairs, where the factorizations \(\{(\mathbf{u}^{(r)}, \mathbf{v}^{(r)}, \mathbf{w}^{(r)})\}_{r=1}^{R}\) are first generated at random, after which the tensor \(\mathscr{D} = \sum_{r=1}^{R} \mathbf{u}^{(r)} \otimes \mathbf{v}^{(r)} \otimes \mathbf{w}^{(r)}\) is formed. We create a dataset containing 5 million such tensor-factorization pairs. Each element in the factors is sampled independently and identically distributed (i.i.d.) from a given categorical distribution over F (all possible values that can be taken). We discarded instances whose decompositions were clearly suboptimal (contained a factor with \(\mathbf{u} = \mathbf{0}\), \(\mathbf{v} = \mathbf{0}\), or \(\mathbf{w} = \mathbf{0}\)).
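Generating one such pair can be sketched as follows. The function name, the tensor size, the rank and the categorical probabilities used here are hypothetical; the source does not state the distribution.

```python
import numpy as np

def sample_demonstration(rng, size, rank, values, probs):
    """Sample `rank` random factor triples i.i.d. from a categorical
    distribution over `values`, rejecting draws that contain an all-zero
    u, v or w, then form the tensor they sum to."""
    while True:
        factors = rng.choice(values, size=(3, rank, size), p=probs)
        if np.all(np.any(factors != 0, axis=-1)):
            break
    u, v, w = factors
    tensor = np.einsum("ri,rj,rk->ijk", u, v, w)
    return tensor, (u, v, w)

rng = np.random.default_rng(0)
values = np.array([-2, -1, 0, 1, 2])
probs = np.array([0.1, 0.2, 0.4, 0.2, 0.1])  # hypothetical probabilities
tensor, (u, v, w) = sample_demonstration(rng, size=4, rank=3,
                                         values=values, probs=probs)
```

The factorization is known by construction, so each pair provides supervised targets for the policy without any search, although the generated factorization is not guaranteed to be of minimal rank.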

In addition to these synthetic demonstrations, we further add to the demonstration buffer previous games that have achieved large scores to reinforce the good moves made by the agent in these games.

Change of basis

The rank of a bilinear operation does not depend on the basis in which the tensor representing it is expressed, and for any invertible matrices A, B and C we have \(\mathrm{Rank}(\mathscr{T}) = \mathrm{Rank}(\mathscr{T}^{(\mathbf{A},\mathbf{B},\mathbf{C})})\), where \(\mathscr{T}^{(\mathbf{A},\mathbf{B},\mathbf{C})}\) is the tensor after change of basis given by

$$\mathscr{T}_{ijk}^{(\mathbf{A},\mathbf{B},\mathbf{C})} = \sum_{a=1}^{S} \sum_{b=1}^{S} \sum_{c=1}^{S} \mathbf{A}_{ia} \mathbf{B}_{jb} \mathbf{C}_{kc} \mathscr{T}_{abc}.$$

(3)

Hence, exhibiting a rank-R decomposition of the matrix multiplication tensor \(\mathscr{T}_n\) expressed in any basis proves that the product of two n × n matrices can be computed using R scalar multiplications. Moreover, it is straightforward to convert such a rank-R decomposition into a rank-R decomposition in the canonical basis, thus yielding a practical algorithm of the form shown in Algorithm 1. We leverage this observation by expressing the matrix multiplication tensor \(\mathscr{T}_n\) in a large number of randomly generated bases (typically 100,000) in addition to the canonical basis, and letting AlphaTensor play games in all bases in parallel.

This approach has three appealing properties: (1) it provides a natural exploration mechanism as playing games in different bases automatically injects diversity into the games played by the agent; (2) it exploits properties of the problem as the agent need not succeed in all bases—it is sufficient to find a low-rank decomposition in any of the bases; (3) it enlarges coverage of the algorithm space because a decomposition with entries in a finite set F = {−2, −1, 0, 1, 2} found in a different basis need not have entries in the same set when converted back into the canonical basis.

In full generality, a basis change for a 3D tensor of size S × S × S is specified by three invertible S × S matrices A, B and C. However, in our procedure, we sample bases at random and impose two restrictions: (1) A = B = C, as this performed better in early experiments, and (2) unimodularity (\(\det \mathbf{A} \in \{-1, +1\}\)), which ensures that after converting an integral factorization into the canonical basis it still contains integer entries only (this is for representational convenience and numerical stability of the resulting algorithm). See Supplementary Information for the exact algorithm.
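The change of basis in equation (3), and the conversion of a factor back to the canonical basis, can be sketched as below. The sampling scheme shown (a product of unit-triangular integer matrices, which always has determinant +1) is an assumption made for illustration and is not the exact algorithm from the Supplementary Information. All function names are hypothetical.

```python
import numpy as np

def change_basis(T, A, B, C):
    """Equation (3): T'_{ijk} = sum_{abc} A_ia B_jb C_kc T_abc."""
    return np.einsum("ia,jb,kc,abc->ijk", A, B, C, T)

def random_unimodular(rng, size, low=-1, high=1):
    """Integer matrix with determinant 1, built as (unit lower-triangular)
    times (unit upper-triangular). Illustrative sampling scheme only."""
    eye = np.eye(size, dtype=np.int64)
    lower = np.tril(rng.integers(low, high + 1, size=(size, size)), k=-1) + eye
    upper = np.triu(rng.integers(low, high + 1, size=(size, size)), k=1) + eye
    return lower @ upper

def integer_inverse(A):
    """Inverse of a unimodular integer matrix; its entries are integers."""
    return np.rint(np.linalg.inv(A)).astype(np.int64)

def to_canonical(u, v, w, A_inv, B_inv, C_inv):
    """Map a factor found in the new basis back to the canonical basis."""
    return A_inv @ u, B_inv @ v, C_inv @ w

rng = np.random.default_rng(0)
A = random_unimodular(rng, 4)
A_inv = integer_inverse(A)
T = rng.integers(-2, 3, size=(4, 4, 4))
T_new = change_basis(T, A, A, A)                    # restriction A = B = C
T_back = change_basis(T_new, A_inv, A_inv, A_inv)   # recovers T
```

Because a rank-1 term \(\mathbf{u} \otimes \mathbf{v} \otimes \mathbf{w}\) maps to \((\mathbf{A}\mathbf{u}) \otimes (\mathbf{B}\mathbf{v}) \otimes (\mathbf{C}\mathbf{w})\) under equation (3), applying the inverse matrices factor by factor converts a decomposition found in the new basis into one of the same rank in the canonical basis. Unimodularity keeps all entries integral.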

Signed permutations

In addition to playing (and training on) games in different bases, we also utilize a data augmentation mechanism whenever the neural network is queried in a new MCTS node. At acting time, when the network is queried, we transform the input tensor by applying a change of basis—where the change of basis matrix is set to a random signed permutation. We then query the network on this transformed input tensor, and finally invert the transformation in the network’s policy predictions. Although this data augmentation procedure can be applied with any generic change of basis matrix (that is, it is not restricted to signed permutation matrices), we use signed permutations mainly for computational efficiency. At training time, whenever the neural network is trained on an (input, policy targets, value target) triplet (Fig. 2), we apply a randomly chosen signed permutation to both the input and the policy targets, and train the network on this transformed triplet. In practice, we sample 100 signed permutations at the beginning of an experiment, and use them thereafter.
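The acting-time augmentation can be sketched as follows. The sketch assumes the same signed permutation is applied to all three modes, and `network` is a hypothetical callable that maps a tensor to a single action (u, v, w), standing in for the policy predictions.

```python
import numpy as np

def random_signed_permutation(rng, size):
    """Permutation matrix whose non-zero entries are randomly +1 or -1."""
    P = np.zeros((size, size), dtype=np.int64)
    P[np.arange(size), rng.permutation(size)] = rng.choice([-1, 1], size=size)
    return P

def transform_tensor(T, P):
    """Change of basis with A = B = C = P (an assumption of this sketch)."""
    return np.einsum("ia,jb,kc,abc->ijk", P, P, P, T)

def transform_action(action, P):
    """Apply P to each factor of an action (u, v, w)."""
    return tuple(P @ x for x in action)

def query_with_augmentation(network, state, P):
    """Query the network on the transformed state, then undo the transform
    on its prediction. The inverse of a signed permutation is its transpose."""
    predicted = network(transform_tensor(state, P))
    return transform_action(predicted, P.T)

rng = np.random.default_rng(0)
P = random_signed_permutation(rng, 4)
state = rng.integers(-1, 2, size=(4, 4, 4))
action = tuple(rng.integers(-1, 2, size=4) for _ in range(3))
```

At training time the same idea applies in the forward direction: the input is replaced by `transform_tensor(state, P)` and the policy target by `transform_action(action, P)`, which keeps the pair consistent because the transformation commutes with the state update of equation (2).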

Action canonicalization

For any \(\lambda_1, \lambda_2, \lambda_3 \in \{-1, +1\}\) such that \(\lambda_1 \lambda_2 \lambda_3 = 1\), the actions \((\lambda_1 \mathbf{u}, \lambda_2 \mathbf{v}, \lambda_3 \mathbf{w})\) and \((\mathbf{u}, \mathbf{v}, \mathbf{w})\) are equivalent because they lead to the same rank-one tensor \((\lambda_1 \mathbf{u}) \otimes (\lambda_2 \mathbf{v}) \otimes (\lambda_3 \mathbf{w}) = \mathbf{u} \otimes \mathbf{v} \otimes \mathbf{w}\). To prevent the network from wasting capacity on predicting multiple equivalent actions, during training we always present targets \((\mathbf{u}, \mathbf{v}, \mathbf{w})\) for the policy head in a canonical form, defined as having the first non-zero element of u and the first non-zero element of v strictly positive. This is well defined because neither u nor v can be all zeros (if they are to be part of a minimal rank decomposition), and for any \((\mathbf{u}, \mathbf{v}, \mathbf{w})\) there are unique \(\lambda_1, \lambda_2, \lambda_3 \in \{-1, +1\}\) (with \(\lambda_1 \lambda_2 \lambda_3 = 1\)) that transform it into canonical form. In case the network predicts multiple equivalent actions anyway, we merge them together (summing their empirical policy probabilities) before inserting them into the MCTS tree.
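A minimal sketch of the canonical form and of merging equivalent actions is given below. The function names and the dictionary-based merge are assumptions made for illustration.

```python
from collections import defaultdict

import numpy as np

def canonicalize(u, v, w):
    """Flip signs so the first non-zero entries of u and v are positive.
    l3 = l1 * l2 guarantees l1 * l2 * l3 = 1, so u (x) v (x) w is unchanged.
    Assumes u and v are not all zeros."""
    u, v, w = (np.asarray(x, dtype=np.int64) for x in (u, v, w))
    l1 = np.sign(u[np.flatnonzero(u)[0]])
    l2 = np.sign(v[np.flatnonzero(v)[0]])
    l3 = l1 * l2
    return l1 * u, l2 * v, l3 * w

def merge_equivalent(actions, probs):
    """Sum the empirical probabilities of actions sharing a canonical form."""
    merged = defaultdict(float)
    for action, p in zip(actions, probs):
        key = tuple(tuple(int(e) for e in x) for x in canonicalize(*action))
        merged[key] += p
    return dict(merged)

a1 = ([0, -1, 1], [-2, 0, 1], [1, 0, -1])
a2 = ([0, 1, -1], [-2, 0, 1], [-1, 0, 1])  # a1 with u and w negated
merged = merge_equivalent([a1, a2], [0.25, 0.5])
```

Both actions map to the same canonical triple, so they are merged into a single entry with probability 0.75.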

Training regime

We train AlphaTensor on a TPU v3, with a total batch size of 2,048. We use 64 TPU cores, and train for 600,000 iterations. On the actor side, the games are played on standalone TPU v4, and we use 1,600 actors. In practice, the procedure takes a week to converge.

Neural network

The architecture is composed of a torso, followed by a policy head that predicts a distribution over actions, and a value head that predicts a distribution of the returns from the current state (see Extended Data Fig. 3).

Input

The input to the network contains all the relevant information of the current state and is composed of a list of tensors and a list of scalars. The most important piece of information is the current 3D tensor \(\mathscr{S}_t\) of size S × S × S. (For simplicity, in the description here we assume that all the three dimensions of the tensor are equal in size. The generalization to different sizes is straightforward.) In addition, the model is given access to the last h actions (h being a hyperparameter usually set to 7), represented as h rank-1 tensors that are concatenated to the input. The list of scalars includes the time index t of the current action (where \(0 \le t < R_{\text{limit}}\)).

Torso

The torso of the network is in charge of mapping both scalars and tensors from the input to a representation that is useful to both policy and value heads. Its architecture is based on a modification of transformers (ref. 23), and its main signature is that it operates over three S × S grids projected from the S × S × S input tensors. Each grid represents two out of the three modes of the tensor. Defining the modes of the tensor as \(\mathcal{U}, \mathcal{V}, \mathcal{W}\), the rows and columns of the first grid are associated with \(\mathcal{U}\) and \(\mathcal{V}\), respectively, the rows and columns of the second grid are associated with \(\mathcal{W}\) and \(\mathcal{U}\), and the rows and columns of the third grid are associated with \(\mathcal{V}\) and \(\mathcal{W}\). Each element of each grid is a feature vector, and its initial value is given by the elements of the input tensors along the grid’s missing mode. These feature vectors are enriched by concatenating an S × S × 1 linear projection from the scalars. This is followed by a linear layer projecting these feature vectors into a 512-dimensional space.
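One possible reading of the grid construction is sketched below. The sketch assumes the input tensors are stacked into an array of shape (C, S, S, S) with axes ordered as (channel, \(\mathcal{U}\), \(\mathcal{V}\), \(\mathcal{W}\)), and it omits the scalar projection and the 512-dimensional linear layer. The function name and the feature ordering are hypothetical.

```python
import numpy as np

def make_grids(x):
    """Build three S x S grids of feature vectors from x of shape
    (C, S, S, S). Each grid cell collects, for every input tensor, the
    entries along the mode that the grid does not index."""
    c, s = x.shape[0], x.shape[1]
    # Rows U, columns V, features along W.
    grid_uv = np.transpose(x, (1, 2, 0, 3)).reshape(s, s, c * s)
    # Rows W, columns U, features along V.
    grid_wu = np.transpose(x, (3, 1, 0, 2)).reshape(s, s, c * s)
    # Rows V, columns W, features along U.
    grid_vw = np.transpose(x, (2, 3, 0, 1)).reshape(s, s, c * s)
    return grid_uv, grid_wu, grid_vw

x = np.arange(2 * 3 * 3 * 3).reshape(2, 3, 3, 3)  # C = 2 tensors, S = 3
grid_uv, grid_wu, grid_vw = make_grids(x)
```

Each grid has S × S cells and C·S features per cell, so every entry of every input tensor appears exactly once in each grid.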

The rest of the torso is a sequence of attention-based blocks with the objective of propagating information between the three grids. Each of those blocks has three stages, one for every pair of grids. In each stage, the grids involved are concatenated, and axial attention (ref. 24) is performed over the columns. Note that in each stage we perform in parallel S self-attention operations of 2S elements each. The representation sent to the policy head corresponds to the \(3S^2\) 512-dimensional feature vectors produced by the last layer of the torso. A detailed description of the structure of the torso is specified in Extended Data Fig. 4 (top) and Appendix A.1.1 in Supplementary Information.

Policy head

The policy head uses the transformer architecture (ref. 23) to model an autoregressive policy. Factors are decomposed into k tokens of dimensionality d such that k × d = 3S. The transformer conditions on the tokens already generated and cross-attends to the features produced by the torso. At training time, we use teacher-forcing, that is, the ground truth actions are decomposed into tokens and taken as inputs into the causal transformer in such a way that the prediction of a token depends only on the previous tokens. At inference time, K actions are sampled from the head. The feature representation before the last linear layer of the initial step (that is, the only step that is not conditioned on the ground truth) is used as an input to the value head, described below. Details of the architecture are presented in Extended Data Fig. 4 (centre) and Appendix A.1.2 in Supplementary Information.

Value head

The value head is composed of a four-layer multilayer perceptron whose last layer produces q outputs corresponding to the \(\frac{1}{2q}, \frac{3}{2q}, \ldots, \frac{2q-1}{2q}\) quantiles. In this way, the value head predicts the distribution of returns from this state in the form of values predicted for the aforementioned quantiles (ref. 34). At inference time, we encourage the agent to be risk-seeking by using the average of the predicted values for quantiles over 75%. A detailed description of the value head is presented in Extended Data Fig. 4 (bottom) and Appendix A.1.3 in Supplementary Information.
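The risk-seeking value can be sketched as follows. The function name is hypothetical, and "over 75%" is read here as quantile levels strictly greater than 0.75, which is an assumption.

```python
import numpy as np

def risk_seeking_value(quantile_values, threshold=0.75):
    """Average the predicted values whose quantile level (2i + 1) / (2q),
    for i = 0, ..., q - 1, is strictly above `threshold`."""
    quantile_values = np.asarray(quantile_values, dtype=float)
    q = quantile_values.shape[-1]
    levels = (2 * np.arange(q) + 1) / (2 * q)
    return quantile_values[..., levels > threshold].mean(axis=-1)

predicted = np.arange(8.0)  # stand-in outputs for q = 8 quantiles
value = risk_seeking_value(predicted)
mean_value = predicted.mean()
```

With q = 8, only the two top quantiles (levels 13/16 and 15/16) lie above 0.75, so the risk-seeking value is 6.5 compared with a mean of 3.5.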

Related work

The quest for efficient matrix multiplication algorithms started with Strassen’s breakthrough in ref. 2, which showed that one can multiply 2 × 2 matrices using 7 scalar multiplications, leading to an algorithm of complexity \(\mathcal{O}(n^{2.81})\). This led to the development of a very active field of mathematics attracting worldwide interest, which studies the asymptotic complexity of matrix multiplication (see refs. 3,4,5,6). So far, the best known complexity for matrix multiplication is \(\mathcal{O}(n^{2.37286})\) (ref. 12), which improves over ref. 11, and builds on top of fundamental results in the field (refs. 8,9,10). However, this does not yield practical algorithms, as such approaches become advantageous only for astronomical matrix sizes. Hence, a significant body of work aims at exhibiting explicit factorizations of matrix multiplication tensors, as these factorizations provide practical algorithms. After Strassen’s breakthrough showing that \(\text{rank}(\mathscr{T}_2) \le 7\), efficient algorithms for larger matrix sizes were found (refs. 15,16,18,26,38). Most notably, Laderman showed in ref. 15 that 3 × 3 matrix multiplications can be performed with 23 scalar multiplications. Beyond exhibiting individual low-rank factorizations, an important research direction aims at understanding the space of matrix multiplication algorithms by studying the symmetry groups and diversity of factorizations (see ref. 5 and references therein). For example, the symmetries of 2 × 2 matrix multiplication were studied in refs. 39,40,41,42, where Strassen’s algorithm was shown to be essentially unique. The case of 3 × 3 was studied in ref. 43, whereas a symmetric factorization for all n is provided in ref. 44.

On the computational front, continuous optimization has been the main workhorse for decomposing tensors (refs. 17,45,46), and in particular matrix multiplication tensors. Such continuous optimization procedures (for example, alternating least squares), however, yield approximate solutions, which correspond to inexact matrix multiplication algorithms with floating point operations. To circumvent this issue, regularization procedures have been proposed, such as ref. 18, to extract exact decompositions. Unfortunately, such approaches often require substantial human intervention and expertise to decompose large tensors. A different line of attack was explored in refs. 47,48, based on learning the continuous weights of a two-layer network that mimics the structure of the matrix multiplication operation. This method, which is trained through supervised learning of matrix multiplication examples, finds approximate solutions to 2 × 2 and 3 × 3 matrix multiplications. In ref. 48, a quantization procedure is further used to obtain an exact decomposition for 2 × 2. Unlike continuous optimization-based approaches, AlphaTensor directly produces algorithms from the desired set of valid algorithms, and is flexible in that it allows us to optimize a wide range of (even non-differentiable) objectives. This unlocks tackling broader settings (for example, optimization in finite fields, optimization of runtime), as well as larger problems (for example, \(\mathscr{T}_4\) and \(\mathscr{T}_5\)) than those previously considered. Different from continuous optimization, a boolean satisfiability (SAT) based formulation of the problem of decomposing 3 × 3 matrix multiplication was recently proposed in ref. 20, which adds thousands of new decompositions of rank 23 to the list of known 3 × 3 factorizations. The approach relies on a state-of-the-art SAT solving procedure, where several assumptions and simplifications are made on the factorizations to reduce the search space. As is, this approach is, however, unlikely to scale to larger tensors, as the search space grows very quickly with the size.

On the practical implementation front, ref. 31 proposed several ideas to speed up implementation of fast matrix multiplication algorithms on central processing units (CPUs). Different fast algorithms are then compared and benchmarked, and the potential speed-up of such algorithms is shown against standard multiplication. Other works focused on getting the maximal performance out of a particular fast matrix multiplication algorithm (Strassen’s algorithm with one or two levels of recursion) on a CPU (ref. 32) or a GPU (ref. 49). These works show that, despite popular belief, such algorithms are of practical value. We see writing a custom low-level implementation of a given algorithm as distinct from the focus of this paper, which is developing new efficient algorithms, and we believe that the algorithms we discovered can further benefit from a more efficient implementation by experts.

Beyond matrix multiplication and bilinear operations, a growing amount of research studies the use of optimization and machine learning to improve the efficiency of computational operations. There are three levels of abstraction at which this can be done: (1) in the hardware design, for example, chip floor planning (ref. 50), (2) at the hardware–software interface, for example, program super-optimization of a reference implementation for specific hardware (ref. 51), and (3) on the algorithmic level, for example, program induction (ref. 52), algorithm selection (ref. 53) or meta-learning (ref. 54). Our work focuses on the algorithmic level of abstraction, although AlphaTensor is also flexible enough to discover efficient algorithms for specific hardware. Different from previous works, we focus on discovering matrix multiplication algorithms that are provably correct, without requiring initial reference implementations. We conclude by relating our work broadly to existing reinforcement learning methods for scientific discovery. Within mathematics, reinforcement learning was applied, for example, to theorem proving (refs. 55,56,57,58), and to finding counterexamples refuting conjectures in combinatorics and graph theory (ref. 59). Reinforcement learning was further shown to be useful in many areas in science, such as molecular design (refs. 60,61) and synthesis (ref. 62) and optimizing quantum dynamics (ref. 63).
