[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 03:11:32 PST 2022


  chunk_values, chunk_weights, chunk_max = jax.lax.map(
    chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

from help(jax.lax.map):

      def map(f, xs):
        return np.stack([f(x) for x in xs])

so it just passes each element of xs into chunk_scanner, and again
stacks the results. Since chunk_scanner returns a tuple of three
arrays, each part of the tuple gets stacked separately, which is why
the call unpacks into chunk_values, chunk_weights, and chunk_max.
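
to make that concrete, here's a tiny runnable sketch. the
chunk_scanner below is just a made-up stand-in, not the one from the
attention code; all it shares with the real one is that it takes a
chunk's start offset and returns a three-part tuple:

    import jax
    import jax.numpy as jnp

    # toy stand-in for chunk_scanner (hypothetical): given a chunk's
    # start offset, return three per-chunk results, mirroring the
    # (values, weights, max) structure in the quoted code.
    def chunk_scanner(chunk_idx):
      x = jnp.arange(4, dtype=jnp.float32) + chunk_idx
      return x * 2.0, x.sum(), x.max()

    num_kv = 8
    key_chunk_size = 4

    # lax.map calls chunk_scanner once per offset (0 and 4 here) and
    # stacks the results; since the function returns a tuple, each
    # part is stacked separately along a new leading axis.
    chunk_values, chunk_weights, chunk_max = jax.lax.map(
      chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    print(chunk_values.shape)   # (2, 4) -- one row per chunk
    print(chunk_weights.shape)  # (2,)
    print(chunk_max.shape)      # (2,)

unlike the np.stack one-liner in the docstring, lax.map traces the
function once and runs it as a loop, so it doesn't materialize all
the per-chunk intermediates at the same time -- which is the point of
chunking here.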

