18 Jan
2022
10:09 p.m.
a jax contributor kindly shared this with me. you can store tpu models precompiled, which significantly speeds up launch time, by using a compilation cache folder:

    from jax.experimental.compilation_cache import compilation_cache as cc
    cc.initialize_cache("/path/name/here", max_cache_size_bytes=32 * 2**30)

not presently relevant as i'm on gpu, but i should try to use this with tpus.
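a rough sketch of how i might wrap this so it degrades gracefully, since the experimental api may change between jax versions. assumptions, not confirmed by the contributor: the helper name `enable_compilation_cache` and the temp-dir path are made up for illustration; i believe newer jax releases expose a `jax_compilation_cache_dir` config option, so the sketch tries that first and falls back to the `initialize_cache` call from the snippet above. i leave out `max_cache_size_bytes` in the fallback because i'm not sure every version accepts it.

```python
import os
import tempfile

# 32 * 2**30 bytes is 32 GiB, the size cap passed in the snippet above.
MAX_CACHE_SIZE_BYTES = 32 * 2**30


def enable_compilation_cache(cache_dir):
    """Try to point jax's persistent compilation cache at cache_dir.

    Hypothetical helper. Returns True if a cache directory was
    configured, False if jax is missing or neither api worked.
    """
    os.makedirs(cache_dir, exist_ok=True)
    try:
        import jax
    except ImportError:
        return False  # jax not installed, nothing to configure

    try:
        # assumption: newer jax versions expose the cache dir as a config option
        jax.config.update("jax_compilation_cache_dir", cache_dir)
        return True
    except Exception:
        pass  # option not recognised here, try the older api

    try:
        # the experimental api from the note above
        from jax.experimental.compilation_cache import compilation_cache as cc
        cc.initialize_cache(cache_dir)
        return True
    except Exception:
        return False


# hypothetical location; on a real tpu vm this should be a persistent path
cache_dir = os.path.join(tempfile.gettempdir(), "jax_cache_demo")
enabled = enable_compilation_cache(cache_dir)
```

the call needs to happen before the first jit compilation, otherwise the first compile won't be written to the cache. `enabled` just records whether any of the attempts went through, so a launch script can log it and carry on either way.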