Sunday, August 8, 2010

Segment Tree in F#

Edit (2011.02.18): Oops! I found a bug in the code below, specifically, in the "findCollect" function. Please see this post (2011.02.18) for details. Correction is pending testing.

Today’s entry is a segment tree in F#. The details of segment trees are described elsewhere, so I won’t repeat them here. Suffice to say that a segment tree stores a group of intervals on an ordered set, in such a way that all the intervals containing a given value can be efficiently located. Such a query is sometimes called a “stabbing query,” because it’s like sticking a spear through the data and seeing which entries get stabbed.

For example, if one had a data set consisting of the professional starting and ending years of various baseball players, one could use it to efficiently answer a query such as: “What players were active in 1977?” In computer graphics, segment trees can be used for various types of hit testing, particularly when extended to two or three dimensions.

Segment trees are especially useful for large amounts of static data; that is, where all the segments are known before the tree is constructed. The implementation below is for the classic static segment tree. There are references one can find regarding ways to make segment trees dynamic, but if the data is very dynamic, or of lower quantity, you might also want to consider other algorithms.

In my case, I'll be using the segment tree for some fuzzy logic experiments I’m doing, so a static, one-dimensional segment tree will work. As a first step in fuzzification, I want to quickly locate those fuzzy sets with domains containing a given non-fuzzy (crisp) input value.

The code is below, along with an example. Let me say up front that I’m less than 100% satisfied with the current factoring. I think it could be more F#-like and use more of F#’s built-in functionality. But I wanted to get it posted rather than let it languish. I did manage to avoid classes other than the built-in class types of record and discriminated union (and, of course, tuples). I did this deliberately as part of my goal of seeing beyond “OOP-think.”

Also, I plan on using this internal to an F# project, so it doesn't quite meet general .NET guidelines. However, it should be easy to adapt the code to a general .NET library.

As always, all the code here is presented "as-is" and without warranty or implied fitness of any kind; use at your own risk. I tested it by cross-checking against a (theoretically) good segment tree implementation, but you never know, I may have missed a case or they both might have the same bugs.

 
open System
 
// The next two functions assist in the 
// creation of a binary tree from a list.
 
/// Combine neighbouring list elements two at a time with f,
/// keeping the original order.  When the list holds an odd
/// number of elements, the last one is carried over unchanged.
// e.g. [a;b;c;d;e] -> [f a b; f c d; e]
let pairWise (f:'a->'a->'a) l =
  // 'combined' is built in reverse and flipped once at the end.
  let rec loop remaining combined =
    match remaining with
    | first::second::rest -> loop rest (f first second :: combined)
    | [last] -> List.rev (last :: combined)
    | [] -> List.rev combined
  loop l []
 
/// Apply pairWise repeatedly until at most one item remains.
/// Returns the empty list for empty input, otherwise a
/// single-element list holding the fully combined value.
let rec pairWiseReduce (f:'a->'a->'a) = function
  // Nothing left to combine: zero or one item.
  | ([] | [_]) as settled -> settled
  // Halve the list and go around again.
  | items -> items |> pairWise f |> pairWiseReduce f
 
// These types encode the segment tree.
 
/// A leaf node.
// A leaf stands for one endpoint value plus the gap that follows it.
// Eq - The leaf's own value; a query equal to Eq reads EqVal.
// Hi - Upper bound of the gap after Eq; queries routed to this
//      leaf that are not equal to Eq read HiVal.
// EqVal - User data of the intervals that contain Eq.
// HiVal - User data of the intervals that overlap the gap
//         between Eq and Hi.
// Both lists are mutable because addInterval fills them in
// after the tree has been built.
type LeafNode<'a,'b> =
  {  
    Eq:'b;
    Hi:'b;
    mutable EqVal:'a list;
    mutable HiVal:'a list; 
  }
  // Create a leaf with the given bounds and no user data yet.
  static member Make eq hi =
    { Eq=eq; Hi=hi; EqVal=[]; HiVal=[]; }
 
/// An intermediate node.
// Lo - Inclusive lowest subtree value.
// Mid - Split point (inclusive to the right): queries below Mid
//       descend into the left child, all others into the right.
// Hi - Upper bound of the subtree, copied from the right child
//      by makeTree.
// RngVal - User data of the intervals that span this node's
//          whole range; addInterval stores them here instead of
//          pushing them further down.
type TreeNode<'a,'b> =
  {
    Lo:'b;
    Mid:'b;
    Hi:'b;
    mutable RngVal:'a list;
  }
  // Create a node with the given bounds and no user data yet.
  static member Make lo mid hi =
    { Lo=lo; Mid=mid; Hi=hi; RngVal=[]; }
 
/// Typical F# tree with node specializations.
// 'a is the type of the user data attached to intervals,
// 'b is the type of the interval endpoints.
type Tree<'a,'b> =
  // A terminal node carrying a LeafNode.
  | Leaf of LeafNode<'a,'b>
  // An inner node carrying a TreeNode and its (left, right) children.
  | Tree of TreeNode<'a,'b>*(Tree<'a,'b>*Tree<'a,'b>)
 
/// Join two adjacent subtrees under a new inner node.
/// t0 becomes the left child and t1 the right child; the new
/// node spans from the low bound of t0 to the high bound of t1
/// and splits at the low bound of t1.
let makeTree<'a,'b> (t0:Tree<'a,'b>) (t1:Tree<'a,'b>) =
  // Lowest value covered by a subtree.
  let lowOf = function
    | Leaf leaf -> leaf.Eq
    | Tree (node, _) -> node.Lo
  // Upper bound of a subtree.
  let highOf = function
    | Leaf leaf -> leaf.Hi
    | Tree (node, _) -> node.Hi
  let joined = TreeNode<'a,'b>.Make (lowOf t0) (lowOf t1) (highOf t1)
  Tree (joined, (t0, t1))
 
/// Record the closed interval [l, h], tagged with user data a,
/// in tree t by mutating the nodes it reaches.
// NOTE(review): the tree built in this file has a leaf for every
// interval endpoint; the leaf tests below appear to rely on that.
let rec addInterval (t:Tree<'a,'b>) l h a =
  match t with
  | Leaf leaf ->
    // The leaf's own value lies inside [l, h].
    if l <= leaf.Eq && leaf.Eq <= h then
      leaf.EqVal <- a :: leaf.EqVal
    // The interval reaches into the gap between Eq and Hi.
    if leaf.Eq < h && l < leaf.Hi then
      leaf.HiVal <- a :: leaf.HiVal
  | Tree (node, (left, right)) ->
    if l <= node.Lo && node.Hi <= h then
      // The whole node range is covered: store here and stop.
      node.RngVal <- a :: node.RngVal
    else
      // Otherwise descend into each side the interval touches.
      if l < node.Mid then
        addInterval left l h a
      if node.Mid <= h then
        addInterval right l h a
 
// Note: either of the following functions 
// could be synthesized from the other.
// Many other aggregate functions are 
// also possible.
 
/// Find the intervals containing b, and apply f.
/// f is called once for every non-empty user-data list met on
/// the path from the root to the leaf responsible for b, with
/// that list and b as arguments.
/// Values of b outside the range covered by the tree (below the
/// root's Lo or above its Hi) trigger no calls at all.
let rec findIter(t:Tree<'a,'b>) (b:'b) (f:'a list->'b->unit) =
  match t with
  | Leaf(ln) ->
    // Only report data when b really belongs to this leaf: either
    // it equals the leaf value or it lies strictly inside the gap.
    // Without this check a value below the first leaf would
    // wrongly receive that leaf's HiVal.
    let hits =
      if b = ln.Eq then ln.EqVal
      elif b > ln.Eq && b < ln.Hi then ln.HiVal
      else []
    if not hits.IsEmpty then f hits b
  | Tree(tn,(t0,t1)) ->
    // Skip the subtree entirely when b is outside its bounds, so
    // that RngVal is never reported for an out-of-range query.
    if b >= tn.Lo && b <= tn.Hi then
      if not tn.RngVal.IsEmpty then f tn.RngVal b
      if b < tn.Mid then findIter t0 b f
      else findIter t1 b f
 
/// Find the intervals containing b, and collect in a list.
/// The result holds the user data stored at the leaf responsible
/// for b followed by the RngVal entries of its ancestors, from
/// the deepest node up to the root.
/// Values of b outside the range covered by the tree (below the
/// root's Lo or above its Hi) yield an empty list.
let findCollect (t:Tree<'a,'b>) (b:'b) =
  let rec f acc = function
    | Leaf(ln) ->
      // Only take data when b really belongs to this leaf: either
      // it equals the leaf value or it lies strictly inside the
      // gap.  Without this check a value below the first leaf
      // would wrongly pick up that leaf's HiVal.
      if b = ln.Eq then List.append ln.EqVal acc
      elif b > ln.Eq && b < ln.Hi then List.append ln.HiVal acc
      else acc
    | Tree(tn,(t0,t1)) ->
      // Reject values outside the subtree before touching RngVal,
      // so an out-of-range query collects nothing on its way down.
      if b < tn.Lo || b > tn.Hi then acc
      else
        let found = List.append tn.RngVal acc
        if b < tn.Mid then f found t0 else f found t1
  f [] t
 
/// Dump a tree to standard output for inspection.
/// Every node prints its bounds followed by its user data, and
/// each level of children is indented two spaces further.
let printTree<'a,'b> (t:Tree<'a,'b>) =
  let rec show indent node =
    match node with
    | Leaf(leaf) ->
      // Bounds first, then the data for Eq, then the data for the gap.
      printfn "%s(%A %A)" indent leaf.Eq leaf.Hi
      printfn "%s  %A" indent leaf.EqVal
      printfn "%s  %A" indent leaf.HiVal
    | Tree(inner,(left,right)) ->
      printfn "%s(%A %A %A)" indent inner.Lo inner.Mid inner.Hi
      printfn "%s  %A" indent inner.RngVal
      let deeper = indent + "  "
      show deeper left
      show deeper right
  show "" t
 
 
// A test.  Uses tuples of the form int*int*string.
// The ints represent the range, and the string the 
// user data.
 
/// Produce a lazy sequence of n random test ranges.
/// Each item is (low, high, label): low is drawn from
/// [min, max), high is low plus a random length in [0, rng),
/// and label is the text "low..high".
let genTests n min max rng =
  let random = System.Random()
  // Draw one range; the start is drawn before the length.
  let nextRange () =
    let low = random.Next(min, max)
    let high = low + random.Next(rng)
    let label = sprintf "%i..%i" low high
    (low, high, label)
  seq { for _ in 1..n do yield nextRange () }
 
// Ten random test ranges, forced into a list so the same
// values are reused by every later step.
let segTests = List.ofSeq (genTests 10 0 10000 10000)
 
// Gather both endpoints of every test range, drop duplicates
// and order them from highest to lowest.  The descending order
// matters: the fold that builds the leaves conses each new leaf
// onto the front, which leaves the final list ascending.
let segments =
  segTests
  |> List.collect (fun (lo, hi, _) -> [lo; hi])
  |> Set.ofList   // a set is both duplicate-free and ordered
  |> Set.toList   // ascending
  |> List.rev
 
// Fold step that makes a leaf from a pair of neighbouring
// endpoints.  The state is (leaves so far, previous endpoint b1);
// the incoming endpoint b0 becomes the new leaf's value and b1
// its upper bound.  b0 is then handed on as the next "previous".
let makeLeaf<'a,'b> (ll:LeafNode<'a,'b> list,b1) (b0:'b) =
  let leaf = LeafNode<'a,'b>.Make b0 b1
  (leaf :: ll, b0)
 
// Create a tree.
// The fold walks the endpoints with makeLeaf, so each leaf gets
// the previously visited endpoint as its upper bound.  The seed
// is the first endpoint itself, which makes the first leaf built
// have equal value and upper bound.
// The leaf list is then wrapped and reduced pairwise to one root.
let tree =
  let leafNodes, _ =
    List.fold makeLeaf<string,int> ([], segments.Head) segments
  leafNodes
  |> List.map Leaf
  |> pairWiseReduce makeTree
  |> List.head
 
// Register every generated range, with its label, in the tree.
for (lo, hi, label) in segTests do
  addInterval tree lo hi label

// Make a one-dimensional stabbing query at 12000.
// NOTE(review): the generated ranges start below 10000, so 12000
// may lie beyond every endpoint in the tree on some runs.
let stab = findCollect tree 12000

printfn "Your breakpoint here."

No comments: