[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 02:42:57 PST 2022


And here's their explanation of the inner loops:

In each iteration of the outer loop, we call _query_chunk_attention,
which itself processes the keys and values in chunks (lines 23 to 33).
The chunks are processed sequentially and each chunk is summarized
independently (lines 14 to 21). Assuming a chunk size of sqrt(n) for
the keys and values, we hence obtain sqrt(n) summaries, giving rise to
the O(sqrt(n)) memory complexity.
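
To make that memory claim concrete, here's a back-of-the-envelope sketch (mine, not the paper's; the sizes are arbitrary): each chunk summary has a size that doesn't depend on the chunk length, so holding one summary per chunk costs memory proportional to the number of chunks.

# Rough memory accounting for one query block of q queries, h heads,
# d features, with n keys/values split into chunks of size c.
# Each chunk summary consists of:
#   exp_values:  (q, h, d)  weighted value sums
#   exp_weights: (q, h)     softmax denominator contributions
#   max_score:   (q, h)     per-chunk maxima used for stabilization
# None of these depend on c, so the n // c summaries stacked by lax.map
# cost O(n / c) memory, i.e. O(sqrt(n)) when c = sqrt(n), while the
# naive logits array would cost O(n).
q, h, d = 1024, 8, 64
n = 1_048_576
c = int(n ** 0.5)                    # chunk size sqrt(n) = 1024
num_chunks = n // c                  # sqrt(n) summaries held at once
per_summary = q * h * d + 2 * q * h  # elements in one chunk summary
print(num_chunks * per_summary)      # grows like sqrt(n)
print(q * h * n)                     # naive attention logits, grows like n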

After the summaries are computed, they need to be rescaled (lines 35
to 38) along the lines of Section 3, before we return the values
divided by the weights (line 42). The result of each iteration of the
outer loop is directly written to the output tensor res (line 56), so
that no additional memory is consumed across iterations. (A
multi-stage summarization approach could achieve O(log n) but would
complicate the implementation.)
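
For intuition, here's a small self-contained sketch (not from the paper; the toy numbers are arbitrary) of that rescaling step: two key chunks are each summarized against their own local max score, then both summaries are shifted to a shared global max before combining and dividing, which reproduces the ordinary softmax-weighted average.

import jax.numpy as jnp

# Toy example: one query, 4 keys split into two chunks of 2.
scores = jnp.array([1.0, 3.0, 2.0, 5.0])          # attention logits
values = jnp.array([[1.0], [2.0], [3.0], [4.0]])  # one value feature per key

def summarize(s, v):
  # Per-chunk summary: stabilized exponentials times values, their sum,
  # and the local max (mirrors summarize_chunk in the listing below).
  m = s.max()
  e = jnp.exp(s - m)
  return e @ v, e.sum(), m

ev0, ew0, m0 = summarize(scores[:2], values[:2])
ev1, ew1, m1 = summarize(scores[2:], values[2:])

# Rescale each summary to the shared global max, then combine and divide.
g = jnp.maximum(m0, m1)
num = ev0 * jnp.exp(m0 - g) + ev1 * jnp.exp(m1 - g)
den = ew0 * jnp.exp(m0 - g) + ew1 * jnp.exp(m1 - g)
chunked = num / den

# Reference: plain softmax attention over all 4 keys at once.
w = jnp.exp(scores - scores.max())
reference = (w @ values) / w.sum()
print(chunked, reference)  # agree up to floating-point error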

[... In Figure 1 we provide default values for the chunk sizes that
lead to minimal runtime impact (on TPUv2), while still providing
significant memory savings.]

01:import functools, jax, math
02:from jax import numpy as jnp
03:
04:def _query_chunk_attention(query, key, value, precision, key_chunk_size=4096):
05:  """Multi-head dot product attention with a limited number of queries."""
06:  num_kv, num_heads, k_features = key.shape
07:  v_features = value.shape[-1]
08:  key_chunk_size = min(key_chunk_size, num_kv)
09:  query = query / jnp.sqrt(k_features)
10:
11:  @functools.partial(jax.checkpoint, prevent_cse=False)
12:  def summarize_chunk(query, key, value):
13:    attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
14:    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
15:    max_score = jax.lax.stop_gradient(max_score)
16:    exp_weights = jnp.exp(attn_weights - max_score)
17:    exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)
18:    return (exp_values, exp_weights.sum(axis=-1),
19:      max_score.reshape((query.shape[0], num_heads)))
20:
21:  def chunk_scanner(chunk_idx):
22:    key_chunk = jax.lax.dynamic_slice(
23:      key, (chunk_idx, 0, 0),
24:      slice_sizes=(key_chunk_size, num_heads, k_features))
25:    value_chunk = jax.lax.dynamic_slice(
26:      value, (chunk_idx, 0, 0),
27:      slice_sizes=(key_chunk_size, num_heads, v_features))
28:    return summarize_chunk(query, key_chunk, value_chunk)
29:
30:  chunk_values, chunk_weights, chunk_max = jax.lax.map(
31:    chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
32:
33:  global_max = jnp.max(chunk_max, axis=0, keepdims=True)
34:  max_diffs = jnp.exp(chunk_max - global_max)
35:  chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
36:  chunk_weights *= max_diffs
37:
38:  all_values = chunk_values.sum(axis=0)
39:  all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
40:  return all_values / all_weights
41:

It looks like the line numbers got offset a little bit in printing: the references in the explanation run about two lines ahead of the listing above, and line 56 (the output tensor res) belongs to the outer query-chunking loop, which isn't reproduced here.
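
As a quick sanity check (my own, not from the paper), the inner function above can be called directly on a full query block and compared against naive softmax attention; the shapes, chunk size and precision setting below are arbitrary choices:

import jax
from jax import numpy as jnp

# Random inputs: 128 queries, 4096 keys/values, 8 heads, 64 features.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (128, 8, 64))
key = jax.random.normal(kk, (4096, 8, 64))
value = jax.random.normal(kv, (4096, 8, 64))

# Chunked attention from the listing above, keys/values processed 1024 at a time.
chunked = _query_chunk_attention(query, key, value,
                                 precision=jax.lax.Precision.HIGHEST,
                                 key_chunk_size=1024)

# Naive reference that materializes the full 128 x 4096 attention matrix per head.
logits = jnp.einsum('qhd,khd->qhk', query, key) / jnp.sqrt(64.0)
naive = jnp.einsum('qhk,khd->qhd', jax.nn.softmax(logits, axis=-1), value)

print(jnp.max(jnp.abs(chunked - naive)))  # should be around float32 rounding error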

