Analysis of 1 layer 1 head transformers learning the 12 2-state and 2 1-state DFAs
Note that n_embd=64, lr=5e-4. Smaller embedding sizes would probably work fine; this is a really easy task.
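For reference, a hypothetical config sketch (the names `n_layer`, `n_head`, `block_size`, and `vocab_size` are my assumptions for a 1-layer, 1-head GPT on the tuple tokens described below, not the actual training code):

```python
# Illustrative hyperparameters only; the actual training script may differ.
config = dict(
    n_layer=1,     # 1 transformer layer
    n_head=1,      # 1 attention head
    n_embd=64,     # embedding size mentioned above
    lr=5e-4,       # learning rate mentioned above
    block_size=8,  # inputs are length-8 strings of a's and b's
    vocab_size=4,  # (symbol, character) tuples: 2 symbols x 2 letters = 4 tokens
)
```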
These are all possible 2- and 1-state DFAs on two letters. The octagons mark the start state. The double circle on the a's is there if you want to read these as traditional "accepting" DFAs, but how they actually work is: each machine first outputs the symbol of its start state, then receives a character, hops to the corresponding state and emits another symbol, receives another character, hops to a new state and emits a symbol, and so on.
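A minimal sketch of how these machines run, assuming a transition-table representation (my own, not the gist's):

```python
# Each state emits a symbol; each input character causes a transition.
def run_dfa(transitions, emit, start, input_string):
    """transitions[state][char] -> next state; emit[state] -> output symbol."""
    state = start
    outputs = [emit[state]]             # the start state's symbol is emitted first
    for ch in input_string:
        state = transitions[state][ch]  # hop to the next state on this character
        outputs.append(emit[state])     # emit the new state's symbol
    return outputs

# Example: one 2-state machine on the letters {a, b}.
transitions = {0: {"a": 1, "b": 0}, 1: {"a": 1, "b": 0}}
emit = {0: "b", 1: "a"}
print(run_dfa(transitions, emit, start=0, input_string="abba"))  # ['b', 'a', 'b', 'b', 'a']
```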
I created a dataset of all possible "binary strings" of a's and b's of length 8 as inputs, then held out a random 20% of them as a test set. This doesn't model a test set very well, since 4 characters are probably already enough to give full information about what the automaton is doing, but that's okay; I mostly just wanted to see how hard it was for the network to learn the various automata. It also doesn't seem like 5-10 gradient steps (which is all these need) would be enough to memorize.
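A rough sketch of that dataset construction (the exact split code is illustrative):

```python
import itertools
import random

# All 2^8 = 256 strings of a's and b's of length 8, with a random 80/20 train/test split.
strings = ["".join(s) for s in itertools.product("ab", repeat=8)]
random.shuffle(strings)
split = int(0.8 * len(strings))
train_strings, test_strings = strings[:split], strings[split:]
```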
I used a GPT-type model (instead of a seq2seq), framing the task as: given the tuple (previous symbol emitted, current character to process), predict the next symbol the automaton will output. To encode tuples, I just represented the tuple (i, j) as the character i + j*alphabetSize, though there are better ways to do this that I want to look at. Because this means the model can output invalid characters, I take the outputs mod alphabetSize (mod 2).
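A sketch of that encoding, with my own variable names:

```python
alphabet_size = 2  # two input letters (a, b) and two output symbols (0, 1)

def encode_pair(prev_symbol, cur_char):
    # the tuple (i, j) becomes the single token i + j * alphabet_size
    return prev_symbol + cur_char * alphabet_size

def decode_output(token):
    # the model can emit any token id, so fold it back onto a valid output symbol
    return token % alphabet_size
```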
My batch size was big enough that the entire dataset fit into a single batch, so one epoch was a single gradient descent step. To test whether a model learned the DFA properly, I decoded a single symbol y_1 from the pair (initialStateOutput, firstCharacter), then appended (y_1, secondCharacter) to the sequence and decoded y_2, and so on, using argmax to decode from the probabilities. I stopped a training run once the model produced the correct output for every possible length-8 sequence; "epochs" refers to how many epochs/gradient steps it took for that to happen.
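A minimal sketch of that evaluation loop, where `predict_logits` is a stand-in for calling the trained model on the token sequence so far (not a real API), and `dfa_outputs` is the ground-truth output sequence as integer symbols:

```python
def decodes_correctly(predict_logits, dfa_outputs, input_string, alphabet_size=2):
    """dfa_outputs[0] is the start state's symbol; dfa_outputs[t+1] is the symbol
    emitted after reading input_string[t]."""
    tokens = []
    prev_symbol = dfa_outputs[0]
    for t, ch in enumerate(input_string):
        char_id = 0 if ch == "a" else 1
        tokens.append(prev_symbol + char_id * alphabet_size)  # encode (prev symbol, char)
        logits = predict_logits(tokens)                       # next-token logits from the model
        y = int(logits.argmax()) % alphabet_size              # argmax, folded to a valid symbol
        if y != dfa_outputs[t + 1]:
            return False
        prev_symbol = y                                       # feed the prediction back in
    return True
```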
The "avg" values are over 100 runs