@Phylliida
Last active November 13, 2020 22:56
Analysis of 1 layer 1 head transformers learning the 12 2-state and 2 1-state DFAs
Phylliida commented Nov 13, 2020

Note that n_embd=64 and lr=5e-4. Smaller embedding sizes would probably work fine; this is a really easy task.

These are all possible 2- and 1-state DFAs over a two-letter alphabet. The octagons mark the start state. The double circle on the a states is there if you want to read these as traditional "accepting DFAs", but they actually operate as transducers: each one first emits the symbol of its start state, then receives a character, hops to the corresponding state and emits that state's symbol, receives another character, hops to a new state and emits a symbol, and so on.
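The transducer semantics above can be sketched in a few lines. This is my own illustration (not the gist's code), with made-up state and symbol names:

```python
# Run a DFA as a transducer: emit the start state's symbol first, then
# alternate receive-character / hop-state / emit-symbol.
def run_dfa(transitions, emit, start, string):
    """transitions: dict[(state, char)] -> state; emit: dict[state] -> symbol."""
    state = start
    outputs = [emit[state]]          # initial output, before any input
    for ch in string:
        state = transitions[(state, ch)]
        outputs.append(emit[state])  # one emission per character consumed
    return outputs

# Example: a 2-state DFA that toggles state on 'a' and stays put on 'b'.
toggle = {("s0", "a"): "s1", ("s0", "b"): "s0",
          ("s1", "a"): "s0", ("s1", "b"): "s1"}
run_dfa(toggle, {"s0": 0, "s1": 1}, "s0", "aab")  # -> [0, 1, 0, 0]
```

Note the output is always one symbol longer than the input, since the start-state symbol comes for free.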

I created a dataset of all possible "binary strings" of a's and b's of length 8 as inputs, then held out a random 20% of them as a test set. This doesn't model a true test set very well, since 4 characters are probably already enough to pin down what the automaton is doing, but that's okay: I mostly just wanted to see how hard it was for the network to learn the various automata. It also doesn't seem like 5-10 gradient steps (which is all these need) would be enough to memorize.
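The dataset construction, as I read it, is just an enumeration and a random split. A minimal sketch (variable names and the seed are my own, not from the gist):

```python
import itertools
import random

# All 2^8 = 256 strings over {a, b} of length 8.
strings = ["".join(s) for s in itertools.product("ab", repeat=8)]

# Hold out a random 20% as the test set.
random.seed(0)  # arbitrary seed for reproducibility of the sketch
random.shuffle(strings)
split = int(0.8 * len(strings))
train_set, test_set = strings[:split], strings[split:]
# -> 204 training strings, 52 held-out strings
```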

I used a GPT-style model (instead of a seq2seq), framing the task as: given the tuple (past symbol emitted, current character to process), predict the next symbol the automaton will output. To encode tuples, I just represented the pair (i, j) as the single token i + j*alphabetSize, though there are better ways to do this that I want to look at. Because this means the model can output invalid tokens, I take the output mod alphabetSize (mod 2).
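The pairing and the mod trick described above look like this (my own sketch of the scheme, with function names invented for illustration):

```python
ALPHABET_SIZE = 2  # two symbols/characters, 0 and 1

def encode_pair(past_symbol, cur_char):
    # Flatten the tuple (i, j) into the single token i + j * alphabetSize,
    # giving token ids 0..3 for alphabet size 2.
    return past_symbol + cur_char * ALPHABET_SIZE

def decode_output(token):
    # The model's vocabulary includes all pair tokens, so it can emit
    # "invalid" ids; folding mod alphabetSize maps any id to a valid symbol.
    return token % ALPHABET_SIZE

encode_pair(1, 1)   # -> 3
decode_output(3)    # -> 1
```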

My batch size was big enough that the entire dataset fit into a single batch, so one epoch was a single gradient-descent step. To test whether a model captured the DFA, I decoded a single symbol y_1 from the pair (initialStateOutput, firstCharacter), then appended (y_1, secondCharacter) to the sequence and decoded y_2, and so on, using argmax to decode from the probabilities. I stopped a training run once the model produced the correct output for every possible length-8 sequence; "epochs" refers to how many epochs/gradient steps it took for that to happen.
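The greedy evaluation loop above can be sketched as follows. `model_argmax` is a hypothetical stand-in for the argmax over the GPT's next-token probabilities (the gist's actual model isn't shown here); everything else follows the pair-encoding and mod-2 decoding described earlier.

```python
def dfa_matches(model_argmax, dfa_outputs, chars, alphabet_size=2):
    """Check whether greedy decoding reproduces the DFA's output sequence.

    dfa_outputs: ground-truth symbols, starting with the start-state output
                 (so len(dfa_outputs) == len(chars) + 1).
    chars: the input characters, as ints in range(alphabet_size).
    """
    seq = []
    prev = dfa_outputs[0]                      # initial start-state output
    for i, ch in enumerate(chars):
        seq.append(prev + ch * alphabet_size)  # encode (past symbol, char)
        y = model_argmax(seq) % alphabet_size  # argmax token, folded to a symbol
        if y != dfa_outputs[i + 1]:
            return False
        prev = y                               # feed the emitted symbol back in
    return True
```

A run counts as solved once this returns True for every length-8 input sequence.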

The "avg" values are averaged over 100 runs.
