25 Jan 2022, 11:11 a.m.
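chunk_values, chunk_weights, chunk_max = jax.lax.map(
    chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

from help(jax.lax.map):

def map(f, xs):
  return np.stack([f(x) for x in xs])

so it just passes each element of xs (here, the start index of each key chunk) into chunk_scanner, and again stacks the results. Since chunk_scanner returns three arrays, each of the three outputs is stacked separately along a new leading axis.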
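To see the stacking behaviour in isolation, here is a small sketch. The `keys` array, the sizes, and the body of this `chunk_scanner` are made up for illustration and are not the original function; only the `jax.lax.map` call has the same shape as the one above.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes and data, chosen only for this illustration.
num_kv = 12
key_chunk_size = 4
keys = jnp.arange(num_kv * 2, dtype=jnp.float32).reshape(num_kv, 2)

def chunk_scanner(chunk_idx):
    # Toy stand-in: slice one chunk of rows starting at chunk_idx,
    # then return two per-chunk summaries as a tuple.
    chunk = jax.lax.dynamic_slice(keys, (chunk_idx, 0), (key_chunk_size, 2))
    return chunk.sum(axis=0), chunk.max()

# xs holds the chunk start indices: [0, 4, 8].
chunk_sums, chunk_maxes = jax.lax.map(
    chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

print(chunk_sums.shape)   # (3, 2): one row per chunk
print(chunk_maxes.shape)  # (3,): one scalar per chunk
```

Each element of the returned tuple gains a leading axis of length 3, one entry per chunk, which matches stacking the results of a plain Python loop over the same start indices.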