Friday, January 27, 2012

Mixins Simplify Composition in Scala

Scala has a nifty feature called “traits.”  A trait encapsulates methods and fields, and behaves in some ways like an interface and in other ways like an abstract class.  Like an interface or abstract class, a trait can declare methods and fields that must later be defined in some concrete class.  Like an abstract class, but unlike an interface, a trait can also supply default bodies and values for those methods and fields.  Unlike a class, however, a trait cannot have constructor parameters (at least in Scala 2; Scala 3 added trait parameters).  Here is a simple trait that specifies an iterator:

 

trait MyIter [A] {
  def next(): A
  def hasNext: Boolean
}

 

In the first of these roles, traits can be used just like traditional interfaces: as a way of specifying the external behavior of an object without saying anything about its implementation.  In the second of these roles, traits can be used like abstract base classes, supplying both declarations and definitions. 
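A minimal sketch of one trait playing both roles at once. The names Countdown and FromThree are hypothetical and not part of this post's code: start is only declared, as in an interface, while steps gets a default body built on it, as in an abstract base class.

```scala
// Hypothetical trait showing both roles in one definition.
trait Countdown {
  // Interface role: declared only; a concrete class must define it.
  def start: Int

  // Abstract-class role: a default body built on the abstract member.
  def steps: List[Int] = (start to 0 by -1).toList
}

// A concrete class only has to supply the missing definition.
class FromThree extends Countdown {
  def start = 3
}

object CountdownDemo {
  def main(args: Array[String]): Unit =
    println(new FromThree().steps) // List(3, 2, 1, 0)
}
```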

 

However (and this is the really cool part), because traits are not classes, a single class can mix in several traits without violating Scala’s single-inheritance rule.  So, like interfaces, multiple traits can contribute declarations to a class; but unlike traditional interfaces, which supply no implementation, and unlike a base class, of which there can be only one, multiple traits can also contribute definitions.  This is sometimes called “mixin” inheritance, because the capabilities are “mixed into” other types.
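A small sketch of mixin inheritance, again with hypothetical names (Thing, Named, Counted, Widget). Widget has exactly one base class, yet two traits each contribute a concrete method, and Counted also contributes a field. Thing's val name satisfies Named's abstract name, since in Scala a concrete definition always overrides an abstract one.

```scala
// Hypothetical single base class.
class Thing(val name: String)

// Trait contributing a concrete method that depends on an abstract one.
trait Named {
  def name: String
  def describe: String = "I am " + name
}

// Trait contributing both a field and a concrete method.
trait Counted {
  private var calls = 0
  def touch(): Int = { calls += 1; calls }
}

// One base class plus two traits, each adding implementation.
class Widget(n: String) extends Thing(n) with Named with Counted

object MixinDemo {
  def main(args: Array[String]): Unit = {
    val w = new Widget("gear")
    println(w.describe) // I am gear
    w.touch()
    println(w.touch())  // 2
  }
}
```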

 

The example here develops a simple set of iterator traits and classes.  Iterators were chosen because they are a simple construct which can nevertheless illustrate some powerful mixin concepts.  Of course, the standard Scala library contains a much more sophisticated set of iterator traits and classes, and you’ll want to use those rather than mine for real-world programming.  In particular, to keep things simple, I ignore covariance and contravariance, which can be explicitly controlled in Scala types.

 

(And as always, the code and information here are presented "as-is" and without warranty or implied fitness of any kind; use them at your own risk.)

 

(The code is presented in fragments; a complete listing is at the bottom.)

 

Recalling the simple iterator example from above:

 

trait MyIter [A] {
  def next(): A
  def hasNext: Boolean
}

 

This trait is parameterized on the iterated type; the “[]” brackets delineate type parameters in Scala.  Also, the methods follow the standard Scala convention for methods without arguments: a method that modifies the object (here, next(), which advances the iterator) is declared with empty parentheses, while a method that does not modify it (here, hasNext) omits them.
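To see MyIter in its interface role, here is a sketch of a concrete class that implements it directly, without any intermediate abstract class. MyListIter is a hypothetical name, and the block repeats MyIter so it compiles on its own. It follows the convention just described: next() changes state and keeps its parentheses, while hasNext does not.

```scala
// Compact copy of the post's MyIter trait, so the block stands alone.
trait MyIter[A] {
  def next(): A        // changes state, so declared with ()
  def hasNext: Boolean // only inspects state, so no parentheses
}

// Hypothetical concrete class implementing MyIter directly over a List.
class MyListIter[A](private var rest: List[A]) extends MyIter[A] {
  def next(): A = rest match {
    case head :: tail =>
      rest = tail // advance past the element being returned
      head
    case Nil =>
      throw new NoSuchElementException("Iteration past end.")
  }
  def hasNext: Boolean = rest.nonEmpty
}

object ListIterDemo {
  def main(args: Array[String]): Unit = {
    val it = new MyListIter(List("a", "b", "c"))
    while (it.hasNext) print(it.next()) // abc
    println()
  }
}
```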

 

In this case, the trait supplies only declarations, and so is used in an interface role.  (In fact, it resembles the actual Iterator trait from the Scala library, which likewise leaves just next() and hasNext abstract.)  We could supply functionality via another trait or a concrete class, but here we’ll do it using an abstract class.  The abstract class below moves a cursor across an inclusive range of integers, from first to last, and delegates the conversion from integer to iterated type to its derived classes using the abstract method “current.”  Inheritors get basic iteration capability simply by defining this method.

 

abstract class MyCursorIter[A] (
  protected val first: Int,
  protected val last: Int) extends MyIter[A] {

  protected var cursor = first - 1

  protected def current: A

  def next() = {
    if (hasNext) {
      cursor += 1
      current
    }
    else throw new Exception("Iteration past end.")
  }

  def hasNext = cursor < last
}
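A sketch of how little an inheritor needs to supply: the hypothetical MySquaresIter defines only current, and inherits next() and hasNext. Compact copies of MyIter and MyCursorIter are included so the block stands alone.

```scala
// Compact copies of MyIter and MyCursorIter from the post.
trait MyIter[A] {
  def next(): A
  def hasNext: Boolean
}

abstract class MyCursorIter[A](protected val first: Int, protected val last: Int)
    extends MyIter[A] {
  protected var cursor = first - 1 // index of the most recently returned element
  protected def current: A         // maps the cursor to an element
  def next(): A =
    if (hasNext) { cursor += 1; current }
    else throw new Exception("Iteration past end.")
  def hasNext: Boolean = cursor < last // last is inclusive
}

// Hypothetical inheritor: it only defines current.
class MySquaresIter(lo: Int, hi: Int) extends MyCursorIter[Int](lo, hi) {
  protected def current = cursor * cursor
}

object SquaresDemo {
  def main(args: Array[String]): Unit = {
    val it = new MySquaresIter(1, 4)
    while (it.hasNext) print(s"${it.next()} ") // prints 1 4 9 16
    println()
  }
}
```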

 

Note that the base class (or, when there is none, the first trait) is introduced with the “extends” keyword; subsequent traits are mixed in using the “with” keyword.

 

Once you have an iterator, you can use it to do all sorts of wonderful things.  But perhaps the most wonderful of all is the higher order function: “fold.”  Fold is a good candidate for the second role of traits: defining methods and values.  Below is a definition of a fold trait as applied to the “MyIter” trait:

 

import annotation.tailrec

trait MyFold[A] {

  this: MyIter[A] =>

  @tailrec
  final def fold[B] (acc: B, f:(B,A) => B): B = {
    if (hasNext) {
      fold(f(acc,next()),f)
    } else acc
  }
}

 

The curious-looking line “this: MyIter[A] =>” is a self-type annotation: it tells the compiler that this trait may only be mixed into types that also have the “MyIter[A]” trait.  This is what allows you to use the members “hasNext” and “next()” from that trait.  The “@tailrec” annotation on the fold method asks the Scala compiler to verify that the recursive call is in tail position, so that it can be compiled into a loop, and to report a compile-time error if it is not.  (The method is also “final,” because @tailrec rejects a method that could be overridden.)  What follows is a standard left fold: it takes an initial result “acc” (for accumulator) and steps through the elements, passing along the result of applying the function “f” to the previous result and each element, and finally returns the last result.
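A sketch of fold in use, assuming compact copies of the definitions so far plus a hypothetical MyRangeIter over an inclusive integer range. Each fold consumes its iterator, so a fresh iterator is created for each call; the second call shows that the accumulator need not be a number.

```scala
import scala.annotation.tailrec

// Compact copies of MyIter, MyCursorIter and MyFold from the post.
trait MyIter[A] { def next(): A; def hasNext: Boolean }

abstract class MyCursorIter[A](protected val first: Int, protected val last: Int)
    extends MyIter[A] {
  protected var cursor = first - 1
  protected def current: A
  def next(): A =
    if (hasNext) { cursor += 1; current }
    else throw new Exception("Iteration past end.")
  def hasNext: Boolean = cursor < last
}

trait MyFold[A] {
  this: MyIter[A] =>
  // Left fold: acc is threaded through f, one element at a time.
  @tailrec
  final def fold[B](acc: B, f: (B, A) => B): B =
    if (hasNext) fold(f(acc, next()), f) else acc
}

// Hypothetical iterator over the inclusive range lo..hi.
class MyRangeIter(lo: Int, hi: Int) extends MyCursorIter[Int](lo, hi) with MyFold[Int] {
  protected def current = cursor
}

object FoldDemo {
  def main(args: Array[String]): Unit = {
    // ((((0 + 1) + 2) + 3) + 4)
    println(new MyRangeIter(1, 4).fold(0, (acc: Int, x: Int) => acc + x)) // 10
    // The accumulator need not be a number.
    val joined = new MyRangeIter(1, 3).fold("", (s: String, x: Int) =>
      if (s.isEmpty) x.toString else s + "," + x)
    println(joined) // 1,2,3
  }
}
```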

 

Fold is like the Swiss Army Knife of higher-order functions in that it can be used to implement a whole range of other higher-order functions.  The fragment below shows two more mixins that use fold: the first computes the length of an iterated sequence, and the second converts an iterated sequence to a list.  Note the use of the “this: … =>” construct to tell the compiler that these traits can only be mixed into types that also have the fold trait.  Note also the “_” in “this: MyFold[_] =>” in MyLength, which indicates that the type parameter does not matter for this definition.  Finally, note that both methods consume the iterator, leaving it exhausted.

 

trait MyLength {

  this: MyFold[_] =>

  def length() = {
    fold(0, (b: Int, _) => b + 1)
  }
}

 

 

trait MyToList[A] {

  this: MyFold[A] =>

  def toList() =
    fold(Nil, (b: List[A], a: A) => a::b).reverse
}
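As one more illustration of fold's versatility, here is a hypothetical MyCountIf mixin, not part of the original code, that counts the elements satisfying a predicate. It is mixed into MyCharIter, a hypothetical stand-in for MyStringIter that includes only the traits it needs.

```scala
import scala.annotation.tailrec

// Compact copies of MyIter, MyCursorIter and MyFold from the post.
trait MyIter[A] { def next(): A; def hasNext: Boolean }

abstract class MyCursorIter[A](protected val first: Int, protected val last: Int)
    extends MyIter[A] {
  protected var cursor = first - 1
  protected def current: A
  def next(): A =
    if (hasNext) { cursor += 1; current }
    else throw new Exception("Iteration past end.")
  def hasNext: Boolean = cursor < last
}

trait MyFold[A] {
  this: MyIter[A] =>
  @tailrec
  final def fold[B](acc: B, f: (B, A) => B): B =
    if (hasNext) fold(f(acc, next()), f) else acc
}

// Hypothetical fold-based mixin: counts the elements that satisfy p.
// Like length() and toList(), it consumes the iterator.
trait MyCountIf[A] {
  this: MyFold[A] =>
  def countIf(p: A => Boolean): Int =
    fold(0, (n: Int, a: A) => if (p(a)) n + 1 else n)
}

// Character iterator over a string, mixing in only what it needs.
class MyCharIter(s: String)
  extends MyCursorIter[Char](0, s.length - 1)
  with MyFold[Char]
  with MyCountIf[Char] {
  protected def current = s(cursor)
}

object CountIfDemo {
  def main(args: Array[String]): Unit =
    println(new MyCharIter("This is a test.").countIf(_ == ' ')) // 3
}
```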

 

And, for convenience, here is a trait that mixes a reset method into classes derived from “MyCursorIter.”  Its self-type is what lets it use that class’s protected members “cursor” and “first.”  If an iterator can be made resettable, methods can be written that consume it without rendering it useless for further computation.

 

trait MyReset {

  this: MyCursorIter[_] =>

  def reset() = cursor = first - 1
}
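A sketch of reset() in action, assuming compact copies of the traits above and a hypothetical MyDigitIter over 0 to 4. The first toList() exhausts the iterator, a second call returns an empty list, and reset() makes the elements available again. This copy of toList() starts from List.empty[A] rather than Nil, a small liberty that pins the accumulator type explicitly.

```scala
import scala.annotation.tailrec

// Compact copies of MyIter, MyCursorIter, MyFold, MyToList and MyReset.
trait MyIter[A] { def next(): A; def hasNext: Boolean }

abstract class MyCursorIter[A](protected val first: Int, protected val last: Int)
    extends MyIter[A] {
  protected var cursor = first - 1
  protected def current: A
  def next(): A =
    if (hasNext) { cursor += 1; current }
    else throw new Exception("Iteration past end.")
  def hasNext: Boolean = cursor < last
}

trait MyFold[A] {
  this: MyIter[A] =>
  @tailrec
  final def fold[B](acc: B, f: (B, A) => B): B =
    if (hasNext) fold(f(acc, next()), f) else acc
}

trait MyToList[A] {
  this: MyFold[A] =>
  def toList(): List[A] = fold(List.empty[A], (b: List[A], a: A) => a :: b).reverse
}

trait MyReset {
  this: MyCursorIter[_] =>
  // The self-type gives access to the protected cursor and first.
  def reset(): Unit = cursor = first - 1
}

// Hypothetical iterator over 0 to 4.
class MyDigitIter extends MyCursorIter[Int](0, 4)
    with MyReset with MyFold[Int] with MyToList[Int] {
  protected def current = cursor
}

object ResetDemo {
  def main(args: Array[String]): Unit = {
    val it = new MyDigitIter
    println(it.toList()) // List(0, 1, 2, 3, 4)
    println(it.toList()) // List(), because the first call exhausted it
    it.reset()
    println(it.toList()) // List(0, 1, 2, 3, 4) again
  }
}
```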

 

At last, time for the first concrete implementation.  And here it is, “MyStringIter,” a character iterator across strings:

 

class MyStringIter (s: String)
  extends MyCursorIter[Char](0,s.length - 1)
  with MyReset
  with MyFold[Char]
  with MyLength
  with MyToList[Char] {

  protected def current = s(cursor)
}

 

And the test:

 

object Test00 {

  def main(args: Array[String]): Unit = {

    // String example.

    val msi = new MyStringIter("This is a test.")

    println(msi.toList())

    msi.reset()

    println(msi.length())

    println("Your breakpoint here.")
  }
}

 

But wait!  There’s more!  In Scala you can use all these traits and classes to cobble together types on the fly.  The extended test below illustrates this by defining, within the test function itself, an ad-hoc iterator across an integer range.

 

object Test00 {

  def main(args: Array[String]): Unit = {

    // String example.

    val msi = new MyStringIter("This is a test.")

    println(msi.toList())

    msi.reset()

    println(msi.length())

    // Range example.

    val msr =
      new MyCursorIter[Int](0,9)
        with MyReset
        with MyFold[Int]
        with MyToList[Int]
        { protected def current = cursor }

    println(msr.toList())

    msr.reset()

    // Error!  MyLength is not mixed into msr, so this line would not compile.
    // println(msr.length())

    println("Your breakpoint here.")
  }
}
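For comparison, the compile error above goes away once the missing trait is mixed in. This sketch repeats compact copies of the classes; its MyLength takes its own type parameter [A] rather than using the post's MyFold[_] self-type, a variant chosen here only to keep type inference simple.

```scala
import scala.annotation.tailrec

// Compact copies of the post's classes; MyLength is a variant (see above).
trait MyIter[A] { def next(): A; def hasNext: Boolean }

abstract class MyCursorIter[A](protected val first: Int, protected val last: Int)
    extends MyIter[A] {
  protected var cursor = first - 1
  protected def current: A
  def next(): A =
    if (hasNext) { cursor += 1; current }
    else throw new Exception("Iteration past end.")
  def hasNext: Boolean = cursor < last
}

trait MyFold[A] {
  this: MyIter[A] =>
  @tailrec
  final def fold[B](acc: B, f: (B, A) => B): B =
    if (hasNext) fold(f(acc, next()), f) else acc
}

// Variant of MyLength with an explicit type parameter instead of MyFold[_].
trait MyLength[A] {
  this: MyFold[A] =>
  def length(): Int = fold(0, (n: Int, _: A) => n + 1)
}

trait MyReset {
  this: MyCursorIter[_] =>
  def reset(): Unit = cursor = first - 1
}

object AdHocDemo {
  def main(args: Array[String]): Unit = {
    // Built on the fly, this time with MyLength mixed in.
    val msr = new MyCursorIter[Int](0, 9) with MyReset with MyFold[Int] with MyLength[Int] {
      protected def current = cursor
    }
    println(msr.length()) // 10
    msr.reset()
    println(msr.fold(0, (acc: Int, x: Int) => acc + x)) // 45
  }
}
```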

 

Now isn’t that cool?

 

-Neil

 

Here is the full listing:

 

/*
 * Use of this software is governed by
 * the following license:
 * The MIT License (MIT)
 * http://www.opensource.org/licenses/mit-license.php
 *
 * Copyright (c) 2012 Neil P. Carrier
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT
 * WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES
 * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
 * PURPOSE AND NONINFRINGEMENT. IN NO EVENT
 * SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
 * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF
 * OR IN CONNECTION WITH THE SOFTWARE OR THE USE
 * OR OTHER DEALINGS IN THE SOFTWARE.
 *
 */

package com.techneilogy.test00

import annotation.tailrec

trait MyIter [A] {
  def next(): A
  def hasNext: Boolean
}

abstract class MyCursorIter[A] (
  protected val first: Int,
  protected val last: Int) extends MyIter[A] {

  protected var cursor = first - 1

  protected def current: A

  def next() = {
    if (hasNext) {
      cursor += 1
      current
    }
    else throw new Exception("Iteration past end.")
  }

  def hasNext = cursor < last
}

trait MyFold[A] {

  this: MyIter[A] =>

  @tailrec
  final def fold[B] (acc: B, f:(B,A) => B): B = {
    if (hasNext) {
      fold(f(acc,next()),f)
    } else acc
  }
}

trait MyLength {

  this: MyFold[_] =>

  def length() = {
    fold(0, (b: Int, _) => b + 1)
  }
}

trait MyToList[A] {

  this: MyFold[A] =>

  def toList() =
    fold(Nil, (b: List[A], a: A) => a::b).reverse
}

trait MyReset {

  this: MyCursorIter[_] =>

  def reset() = cursor = first - 1
}

class MyStringIter (s: String)
  extends MyCursorIter[Char](0,s.length - 1)
  with MyReset
  with MyFold[Char]
  with MyLength
  with MyToList[Char] {

  protected def current = s(cursor)
}

object Test00 {

  def main(args: Array[String]): Unit = {

    // String example.

    val msi = new MyStringIter("This is a test.")

    println(msi.toList())

    msi.reset()

    println(msi.length())

    // Range example.

    val msr =
      new MyCursorIter[Int](0,9)
        with MyReset
        with MyFold[Int]
        with MyToList[Int]
        { protected def current = cursor }

    println(msr.toList())

    msr.reset()

    // Error!  MyLength is not mixed into msr, so this line would not compile.
    // println(msr.length())

    println("Your breakpoint here.")
  }
}