Skip to content

Instantly share code, notes, and snippets.

@cpylua
Created April 14, 2011 09:12
Show Gist options
  • Save cpylua/919163 to your computer and use it in GitHub Desktop.
Save cpylua/919163 to your computer and use it in GitHub Desktop.
binary search on a rotated sorted list
/*
Problem:
An element in a sorted array can be found in O(log n) time via binary
search. But suppose I rotate the sorted array at some pivot unknown to
you beforehand. So for instance, 1 2 3 4 5 might become 3 4 5 1 2.
Now devise a way to find an element in the rotated array in O(log n) time.
Solution:
Binary search cannot be applied directly to an unsorted list, so we must
find a way to turn the rotated list back into "sorted" pieces.
Take the list "4, 5, 1, 2, 3" as an example. We must find the "pivot",
that is, the element at which the list can be split into two sorted
sub-lists (here the pivot is 5, giving "4, 5" and "1, 2, 3"). The key
point is to use binary search to find the pivot itself. We observe that
if a sub-list x[l..h] contains both the pivot and the element after it,
then x[l] > x[h] (assuming the list was sorted in ascending order before
rotating).
bspivot(l, h)
    m = l + (h-l)/2
    if (x[m] < x[l])          -- pivot lies in x[l..m-1]
        return bspivot(l, m)
    if (x[m+1] > x[h])        -- pivot lies in x[m+1..h-1]
        return bspivot(m+1, h)
    return m
Once we have the pivot, we split the list into the two sorted sub-lists
x[l..pivot] and x[pivot+1..h] and binary search the one whose range of
values can contain the target.
*/
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
/* Number of elements in an array.  Only valid for true arrays: applied to a
 * pointer (including an array-typed function parameter) it silently yields
 * the wrong value.  The argument is parenthesised so expressions expand
 * safely inside the subscript. */
#define countof(x) (sizeof(x) / sizeof((x)[0]))
/*
 * bspivot - locate the rotation point of a rotated ascending array.
 *
 * seq[l..h] (inclusive bounds) is expected to be an ascending sequence of
 * distinct values that has been rotated by an unknown amount.  Returns the
 * index of the last element of the first ascending run (the largest value).
 *
 * If the range is not rotated at all, the midpoint m is returned; seq[l..m]
 * and seq[m+1..h] are then both still sorted, which is all callers need.
 * For l < h the result always satisfies l <= result < h.  An empty or
 * single-element range returns l without reading outside seq[l..h].
 */
int bspivot(const int *seq, int l, int h) {
    while (l < h) {
        int m = l + (h - l) / 2;        /* avoids overflow of (l + h) */

        if (seq[m] < seq[l])            /* rotation point lies in [l, m-1] */
            h = m;
        else if (seq[m + 1] > seq[h])   /* rotation point lies in [m+1, h-1] */
            l = m + 1;
        else
            return m;                   /* seq[l..m] and seq[m+1..h] sorted */
    }
    return l;
}
/*
 * _bsearch - binary search on an ascending range.
 *
 * Searches x[l..h] (inclusive bounds) for t.  Returns the index of a
 * matching element, or -1 if t is not present.  An empty range (l > h)
 * returns -1 immediately.  The original recursive version had no such
 * base case and recursed forever when t was absent.
 *
 * NOTE(review): identifiers beginning with an underscore are reserved at
 * file scope in C; consider renaming once all callers can be updated.
 */
int _bsearch(const int *x, int l, int h, int t) {
    while (l <= h) {
        int m = l + (h - l) / 2;   /* avoids overflow of (l + h) */

        if (t > x[m])
            l = m + 1;
        else if (t < x[m])
            h = m - 1;
        else
            return m;
    }
    return -1;   /* range exhausted: t is not present */
}
/*
 * rbsearch - binary search for t in a rotated ascending array.
 *
 * x[l..h] (inclusive bounds) is expected to be an ascending sequence of
 * distinct values rotated by an unknown amount.  With duplicate values the
 * pivot search can choose a split whose halves are not sorted, and the
 * result is then unreliable.
 *
 * The rotation point from bspivot() splits the range into two sorted runs,
 * x[l..pivot] and x[pivot+1..h].  t can only be in the run whose
 * [first, last] value range covers it.
 *
 * Returns the index of t, or -1 if t is not present or the range is empty.
 */
int rbsearch(int *x, int l, int h, int t) {
    int pivot;

    if (l > h)
        return -1;
    /* For a one-element range x[pivot+1] would lie past h, so answer it
     * directly without computing a pivot. */
    if (l == h)
        return x[l] == t ? l : -1;

    pivot = bspivot(x, l, h);   /* l <= pivot < h, so pivot+1 is in range */
    if (t >= x[l] && t <= x[pivot])
        return _bsearch(x, l, pivot, t);
    if (t >= x[pivot + 1] && t <= x[h])
        return _bsearch(x, pivot + 1, h, t);
    return -1;
}
/* Reverse the half-open range seq[lo..hi) in place. */
static void reverse_range(int *seq, size_t lo, size_t hi) {
    while (lo + 1 < hi) {
        int tmp = seq[lo];
        seq[lo] = seq[hi - 1];
        seq[hi - 1] = tmp;
        lo++;
        hi--;
    }
}

/*
 * rotate - rotate seq[0..len) right by n places, in place.
 *
 * After the call, the element that was at index i is at index
 * (i + n) mod len.  A negative n rotates left by |n| places, and n may
 * exceed len.  A zero-length sequence is left untouched.
 *
 * Uses the three-reversal technique, so no temporary buffer is allocated
 * and the call cannot fail.  The original version did not check malloc and
 * mangled negative n by mixing int and size_t in the modulo.
 */
void rotate(int *seq, size_t len, int n) {
    size_t k;

    if (len == 0)
        return;

    /* Normalise n into a right-rotation amount k in [0, len) without
     * mixing signed and unsigned arithmetic. */
    if (n >= 0) {
        k = (size_t)n % len;
    } else {
        /* -(n + 1) cannot overflow, even for n == INT_MIN. */
        size_t left = ((size_t)(-(n + 1)) + 1) % len;
        k = (len - left) % len;
    }
    if (k == 0)
        return;

    /* Right rotation by k: reverse everything, then each of the two parts. */
    reverse_range(seq, 0, len);
    reverse_range(seq, 0, k);
    reverse_range(seq, k, len);
}
/*
void printseq(int *x, int len) {
int i;
for (i = 0; i < len-1; i++)
printf("%d,", x[i]);
printf("%d", x[i]);
}
*/
/*
 * test - self-check of rbsearch() over the rotations of 1..10.
 *
 * x starts as 1..10.  Each round rotates x right by one more place (9 rounds,
 * total shifts 1..9) and then probes every value 0..11.  expect[v] holds the
 * index at which value v should be found; the probes 0 and 11 are never
 * present, so they expect -1.  The ten middle slots of expect are rotated in
 * step with x before the probes and rotated back afterwards.  The first
 * mismatch is printed and ends the test; otherwise "Test OK!" is printed.
 */
void test() {
    int x[10] = {0};
    int expect[12] = {-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, -1};
    const int size = (int)countof(x);
    int round, probe, got;

    for (probe = 0; probe < size; ++probe)
        x[probe] = probe + 1;

    for (round = 0; round + 1 < size; ++round) {
        rotate(x, countof(x), 1);
        rotate(expect + 1, countof(x), size - round);

        for (probe = 0; probe < size + 2; ++probe) {
            got = rbsearch(x, 0, size - 1, probe);
            if (got != expect[probe]) {
                printf("Test failed\nExpected %d, but got %d\n",
                       expect[probe], got);
                return;
            }
        }

        rotate(expect + 1, countof(x), round);
    }
    printf("%s\n", "Test OK!");
}
/* Program entry point: runs the self-test; the exit status is always 0. */
int main(int argc, char **argv)
{
    (void)argc;   /* command-line arguments are not used */
    (void)argv;

    test();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment