regarding the idea for saving state, that could work here. basically you take a fancy text generation model and finetune it to produce its own embeddings: you feed it one token at a time instead of a whole document, and at each step you feed its generated state back in as embeddings. it's then bound by state size and complexity rather than by input and output size, and could plausibly generate coherent documents of arbitrary length. I have a git repo somewhere where I implemented this a year or so ago.
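
here's a minimal sketch of the loop I mean, with a toy stub standing in for the real model. everything in it (init_params, step, the sizes) is made up for illustration; the actual repo would finetune a pretrained model's hidden state rather than this stub:

```python
# toy sketch: a model that consumes one token plus its own fed-back state,
# emitting a new fixed-size state and next-token logits each step.
from functools import partial
import jax
import jax.numpy as jnp

VOCAB, STATE = 256, 64  # placeholder sizes

def init_params(key):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "embed": jax.random.normal(k1, (VOCAB, STATE)) * 0.02,      # token -> embedding
        "mix":   jax.random.normal(k2, (2 * STATE, STATE)) * 0.02,  # [state, token] -> next state
        "out":   jax.random.normal(k3, (STATE, VOCAB)) * 0.02,      # state -> logits
    }

def step(params, state, token):
    """One step: take the fed-back state and a single token, return the
    new state (the 'generated embedding') and logits for the next token."""
    x = jnp.concatenate([state, params["embed"][token]])
    new_state = jnp.tanh(x @ params["mix"])
    logits = new_state @ params["out"]
    return new_state, logits

@partial(jax.jit, static_argnames="length")
def generate(params, key, length):
    """Autoregressive loop whose memory cost depends only on the fixed-size
    state, not on how many tokens have been generated so far."""
    def body(carry, _):
        state, token, key = carry
        key, sub = jax.random.split(key)
        state, logits = step(params, state, token)
        token = jax.random.categorical(sub, logits)
        return (state, token, key), token
    init = (jnp.zeros(STATE), jnp.int32(0), key)
    _, tokens = jax.lax.scan(body, init, None, length=length)
    return tokens

params = init_params(jax.random.PRNGKey(0))
print(generate(params, jax.random.PRNGKey(1), 32))
```

the point of the sketch is just the data flow: the only thing carried between steps is the state, so nothing grows with document length.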

also, I learned more about jax jit compilation while working on the memory-efficient attention improvements. jax has an option to inline or not inline jitted subfunctions, so the issue is likely bisectable and removable by jit()ing subparts of the computation that fails to compile on colab. theoretically jitted subfunctions aren't inlined by default (inlining would prevent the approach).
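
for concreteness, this is roughly what I mean by jit()ing subparts; the function names and shapes are placeholders for whatever pieces of the attention computation you'd want to isolate, and it assumes jax.jit's inline option behaves as described above (it may vary by jax version):

```python
import jax
import jax.numpy as jnp

def chunk_scores(q, k):
    # placeholder subcomputation: scaled dot-product scores
    return q @ k.T / jnp.sqrt(q.shape[-1])

def chunk_weights(s):
    # placeholder subcomputation: softmax over scores
    return jax.nn.softmax(s, axis=-1)

# jit the suspect subparts separately; inline=False (the stated default) should
# keep each one a distinct compiled call rather than fusing it into the outer
# function, which is what makes the failure bisectable piece by piece.
chunk_scores_c = jax.jit(chunk_scores, inline=False)
chunk_weights_c = jax.jit(chunk_weights, inline=False)

@jax.jit
def attention(q, k, v):
    s = chunk_scores_c(q, k)
    w = chunk_weights_c(s)
    return w @ v

q = k = v = jnp.ones((8, 16))
print(attention(q, k, v).shape)
```

if one of the sub-jits reproduces the colab failure on its own, that narrows down which part of the compilation to work around.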