[ot][crazy][spam] Notes: Matmul Acceleration
Goal: Take implementation-oriented notes on https://arxiv.org/pdf/2106.10860.pdf .

My understanding of the approach is that it lossily compresses the data into a smaller representation that retains most of the significant information, combines it in a way that involves no multiply-adds, then decompresses it to produce an approximate matmul result. The goal is to understand the derivation and implementation well enough to daydream about implementing a convolution operation rather than matrix multiplication. [multiple people are working on convolution; i don't know whether i will find their results]

Planning to skip the abstract, then start through the intro.

1 Introduction
- Approximate Matrix Multiplication (AMM)
- assumptions: matrices are tall, dense, held in a single address space
- A is large data, B is a linear operator such as model weights
- nonlinear preprocessing reduces the problem to table lookups
- no multiply-adds if B is constant
- a family of quantizers is used rather than a single expensive one
- 3 parts:
  1. high-speed quantization functions
  2. high-speed low-bitwidth stable summation algorithm
  3. matmul algorithm using 1&2 with theoretical quality guarantees

1.1 Problem Formulation
A: NxD, B: DxM, N >> D >= M
tau: computation time bound
Problem formulation: choose g, h, f and constants alpha, beta to minimize the error
  epsilon(tau) = || alpha * f(g(A), h(B)) + beta - A*B ||_F
subject to the computation time staying under tau.
[interruption]
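a minimal numpy sketch of how i read this formulation (mine, not the paper's code; the names g_fn/h_fn/f_fn and the identity-function sanity check are my placeholders): g only sees A, h only sees B, f combines their outputs, and the error is the Frobenius norm of the residual against the exact product.

import numpy as np

def amm_error(A, B, g_fn, h_fn, f_fn, alpha=1.0, beta=0.0):
    # residual of the approximation alpha * f(g(A), h(B)) + beta against A @ B
    approx = alpha * f_fn(g_fn(A), h_fn(B)) + beta
    return np.linalg.norm(approx - A @ B)

# sanity check: plugging in exact (non-approximate) g, h, f gives zero error
A = np.random.randn(1000, 64)   # N x D, N >> D
B = np.random.randn(64, 16)     # D x M, D >= M
err = amm_error(A, B, g_fn=lambda A: A, h_fn=lambda B: B,
                f_fn=lambda Ag, Bh: Ag @ Bh)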
A~: a training set assumed statistically similar to A (drawn from the same distribution)

3 Background - Product Quantization
- PQ is a classic algorithm for approximating inner products and euclidean distances
- commonly used for tasks like this

a * b is approximated by a^ * b, where a^ is close to a and has a special structure that makes the computation quick. a^ is formed by "concatenating learned prototypes in disjoint subspaces". dot products are precomputed between b and the prototypes, and reused for every a that is near some combination of prototypes a^. so a must come from a set of vectors that are all near known prototypes; the precomputations are made with those prototypes. here a ranges over A's rows and b over B's columns, with a transposed for the dot product.

1. Prototype Learning
K-means is run once per subspace (C times in total) on the corresponding subvectors of A~'s rows. The clusters become C sets of K prototypes.

2. Encoding Function g(a)
The closest prototype to a's subvector is picked in each subspace. These choices are stored as integer indices using C log2(K) bits.

3. Table Construction h(B)
For each subspace, the dot product between b's subvector and each prototype is precomputed or cached. This makes C lookup tables of size K.

4. Aggregation f(,)
The above indices and tables are used to look up an estimated partial of a * b in each subspace. The partials are then summed across all C subspaces.

(a small numpy sketch of these four steps follows after this message)

---

The next section elaborates on the above 4 steps in detail with summation formulae. I'll probably take a bit to collect my thoughts. They basically described almost the entire algorithm in those 4 points, almost enough for implementation. Maybe I'll look for any missing bits, or information to help one think about generalising it to other operations.

There's a concern that this biases the data toward a^ and A~. I wasn't expecting that. It is very tuned toward having a well-sampled, known space of data. Halting use and updating a^, A~, C, and the prototypes with newly found outliers, and ensuring these outliers have prototypes, would be very important for keeping results accurate. I infer this relates to the theoretical quality guarantees.
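a minimal numpy sketch of my reading of PQ steps 1-4 above (mine, not the paper's code; the function names, the plain Lloyd's k-means, and the assumption that D divides evenly by C are my own; the 8-bit table quantization mentioned later is omitted here):

import numpy as np

def kmeans(X, K, iters=25, seed=0):
    # plain Lloyd's iterations; enough for a sketch
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), K, replace=False)].copy()
    for _ in range(iters):
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        assign = dists.argmin(1)
        for k in range(K):
            members = X[assign == k]
            if len(members):
                centroids[k] = members.mean(0)
    return centroids

def learn_prototypes(A_train, C, K):
    # step 1: k-means once per subspace (C runs) over A~'s subvectors
    subspaces = np.split(A_train, C, axis=1)        # assumes D % C == 0
    return [kmeans(S, K) for S in subspaces]        # C codebooks of K prototypes

def encode(A, prototypes):
    # step 2: g(A) -- index of the nearest prototype in each subspace
    subspaces = np.split(A, len(prototypes), axis=1)
    codes = [((S[:, None, :] - P[None, :, :]) ** 2).sum(-1).argmin(1)
             for S, P in zip(subspaces, prototypes)]
    return np.stack(codes, axis=1)                  # (N, C) small integers

def build_tables(b, prototypes):
    # step 3: h(b) -- dot product of b's subvector with every prototype
    subvecs = np.split(b, len(prototypes))
    return np.stack([P @ s for s, P in zip(subvecs, prototypes)])   # (C, K)

def aggregate(codes, tables):
    # step 4: f(,) -- look up one partial dot product per subspace and sum
    return sum(tables[c, codes[:, c]] for c in range(tables.shape[0]))  # (N,)

usage as i picture it: encode A once, then for each column b of B build its C x K table and aggregate; the encodings are reused for every column.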
plan: review source code implementing this to understand PQ steps 1-4 better. comparing with sources reduces errors when absorbing new material.

Encoding Function g(A) detail:
the sequence of indices of the K-means centroids chosen for a is called a's _encoding_. the K centroids of one subspace are called a _codebook_. a is considered to be composed of C disjoint _subvectors_, one per codebook [i think].

Table Construction h(B) detail:
- K <= 16 with 8-bit quantized lookup tables is commonly known to offer enormous speedups compared to other choices. 8-bit integers provide more parallelism per SIMD register.
- 8-bit quantization is done by subtracting the per-table minimum and linearly rescaling so the maximum per-table value is <= 255. it is an invertible affine transform. see: appendix A
(a small sketch of this quantization follows after this message)

Aggregation f(,) detail:
summation of the table entries selected by the encodings, rather than the original multiplications.

==
the equations make it kind of look like they basically precalculate a bunch of dot products with nearby data, and select from among these precalculated dot products to form the result matrix. this may be all that is going on. the approach seems like it would tell an algorithm developer a lot about the kind of information processing that is happening. future work might imply that transformer models that use matrices can be replaced by lookup tables and bit encodings, or trees of branch conditions with names that are humanly meaningful. i think apple released some research recently replacing a large part of transformers with something simpler.
==
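a minimal numpy sketch of how i picture the 8-bit table quantization and the corrected aggregation (my own reading, not the paper's code: i use one shared scale and per-table offsets so the sum is easy to invert; the paper's appendix A additionally restricts the scale, e.g. to a power of 2):

import numpy as np

def quantize_tables(tables):
    # subtract each table's minimum (per-table offset), then rescale everything
    # by one shared factor so every entry fits in 0..255 -- an invertible affine map
    offsets = tables.min(axis=1, keepdims=True)             # (C, 1)
    scale = 255.0 / max(float((tables - offsets).max()), 1e-12)
    q = np.clip(np.round((tables - offsets) * scale), 0, 255).astype(np.uint8)
    return q, offsets, scale

def estimate_dot(code_row, q_tables, offsets, scale):
    # aggregate the selected uint8 entries in a wide accumulator, then undo the
    # affine transform: divide by the shared scale, add back the summed offsets
    acc = sum(int(q_tables[c, code_row[c]]) for c in range(q_tables.shape[0]))
    return acc / scale + float(offsets.sum())

together with the earlier sketch, estimate_dot(codes[i], *quantize_tables(build_tables(b, protos))) approximates the dot product of A's row i with b.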
=== idea daydream between equal signs was likely not informed! as usual! ===

4 Our Method
- PQ is good for N,M >> D
- but only N >> M,D is needed
- g(a) can be slow to process
- new g(a): [stepping away]
4.1 new encoding g(a)
- new trainable hash functions
- balanced binary regression trees, each leaf a hash bucket
- leaf(x): traverse the tree from the root, selecting a child based on whether x_j is above or below a node-specific threshold v
- for simd, trees are limited to 16 leaves, and the split index j is depth-specific. vectorization details in Appendix B

this is called Algorithm 1, the MaddnessHash
input: vector x, split indices j^1..j^4, split thresholds v^1..v^4
    i <- 1                          // node index within the current level of the tree
    for t <- 1 to 4 do
        v <- v^t_i                  // lookup split threshold for node i at level t
        b <- x_(j^t) >= v ? 1 : 0   // above split threshold?
        i <- 2i - 1 + b             // descend to the left or right child
    end for
    return i                        // leaf / bucket index in 1..16
(a python sketch of this hash is at the end of these notes)
-------------------
i'm guessing i can learn more effectively what i might need from source code than most of this. skipping a little to see if theory is mentioned.
- hash function parameters are constructed in a greedy manner
- the algorithm for forming them is Algorithm 2
- experimentally, selecting more than 4 split indices gave little additional reduction in loss
Algorithm 2 is 22 lines long. section 4.2 describes its parts.

4.3 new method for optimizing pq prototypes
4.4 fast 8-bit aggregation f(,)
4.5 theoretical guarantees
proof and further discussion in Appendix F
A~ is drawn from a probability distribution D with maximum value sigma_A. C is the # of codebooks, lambda > 0 is a regularization parameter used in 4.3.
the generalization bound contains a difference term of the form
  C * sigma_A * ||b||_2 / (2 * sqrt(lambda)) * [ 1/256 + (8 + sqrt(nu(C, D, delta))) / sqrt(2n) ]
  where nu(C, D, delta) ≜ C * (4 * ceil(log2(D)) + 256) * log(2) - log(delta)   (D here being the column count)
the trailing "log2" in the pdf is multiplication by log(2); delta is the failure probability and n is the number of training rows. the expression bounds the added error of the approximation, so smaller values mean a tighter guarantee. i think i'd get more return looking through the appendix, for quality guarantees.

5 Experiments
- note: post-softmax, this algorithm gave the same results as exact matmul.
- the algorithm is the fastest, but would be over 4x faster still if hardware lookup-accumulates were optimized as much as multiply-accumulates.

6 Discussion, Conclusion
- theoretical guarantees incomplete, quantization errors not yet described
- tested only on cpus, not gpus; different parameters are expected to be optimal there
- a parallel implementation using threads has not been tested
- next they are thinking of exploiting weight reuse in convolutional layers to further optimize
- accelerating an entire network is considered an engineering challenge due to boilerplate work
- it is unknown when it is appropriate to use this acceleration
- no work has been done to differentiate across the hash function at all
- mostly promising for inference rather than training
- strong reduction in electricity needs due to implementability with multiplexers only. very strong locality of data access.
- speed increase over previous AMM methods is up to 10x
- future methods similar to this also hold promise. this kind of work unbottlenecks many things.
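a minimal python sketch of Algorithm 1 as i read it (0-indexed rather than the paper's 1-indexed form, so the child step becomes 2i + b; no SIMD; the parameter shapes and the made-up usage values are my own assumptions, since real parameters come from Algorithm 2's greedy training):

def maddness_hash(x, split_idxs, split_vals):
    # split_idxs: 4 feature indices j^1..j^4, one per tree level
    # split_vals: 4 sequences of per-node thresholds; level t has 2**t nodes
    i = 0                                        # node index within the current level
    for t in range(4):
        v = split_vals[t][i]                     # threshold for node i at level t
        b = 1 if x[split_idxs[t]] >= v else 0    # above the split threshold?
        i = 2 * i + b                            # descend to the left or right child
    return i                                     # leaf / bucket index in 0..15

# hypothetical usage with made-up parameters:
# split_idxs = [3, 17, 41, 5]                    # j^1..j^4
# split_vals = [[0.0], [0.1, -0.2], [0.3, 0.0, -0.1, 0.2],
#               [0.0, 0.1, -0.3, 0.2, 0.4, -0.1, 0.0, 0.05]]
# bucket = maddness_hash([0.5] * 64, split_idxs, split_vals)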
a github repo containing maddness and its predecessor bolt: https://github.com/dblalock/bolt
the last PR was a python interfacing cleanup, merged 12 hours ago
citations of this paper are listed at https://scholar.google.com/scholar?cites=16672894839769153249
- memory-efficient backpropagation (2022)
- current trends in data summaries (2022)
- ARM 4-bit PQ, SIMD-accelerated approximate nearest neighbor search
- High performance deep learning on resource constrained platforms (2021)