Enhancing the DSL for Type Safety

In "A mini DSL" we used the Int type to represent millimeters in our DSL of measurements. That was a good start but also somewhat limiting: we cannot use Int for any other purpose and still expect type-safe expressions.

Let’s see how we can improve the situation.

The New Goal

We would still like the DSL to work for units of distance as before, but also for units of duration. Combining the two naturally leads to units of velocity.

What we try to achieve with our DSL of units
10.m   -- distance in meters
10.sec -- duration in seconds
1000.m `per` 60.min == 1.kmh  -- velocity

The type system shall disallow improper operations like adding a duration to a distance.

The implementation strategy is to first (intrusively) refactor the old solution to a point where it still works exactly like before but allows adding durations as a non-intrusive increment.

Step 1: Refactoring

In the old code, we made Int an instance of the Metric type class, and every measurement was itself a plain Int. We will now introduce a dedicated Millimeter data type and make it an instance of the Num type class so that we can do arithmetic on it.

The compiler tells us what functions to implement to satisfy Num
data Millimeter = MM Int

instance Num Millimeter where
    zero = MM 0
    one  = MM 1
    (MM a) + (MM b) = MM (a+b)
    (MM a) - (MM b) = MM (a-b)
    (MM a) * (MM b) = MM (a*b)
    MM a  <=> MM b  = a <=> b
    hashCode (MM a) = hashCode a
    fromInt     a   = MM a
    fromInteger a   = MM a.fromIntegral

This means that we can now calculate in terms of millimeters:

Calculating in millimeters
MM 10 + MM 10 == MM 20

To make 10.mm evaluate to MM 10, we simply change our Metric type class to use Millimeter as the return type instead of Int.

Update return types from Int to Millimeter
class (Integral a) => Metric a  where
    mm :: a -> Millimeter
    cm :: a -> Millimeter
    m  :: a -> Millimeter

instance Metric Int where
    mm i = MM i
    cm i = i.mm * 10
    m  i = i.cm * 100

And voilà, after the refactoring all old code still works like before.

Proceed. There is nothing to see here. It’s all like before…
main args = do
    println $ 10.m - 20.cm + 10.mm - 3.cm  == 9780.mm
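
To see why this comparison holds, it helps to expand every term into millimeters using the factors from the Metric instance (1 cm = 10 mm, 1 m = 100 cm = 1000 mm). This is only a worked check of the arithmetic:

$$
10\,\mathrm{m} - 20\,\mathrm{cm} + 10\,\mathrm{mm} - 3\,\mathrm{cm}
= (10000 - 200 + 10 - 30)\,\mathrm{mm}
= 9780\,\mathrm{mm}
$$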

So why the effort?

  • We get a new level of type safety.

  • We can now add new units of measurement.
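
As a small illustration of the first point, here is a minimal sketch that assumes the Millimeter type and the Metric instance from above are in scope. The function perimeter is a hypothetical helper invented for this example and is not part of the original DSL. Its signature accepts and returns only Millimeter values, so a caller can no longer hand it an arbitrary Int by accident.

```frege
-- Hypothetical helper, not part of the original DSL.
-- Assumes Millimeter and the Metric instance defined above.
perimeter :: Millimeter -> Millimeter -> Millimeter
perimeter width height = (width + height) * 2   -- same literal style as "i.mm * 10"

main args = do
    -- (2000 mm + 500 mm) * 2 is 5000 mm, the same value as 5.m
    println $ perimeter (2.m) (50.cm) == 5.m
```

The arithmetic inside the helper relies only on the Num instance for Millimeter, so it needs no knowledge of how the value is represented.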

Step 2: Adding Durations

The next step is to apply a non-intrusive increment: we do not touch existing code at all but add the new feature of calculating with durations. We follow the same strategy as before.

The full solution for durations in one go
data Seconds = SEC Int

instance Num Seconds where
    zero = SEC 0
    one  = SEC 1
    (SEC a) + (SEC b) = SEC (a+b)
    (SEC a) - (SEC b) = SEC (a-b)
    (SEC a) * (SEC b) = SEC (a*b)
    SEC a  <=> SEC b  = a <=> b
    hashCode (SEC a)  = hashCode a
    fromInt     a     = SEC a
    fromInteger a     = SEC a.fromIntegral

class (Integral a) => Time a where
    sec     :: a -> Seconds
    minutes :: a -> Seconds
    h       :: a -> Seconds

instance Time Int where
    sec     i = SEC i
    minutes i = i.sec * 60
    h       i = i.minutes * 60

main args = do
    println $ 3.sec + 1.h - 30.minutes == 1803.sec

Let’s see what this gives us.

Putting it all together

With this new structure, we have some serious benefits.

  • We can no longer accidentally write 10.m + 10.sec, since the compiler rejects it with: Type error in expression "sec 10". Type is Seconds used as Millimeter. Now, that is a nice and telling error message! Is Frege reading my mind?

  • We can still mix Int values into the calculation and they will be safely promoted to the right type! 10.mm + 3 == 13.mm is just as correct and type-safe as 10.sec + 3 == 13.sec. This is controlled by the fromInt implementation of the Num type class.

And now we can go even further and mix distances and duration to describe velocity.

Describing velocity as the quotient of distance over duration
data Velocity = KMH Double
derive Eq Velocity

class (Real a)  => Unit a where
    kmh :: a -> Velocity

instance Unit Double where
    kmh val = KMH val

per (MM mm) (SEC sec) =
    KMH (( mm.fromIntegral / sec.fromIntegral) * 0.0036)

main args = do
    println $ (500.m * 3) `per` (3000.sec + 600)  == 1.5.kmh
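
The factor 0.0036 in per may look arbitrary, so here is where it comes from. The function receives millimeters and seconds, and one millimeter per second converts to kilometers per hour as follows:

$$
1\,\frac{\mathrm{mm}}{\mathrm{s}}
= \frac{10^{-6}\,\mathrm{km}}{\tfrac{1}{3600}\,\mathrm{h}}
= 0.0036\,\frac{\mathrm{km}}{\mathrm{h}}
$$

Applied to the example, 500.m * 3 is 1 500 000 mm and 3000.sec + 600 is 3600 seconds, so in exact arithmetic:

$$
\frac{1\,500\,000\,\mathrm{mm}}{3600\,\mathrm{s}} \times 0.0036
= 1.5\,\frac{\mathrm{km}}{\mathrm{h}}
$$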

Our DSL is really shaping up nicely. It supports us with

  • full type safety,

  • full type inference, and

  • excellent error messages.

It is still easy enough to read and write.

The main workhorses for creating the DSL were type classes in combination with the dot-syntax. We mostly worked in non-intrusive increments, with one intrusive refactoring in between.

All modifications were robust:

  • non-intrusive increments are robust by definition and

  • the type system helped us do a robust refactoring when needed.
