The goal is to solve two types of probability problems about sampling strings/multisets without replacement. The main difference between the two is permutation invariance. One example of each:
- Problem type # 1: What is the probability of obtaining the sequence 'abb' when picking 3 letters without replacement from {a: 2, b: 3, z: 1}
- Problem type # 2: What is the probability of picking 1 'd', 1 's' and 3 'g' when picking five letters without replacement from 'gggggggsggdgggg'?
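Both kinds of answer can be checked directly; a short sketch (the function names `seq_prob` and `multiset_prob` are mine, not part of the task format):

```python
from collections import Counter
from fractions import Fraction
from math import comb

def seq_prob(target, counts):
    """Problem type 1: probability of drawing an exact sequence without replacement."""
    counts = Counter(counts)
    total = sum(counts.values())
    p = Fraction(1)
    for i, ch in enumerate(target):
        p *= Fraction(counts[ch], total - i)  # remaining copies / remaining letters
        counts[ch] -= 1
    return p

def multiset_prob(target, source):
    """Problem type 2: probability of drawing a multiset of letters (any order),
    via the multivariate hypergeometric distribution."""
    src, tgt = Counter(source), Counter(target)
    num = 1
    for ch, c in tgt.items():
        num *= comb(src[ch], c)               # ways to pick c copies of ch
    return Fraction(num, comb(sum(src.values()), sum(tgt.values())))

print(seq_prob('abb', {'a': 2, 'b': 3, 'z': 1}))   # 1/10
print(multiset_prob('dsggg', 'gggggggsggdgggg'))   # 2/21
```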
In the naive setting, a transformer is used to map the question to the answer, e.g.
INPUT = "What is the probability of picking 1 'd', 1 's' and 3 'g' when picking five letters without replacement from 'gggggggsggdgggg'?"
OUTPUT = "2/21"
Baseline performance for this approach was
- ~75% for interpolation, i.e. problems involving strings of length no greater than those in the training set
- Only ~5% for extrapolation, i.e. problems with strings longer than those in the training set
The new approach introduces two main ideas:
- Problem decomposition: the output string is not just the final answer but a series of concatenated strings, each representing a single step in the solution, e.g. (Problem type # 1 is the same but without steps 4 and 5)
INPUT = "What is the probability of picking 1 'd', 1 's' and 3 'g' when picking five letters without replacement from 'gggggggsggdgggg'?"
# times each letter in target strings appears in source
STEP1 = "d:1 s:1 g:13"
# how many letters in total in source string
STEP2 = "1+1+13=15"
# probability of one target string
STEP3 = "(1/15)*(1/14)*(13/13)*(12/12)*(11/11)=1/210"
# number of target strings - from permutations
STEP4 = "5!/(3!)=20"
# probability of any target string
STEP5 = "20*1/210=2/21"
# solution
STEP6 = "2/21"
STEPS = STEP1, STEP2, STEP3, STEP4, STEP5, STEP6
OUTPUT = ' '.join(STEPS)
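The six steps above can be mirrored numerically; a sketch for problem type # 2, assuming the target is given as one explicit ordering of the letters (e.g. 'dsggg'):

```python
from collections import Counter
from fractions import Fraction
from math import factorial, prod

def solve_type2(target, source):
    """Mirror of the 6-step decomposition for problem type # 2 (sketch)."""
    src, tgt = Counter(source), Counter(target)

    # STEP1/STEP2: per-letter counts and total number of source letters
    total = sum(src.values())

    # STEP3: probability of drawing one fixed ordering of the target letters
    remaining, left, p = dict(src), total, Fraction(1)
    for ch in target:                     # e.g. 'dsggg'
        p *= Fraction(remaining[ch], left)
        remaining[ch] -= 1
        left -= 1

    # STEP4: number of distinct orderings of the target multiset
    n_orders = factorial(len(target)) // prod(factorial(c) for c in tgt.values())

    # STEP5/STEP6: probability of any ordering
    return n_orders * p

print(solve_type2('dsggg', 'gggggggsggdgggg'))  # 2/21
```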
- External calculator: transformers translate sequences one character at a time, which lets us check whether an equals sign has just been written to the output tape. When this happens, the next characters are not generated by the transformer; instead, an external calculator evaluates the expression immediately before the equals sign and its result is written to the output. Sequence translation then resumes as usual.
- During training, characters obtained using the external calculator are masked so that they don't contribute to the transformer loss
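A sketch of the calculator hook. The exact-arithmetic `calculator` and the greedy `decode` driver are hypothetical stand-ins; in the real setup the transformer plays the role of `next_char`:

```python
import re
from fractions import Fraction
from math import factorial

def calculator(expr):
    """Exact evaluation of + - * / ( ) and n! over integers, via Fractions.
    Fine for this toy grammar; not a general-purpose safe evaluator."""
    expr = re.sub(r'(\d+)!', r'fact(\1)', expr)    # 5! -> fact(5)
    expr = re.sub(r'\b(\d+)\b', r'F(\1)', expr)    # keep division exact
    fact = lambda x: Fraction(factorial(int(x)))
    return eval(expr, {'__builtins__': {}}, {'F': Fraction, 'fact': fact})

def decode(next_char, question, max_len=400):
    """Greedy decoding with the hook: whenever the model writes '=', the
    expression since the last space is evaluated externally and its result
    is appended; generation then resumes. next_char(question, out) -> str."""
    out = ''
    while len(out) < max_len:
        ch = next_char(question, out)
        if not ch:                       # model signalled end of sequence
            break
        out += ch
        if ch == '=':
            out += str(calculator(out.rsplit(' ', 1)[-1][:-1]))
    return out
```

On the example above, `calculator` reproduces the results in steps 2-5: `"(1/15)*(1/14)*(13/13)*(12/12)*(11/11)"` evaluates to `1/210`, `"5!/(3!)"` to `20`, and `"20*1/210"` to `2/21`.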
With these changes, performance is >99% for interpolation but still only ~5% for extrapolation. The failure modes are different in each case:
- When interpolating, the network learns the decomposition exactly; the very few errors come from single-digit mistakes when counting letters or copying numbers, e.g. writing "d:5" instead of "d:6", or writing "1001/9690" in step 3 but then "101/9690" in step 5.
- When extrapolating, the network simply fails to generalize in step 3, where all the conditional probabilities are multiplied together: that expression only ever contains as many factors as the length of the target sequences seen during training. For example, if the test target sequence has length 5 but the network was trained on sequences of length <=3, it counts all 5 letters correctly in step 1 but then multiplies only the (correct) probabilities for the first 3 letters, completely ignoring the remaining 2. The few cases it gets right are degenerate ones in which the conditional probabilities it fails to compute are all equal to 1.
I would do a longer decomposition that includes intermediate symbolic steps. My intuition is that this way the network would learn the general rule mapping target strings to a product of probabilities. For the simpler case (problem type # 1), with target t = t1...tn drawn from a source string s, we write
P(t) = P(t1|'') * P(t1t2|t1) * ... * P(t1...tn|t1...tn-1)
then, evaluating each conditional probability,
P(t1...tk|t1...tk-1) = (C(tk, s) - C(tk, t1...tk-1)) / (N - (k-1))
where C(x, y) is the number of times character x occurs in string y and N is the length of the source string s.
Concrete example:
INPUT = "Prob of 'abb' sampling w/o repl from 'abbcab'?"
# counting
STEP1 = "a:2 b:3 c:1"
# total
STEP2 = "2+3+1=6"
# count of each target letter in the part of the target that precedes it
STEP3 = "count_a_:0 count_b_a:0 count_b_ab:1"
# completely general symbolic decomposition of the target
STEP4 = "(a|'')*(ab|a)*(abb|ab)"
# substituting character counts
STEP5 = "(2-0)/(6-0)*(3-0)/(6-1)*(3-1)/(6-2)=1/10"
# solution
STEP6 = "1/10"
STEPS = STEP1, STEP2, STEP3, STEP4, STEP5, STEP6
OUTPUT = ' '.join(STEPS)
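The symbolic formula reduces to a short function; a direct implementation (the name `chain_prob` is mine), handy for checking hand-worked examples:

```python
from collections import Counter
from fractions import Fraction

def chain_prob(target, source):
    """P(t) as a product of conditionals:
    (C(t_k, source) - C(t_k, prefix)) / (N - (k-1)), for k = 1..len(target).
    Assumes the target can actually be drawn from the source."""
    N, src = len(source), Counter(source)
    p = Fraction(1)
    for k, ch in enumerate(target):
        p *= Fraction(src[ch] - target[:k].count(ch), N - k)
    return p

print(chain_prob('abb', 'abbcab'))  # 1/10
```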
I feel like this approach is still somewhat sketchy. Ideally, the steps in the computation wouldn't be substrings of one large OUTPUT string but something else entirely (nodes in a tree?), and we'd have a separate transformer solving each step.
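Purely as an illustration of that last idea (all names invented), steps could become nodes that read their parents' results, with one solver per node:

```python
from dataclasses import dataclass

@dataclass
class Step:
    """Illustrative only: one node of a solution tree, not an OUTPUT substring."""
    name: str       # e.g. "count_letters"
    inputs: list    # names of parent steps this one reads from
    result: str = ''

def solve(steps, solvers):
    """Toy driver: run one solver per step, passing it its parents' results.
    Each solver stands in for a dedicated transformer (or a calculator)."""
    done = {}
    for s in steps:
        s.result = solvers[s.name]({p: done[p] for p in s.inputs})
        done[s.name] = s.result
    return done
```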