// Source: GitHub Gist by @Drevanoorschot, created January 18, 2020 20:44
// (gist id 9d0efb3dd78fa147c720b62ce0d491d7).
//imports?
class ParRadixSort {
// Abstract ghost helper: returns the elements of `ar` as a sequence of the same length.
// It needs (and hands back) fraction p of the permission on every element,
// and it leaves the array contents unchanged.
context p != none;
context ar != null;
context (\forall* int idx; 0 <= idx && idx < ar.length; Perm(ar[idx], p));
ensures |\result| == ar.length;
ensures (\forall int idx; 0 <= idx && idx < ar.length; \result[idx] == ar[idx]);
ensures (\forall int idx; 0 <= idx && idx < ar.length; ar[idx] == \old(ar[idx]));
static seq<int> vals_method(frac p, int[] ar);
//ghost variables
given int k;
given seq<int> input_seq;
// Parallel LSD radix sort of `input` into `output` (verification in progress).
// Each of the maxDigits passes does the following:
//  - extracts one base-`radix` digit of every element into partialInput;
//  - counts the digit occurrences with parCount;
//  - turns the counts into bucket start offsets with prefixSum;
//  - scatters the working copy by digit with parReorder.
// The reordered array of one pass is the input of the next pass.
// At the end the final working copy is copied element-wise into `output`.
//function invariants
context_everywhere input != null && output != null && tempCounts != null && prefixsum != null && count != null;
context_everywhere partitionSize >= 0 && tCount >= 0 && radix > 1 && maxDigits >= 0;
context_everywhere tCount >= radix;
context_everywhere partitionSize * tCount == input.length;
context_everywhere input.length == output.length;
context_everywhere tempCounts.length == tCount;
context_everywhere (\forall int i; i >= 0 && i < tCount; tempCounts[i].length == radix);
context_everywhere k > 0;
// together with count.length == radix below, this forces radix to be a power of two
context_everywhere prefixsum.length == ExpTwo(k);
context_everywhere count.length == radix;
context_everywhere count.length == prefixsum.length;
context_everywhere |input_seq| == count.length;
//permission specifications
context (\forall* int i; i >= 0 && i < input.length; Perm(input[i], read));
context (\forall* int i; i >= 0 && i < output.length; Perm(output[i], write));
context (\forall* int i; i >= 0 && i < tCount;
    (\forall* int j; j >= 0 && j < radix; Perm(tempCounts[i][j], write)));
context (\forall* int i; i >= 0 && i < prefixsum.length; Perm(prefixsum[i], write));
context (\forall* int i; i >= 0 && i < count.length; Perm(count[i], write));
//pre-conditions
// NOTE(review): an element fits in maxDigits base-radix digits iff
// input[i] / radix^maxDigits == 0. The bound below uses (maxDigits + 1) * radix,
// which only coincides with that for some parameter values. Confirm the intended bound.
requires (\forall int i; i >= 0 && i < input.length; input[i] / ((maxDigits + 1) * radix) == 0);
//requires (\forall* int i; 0 <= i && i < count.length; count[i] == input_seq[i]);
//ensures (\forall* int i; 0 <= i && i < count.length; count[i] == input_seq[i]);
void parRadixSort(int[] input, int[] output, int radix, int partitionSize, int tCount, int maxDigits, int[][] tempCounts, int[] prefixsum, int[] count) {
    // working copy of the input; it is replaced by the reordered array after every pass
    int[] inputCopy = new int[input.length];
    //permission specifications
    loop_invariant (\forall* int i; i >= 0 && i < input.length; Perm(input[i], read));
    loop_invariant (\forall* int i; i >= 0 && i < inputCopy.length; Perm(inputCopy[i], write));
    //helper invariants
    loop_invariant l >= 0 && l <= input.length;
    loop_invariant input.length == inputCopy.length;
    for(int l = 0; l < input.length; l++) {
        inputCopy[l] = input[l];
    }
    int[] partialInput = new int[inputCopy.length];
    // divisor == radix^i during pass i, so (x / divisor) % radix is digit i of x.
    // Dividing by i * radix instead is only correct for i <= 1.
    int divisor = 1;
    loop_invariant inputCopy != null && count != null && tempCounts != null && partialInput != null;
    //helper invariants
    loop_invariant i >= 0 && i <= maxDigits;
    loop_invariant inputCopy.length == input.length;
    loop_invariant divisor > 0;
    //permission specifications
    loop_invariant (\forall* int i; i >= 0 && i < inputCopy.length; Perm(inputCopy[i], write));
    loop_invariant (\forall* int i; i >= 0 && i < input.length; Perm(input[i], read));
    loop_invariant (\forall* int i; i >= 0 && i < output.length; Perm(output[i], write));
    loop_invariant (\forall* int i; i >= 0 && i < tCount;
        (\forall* int j; j >= 0 && j < radix; Perm(tempCounts[i][j], write)));
    loop_invariant (\forall* int i; i >= 0 && i < prefixsum.length; Perm(prefixsum[i], write));
    loop_invariant (\forall* int i; i >= 0 && i < count.length; Perm(count[i], write));
    //more invariants
    //loop_invariant (\forall int i; i >= 0 && i < count.length; count[i] == input_seq[i]);
    for(int i = 0; i < maxDigits; i++) {
        // 1. extract digit i of every element of the working copy
        partialInput = new int[inputCopy.length];
        //helper invariants
        loop_invariant inputCopy != null && partialInput != null;
        loop_invariant j >= 0 && j <= inputCopy.length;
        loop_invariant partialInput.length == inputCopy.length;
        loop_invariant divisor > 0;
        //permission specifications
        loop_invariant (\forall* int i; i >= 0 && i < inputCopy.length; Perm(inputCopy[i], write));
        loop_invariant (\forall* int i; i >= 0 && i < partialInput.length; Perm(partialInput[i], write));
        //more invariants
        loop_invariant (\forall int i; i >= 0 && i < j; partialInput[i] >= 0 && partialInput[i] < radix);
        for(int j = 0; j < inputCopy.length; j++) {
            partialInput[j] = (inputCopy[j] / divisor) % radix;
        }
        // 2. reset the digit histogram and the per-thread partial histograms
        //helper invariant
        loop_invariant z >= 0 && z <= count.length;
        //permission specifications
        loop_invariant (\forall* int i; i >= 0 && i < count.length; Perm(count[i], write));
        for(int z = 0; z < count.length; z++) {
            count[z] = 0;
        }
        //helper invariants
        loop_invariant x >= 0 && x <= tCount;
        //permission specifications
        loop_invariant (\forall* int i; i >= 0 && i < tCount;
            (\forall* int j; j >= 0 && j < radix; Perm(tempCounts[i][j], write)));
        for(int x = 0; x < tCount; x++) {
            //helper invariants
            loop_invariant x >= 0 && x <= tCount;
            loop_invariant y >= 0 && y <= radix;
            //permission specifications
            loop_invariant (\forall* int i; i >= 0 && i < tCount;
                (\forall* int j; j >= 0 && j < radix; Perm(tempCounts[i][j], write)));
            for(int y = 0; y < radix; y++) {
                tempCounts[x][y] = 0;
            }
        }
        // 3. count the occurrences of each digit value
        parCount(radix, partialInput, count, partitionSize, tCount, tempCounts);
        // 4. exclusive prefix sum of the counts (per prefixSum's contract: psum2 of input_seq)
        int[] count_copy = new int[count.length];
        //permission specifications
        loop_invariant (\forall* int i; i >= 0 && i < prefixsum.length; Perm(prefixsum[i], write));
        loop_invariant (\forall* int i; i >= 0 && i < prefixsum.length; Perm(count[i], write)); //TODO change to read
        loop_invariant (\forall* int i; i >= 0 && i < prefixsum.length; Perm(count_copy[i], write));
        //helper invariant
        loop_invariant y >= 0 && y <= prefixsum.length;
        loop_invariant (\forall int i; i >= 0 && i < y; prefixsum[i] == count[i]);
        loop_invariant (\forall int i; i >= 0 && i < y; count_copy[i] == count[i]);
        for(int y = 0; y < prefixsum.length; y++) {
            prefixsum[y] = count[y];
            count_copy[y] = count[y];
        }
        input_seq = vals_method(1/2, count_copy);
        prefixSum(count_copy, prefixsum) with {
            k = k;
            input_seq = input_seq;
        };
        // 5. append the total element count as sentinel, so that bucket d is
        //    described by the pair prefixsumExtended[d], prefixsumExtended[d + 1]
        int[] tempOutput = new int[inputCopy.length];
        int[] prefixsumExtended = new int[radix + 1];
        //permission invariants
        loop_invariant (\forall* int i; i >= 0 && i < prefixsum.length; Perm(prefixsum[i], write));
        loop_invariant (\forall* int i; i >= 0 && i <= radix; Perm(prefixsumExtended[i], write));
        //helper invariants
        loop_invariant w >=0 && w <= prefixsum.length;
        loop_invariant (\forall int i; i >= 0 && i < w; prefixsum[i] == prefixsumExtended[i]);
        for(int w = 0; w < prefixsum.length; w++) {
            prefixsumExtended[w] = prefixsum[w];
        }
        prefixsumExtended[radix] = input.length;
        assert (\forall* int i; i >= 0 && i <= radix; Perm(prefixsumExtended[i], write));
        assert prefixsumExtended[prefixsumExtended.length - 1] == input.length;
        // 6. scatter the current working copy (not the original input) by digit i
        parReorder(radix, partialInput, inputCopy, tempOutput, prefixsumExtended) with {
            prefixsumCopy = prefixsumExtended;
        };
        // the reordered array is the input of the next pass
        inputCopy = tempOutput;
        divisor = divisor * radix;
    }
    // copy the final working copy into the caller's output array
    // (reassigning the parameter would not be visible to the caller)
    //permission specifications
    loop_invariant (\forall* int i; i >= 0 && i < output.length; Perm(output[i], write));
    loop_invariant (\forall* int i; i >= 0 && i < inputCopy.length; Perm(inputCopy[i], write));
    //helper invariants
    loop_invariant inputCopy != null && inputCopy.length == output.length;
    loop_invariant m >= 0 && m <= output.length;
    for(int m = 0; m < output.length; m++) {
        output[m] = inputCopy[m];
    }
}
//function invariants
// Parallel histogram of `input` (values in [0, radix)) into `output`.
// Phase 1: thread tid counts its partition
//   [tid * partitionSize, (tid + 1) * partitionSize) into row tempCounts[tid].
// Phase 2: thread tid < radix adds column tid of tempCounts into output[tid].
// Neither output nor tempCounts is reset here; the counts are added to their existing contents.
context_everywhere input != null && output != null && tempCounts != null;
context_everywhere partitionSize >= 0 && tCount >= 0 && radix > 1;
context_everywhere tCount >= radix;
context_everywhere partitionSize * tCount == input.length;
context_everywhere output.length == radix;
context_everywhere tempCounts.length == tCount;
context_everywhere (\forall int i; i >= 0 && i < tempCounts.length; tempCounts[i].length == radix);
//permission specifications
context (\forall* int i; i >= 0 && i < input.length; Perm(input[i], read));
context (\forall* int i; i >= 0 && i < output.length; Perm(output[i], write));
context (\forall* int i; i >= 0 && i < tCount;
    (\forall* int j; j >= 0 && j < radix; Perm(tempCounts[i][j], write)));
//pre conditions
requires (\forall int i; i >= 0 && i < input.length; input[i] >= 0 && input[i] < radix);
void parCount(int radix, int[] input, int[] output, int partitionSize, int tCount, int[][] tempCounts) {
    par count (int tid = 0..tCount)
    //permission specifications
    context (\forall* int i; i >= 0 && i < radix; Perm(tempCounts[tid][i], write));
    requires (\forall* int i; i >= (partitionSize * tid) && i < (partitionSize * (tid + 1)); Perm(input[i], read));
    //pre-conditions
    requires (\forall int i; i >= (partitionSize * tid) && i < (partitionSize * (tid + 1)); input[i] >= 0 && input[i] < radix);
    {
        // phase 1: histogram of this thread's partition into row tempCounts[tid]
        //permission invariants
        loop_invariant (\forall* int i; i >= 0 && i < radix; Perm(tempCounts[tid][i], write));
        loop_invariant (\forall* int i; i >= (partitionSize * tid) && i < (partitionSize * (tid + 1)); Perm(input[i], read));
        //helper invariants
        loop_invariant 0 <= i && i <= partitionSize;
        loop_invariant (\forall int i; i >= (partitionSize * tid) && i < (partitionSize * (tid + 1)); input[i] >= 0 && input[i] < radix);
        for (int i = 0; i < partitionSize; i++) {
            int index = (tid * partitionSize) + i;
            int number = input[index];
            tempCounts[tid][number] = tempCounts[tid][number] + 1;
        }
        // Switch from row ownership to column ownership. Only threads tid < radix own a bucket:
        // output and every tempCounts row have length radix, and tCount may exceed radix.
        // (A `tid <= radix` guard would let thread tid == radix index out of bounds.)
        // NOTE(review): the par block's contract gives the threads no Perm(output[..]),
        // yet this barrier hands out Perm(output[tid], write). Confirm VerCors accepts this.
        barrier(count) {
            //permissions specifications
            requires (\forall* int i; i >= 0 && i < radix; Perm(tempCounts[tid][i], write));
            requires (\forall* int i; i >= (partitionSize * tid) && i < (partitionSize * (tid + 1)); Perm(input[i], read));
            requires (\forall int i; i >= (partitionSize * tid) && i < (partitionSize * (tid + 1)); input[i] >= 0 && input[i] < radix);
            ensures tid < radix ==> (\forall* int i; i >= 0 && i < tCount; Perm(tempCounts[i][tid], write));
            ensures tid < radix ==> Perm(output[tid], write);
        }
        // phase 2: thread tid sums column tid, i.e. every partial count of digit value tid
        if (tid < radix) {
            //permission invariant
            loop_invariant (\forall* int i; i >= 0 && i < tCount; Perm(tempCounts[i][tid], write));
            loop_invariant Perm(output[tid], write);
            //helper invariants
            loop_invariant tid >= 0 && tid < radix;
            loop_invariant i >= 0 && i <= tCount;
            for (int i = 0; i < tCount; i++) {
                output[tid] = output[tid] + tempCounts[i][tid];
            }
        }
        // hand the columns back so that every thread again owns its own row
        barrier(count) {
            requires tid < radix ==> (\forall* int i; i >= 0 && i < tCount; Perm(tempCounts[i][tid], write));
            requires tid < radix ==> Perm(output[tid], write);
            ensures (\forall* int i; i >= 0 && i < radix; Perm(tempCounts[tid][i], write));
        }
    }
}
//ghost variables
// Ghost alias of the `prefixsum` argument (callers bind it with `prefixsumCopy = ...`).
given int[] prefixsumCopy;
// Intended: scatter `input` into `output` by the digits in `partialInput`.
// Under the commented-out design, thread tid writes the elements whose digit is tid
// starting at offset prefixsum[tid]. The body is currently disabled, so as written
// the method only checks and returns the permissions in its contract.
//function invariants
context_everywhere input != null && output != null && partialInput != null && prefixsum != null;
context_everywhere radix > 1;
context_everywhere input.length == output.length && input.length == partialInput.length;
// one start offset per digit value plus a trailing end sentinel
context_everywhere prefixsum.length == radix + 1;
//permission specifications
context (\forall* int i; i >= 0 && i < input.length; Perm(input[i], read));
context (\forall* int i; i >= 0 && i < partialInput.length; Perm(partialInput[i], read));
context (\forall* int i; i >= 0 && i < output.length; Perm(output[i], write));
context (\forall* int i; i >= 0 && i < prefixsum.length; Perm(prefixsum[i], write));
//preconditions
//requires prefixsum[prefixsum.length - 1] == input.length;
//requires (\forall int i; i >= 0 && i < prefixsum.length;
// (\forall int j; j >= 0 && j < prefixsum.length; i >= j ==> prefixsum[i] >= prefixsum[j]));
//requires (\forall int i; i >= 0 && i < prefixsum.length; prefixsum[i] >= 0 && prefixsum[i] < output.length);
void parReorder(int radix, int[] partialInput, int[] input, int[] output, int[] prefixsum) {
// NOTE(review): in the disabled design below, `prefixsum == prefixsumCopy` makes both names
// denote the same array, which causes two problems:
//  (a) incrementing prefixsum[tid] also moves prefixsumCopy[tid], so the output range that
//      thread tid owns shifts while the loop runs;
//  (b) thread tid + 1 holds write permission on prefixsum[tid + 1] while thread tid asks
//      for read permission on the same location, more than a full permission in total.
// A separate ghost snapshot of the offsets, or a thread-local cursor, is presumably needed.
/*
par reorder(int tid = 0..radix)
//ghost variables
context prefixsum == prefixsumCopy;
//permission specifications
requires Perm(prefixsum[tid], write);
requires Perm(prefixsumCopy[tid], read);
requires Perm(prefixsumCopy[tid + 1], read);
requires (\forall* int i; i >= 0 && i < input.length; Perm(input[i], read));
requires (\forall* int i; i >= 0 && i < partialInput.length; Perm(partialInput[i], read));
requires (\forall* int i; i >= prefixsumCopy[tid] && i < prefixsumCopy[tid + 1]; Perm(output[i], write));
{
loop_invariant Perm(prefixsum[tid], write);
loop_invariant Perm(prefixsumCopy[tid], read);
loop_invariant Perm(prefixsumCopy[tid + 1], read);
loop_invariant (\forall* int i; i >= 0 && i < input.length; Perm(input[i], read));
loop_invariant (\forall* int i; i >= 0 && i < partialInput.length; Perm(partialInput[i], read));
loop_invariant (\forall* int i; i >= prefixsumCopy[tid] && i < prefixsumCopy[tid + 1]; Perm(output[i], write));
for(int i = 0; i < partialInput.length; i++) {
if(partialInput[i] == tid) {
output[prefixsum[tid]] = input[i];
prefixsum[tid] = prefixsum[tid] + 1;
}
}
}*/
}
//parallel prefix sum
// ExpTwo(p) is 2^p for p >= 0, defined by recursion on p with base case 2^0 = 1.
// The postcondition p < 2^p follows inductively from the recursive call.
requires 0 <= p;
ensures p < \result;
static pure int ExpTwo(int p) = p == 0 ? 1 : 2 * ExpTwo(p - 1);
// Sum of all elements of xs (0 for the empty sequence), by head/tail recursion.
ensures |xs| == 0 ==> \result == 0;
ensures |xs| == 1 ==> \result == head(xs);
static pure int intsum(seq<int> xs) =
0 < |xs| ? head(xs) + intsum(tail(xs)) : 0;
// The first n elements of xs, built by recursion on n; empty when n <= 0.
// Callers may not ask for more elements than xs has.
requires n <= |xs|;
ensures n < 0 ==> |Take(xs, n)| == 0;
ensures 0 <= n ==> |Take(xs, n)| == n;
ensures (\forall int i; 0 <= i && i < n; xs[i] == get(Take(xs, n), i));
static pure seq<int> Take(seq<int> xs, int n) =
0 < n ? seq<int> { head(xs) } + Take(tail(xs), n - 1) : seq<int> { };
// Exclusive prefix sums of xs, starting at position i:
//   psum(xs, i) == [intsum(Take(xs, i)), intsum(Take(xs, i + 1)), ..., intsum(Take(xs, |xs| - 1))].
requires 0 <= i && i <= |xs|;
ensures |\result| == |xs| - i;
ensures (\forall int j; 0 <= j && j < |\result|; \result[j] == intsum(Take(xs, i+j)));
static pure seq<int> psum(seq<int> xs, int i) =
i < |xs| ? seq<int> { intsum(Take(xs, i)) } + psum(xs, i + 1) : seq<int> { };
// TODO use this version instead of the above `psum` (the above version is just a helper definition).
// Exclusive prefix sums of the whole sequence: element j is the sum of xs[0 .. j-1].
// Element 0, if present, is therefore intsum of the empty prefix, i.e. 0.
ensures |\result| == |xs|;
ensures (\forall int j; 0 <= j && j < |\result|; \result[j] == intsum(Take(xs, j)));
static pure seq<int> psum2(seq<int> xs) = psum(xs, 0);
// Pairwise reduction: implode([x0, x1, x2, x3, ...]) == [x0 + x1, x2 + x3, ...].
// A sequence of length <= 1 is returned unchanged, so for odd lengths the last
// element is carried over as-is.
requires |xs| >= 0;
ensures |xs| == 0 ==> \result == xs;
ensures |xs| == 1 ==> \result == xs;
ensures |xs| == 2 ==> \result == seq<int> { head(xs) + head(tail(xs)) };
ensures |xs| % 2 == 0 ==> |\result| == |xs| / 2;
static pure seq<int> implode(seq<int> xs) =
1 < |xs| ? seq<int> { head(xs) + head(tail(xs)) } + implode(tail(tail(xs))) : xs;
// Integer power n^p for p >= 0, by recursion on the exponent (n^0 = 1).
requires 0 <= p;
static pure int exp(int n, int p) = p == 0 ? 1 : n * exp(n, p - 1);
// Function-call form of xs[n], usable in specifications; n must be a valid index.
requires 0 <= n && n < |xs|;
static pure int get(seq<int> xs, int n) = xs[n];
// Ghost model of one up-sweep level of a tree-shaped scan (appears to follow the
// Blelloch-style scheme; confirm), built from index i to the end of xs.
// Position p (for p = i .. |xs|-1):
//  - becomes xs[p] + xs[p - stride] when p % (2*stride) == 2*stride - 1 and p >= stride;
//  - otherwise stays xs[p].
// k and lvl only appear in the preconditions (|xs| == 2^k, stride == 2^(lvl-1)).
requires k > 0;
requires |xs| == ExpTwo(k);
requires i >= 0 && i <= |xs|;
requires 1 <= lvl && lvl <= k;
requires stride == ExpTwo(lvl-1);
requires stride > 0 && stride < |xs|;
ensures |\result| == |xs| - i;
ensures (\forall int j; j >= 0 && j < |\result|; ((i < |xs|) && ((i+j) >= stride) && (((i+j) % (2*stride)) == (2*stride-1))) ==> \result[j] == xs[i+j] + xs[i+j - stride]);
ensures (\forall int j; j >= 0 && j < |\result|; ((i < |xs|) && (((i+j) < stride) || (((i+j) % (2*stride)) != (2*stride-1)))) ==> \result[j] == xs[i+j]);
static pure seq<int> up(seq<int> xs, int stride, int i, int k, int lvl) =
i < |xs| ? (
((i % (2*stride)) == (2*stride-1) && (i >= stride)?
seq<int> {xs[i] + xs[i-stride]} + up(xs, stride, i+1, k, lvl)
:
seq<int> {xs[i]} + up(xs, stride, i+1, k, lvl) ))
:
seq<int> {};
////////////////////////////////////////////////////////////////////////////////////////Lemmas
// The sum of the empty sequence is 0 (follows from unfolding intsum).
ensures intsum(seq<int> { }) == 0;
void lemma_intsum_zero() {
}
// The prefix sums of the empty sequence are empty.
ensures psum2(seq<int> { }) == seq<int> { };
void lemma_psum_zero() {
}
// The sum of a one-element sequence is that element.
ensures intsum(seq<int> { x }) == x;
void lemma_intsum_single(int x) {
// intsum({x}) unfolds to x + intsum(tail({x})); show that the tail is empty and sums to 0
assert tail(seq<int> { x }) == seq<int> { };
lemma_intsum_zero();
}
// The exclusive prefix sums of any one-element sequence are [0].
requires |xs| == 1;
ensures psum2(xs) == seq<int> {0};
void lemma_psum_single(seq<int> xs) {
assert tail(xs) == seq<int> { };
lemma_psum_zero();
}
// The exclusive prefix sums of [x, y] are [0, x].
ensures psum2(seq<int> { x, y }) == seq<int> { 0, x };
void lemma_psum_double(int x, int y) {
lemma_psum_single(tail(seq<int> { x, y }));
}
// intsum distributes over concatenation: intsum(xs + ys) == intsum(xs) + intsum(ys).
// Proved by induction on xs; the other postconditions are auxiliary facts for callers.
requires |xs| >= 0;
requires |ys| >= 0;
ensures |xs| == 0 ==> intsum(xs + ys) == intsum(ys);
ensures |ys| == 0 ==> intsum(xs + ys) == intsum(xs);
ensures |xs + ys| == |xs| + |ys|;
ensures intsum(tail(xs) + ys) == intsum(tail(xs)) + intsum(ys);
ensures intsum(xs + ys) == intsum(xs) + intsum(ys);
void lemma_intsum_app(seq<int> xs, seq<int> ys) {
if (0 < |xs|) {
// induction hypothesis for the tail, then relate tail(xs) + ys to tail(xs + ys)
lemma_intsum_app(tail(xs), ys);
assert tail(xs) + ys == tail(xs + ys);
}
}
// Base case of implode: sequences of length 0 or 1 are left unchanged.
requires |xs| <= 1;
ensures xs == implode(xs);
void lemma_implode_base(seq<int> xs) {
}
// implode([x, y]) == [x + y].
ensures implode(seq<int> { x, y }) == seq<int> { x + y };
void lemma_implode_single(int x, int y) {
// what remains after the first pair is empty, and implode leaves it unchanged
lemma_implode_base(tail(tail(seq<int> { x, y })));
}
// implode preserves the total sum of a sequence. Induction on xs, two elements at a time.
ensures intsum(xs) == intsum(implode(xs));
void lemma_implode_sum(seq<int> xs) {
if (1 < |xs|) {
lemma_implode_sum(tail(tail(xs)));
// implode(xs) == {head + second} + implode(tail(tail(xs))); split the sum of that concatenation
lemma_intsum_app(seq<int> { head(xs) + head(tail(xs)) }, implode(tail(tail(xs))));
lemma_intsum_single(head(xs) + head(tail(xs)));
assert intsum(xs) == head(xs) + intsum(tail(xs));
}
}
// One unfolding step of ExpTwo: 2^n == 2 * 2^(n-1) for n > 0.
requires 0 < n;
ensures ExpTwo(n) == 2 * ExpTwo(n - 1);
void lemma_exp2_red_mult(int n) {
}
// Halving 2^n gives 2^(n-1) for n > 0.
requires 0 < n;
ensures ExpTwo(n) / 2 == ExpTwo(n - 1);
void lemma_exp2_red_div(int n) {
}
// 2^n is strictly positive for n >= 0 (induction on n).
requires 0 <= n;
ensures 0 < ExpTwo(n);
void lemma_exp2_positive(int n) {
if (0 < n) {
lemma_exp2_positive(n - 1);
}
}
// ExpTwo is monotone: i <= j implies 2^i <= 2^j.
// Decrements both arguments until i reaches 0, then uses 2^0 == 1 <= 2^j.
requires 0 <= i;
requires i <= j;
ensures ExpTwo(i) <= ExpTwo(j);
void lemma_exp2_leq(int i, int j) {
if (0 < i) {
lemma_exp2_leq(i - 1, j - 1);
} else {
// i == 0: ExpTwo(0) == 1, and a positive 2^j is at least 1
lemma_exp2_positive(j);
}
}
// ExpTwo is injective on non-negative arguments: 2^i == 2^j implies i == j.
// Decrements both arguments while both are positive. In the base case at least one of
// them is 0, and the monotonicity lemma is applied in the matching direction.
requires i >= 0 && j >= 0;
requires ExpTwo(i) == ExpTwo(j);
ensures i == j;
void power_two_lemma(int i, int j){
if (0 < i && j > 0) {
power_two_lemma(i - 1, j - 1);
} else {
if(i > j){
lemma_exp2_leq(j, i);
} else {
lemma_exp2_leq(i, j);
}
}
}
// For even-length xs, implode halves the length. Induction two elements at a time.
requires |xs| % 2 == 0;
ensures |implode(xs)| == |xs| / 2;
void lemma_implode_length_mod_two(seq<int> xs) {
if (1 < |xs|) {
lemma_implode_length_mod_two(tail(tail(xs)));
}
}
// Imploding a sequence of length 2^n (n > 0) yields a sequence of length 2^(n-1).
// Since 2^n == 2 * 2^(n-1) is even, lemma_implode_length_mod_two gives the halved length
// directly. No recursion is needed: a recursive call on implode(xs) would only say
// something about implode(implode(xs)).
requires 0 < n && |xs| == ExpTwo(n);
ensures |implode(xs)| == ExpTwo(n - 1);
void lemma_implode_red_exp2(seq<int> xs, int n) {
    // |xs| == 2 * ExpTwo(n - 1), hence even and halving to ExpTwo(n - 1)
    lemma_exp2_red_mult(n);
    lemma_implode_length_mod_two(xs);
}
// Index shift for tail: tail(xs)[i - 1] == xs[i] for 0 < i < |xs|.
requires 0 < i;
requires i < |xs|;
ensures get(tail(xs), i - 1) == xs[i];
void lemma_intseq_index_tail(seq<int> xs, int i) {
}
// For even-length xs, element i of implode(xs) is xs[2i] + xs[2i+1].
// Induction on i, dropping the first pair of xs at each step.
requires |xs| % 2 == 0;
requires 0 <= i && i < |implode(xs)|;
requires (2 * i) < |xs|;
requires (2 * i + 1) < |xs|;
ensures get(implode(xs), i) == xs[2 * i] + xs[2 * i + 1];
void lemma_implode_get(seq<int> xs, int i) {
if (1 < |xs|) {
if (0 < i) {
lemma_implode_get(tail(tail(xs)), i - 1);
lemma_implode_length_mod_two(xs);
}
}
}
// Quantified form of lemma_implode_get: every element m of implode(xs) is
// xs[2m] + xs[2m+1]. Established by applying lemma_implode_get to each index in turn.
requires |xs| % 2 == 0;
requires |implode(xs)| == |xs|/2;
ensures (\forall int i; 0 <= i && i < |implode(xs)|; get(implode(xs), i) == xs[2 * i] + xs[2 * i + 1]);
void lemma_implode_get_all(seq<int> xs) {
    loop_invariant 0 <= idx && idx <= |implode(xs)|;
    loop_invariant (\forall int m; 0 <= m && m < idx; get(implode(xs), m) == xs[2 * m] + xs[2 * m + 1]);
    for (int idx = 0; idx < |implode(xs)|; idx++) {
        lemma_implode_get(xs, idx);
    }
}
// Characterises implode: if |xs| == 2|ys| and ys[i] == xs[2i] + xs[2i+1] for every i,
// then ys == implode(xs). Induction on ys, dropping one pair of xs per step.
requires |xs| == 2 * |ys|;
requires 0 <= |ys|;
requires (\forall int i; 0 <= i && i < |ys|; ys[i] == xs[2*i] + xs[2*i+1]);
ensures ys == implode(xs);
void lemma_implode_rel(seq<int> xs, seq<int> ys) {
if (0 < |ys|) {
lemma_implode_rel(tail(tail(xs)), tail(ys));
}
}
// Pointwise form of psum2's postcondition at index i: psum2(xs)[i] == intsum(Take(xs, i)).
// Recurses on (tail(xs), i - 1).
requires 0 <= i && i < |xs|;
ensures get(psum2(xs), i) == intsum(Take(xs, i));
void lemma_psum_get(seq<int> xs, int i) {
if (0 < |xs|) {
if (0 < i) {
lemma_psum_get(tail(xs), i - 1);
}
}
}
// Quantified form of lemma_psum_get: every element m of psum2(xs) is intsum(Take(xs, m)).
// Established by applying lemma_psum_get to each index in turn.
ensures (\forall int i; 0 <= i && i < |xs|; get(psum2(xs), i) == intsum(Take(xs, i)));
void lemma_psum_get_all(seq<int> xs) {
    loop_invariant 0 <= idx && idx <= |xs|;
    loop_invariant (\forall int m; 0 <= m && m < idx; get(psum2(xs), m) == intsum(Take(xs, m)));
    for (int idx = 0; idx < |xs|; idx++) {
        lemma_psum_get(xs, idx);
    }
}
// Taking one more element: Take(xs, n) == Take(xs, n - 1) + [xs[n - 1]] for 0 < n <= |xs|.
// Induction on n over tail(xs).
requires 0 < n && n <= |xs|;
ensures Take(xs, n) == Take(xs, n - 1) + seq<int> { xs[n - 1] };
void missing_lemma_2(seq<int> xs, int n) {
if (1 < n) {
missing_lemma_2(tail(xs), n - 1);
}
}
// implode distributes over concatenation of two even-length sequences:
// implode(xs + ys) == implode(xs) + implode(ys). Induction on xs, two elements at a time.
requires |xs| % 2 == 0;
requires |ys| % 2 == 0;
ensures implode(xs + ys) == implode(xs) + implode(ys);
void missing_lemma_3(seq<int> xs, seq<int> ys) {
if (0 < |xs|) {
missing_lemma_3(tail(tail(xs)), ys);
assert tail(tail(xs)) + ys == tail(tail(xs + ys));
}
}
// Associativity of sequence concatenation (discharged without a proof body).
ensures xs + (ys + zs) == (xs + ys) + zs;
void intseq_concat_assoc(seq<int> xs, seq<int> ys, seq<int> zs) { }
// Taking a prefix commutes with implode: the first n elements of implode(xs) equal
// implode of the first 2n elements of xs. Induction on n; each assert in the inductive
// case is one rewrite step, justified by the lemma call just before it.
requires |xs| % 2 == 0;
requires 0 <= n && n < |implode(xs)|;
requires |implode(xs)| == |xs| / 2;
ensures Take(implode(xs), n) == implode(Take(xs, 2 * n));
void missing_lemma(seq<int> xs, int n) {
if (0 < n) {
missing_lemma(xs, n - 1);
assert Take(implode(xs), n - 1) == implode(Take(xs, 2 * n - 2)); // this is our induction hypothesis (IH)
missing_lemma_2(implode(xs), n);
assert Take(implode(xs), n) == Take(implode(xs), n - 1) + seq<int> { get(implode(xs), n - 1) };
assert Take(implode(xs), n) == implode(Take(xs, 2 * n - 2)) + seq<int> { get(implode(xs), n - 1) };
lemma_implode_get(xs, n - 1);
assert Take(implode(xs), n) == implode(Take(xs, 2 * n - 2)) + seq<int> { xs[2 * (n - 1)] + xs[2 * (n - 1) + 1] };
assert Take(implode(xs), n) == implode(Take(xs, 2 * n - 2)) + implode(seq<int> { xs[2 * n - 2], xs[2 * n - 1] });
missing_lemma_3(Take(xs, 2 * n - 2), seq<int> { xs[2 * n - 2] } + seq<int> { xs[2 * n - 1] });
assert Take(implode(xs), n) == implode(Take(xs, 2 * n - 2) + (seq<int> { xs[2 * n - 2] } + seq<int> { xs[2 * n - 1] }));
intseq_concat_assoc(Take(xs, 2 * n - 2), seq<int> { xs[2 * n - 2] }, seq<int> { xs[2 * n - 1] });
assert Take(implode(xs), n) == implode((Take(xs, 2 * n - 2) + seq<int> { xs[2 * n - 2] }) + seq<int> { xs[2 * n - 1] });
missing_lemma_2(xs, 2 * n - 1);
assert Take(implode(xs), n) == implode(Take(xs, 2 * n - 1) + seq<int> { xs[2 * n - 1] });
missing_lemma_2(xs, 2 * n);
assert Take(implode(xs), n) == implode(Take(xs, 2 * n));
}
else {
// n == 0: both sides are empty sequences
assert Take(implode(xs), n) == implode(Take(xs, 2 * n));
}
}
// Prefix sum i of implode(xs) equals the sum of the first 2i elements of xs.
requires |xs| % 2 == 0;
requires |implode(xs)| == |xs|/2;
requires 0 <= i && i < |implode(xs)|;
requires 2 * i < |xs|;
ensures get(psum2(implode(xs)), i) == intsum(Take(xs, 2 * i));
void lemma_psum_Take2(seq<int> xs, int i) {
// we first have
assert get(psum2(implode(xs)), i) == intsum(Take(implode(xs), i));
// then
missing_lemma(xs, i);
assert intsum(Take(implode(xs), i)) == intsum(implode(Take(xs, 2 * i)));
// thus
lemma_implode_sum(Take(xs, 2 * i));
assert intsum(implode(Take(xs, 2 * i))) == intsum(Take(xs, 2 * i));
}
// The prefix sums of implode(xs) are the even-indexed prefix sums of xs:
// psum2(implode(xs))[i] == psum2(xs)[2i].
// The body only proves the left-hand side (via lemma_psum_Take2). The right-hand side
// presumably follows from psum2's own postcondition.
requires |xs| % 2 == 0;
requires |implode(xs)| == |xs|/2;
requires 0 <= i && i < |implode(xs)|;
requires 2 * i < |xs|;
ensures get(psum2(implode(xs)), i) == get(psum2(xs), 2 * i);
void lemma_get_psum_implode(seq<int> xs, int i) {
lemma_psum_Take2(xs, i);
}
// Consecutive exclusive prefix sums differ by one element:
// psum2(xs)[2i + 1] == psum2(xs)[2i] + xs[2i].
requires 0 <= i;
requires 2 * i + 1 < |xs|;
ensures get(psum2(xs), 2 * i + 1) == get(psum2(xs), 2 * i) + get(xs, 2 * i);
void lemma_combine_psum(seq<int> xs, int i){
    // Instantiate lemma_psum_get at the two indices the proof is about (2i and 2i + 1).
    // Instantiating it at index i does not support the asserts that follow.
    lemma_psum_get(xs, 2 * i);
    assert get(psum2(xs), 2 * i) == intsum(Take(xs, 2 * i));
    lemma_psum_get(xs, 2 * i + 1);
    assert get(psum2(xs), 2 * i + 1) == intsum(Take(xs, 2 * i + 1));
    // Take(xs, 2i + 1) is Take(xs, 2i) with xs[2i] appended
    missing_lemma_2(xs, 2 * i + 1);
    assert Take(xs, 2 * i + 1) == Take(xs, 2 * i) + seq<int> { xs[2 * i] };
    assert intsum(Take(xs, 2 * i + 1)) == intsum(Take(xs, 2 * i) + seq<int> { xs[2 * i] });
    // split the sum over that concatenation
    lemma_intsum_app(Take(xs, 2 * i), seq<int> { xs[2 * i] } );
    assert intsum(Take(xs, 2 * i) + seq<int> { xs[2 * i] }) == intsum(Take(xs, 2 * i)) + intsum(seq<int> { xs[2 * i] });
    assert intsum(Take(xs, 2 * i) + seq<int> { xs[2 * i] }) == intsum(Take(xs, 2 * i)) + get(xs, 2 * i);
    assert intsum(Take(xs, 2 * i + 1)) == intsum(Take(xs, 2 * i)) + get(xs, 2 * i);
    assert get(psum2(xs), 2 * i + 1) == get(psum2(xs), 2 * i) + get(xs, 2 * i);
}
///////////////////////////////////////////////////////////////////////////////////////////
given int k;
given seq<int> input_seq;
context_everywhere k > 0;
context_everywhere input != null;
context_everywhere output != null;
context_everywhere output.length == ExpTwo(k);
context_everywhere input.length == output.length;
context_everywhere |input_seq| == input.length;
requires (\forall* int i; 0<=i && i<input.length; Perm(input[i], 1/2)); //TODO ask how to retain read
requires (\forall* int i; 0<=i && i<output.length; Perm(output[i], write));
requires (\forall* int i; 0<=i && i<output.length; output[i] == input[i]);
requires (\forall* int i; 0<=i && i<output.length; input_seq[i] == input[i]);
ensures (\forall* int i; 0<=i && i<input.length; Perm(input[i], read)); //TODO ask how to retain read
ensures (\forall* int i; 0<=i && i<output.length; Perm(output[i], write));
ensures (\forall* int i; 0<=i && i<output.length; input_seq[i] == input[i]);
ensures (\forall* int i; 0<=i && i<output.length; output[i] == get(psum2(input_seq), i));
void prefixSum(int[] input, int[] output)
{
par Threads (int tid=0..output.length)
requires Perm(input[tid], read);
requires 2 * tid < output.length ==> Perm(output[2 * tid], write);
requires 2 * tid + 1 < output.length ==> Perm(output[2 * tid + 1], write);
requires tid==0 ==> (\forall* int i; 0 <= i && i < output.length && (i + 1) % 1 != 0; Perm(output[i], write));
requires 2 * tid < output.length ==> output[2 * tid] == input_seq[2 * tid];
requires 2 * tid + 1 < output.length ==> output[2 * tid + 1] == input_seq[2 * tid + 1];
requires input_seq[tid] == input[tid];
ensures Perm(input[tid], read);
ensures input_seq[tid] == input[tid];
ensures tid < output.length ==> Perm(output[tid], write);
ensures tid < output.length ==> get(psum2(input_seq), tid) == output[tid];
{
int indicator = 2 * tid + 1;
int stride = 1;
int lvl = 1;
seq<seq<int>> Matrix_UP = seq<seq<int>> { input_seq }; // ghost code
assert (\forall int i; 0 < i && i < lvl; Matrix_UP[i] == up(Matrix_UP[i - 1], stride/ExpTwo(lvl-i), 0, k, i));
seq<seq<int>> Matrix = seq<seq<int>> { input_seq };
loop_invariant k > 0;
loop_invariant output.length == ExpTwo(k);
loop_invariant tid >= 0 && tid < output.length;
loop_invariant stride > 0;
loop_invariant 1 <= lvl;
loop_invariant stride == ExpTwo(lvl-1);
loop_invariant lvl <= k+1;
loop_invariant indicator + 1 == ExpTwo(lvl)*(tid+1);
loop_invariant indicator + 1 == 2*stride*(tid+1);
loop_invariant indicator > 0;
loop_invariant stride <= output.length;
loop_invariant indicator < output.length ==> Perm(output[indicator], write);
loop_invariant indicator < output.length && indicator >= stride ==> Perm(output[indicator - stride], write);
loop_invariant tid==0 ==> (\forall* int i; 0 <= i && i < output.length && (i + 1) % stride != 0; Perm(output[i], write));
loop_invariant (tid==0 && (stride == output.length)) ==> (Perm(output[output.length - 1], write));
loop_invariant |Matrix_UP| == lvl;
loop_invariant (\forall int i; 0 <= i && i < lvl; |Matrix_UP[i]| == output.length);
loop_invariant lvl == 1 ==> Matrix_UP[lvl - 1] == input_seq;
loop_invariant lvl > 1 && lvl < |Matrix_UP| ==> Matrix_UP[lvl] == up(Matrix_UP[lvl - 1], (stride/2) - 1, 0, k, lvl - 1);
//loop_invariant (\forall int i; 0 < i && i < lvl; Matrix_UP[i] == up(Matrix_UP[i - 1], stride/ExpTwo(lvl-i), 0, k, i));
//loop_invariant (\forall int i; 0 < i && i < lvl; Matrix_UP[i] == up(Matrix_UP[i - 1], \old(stride)*ExpTwo(i), 0, k, i));
loop_invariant indicator < output.length ==> Matrix_UP[lvl - 1][indicator] == output[indicator];
loop_invariant indicator < output.length && indicator >= stride ==> Matrix_UP[lvl - 1][indicator - stride] == output[indicator - stride];
loop_invariant lvl == k+1 ==> Matrix_UP[lvl-1][output.length - 1] == intsum(input_seq);
loop_invariant lvl == k+1 ==> Matrix_UP[lvl-1][(output.length - 1)/2] == intsum(Take(input_seq, |input_seq|/2));
loop_invariant |Matrix| == lvl;
loop_invariant (\forall int i; 0 <= i && i < lvl; 0 <= |Matrix[i]| && |Matrix[i]| <= output.length);
loop_invariant (\forall int i; 0 <= i && i < lvl; |Matrix[i]| == ExpTwo(k - i));
loop_invariant (\forall int i; 0 < i && i < lvl; Matrix[i] == implode(Matrix[i - 1]));
loop_invariant (\forall int i; 0 <= i && i < lvl; intsum(Matrix[i]) == intsum(input_seq));
loop_invariant Matrix[0] == input_seq;
loop_invariant indicator < output.length && 2 * tid + 1 < |Matrix[lvl - 1]| ==> output[indicator] == Matrix[lvl - 1][2 * tid + 1];
loop_invariant indicator < output.length && indicator >= stride && 2 * tid < |Matrix[lvl - 1]| ==> output[indicator - stride] == Matrix[lvl - 1][2 * tid];
while(stride < output.length)
{
if(indicator < output.length && indicator >= stride)
{
assert 2 * tid + 1 < |Matrix[lvl - 1]| ==> output[indicator] == Matrix[lvl - 1][2 * tid + 1];
assert 2 * tid < |Matrix[lvl - 1]| ==> output[indicator - stride] == Matrix[lvl - 1][2 * tid];
output[indicator] = output[indicator] + output[indicator - stride];
assert 2 * tid + 1 < |Matrix[lvl - 1]| ==> output[indicator] == Matrix[lvl - 1][2 * tid + 1] + Matrix[lvl - 1][2 * tid];
}
lemma_implode_length_mod_two(Matrix[lvl - 1]);
lemma_implode_sum(Matrix[lvl - 1]);
lemma_implode_get_all(Matrix[lvl - 1]);
Matrix = Matrix + seq<seq<int>> { implode(Matrix[lvl - 1]) };
if(tid < |implode(Matrix[lvl - 1])|){
lemma_implode_get(Matrix[lvl - 1], tid);
assert 2 * tid + 1 < |Matrix[lvl - 1]| ==> get(implode(Matrix[lvl - 1]), tid) == Matrix[lvl - 1][2 * tid] + Matrix[lvl - 1][2 * tid + 1];
assert indicator < output.length && indicator >= stride ==> output[indicator] == Matrix[lvl - 1][2 * tid + 1] + Matrix[lvl - 1][2 * tid];
assert Matrix[lvl] == implode(Matrix[lvl - 1]);
assert indicator < output.length && indicator >= stride ==> output[indicator] == Matrix[lvl][tid];
}
barrier(Threads)
{
context_everywhere k > 0;
context_everywhere 1 <= lvl && lvl <= k;
context_everywhere output.length == ExpTwo(k);
context_everywhere |Matrix| == lvl + 1;
requires tid >= 0 && tid < output.length;
requires stride == ExpTwo(lvl-1);
requires stride > 0 && stride < output.length;
requires indicator + 1 == ExpTwo(lvl)*(tid+1);
requires indicator + 1 == 2*stride*(tid+1);
requires indicator > 0;
requires indicator < output.length ==> Perm(output[indicator], write);
requires indicator < output.length && indicator >= stride ==> Perm(output[indicator - stride], write);
requires tid==0 ==> (\forall* int i; 0 <= i && i < output.length && (i + 1) % stride != 0; Perm(output[i], write));
//requires (\forall int i; 0 <= i && i < lvl; 0 <= |Matrix[i]| && |Matrix[i]| <= output.length);
//requires (\forall int i; 0 <= i && i < lvl; |Matrix[i]| == ExpTwo(k - i));
//requires indicator < output.length && indicator >= stride && tid < |Matrix[lvl]| ==> output[indicator] == Matrix[lvl][tid];
ensures tid >= 0 && tid < output.length;
ensures 2 * stride == ExpTwo(lvl);
ensures 2 * stride > 0 && 2 * stride <= output.length;
ensures 2 * indicator + 2 == ExpTwo(lvl+1)*(tid+1);
ensures 2 * indicator + 2 == 2*stride*(tid+1);
ensures 2 * indicator + 1 > 0;
ensures 2 * indicator + 1 < output.length ==> Perm(output[2 * indicator + 1], write);
ensures 2 * indicator + 1 < output.length && 2 * indicator + 1 >= 2 * stride ==> Perm(output[2 * indicator + 1 - 2 * stride], write);
ensures tid==0 ==> (\forall* int i; 0 <= i && i < output.length && (i + 1) % (2 * stride) != 0; Perm(output[i], write));
ensures (tid==0 && (2 * stride == output.length)) ==> (Perm(output[output.length - 1], write));
//ensures (\forall int i; 0 <= i && i <= lvl; 0 <= |Matrix[i]| && |Matrix[i]| <= output.length);
//ensures (\forall int i; 0 <= i && i <= lvl; |Matrix[i]| == ExpTwo(k - i));
//ensures 2 * indicator + 1 < output.length && 2 * tid + 1 < |Matrix[lvl+1]| ==> output[2 * indicator + 1] == Matrix[lvl+1][2 * tid + 1];
//ensures 2 * indicator + 1 < output.length && 2 * indicator + 1 >= 2 * stride && 2 * tid < |Matrix[lvl+1]| ==> output[2 * indicator + 1 - 2 * stride] == Matrix[lvl+1][2 * tid];
}
Matrix_UP = Matrix_UP + seq<seq<int>> { up(Matrix_UP[lvl - 1], stride, 0, k, lvl) };
assert (indicator < output.length) && (indicator >= stride) ==> Matrix_UP[lvl][indicator] == Matrix_UP[lvl - 1][indicator] + Matrix_UP[lvl - 1][indicator-stride];
//assert (\forall int i; 0 < i && i <= lvl; Matrix_UP[i] == up(Matrix_UP[i - 1], stride/ExpTwo(lvl-i), 0, k, i));
indicator = 2 * indicator + 1;
stride = 2 * stride;
lvl = lvl + 1;
assert (\forall int i; 0 < i && i < lvl; Matrix_UP[i] == up(Matrix_UP[i - 1], stride/ExpTwo(lvl-i), 0, k, i));
assert stride == ExpTwo(lvl-1);
lemma_exp2_red_mult(lvl);
assert ExpTwo(lvl) == 2 * ExpTwo(lvl - 1);
assert 2*stride == ExpTwo(lvl);
assert indicator + 1 == ExpTwo(lvl)*(tid+1);
assert indicator + 1 == 2*stride*(tid+1);
}
assert stride == output.length;
assert stride == ExpTwo(lvl-1);
assert output.length == ExpTwo(k);
assert ExpTwo(lvl-1) == ExpTwo(k);
power_two_lemma(lvl-1, k);
assert lvl == k + 1;
assert indicator < output.length ==> Matrix_UP[lvl - 1][indicator] == output[indicator];
//assert (\forall int i; 0 < i && i < lvl; Matrix_UP[i] == up(Matrix_UP[i - 1], stride/ExpTwo(lvl-i), 0, k, i));
assert |Matrix| == lvl;
assert (\forall int i; 0 <= i && i < k + 1; |Matrix[i]| == ExpTwo(k - i));
assert (\forall int i; 0 < i && i < k + 1; Matrix[i] == implode(Matrix[i - 1]));
assert (\forall int i; 0 <= i && i < k + 1; intsum(Matrix[i]) == intsum(input_seq));
assert |Matrix[k]| == 1;
lemma_intsum_single(Matrix[k][0]);
assert intsum(Matrix[k]) == intsum(input_seq);
assert Matrix[k] == seq<int>{intsum(input_seq)};
assert Matrix[0] == input_seq;
assert (\forall int i; 0 <= i && i < k + 1; 0 < |Matrix[i]| && |Matrix[i]| <= output.length);
/////////////////////////////////////////////////////////////////////////////////
barrier(Threads)
{
context_everywhere k > 0;
context_everywhere output.length == ExpTwo(k);
context_everywhere |Matrix_UP| == k + 1;
context_everywhere lvl == k + 1;
requires stride == output.length;
requires indicator + 1 == ExpTwo(lvl)*(tid+1);
requires indicator + 1 == 2*stride*(tid+1);
requires indicator > 0;
requires stride > 0 ;
requires stride == output.length;
requires indicator < output.length ==> Perm(output[indicator], write);
requires indicator < output.length && indicator >= stride ==> Perm(output[indicator - stride], write);
requires tid==0 ==> (\forall* int i; 0 <= i && i < output.length && (i + 1) % stride != 0; Perm(output[i], write));
requires (tid==0 && (stride == output.length)) ==> (Perm(output[output.length - 1], write));
requires (\forall int i; 0 <= i && i <= k; |Matrix_UP[i]| == output.length);
requires indicator < output.length && indicator >= stride ==> Matrix_UP[lvl - 1][indicator] == output[indicator];
requires indicator < output.length && indicator >= stride ==> Matrix_UP[lvl - 1][indicator - stride] == output[indicator - stride];
context tid >= 0 && tid < output.length;
ensures stride == output.length / 2;
ensures indicator == output.length * tid + output.length - 1;
ensures stride > 0 ;
ensures indicator > 0;
ensures output.length * tid + output.length - 1 < output.length ==> Perm(output[output.length * tid + output.length - 1], write);
ensures output.length * tid + output.length - 1 < output.length && output.length * tid + output.length - 1 >= output.length / 2 ==> Perm(output[output.length * tid + output.length - 1 - output.length / 2], write);
ensures (\forall int i; 0 <= i && i <= k; |Matrix_UP[i]| == output.length);
ensures output.length * tid + output.length - 1 < output.length ==> Matrix_UP[lvl - 1][indicator] == output[indicator];
ensures output.length * tid + output.length - 1 < output.length && output.length * tid + output.length - 1 >= output.length / 2 ==> Matrix_UP[lvl - 1][indicator - stride] == output[indicator - stride];
ensures tid==0 ==> (\forall* int i; 0 <= i && i < output.length && (i + 1) % (output.length / 2) != 0; Perm(output[i], write));
}
/////////////////////////////////////////////////////////////////////////////////////// Down
/*if(output.length * tid + output.length - 1 < output.length)
{
output[output.length * tid + output.length - 1] = 0;
}*/
//assert (\forall int i; 0 <= i && i < k + 1; |Matrix_UP[i]| == ExpTwo(k - i));
//assert (\forall int i; 0 < i && i < k + 1; Matrix_UP[i] == implode(Matrix_UP[i - 1]));
//assert (\forall int i; 0 <= i && i < k + 1; intsum(Matrix_UP[i]) == intsum(input_seq));
//assert Matrix_UP[k] == seq<int>{intsum(input_seq)};
indicator = output.length * tid + output.length - 1; // output.length * tid + output.length - 1;
stride = output.length / 2; // output.length / 2;
lvl = k - 1; //lvl - 2;
int temp;
seq<int> temp_seq = seq<int> { 0 };
assert output.length * tid + output.length - 1 < output.length ==> Matrix_UP[lvl + 1][indicator] == output[indicator];
assert output.length * tid + output.length - 1 < output.length && output.length * tid + output.length - 1 >= output.length / 2 ==> Matrix_UP[lvl + 1][indicator - stride] == output[indicator - stride];
if(indicator < output.length)
{
output[indicator] = 0;
}
loop_invariant k > 0;
loop_invariant output.length == ExpTwo(k);
loop_invariant tid >= 0 && tid < output.length;
loop_invariant lvl <= k - 1;
loop_invariant lvl >= -1;
loop_invariant lvl >= 0 ==> stride == ExpTwo(lvl);
loop_invariant lvl == -1 ==> stride == 0;
loop_invariant stride >= 0;
loop_invariant indicator >= 0;
loop_invariant indicator+1 == ExpTwo(lvl+1)*(tid+1);
loop_invariant indicator < output.length ==> Perm(output[indicator], write);
loop_invariant indicator < output.length && indicator >= stride ==> Perm(output[indicator - stride], write);
loop_invariant |temp_seq| == ExpTwo(k - (lvl + 1));
loop_invariant 0 < |temp_seq| && |temp_seq| <= output.length;
loop_invariant temp_seq == psum2(Matrix[lvl + 1]);
loop_invariant (\forall int i; 0 <= i && i < k + 1; 0 < |Matrix[i]| && |Matrix[i]| <= output.length);
loop_invariant (\forall int i; 0 <= i && i < k + 1; |Matrix[i]| == ExpTwo(k - i));
loop_invariant (\forall int i; 0 <= i && i < k + 1; intsum(Matrix[i]) == intsum(input_seq));
loop_invariant (\forall int i; 0 < i && i < k + 1; Matrix[i] == implode(Matrix[i - 1]));
loop_invariant Matrix[0] == input_seq;
loop_invariant Matrix[k] == seq<int>{ intsum(input_seq) };
loop_invariant tid < |temp_seq| && indicator < output.length ==> temp_seq[tid] == output[indicator];
loop_invariant 2 * tid < |Matrix[lvl]| && indicator < output.length && indicator >= stride ==> output[indicator - stride] == get(Matrix[lvl], 2 * tid);
loop_invariant tid==0 ==> (\forall* int i; 0 <= i && i < output.length && (i + 1) % stride != 0; Perm(output[i], write));
while(stride >= 1)
{
if(indicator < output.length && indicator >= stride)
{
//assume 2 * tid < |Matrix[lvl]| ==> output[indicator - stride] == get(Matrix[lvl], 2 * tid);
assert tid < |temp_seq| ==> temp_seq[tid] == output[indicator];
temp = output[indicator];
assert tid < |temp_seq| ==> temp == temp_seq[tid];
output[indicator] = output[indicator] + output[indicator - stride];
assert tid < |temp_seq| ==> output[indicator] == temp_seq[tid] + output[indicator - stride];
assert 2 * tid < |Matrix[lvl]| ==> output[indicator - stride] == get(Matrix[lvl], 2 * tid);
assert 2 * tid < |Matrix[lvl]| && tid < |temp_seq| ==> output[indicator] == temp_seq[tid] + get(Matrix[lvl], 2 * tid);
assert tid < |Matrix[lvl + 1]| && tid < |temp_seq| ==> temp_seq[tid] == get(psum2(Matrix[lvl + 1]), tid);
assert tid < |Matrix[lvl + 1]| && 2 * tid < |Matrix[lvl]| ==> output[indicator] == get(psum2(Matrix[lvl + 1]), tid) + get(Matrix[lvl], 2 * tid);
assert Matrix[lvl + 1] == implode(Matrix[lvl]);
assert tid < |implode(Matrix[lvl])| && 2 * tid < |Matrix[lvl]| ==> output[indicator] == get(psum2(implode(Matrix[lvl])), tid) + get(Matrix[lvl], 2 * tid);
if(tid < |implode(Matrix[lvl])|){
lemma_get_psum_implode(Matrix[lvl], tid);
}
assert tid < |implode(Matrix[lvl])| && 2 * tid < |Matrix[lvl]| ==> get(psum2(implode(Matrix[lvl])), tid) == get(psum2(Matrix[lvl]), 2 * tid);
assert 2 * tid < |Matrix[lvl]| ==> output[indicator] == get(psum2(Matrix[lvl]), 2 * tid) + get(Matrix[lvl], 2 * tid);
if(2 * tid + 1 < |Matrix[lvl]|){
lemma_combine_psum(Matrix[lvl], tid);
}
assert 2 * tid + 1 < |Matrix[lvl]| ==> get(psum2(Matrix[lvl]), 2 * tid + 1) == get(psum2(Matrix[lvl]), 2 * tid) + get(Matrix[lvl], 2 * tid);
assert 2 * tid + 1 < |Matrix[lvl]| ==> output[indicator] == get(psum2(Matrix[lvl]), 2 * tid + 1);
assert tid < |temp_seq| ==> temp == temp_seq[tid];
output[indicator - stride] = temp;
assert tid < |temp_seq| ==> output[indicator - stride] == temp_seq[tid];
assert tid < |Matrix[lvl + 1]| && tid < |temp_seq| ==> temp_seq[tid] == get(psum2(Matrix[lvl + 1]), tid);
assert Matrix[lvl + 1] == implode(Matrix[lvl]);
assert tid < |implode(Matrix[lvl])| && tid < |temp_seq| ==> temp_seq[tid] == get(psum2(implode(Matrix[lvl])), tid);
if(tid < |implode(Matrix[lvl])|){
lemma_get_psum_implode(Matrix[lvl], tid);
}
assert tid < |implode(Matrix[lvl])| && 2 * tid < |Matrix[lvl]| ==> get(psum2(implode(Matrix[lvl])), tid) == get(psum2(Matrix[lvl]), 2 * tid);
assert 2 * tid < |Matrix[lvl]| && tid < |temp_seq| ==> temp_seq[tid] == get(psum2(Matrix[lvl]), 2 * tid);
assert 2 * tid < |Matrix[lvl]| ==> output[indicator - stride] == get(psum2(Matrix[lvl]), 2 * tid);
}
temp_seq = psum2(Matrix[lvl]);
assert 2 * tid < |temp_seq| && indicator < output.length && indicator >= stride ==> output[indicator - stride] == temp_seq[2 * tid];
assert 2 * tid + 1 < |temp_seq| && indicator < output.length && indicator >= stride ==> output[indicator] == temp_seq[2 * tid + 1];
barrier(Threads)
{
context_everywhere output.length == ExpTwo(k);
context_everywhere lvl >= 0 && lvl <= k - 1;
requires tid >= 0 && tid < output.length;
requires indicator >= 0;
requires stride >= 0 ;
requires stride == ExpTwo(lvl);
requires indicator+1 == ExpTwo(lvl+1)*(tid+1);
requires indicator < output.length ==> Perm(output[indicator], write);
requires indicator < output.length && indicator >= stride ==> Perm(output[indicator - stride], write);
requires tid==0 ==> (\forall* int i; 0 <= i && i < output.length && (i + 1) % stride != 0; Perm(output[i], write));
ensures tid >= 0 && tid < output.length;
ensures lvl-1 >= 0 ==> stride / 2 == ExpTwo(lvl - 1);
ensures lvl-1 == -1 ==> stride / 2 == 0;
ensures stride / 2 >= 0;
ensures (indicator - 1) / 2 >= 0;
ensures (indicator - 1) / 2+1 == ExpTwo(lvl)*(tid+1);
ensures (indicator - 1) / 2 < output.length ==> Perm(output[(indicator - 1) / 2], write);
ensures (indicator - 1) / 2 < output.length && (indicator - 1) / 2 >= stride / 2 ==> Perm(output[(indicator - 1) / 2 - stride / 2], write);
ensures (tid==0 && stride/2 > 0) ==> (\forall* int i; 0 <= i && i < output.length && (i + 1) % (stride/2) != 0; Perm(output[i], write));
}
indicator = (indicator - 1) / 2;
stride = stride / 2;
lvl = lvl - 1;
}
assert temp_seq == psum2(Matrix[0]);
assert Matrix[0] == input_seq;
assert temp_seq == psum2(input_seq);
assert tid < |temp_seq| && indicator < output.length ==> temp_seq[tid] == output[indicator];
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment