
@javipus
Created December 16, 2019 18:18

INTRODUCTION

The goal is to solve two types of probability problems about sampling strings/multi-sets without replacement. The main difference between the two is permutation invariance. One example of each:

  • Problem type # 1: What is the probability of obtaining the sequence 'abb' when picking 3 letters without replacement from {a: 2, b: 3, z: 1}?
  • Problem type # 2: What is the probability of picking 1 'd', 1 's' and 3 'g' when picking five letters without replacement from 'gggggggsggdgggg'?
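
Both example answers can be checked with a short exact computation. The sketch below is not part of the original setup; it just uses Python's `fractions` for exact arithmetic, and the names `p_sequence`/`p_multiset` are my own:

```python
from collections import Counter
from fractions import Fraction
from math import factorial

def p_sequence(seq, source_counts):
    """Probability of drawing exactly the sequence seq, in order,
    without replacement (problem type # 1)."""
    counts = dict(source_counts)
    total = sum(counts.values())
    p = Fraction(1)
    for c in seq:
        p *= Fraction(counts[c], total)  # chance of drawing c next
        counts[c] -= 1
        total -= 1
    return p

def p_multiset(target, source):
    """Probability of drawing the letters of target in any order,
    without replacement (problem type # 2)."""
    tcounts = Counter(target)
    # number of distinct orderings of the target multiset
    orderings = factorial(len(target))
    for k in tcounts.values():
        orderings //= factorial(k)
    return orderings * p_sequence(target, Counter(source))

print(p_sequence('abb', {'a': 2, 'b': 3, 'z': 1}))  # 1/10
print(p_multiset('dsggg', 'gggggggsggdgggg'))       # 2/21
```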

In the naive setting, a transformer is used to map the question to the answer, e.g.

INPUT = "What is the probability of picking 1 'd', 1 's' and 3 'g' when picking five letters without replacement from 'gggggggsggdgggg'?"
OUTPUT = "2/21"

Baseline performance for this approach was

  • ~75% for interpolation, i.e. problems involving strings of length no greater than those in the training set
  • Only ~5% for extrapolation, i.e. problems with strings longer than those in the training set

METHODS

The new approach introduces two main ideas:

  • Problem decomposition: the output string is not just the final answer but a concatenation of strings, each representing a single step in the solution (problem type # 1 works the same way but without steps 4 and 5), e.g.
INPUT = "What is the probability of picking 1 'd', 1 's' and 3 'g' when picking five letters without replacement from 'gggggggsggdgggg'?"

# times each letter in target strings appears in source
STEP1 = "d:1   s:1   g:13"
# how many letters in total in source string
STEP2 = "1+1+13=15"
# probability of one target string
STEP3 = "(1/15)*(1/14)*(13/13)*(12/12)*(11/11)=1/210"
# number of target strings - from permutations
STEP4 = "5!/(3!)=20"
# probability of any target string
STEP5 = "20*1/210=2/21"
# solution
STEP6 = "2/21"

STEPS = STEP1, STEP2, STEP3, STEP4, STEP5, STEP6
OUTPUT = ' '.join(STEPS)
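
The six steps above can be generated programmatically. This is my own reconstruction of the data-generation logic, not the original code; it reproduces the step strings for a type-2 problem:

```python
from collections import Counter
from fractions import Fraction
from math import factorial

def decompose(target, source):
    """Emit the six solution steps for a type-2 problem as one string.
    target gives the drawn letters in some fixed order, e.g. 'dsggg'."""
    src = Counter(source)
    tgt = Counter(target)
    # STEP1: how many times each target letter appears in the source
    step1 = "   ".join(f"{c}:{src[c]}" for c in tgt)
    # STEP2: total number of letters in the source
    total = sum(src[c] for c in tgt)
    step2 = "+".join(str(src[c]) for c in tgt) + f"={total}"
    # STEP3: probability of one particular target ordering
    remaining, left = Counter(src), total
    p, factors = Fraction(1), []
    for c in target:
        factors.append(f"({remaining[c]}/{left})")
        p *= Fraction(remaining[c], left)
        remaining[c] -= 1
        left -= 1
    step3 = "*".join(factors) + f"={p}"
    # STEP4: number of distinct orderings of the target multiset
    orderings = factorial(len(target))
    for k in tgt.values():
        orderings //= factorial(k)
    denom = "*".join(f"{k}!" for k in tgt.values() if k > 1)
    step4 = f"{len(target)}!/({denom})={orderings}"
    # STEP5: probability of any ordering; STEP6: final answer
    ans = orderings * p
    step5 = f"{orderings}*{p}={ans}"
    return " ".join([step1, step2, step3, step4, step5, str(ans)])
```

Calling `decompose('dsggg', 'gggggggsggdgggg')` reproduces exactly the six steps shown above, ending in 2/21.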
  • Transformers translate sequences one character at a time. This lets us check whether an equals sign has just been written to the output tape. When that happens, the following characters are not generated by the transformer but by an external calculator that evaluates the expression immediately before the equals sign. Sequence translation then resumes as usual.

    • During training, characters obtained from the external calculator are masked so that they don't contribute to the transformer's loss
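
A minimal sketch of this calculator-in-the-loop decoding. `model_next_char` is a hypothetical stand-in for the transformer (the real model and tokenizer are not shown), and the calculator here is a restricted expression evaluator of my own:

```python
import re
from fractions import Fraction
from math import factorial

def calculator(expr):
    """Exactly evaluate the arithmetic expression preceding an '='.
    Factorials like '5!' are expanded first, then the expression is
    evaluated over Fractions so results like 2/21 stay exact."""
    expr = re.sub(r'(\d+)!', lambda m: str(factorial(int(m.group(1)))), expr)
    if not re.fullmatch(r'[\d+\-*/() ]+', expr):
        raise ValueError(f"unsupported expression: {expr!r}")
    # wrap every integer literal in Fraction() before evaluating
    return eval(re.sub(r'(\d+)', r'Fraction(\1)', expr), {'Fraction': Fraction})

def decode(model_next_char, prompt, max_len=200):
    """Greedy decoding loop: whenever the model writes '=', splice in the
    calculator's answer instead of sampling those characters."""
    out, from_calc = "", []   # from_calc marks characters to mask in the loss
    while len(out) < max_len:
        ch = model_next_char(prompt, out)
        if ch is None:        # end-of-sequence
            break
        out += ch
        from_calc.append(False)
        if ch == '=':
            expr = out.split()[-1][:-1]   # expression before the '='
            result = str(calculator(expr))
            out += result
            from_calc.extend([True] * len(result))
    return out, from_calc
```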

RESULTS

>99% for interpolation; still only ~5% for extrapolation. Failure modes are different in each case:

  • When interpolating, the network learns the decomposition exactly; the very few errors come from single-digit mistakes when counting letters or copying numbers, e.g. d:5 instead of d:6, or writing 1001/9690 in step 3 but then 101/9690 in step 5
  • When extrapolating, the network simply fails to generalize in step 3, where all the conditional probabilities are multiplied together: that product only ever contains as many factors as the length of the target sequences seen during training. For example, if the test target has length 5 but the network was trained on targets of length <=3, it counts all 5 letters correctly in step 1 but then multiplies only the (correct) probabilities for the first 3 letters, completely ignoring the remaining 2
    • The few cases it does get right are degenerate, because the conditional probabilities it fails to include are all equal to 1

COMMENTS

I would do a longer decomposition that includes intermediate symbolic steps. My intuition is that this way the network would learn the general rule mapping target strings to a product of probabilities. For the simpler case (problem type # 1), we write:

$$P(a_0\dots a_{n-1}) = \prod_{i=0}^{n-1}P(a_i \mid a_0\dots a_{i-1})$$

then, evaluating each conditional probability

$$\prod_{i=0}^{n-1} \frac{C(a_i, s)-C(a_i, a_0\dots a_{i-1})}{L-i}$$

where $C(a, b)$ counts how many times character $a$ appears in string $b$, $s$ is the source string, and $L$ is its length.
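
In code, with `count` playing the role of $C$ (a sketch; the name `p_ordered` is mine):

```python
from fractions import Fraction

def count(a, b):
    """C(a, b): how many times character a occurs in string b."""
    return b.count(a)

def p_ordered(target, source):
    """Probability of drawing target, in order, without replacement
    from source, via the product formula above."""
    L = len(source)
    p = Fraction(1)
    for i, a in enumerate(target):
        # (count in source - count already drawn) / letters remaining
        p *= Fraction(count(a, source) - count(a, target[:i]), L - i)
    return p

print(p_ordered('dsggg', 'gggggggsggdgggg'))  # 1/210
print(p_ordered('abb', 'abbcab'))             # 1/10
```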

Concrete example:

INPUT = "Prob of 'abb' sampling w/o repl from 'abbcab'?"

# counting
STEP1 = "a:2 b:3 c:1"
# total
STEP2 = "2+3+1=6"
# calculate C(a,source) for each a in target
STEP3 = "count_a_:0 count_b_a:0 count_b_ab:1"
# completely general symbolic target
STEP4 = "(a|'')*(ab|a)*(abb|ab)"
# substituting character counts
STEP5 = "(2-0)/(6-0)*(3-0)/(6-1)*(3-1)/(6-2)=1/10"
# solution
STEP6 = "1/10"

STEPS = STEP1, STEP2, STEP3, STEP4, STEP5, STEP6
OUTPUT = ' '.join(STEPS)

I feel like this approach is still somewhat sketchy. Ideally, steps in the computation shouldn't be substrings of a larger OUTPUT string but something else (nodes in a tree?), and we'd have one transformer solving each step.
