art with code

2024-01-19

Second round of thoughts on LLMs

LLMs are systems that compress a lot of text in a lossy fashion and pull out the most plausible and popular continuation or infill for the input text.

It's a lot like your autopilot mode. You as the mind are consulted by the brain to do predictions and give high-level feedback on what kind of next actions to take, but most of the execution happens subconsciously with the brain pulling up memories and playing them back. Often your brain doesn't even consult you on what to do, since running a mind is slow and expensive, and it's faster and cheaper to do memory playback instead - i.e. run on autopilot.

If you have enough memories, you can do almost everything on autopilot.

Until you can't, which is where you run into one of the LLM capability limits: structured thinking and search. To solve a more complex problem, you string memories together and search for an answer. That requires exploration, backtracking, and avoiding re-exploring dead ends. Think of solving a math problem: you start off by matching heuristics (the lemmas you've memorized) to the equation, transforming it this way and that, sometimes falling back all the way to the basic axioms of the algebra, going on wild goose chases, abandoning unpromising tracks, until you find the right sequence of transformations that leads you to the answer.

Note that you still need LLM-style memory use for this: you need to know the axioms to use them in the first place. Otherwise you need to go off and search for the axioms themselves, the definition of truth, etc., which is going to add a good chunk of extra work on top of it all. (What is the minimum thought, the minimal memory, that we use? A small random adjustment and its observation? From an LLM perspective, as long as you have a scoring function, the minimum change is changing the output by one token. Brute-force enumeration over all token sequences.)

If you add a search system to the LLM that can backtrack the generation and keep track of the different avenues it has explored, perhaps this system can solve problems that require structured thinking.
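
A minimal sketch of that kind of search harness, assuming a hypothetical llm.continuations(text, k) that returns k candidate continuations and a scoreFn that rates a partial solution (both are stand-ins, not a real API):

// Best-first search over LLM continuations: keep a frontier of partial
// generations, always expand the most promising one, and remember what's
// been visited so we don't re-explore dead ends.
function searchLLM(llm, scoreFn, prompt, { branch = 4, maxSteps = 1000 } = {}) {
  const frontier = [{ text: prompt, score: scoreFn(prompt) }];
  const visited = new Set([prompt]);
  let best = frontier[0];

  for (let step = 0; step < maxSteps && frontier.length > 0; step++) {
    // Expand the highest-scoring partial generation. Backtracking happens
    // implicitly: if an expansion scores poorly, an older branch wins next.
    frontier.sort((a, b) => b.score - a.score);
    const current = frontier.shift();
    if (current.score > best.score) best = current;

    for (const next of llm.continuations(current.text, branch)) {
      if (visited.has(next)) continue;   // skip already-explored avenues
      visited.add(next);
      frontier.push({ text: next, score: scoreFn(next) });
    }
  }
  return best.text;
}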


LLMs as universal optimizers. You can use an LLM to rank its input ("Score the following 0-100: ..."). You can also use an LLM to improve its input ("Make this better: ..."). Combine the two and you get the optimizer:

// improve and score are the prompt prefixes above, seed_program the starting point
best_program = seed_program
best_score = 0

while (true) {
  program = llm(improve + best_program)            // "Make this better: ..."
  program_score = Number(llm(score + program))     // "Score the following 0-100: ..."
  if (program_score > best_score) {
    best_score = program_score
    best_program = program
  }
}


LLMs as universal functions. An LLM takes as its input a sequence of tokens and outputs a sequence of tokens. LLMs are trained using sequences of tokens as the input. The training program for an LLM is a sequence of tokens.

llm2 = train(llm, data)

can become

llm2 = llm(train)(llm, llm(data))

And of course, you can recursively apply an LLM to its own output: output' = llm(llm(llm(llm(...)))). You can also ask the LLM to improve its input, validate the candidates with something else, and keep the highest-scoring one:

optimize = input =>
  Array(10).fill(input)
    .map(x => llm(improve + x))
    .filter(x => isValid(x))
    .map(x => ({score: Number(llm(score + x)), value: x}))
    .reduce((best, x) => (x.score > best.score ? x : best))
    .value

This gives you the self-optimizer:

while (true) {
  train = optimize(train)
  training_data = optimize(training_data)
  llm = train(llm, training_data)
}

If you had Large Model Models - LMMs - you could call optimize directly on the model. You can also optimize the optimization function, scoring function and improver function as you go, for a fully self-optimizing optimizer.

while (true) {
  lmm = optimize(lmm, lmm, scoring_model, improver_model)
  optimize = optimize(lmm, optimize, scoring_model, improver_model)
  scoring_model = optimize(lmm, scoring_model, scoring_model, improver_model)
  improver_model = optimize(lmm, improver_model, scoring_model, improver_model)
}

The laws of numerical integration likely apply here: you'll halve the noise by taking 4x the samples. Who knows!


LLMs generate text at a few hundred bytes per second. An LLM takes a second to do a simple arithmetic calculation (and gets it wrong, because the path generated for math is many tokens long, and the temperature plus lossy compression make it pull the wrong numbers). The hardware is capable of doing I/O at tens or hundreds of gigabytes per second. Ancient CPUs do a billion calculations in a second. I guess you could improve on token-based math by encoding all 16-bit numbers as tokens and having some magic in the tokenizer... but still, you're trying to memorize the multiplication table or addition table or what have you. Ain't gonna work. Use a computer. They're really good at arithmetic.
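
One hedged sketch of the "use a computer" route: let the model write arithmetic in a marker the runtime can spot, and have the CPU do the actual calculation. The CALC(...) convention and the llm() call are made up for illustration:

// Replace CALC(...) markers in the model's output with exact results
// computed on the CPU, instead of trusting token-by-token arithmetic.
function withCalculator(llm, prompt) {
  const draft = llm(prompt + "\nWrite any arithmetic as CALC(expression).");
  return draft.replace(/CALC\(([0-9+\-*/. ()]+)\)/g, (_, expr) =>
    // Only digits, operators and parens reach Function(), per the regex.
    String(Function('"use strict"; return (' + expr + ')')())
  );
}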

We'll probably get something like RAG ("inject search results into the input prompt") but on the output side ("inject 400 bytes at offset 489 from training file x003.txt") to get to megabytes-per-second LLM output rates. Or diffusers... SDXL img2img at 1024x1024 resolution takes a 3MB context and outputs 3MB in a second. If you think about the structure of an LLM, the slow bitrate of the output is a bit funny: Llama2's intermediate layers pass through 32 megabytes of data, and the final output layers up that to 260 MB, which gets combined into 32000 token scores, which are then sampled to determine the final output token. Gigabytes of I/O to produce 2 bytes at the output end.
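
A sketch of what output-side RAG might look like: the model emits a short copy instruction instead of the bytes themselves, and the runtime splices the span in from the training corpus. The <copy file=... offset=... len=...> tag format is invented for illustration:

const fs = require("fs");

// Expand hypothetical <copy file=... offset=... len=...> instructions by
// reading the referenced bytes straight from the corpus, so the model only
// spends tokens on the instruction, not on the content it points at.
function expandCopies(output, corpusDir) {
  return output.replace(/<copy file=(\S+) offset=(\d+) len=(\d+)>/g,
    (_, file, offset, len) => {
      const buf = Buffer.alloc(Number(len));
      const fd = fs.openSync(`${corpusDir}/${file}`, "r");
      fs.readSync(fd, buf, 0, Number(len), Number(offset));
      fs.closeSync(fd);
      return buf.toString("utf8");
    });
}

// e.g. expandCopies("... <copy file=x003.txt offset=489 len=400> ...", "./training")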


SuperHuman benchmark for tool-using models. Feats like "multiply these two 4096x4096 matrices, you've got 50 ms, go!", grepping large files at 20 GB/s, using SAT solvers and TSP solvers, proof assistants, and so on. Combining problem solving with known-good algorithms and optimal hardware utilization. The problems would require creatively combining optimized inner loops. Try to find a Hamiltonian path through a number of locations and do heavy computation at each visited node, that kind of thing.
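
The harness for a benchmark like that could be very small: a task is an input, a correctness check, and a time budget, and the model's tool-using solution is scored on whether it beats the clock. Everything below is illustrative, not an existing benchmark:

// One SuperHuman-style task: the solver gets the raw input and free rein
// over the hardware; the benchmark only checks correctness and wall time.
function runTask(task, solver) {
  const t0 = process.hrtime.bigint();
  const answer = solver(task.input);
  const elapsedMs = Number(process.hrtime.bigint() - t0) / 1e6;
  return {
    correct: task.check(answer),
    withinBudget: elapsedMs <= task.budgetMs,
    elapsedMs,
  };
}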


Diffusers and transformers. A diffuser starts off from a random field of tokens and denoises it into a more plausible arrangement of tokens. A transformer starts off from a string of tokens and outputs a plausible continuation.

SD-style diffusers are coupled with an autoencoder that converts the input into latent space and latents back into the output. In the classic Stable Diffusion model, the autoencoder converts an 8x8 patch of pixels into a single latent, and a latent into an 8x8 patch of pixels. These conversions consider the entire image (more or less), so it's not quite like JPEG's 8x8 DCT/iDCT.

What if you used an autoencoder to turn a single latent space LLM token into 64 output tokens? 64x faster generation with this one trick?
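
As a sketch of the idea (all model calls are hypothetical stand-ins): the big transformer steps once per latent, and a small decoder head expands each latent into a fixed block of 64 text tokens, so one expensive step emits 64 tokens instead of one.

// latentLM.predictNextLatent and decoder.latentToTokens are hypothetical.
function generate(latentLM, decoder, promptLatents, steps) {
  const latents = [...promptLatents];
  const tokens = [];
  for (let i = 0; i < steps; i++) {
    const next = latentLM.predictNextLatent(latents); // one expensive step
    latents.push(next);
    tokens.push(...decoder.latentToTokens(next));     // 64 tokens per step
  }
  return tokens;
}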

A diffuser starts off from a random graph and tweaks it until it resolves into a plausible path. A transformer generates a path one node at a time. 


A transformer keeps track of an attention score for each pair of input tokens, which allows it to consider all the relations between the tokens in the input string. This also makes it O(n^2) in time and space. For short inputs and outputs, this is not much of a problem. At longer input lengths you definitely start to feel it, and this is the reason for the tiny context sizes of TF-based LLMs. If the "large input" to your hundred gigabyte program is 100kB in size, there's probably some work left to be done.

Or maybe there's something here like there was with sorting algorithms. You'd think that to establish the ordering, you have to compare each element with every other element (selection sort, O(n^2)). But you can take advantage of the transitivity of the comparison operation to recursively split the sort into smaller sub-sorts (merge sort, quicksort, O(n log n)), or of the limited element alphabet size to do it in one pass (radix sort, counting sort, O(n)-ish).

What could be the transitive operation in a transformer? At an output token, the previous tokens have been produced without taking the output token into account, so you get the triangular matrix shape. That's still O(n^2). Is there some kind of transitive property to attention? Like, we'd only need to pay attention to the tokens that contributed to high-weight tokens? Some parts of the token output are grammatical, so they weigh the immediately preceding tokens highly but don't really care about anything else. In that case, can we do an early exit? Can we combine token sequences into compressed higher-order tokens and linearly reduce the token count of the content? Maybe you could apply compression to the attention matrix to reduce each input token's attention to the top-n highest values, which would scale linearly. What if you took some lessons from path tracing, like importance sampling, lookup trees, and reducing variance until you hit an error threshold? Some tokens would get resolved in a couple of tree lookups, others might take thousands.
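
As a toy version of the top-n idea, here's attention for a single query over plain arrays, keeping only the k largest scores so the per-query work scales with k rather than with the full input length (a sketch of the trick, not a real transformer kernel):

// Attention for one query against n keys, truncated to the top-k scores.
// Softmax runs over those k entries; everything else is treated as zero.
function topKAttention(q, keys, values, k) {
  const dot = (a, b) => a.reduce((s, x, i) => s + x * b[i], 0);
  const scored = keys.map((key, i) => ({ i, score: dot(q, key) }));
  scored.sort((a, b) => b.score - a.score);
  const top = scored.slice(0, k);

  const max = top[0].score;                     // for numerical stability
  const exps = top.map(t => Math.exp(t.score - max));
  const z = exps.reduce((a, b) => a + b, 0);

  // Weighted sum of the top-k value vectors.
  return values[0].map((_, d) =>
    top.reduce((sum, t, j) => sum + (exps[j] / z) * values[t.i][d], 0)
  );
}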

