This past semester, I took a graduate course, CS 207 - Systems Development in Computational Science. In the course, we talked about good software engineering practices in C++ (but the lessons span beyond C++), in particular representation invariants, abstraction functions, and writing solid code specifications so that one could even prove things about code. The professor made a couple of blog entries for some of the lectures, explaining cool tricks with iterators and bits.

Early in the semester, we discussed several implementations of binary search, starting from a simplistic version and incrementally building up to a production-ready version. I found the discussion extremely eye-opening; it was the first time I had seen invariants used to prove properties of code.

Below is how I’ve written binary search since high school:

/** Returns the index of any occurrence of @a x in @a a
 * @pre @a a has length equal to n
 * @pre @a a is sorted in increasing order
 * @return -1 if not found 
 */
int binary_search(int *a, int n, int x) {
  int lo = 0;
  int hi = n - 1;
  while (lo <= hi) {
    int mid = (lo + hi) / 2;
    if (x == a[mid])
      return mid;
    if (x < a[mid])
      hi = mid - 1;
    else
      lo = mid + 1;
  }
  return -1;
}

Here, I am using Doxygen-style comments for my specifications. In this version of binary search, I return the index of any occurrence of item x in array a, or -1 if there is no such occurrence. While this implementation is acceptable for an array of ints, it is not particularly useful for other data types.

Using C++ templates, we can generalize this implementation to make it polymorphic for any type T, as long as we supply a suitable comparison function compare, where compare(p,q) returns true if and only if p is less than q under some ordering of values of type T. Thus, here is our attempt #2 at binary search:

/** Returns the index of any occurrence of @a x in @a a
 * @param compare(p,q) returns true if p < q
 * @pre @a a has length equal to @a n
 * @pre @a a is sorted by @a compare
 * @return -1 if not found 
 */
template<typename T, typename CMP>
int binary_search2(T *a, int n, T x, CMP compare) {
  int l = 0;
  int r = n;
  while (l < r) {
    int m = l + (r - l) / 2; // fix overflow issues
    if (compare(a[m],x))
      l = m + 1;
    else if (compare(x,a[m]))
      r = m;
    else
      return m; // neither a[m] < x nor x < a[m], so m is a match
  }
  return -1;
}

Now, in order to call binary search, we must provide a function object compare that defines how we compare two elements of type T. Below is an example of how we would invoke this version of binary search:

#include <iostream>

struct IntComp {
  bool operator()(int x, int y) const {
    return x < y;
  }
};

int main(void) {
  int arr[12] = {2,3,4,5,7,8,9,11,13,15,16,17};
  std::cout << binary_search2(arr, 12, 15, IntComp()) << std::endl;
  return 0;
}

We overload operator() to allow IntComp objects to be invoked like functions, and we pass an instance of IntComp to binary_search2 whenever we perform a binary search on an array of ints.
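A hand-written struct is not the only way to supply compare. As a rough sketch, the block below repeats attempt #2 (written to return the matching index m) and calls it with std::greater from the standard library and with a lambda. The helper names find_descending and find_ascending are made up for this illustration and are not part of the course code.

```cpp
#include <cassert>
#include <functional>

// Copy of attempt #2, returning the matching index m (or -1 if not found).
template <typename T, typename CMP>
int binary_search2(T *a, int n, T x, CMP compare) {
  assert(n >= 0);
  int l = 0;
  int r = n;
  while (l < r) {
    int m = l + (r - l) / 2;
    if (compare(a[m], x))
      l = m + 1;
    else if (compare(x, a[m]))
      r = m;
    else
      return m;
  }
  return -1;
}

// Hypothetical helper: search an array sorted in decreasing order.
// std::greater<int> plays the role of "p comes before q".
int find_descending(int *a, int n, int x) {
  return binary_search2(a, n, x, std::greater<int>());
}

// Hypothetical helper: the same search with a lambda (C++11 or later)
// in place of the IntComp struct.
int find_ascending(int *a, int n, int x) {
  return binary_search2(a, n, x, [](int p, int q) { return p < q; });
}
```

The only requirement is that the array is sorted by the same ordering that compare describes; the search itself never needs to know what that ordering is.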

Note one other difference between the two versions of binary search: in attempt #1, we had the line:

int mid = (lo + hi) / 2;

whereas in attempt #2, we replaced this line with:

int m = l + (r - l) / 2;

For all these years, I’ve been writing binary search incorrectly! In the first version, lo + hi overflows whenever the sum exceeds the maximum value of int (and signed overflow is undefined behavior in C++). In the second version, we fix this subtle bug by first computing r - l, then halving the difference and adding the result to l to get the new middle index m. Inside the loop we always have 0 <= l < r <= n, so r - l cannot overflow, and l <= m < r, so m is always a valid index into the array.
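To make the overflow concrete, here is a small sketch. The helper names (safe_midpoint, wide_midpoint, sum_overflows) are invented for this example, and the reference computation assumes a platform where long long is wider than int. Note that the sketch never actually evaluates lo + hi in int, since that would be undefined behavior when it overflows.

```cpp
#include <cassert>
#include <climits>

// Hypothetical helper: midpoint of lo and hi without ever forming lo + hi.
int safe_midpoint(int lo, int hi) {
  assert(0 <= lo && lo <= hi);
  return lo + (hi - lo) / 2;
}

// Reference value computed in a wider type
// (assumes long long is wider than int on this platform).
long long wide_midpoint(int lo, int hi) {
  return (static_cast<long long>(lo) + hi) / 2;
}

// Would lo + hi exceed INT_MAX? Tested without performing the addition.
bool sum_overflows(int lo, int hi) {
  assert(0 <= lo && 0 <= hi);
  return lo > INT_MAX - hi;
}
```

For example, with lo = INT_MAX - 10 and hi = INT_MAX - 2, sum_overflows reports that lo + hi does not fit in an int, while safe_midpoint still returns INT_MAX - 6, which agrees with the wide computation.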

We have generalized our binary search to work on an array containing any type. But we have actually done more than this. In C++, iterators overload pointer syntax to represent positions in a collection. Using iterators, we can represent an entire range of items with only two iterators: one pointing to the beginning of the range, and one pointing to the “position” after its last element. See the CS 207 blog entries here for more information on C++ iterators. In our example, however, we represent the array with a pointer to the first position and the number of items in the array. Because binary search only requires random access into our collection, the same algorithm carries over to any collection that provides random access iterators, once the pointer and length are replaced by a pair of iterators!
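As a sketch of what that might look like, here is a variant of attempt #2 written against an iterator pair instead of a pointer and a length. The names binary_search_iter and index_of are hypothetical, and the "return last when not found" convention is a choice made for this example rather than something from the course.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical iterator-based variant of attempt #2.
// Returns an iterator to some occurrence of x in [first, last),
// or last if x does not occur. Requires random access iterators.
template <typename RandomIt, typename T, typename CMP>
RandomIt binary_search_iter(RandomIt first, RandomIt last, const T &x,
                            CMP compare) {
  RandomIt l = first;
  RandomIt r = last;
  while (l < r) {
    RandomIt m = l + (r - l) / 2; // same overflow-safe midpoint
    if (compare(*m, x))
      l = m + 1;
    else if (compare(x, *m))
      r = m;
    else
      return m;
  }
  return last;
}

// Hypothetical helper: index of x in a sorted vector, or -1 if absent.
int index_of(const std::vector<int> &v, int x) {
  assert(std::is_sorted(v.begin(), v.end())); // precondition, debug builds only
  auto it = binary_search_iter(v.begin(), v.end(), x, std::less<int>());
  return it == v.end() ? -1 : static_cast<int>(it - v.begin());
}
```

Raw pointers are random access iterators too, so a call like binary_search_iter(arr, arr + 12, 15, IntComp()) works on a plain array as well.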

Can we still do better? In our specification for binary search, note that we allowed the index of any occurrence of our search item x to be returned. This ambiguity makes it difficult to make any real use of the return value of binary search (except simply to check whether the item is in the collection). Instead of returning any index, what if we returned a lower bound position of the element x in our collection? By lower bound, we mean the first index in the array at which we could insert x and still keep the elements in sorted order.

For example, with the array {0, 1, 2, 5, 5, 5, 7, 9}, the lower bound of 0 would be 0, because we can insert 0 at index 0 and still keep our array sorted. The lower bound of -1 is also 0 by similar reasoning. The lower bound of 5 is 3 because 3 is the smallest index at which we can insert 5 and keep the array sorted. Similarly, the lower bound of 6 is 6. Note that the lower bound of 10 is 8, which is not a valid index into the array. This is okay because the return value only indicates the index at which one could insert an item and maintain the sorted property of the array.
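The standard library's std::lower_bound (from the algorithm header) computes this same position, so the examples above can be checked against it. The wrapper lower_bound_index below is a hypothetical helper that simply converts the returned iterator into an index.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: lower-bound index of x in a sorted vector,
// computed with the standard library's std::lower_bound.
int lower_bound_index(const std::vector<int> &a, int x) {
  assert(std::is_sorted(a.begin(), a.end())); // precondition, debug builds only
  return static_cast<int>(std::lower_bound(a.begin(), a.end(), x) - a.begin());
}
```

With a = {0, 1, 2, 5, 5, 5, 7, 9}, this gives 0 for x = 0 and x = -1, 3 for x = 5, 6 for x = 6, and 8 for x = 10, matching the values worked out by hand.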

To implement this, we can think of the array as a collection of boolean values where the entries are {false, false, ..., false, true, true, ... true} (all the falses occur together at the beginning of the array). The boolean value at each position records whether our target element x is less than or equal to the value at that position. Our goal, then, is to find the index of the first true in the array, or to return n, the position just past the last element, if there is none (indicating that placing x at the end of the array would maintain the sorted property of our array). Building on the polymorphism we introduced in attempt #2, here is attempt #3 using the lower bound idea:

/** Return the lower-bound position of @a x in @a a
 * @param compare(p,q) returns true if p < q
 * @pre @a a has length equal to @a n
 * @pre @a a is sorted by @a compare
 * @post return R where 0 <= R <= @a n and:
 *   For all 0 <= i < n, 
 *      i < R iff a[i] < x
 *      i >= R iff a[i] >= x 
 */
template<typename T, typename CMP>
int lower_bound(T *a, int n, T x, CMP compare) {
  int l = 0;
  int r = n;
  while (l < r) {
    int m = l + (r - l) / 2;
    if (compare(a[m],x))
      l = m + 1;
    else
      r = m;
  }
  return l;
}

Nice, clean, and simple!

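Note that this version performs only one comparison per iteration instead of up to two (as in attempts #1 and #2)! The lower bound also tells us more than whether x is in the array: x is present exactly when the returned index R is less than n and a[R] is not greater than x, and in either case R is where we should place x to keep the array sorted!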
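Here is one way the return value might be put to use. The block repeats attempt #3 under the name lower_bound3 (renamed only to avoid confusion with std::lower_bound), and the helpers contains and insert_sorted are hypothetical illustrations for a vector of ints.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Copy of attempt #3 under a different name.
template <typename T, typename CMP>
int lower_bound3(const T *a, int n, const T &x, CMP compare) {
  assert(n >= 0);
  int l = 0;
  int r = n;
  while (l < r) {
    int m = l + (r - l) / 2;
    if (compare(a[m], x))
      l = m + 1;
    else
      r = m;
  }
  return l;
}

// Hypothetical helper: x is present iff R is a valid index and a[R] is
// not greater than x (we already know a[R] is not less than x).
bool contains(const std::vector<int> &v, int x) {
  int n = static_cast<int>(v.size());
  int R = lower_bound3(v.data(), n, x, std::less<int>());
  return R < n && !(x < v[R]);
}

// Hypothetical helper: insert x at its lower bound, keeping v sorted.
void insert_sorted(std::vector<int> &v, int x) {
  int n = static_cast<int>(v.size());
  int R = lower_bound3(v.data(), n, x, std::less<int>());
  v.insert(v.begin() + R, x);
}
```

The membership test costs one extra comparison after the loop, rather than one extra comparison in every iteration.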

This code looks simple enough to verify the correctness by eyeballing it; but can we make this rigorous? Can we prove the correctness of this code? Yes! Here is the same piece of code but commented heavily with the proof of its own correctness.

/** Return the lower-bound position of @a x in @a a
 * @param compare(p,q) returns true if p < q
 * @pre @a a has length equal to @a n
 * @pre @a a is sorted by @a compare
 * @post return R where 0 <= R <= @a n and:
 *   For all 0 <= i < n, 
 *      i < R iff a[i] < x
 *      i >= R iff a[i] >= x 
 */
template<typename T, typename CMP>
int lower_bound_proof(T *a, int n, T x, CMP compare) {
  // pre: for all i,j with 0 <= i <= j < n, we have a[i] <= a[j]
  // post: let R be the return value. Then 0 <= R <= n, and
  //   for all 0 <= i < n,
  //     i < R iff a[i] < x    (1)
  //     i >= R iff a[i] >= x  (2) 

  int l = 0;
  int r = n;
  while (l < r) {
    // PRE LOOP
    // loop invariant: l <= R <= r (always true in the loop)
    // decrementing function: d = r - l

    int m = l + (r - l) / 2; // if r - l >= 2, then (r - l)/2 >= 1,
                             //                so l < m < r
                             // if r - l == 1, then l = m < r, 
                             //                so l <= m < r

    if (compare(a[m],x)) {
      // we have a[m] < x. Then by (1), a[m] < x ==> m < R
      // then for all 0 <= i <= m, a[i] < x (b/c sorted)
      l = m + 1; // so l < l_new == m + 1 <= R
                 // r_new == r >= R
                 // so l_new <= R <= r_new
                 // and r_new - l_new < r - l (d decrements)
    } else {
      // we have a[m] >= x. Then by (2), a[m] >= x ==> m >= R
      // then for all m <= i < n, a[i] >= x (b/c sorted)
      r = m; // so r > r_new == m >= R
             // l_new == l <= R
             // so l_new <= R <= r_new
             // and r_new - l_new < r - l (d decrements)
    }

    // POST LOOP
    // loop invariant: l_new <= R <= r_new
    // decrementing function: r_new - l_new < r - l
  }

  // by the decrementing function, d eventually reaches 0;
  //      thus the loop terminates
  // by the loop invariant, we have l <= R <= r; on exit l >= r,
  //      so l == R == r
  return l;
}

To prove the correctness, we make heavy use of the post condition:

/**
 * @post return R where 0 <= R <= @a n and:
 *   For all 0 <= i < n, 
 *      i < R iff a[i] < x
 *      i >= R iff a[i] >= x 
 */

Thus, all elements at indices less than the return value R are less than x, and all other elements are greater than or equal to x. We use both of these if-and-only-if conditions in the two branches of the if conditional to guide us on how we should update l or r.

In both branches of the conditional, the new values of l and r are chosen so that l <= R <= r still holds, where R is the position described by the post condition. Thus, the statement l <= R <= r is a loop invariant of the while loop: it is true on entering the loop body and again on leaving it. To ensure that the loop terminates, we require a decrementing function, a function that strictly decreases on each iteration of the loop and is equal to zero when the loop terminates. In this case, the obvious choice for the decrementing function is d = r - l. We show in both branches that the new values of l and r are such that r_new - l_new < r - l, and so d decreases on each iteration. When d = 0, we have l = r, which is indeed when the loop terminates, and the invariant l <= R <= r then forces l = R. Thus, our final line return l; is proven correct by the combination of our post conditions, pre conditions (array is sorted), loop invariant, and decrementing function. By analyzing the invariants in the code, the code almost writes itself! Cool!
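The proof is what establishes correctness, but the same facts can also be spot-checked at run time. The sketch below is a hypothetical debugging variant (lower_bound_checked is an invented name) that computes R by a linear scan and asserts the loop invariant and the decrementing function on every iteration. It is only meant for testing, since the linear scan defeats the purpose of binary search.

```cpp
#include <cassert>

// Hypothetical debugging variant of attempt #3 that asserts the proof's
// loop invariant and decrementing function while it runs.
template <typename T, typename CMP>
int lower_bound_checked(const T *a, int n, const T &x, CMP compare) {
  // Reference answer R by linear scan: the first index whose element
  // is not less than x (or n if there is no such index).
  int R = 0;
  while (R < n && compare(a[R], x))
    ++R;

  int l = 0;
  int r = n;
  while (l < r) {
    assert(l <= R && R <= r); // loop invariant holds on entry
    int d = r - l;            // decrementing function before the update
    int m = l + (r - l) / 2;
    assert(l <= m && m < r);  // m is a valid index
    if (compare(a[m], x))
      l = m + 1;
    else
      r = m;
    assert(l <= R && R <= r); // loop invariant is preserved
    assert(r - l < d);        // d strictly decreases
  }
  assert(l == R);             // invariant plus l == r gives l == R
  return l;
}
```

Running it over a handful of arrays and targets exercises every assertion; if one ever fired, either the code or the claimed invariant would be wrong.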

To view the code in its entirety (along with a couple of simple test harnesses for each version of binary search), check out the source here.