An algebraic data type in Scala

2019-02-19

It’s time to learn Scala. I took Martin Odersky’s course a couple of years ago and thoroughly enjoyed it. Now that some time has passed and I’ve been deeply immersed in all things Haskell, I feel it’s a good time to revisit Scala.

I thought I’d take a look at algebraic data types, which are, I believe, something that Haskell excels at.

sealed trait Expr
case class Val(x: Int) extends Expr
case class Add(x: Expr, y: Expr) extends Expr
case class Mul(x: Expr, y: Expr) extends Expr
object HelloWorld {
  def eval(e: Expr): Int =
    e match {
      case Val(x) => x
      case Add(x, y) => eval(x) + eval(y)
      case Mul(x, y) => eval(x) * eval(y)
    }
  def main(args: Array[String]): Unit = {
    val e: Expr = Mul(Val(2), Add(Val(3), Val(4)))
    println(eval(e))
  }
}

That’s a lot of typing. “The world’s most verbose sum types”, as someone once said.

But, at least I can pattern-match on it!

Surprisingly, the Haskell version isn’t significantly shorter in terms of line count:

data Expr =
    Val Int
  | Add Expr Expr
  | Mul Expr Expr
eval :: Expr -> Int
eval (Val x) = x
eval (Add x y) = eval x + eval y
eval (Mul x y) = eval x * eval y
main :: IO ()
main = do
  let e = Mul (Val 2) (Add (Val 3) (Val 4))
  print $ eval e

There is, however, less line noise.
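Pattern matching also goes deeper than the one-level matches in eval: patterns nest, so a single case can inspect a constructor and its arguments at once. As a minimal sketch (the simplify function and its rewrite rules are hypothetical, chosen only for illustration), here is a bottom-up simplifier over the same Expr type. Because Expr is sealed, the Scala compiler can also warn when a match over it misses a constructor.

```scala
// Same Expr as above, redeclared so this sketch compiles on its own.
sealed trait Expr
case class Val(x: Int) extends Expr
case class Add(x: Expr, y: Expr) extends Expr
case class Mul(x: Expr, y: Expr) extends Expr

object Simplify {
  // Hypothetical helper: simplify children first, then apply a few
  // integer identities using nested patterns such as (Val(1), r).
  def simplify(e: Expr): Expr =
    e match {
      case Val(_) => e
      case Add(x, y) =>
        (simplify(x), simplify(y)) match {
          case (Val(0), r) => r          // 0 + r == r
          case (l, Val(0)) => l          // l + 0 == l
          case (l, r)      => Add(l, r)
        }
      case Mul(x, y) =>
        (simplify(x), simplify(y)) match {
          case (Val(0), _) | (_, Val(0)) => Val(0) // anything * 0 == 0
          case (Val(1), r) => r                    // 1 * r == r
          case (l, Val(1)) => l                    // l * 1 == l
          case (l, r)      => Mul(l, r)
        }
    }

  def main(args: Array[String]): Unit = {
    println(simplify(Add(Mul(Val(1), Val(5)), Val(0)))) // prints "Val(5)"
  }
}
```

The alternative pattern (Val(0), _) | (_, Val(0)) is allowed because it binds no variables; Scala rejects pattern alternatives that bind names.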

Update: 2019-02-20

Here’s another side-by-side comparison, this time of an extension to my simple expression language embedded in an algebraic data type. The extension adds subtraction and division, and makes the evaluation function return Option (Haskell’s Maybe) so that division by zero yields no result instead of failing. This makes it a convenient vehicle for demonstrating how to handle applicative and monadic computations using Scalaz:

import scalaz._
import std.option._
sealed trait Expr
case class Val(x: Int) extends Expr
case class Add(x: Expr, y: Expr) extends Expr
case class Sub(x: Expr, y: Expr) extends Expr
case class Mul(x: Expr, y: Expr) extends Expr
case class Div(x: Expr, y: Expr) extends Expr
object Main {
  private def eval(e: Expr): Option[Int] =
    e match {
      case Val(x) => some(x)
      case Add(ex, ey) => binOp(ex, ey)(_ + _)
      case Sub(ex, ey) => binOp(ex, ey)(_ - _)
      case Mul(ex, ey) => binOp(ex, ey)(_ * _)
      case Div(ex, ey) =>
        eval(ey)
          .flatMap(y =>
            if (y == 0)
              none
            else
              eval(ex).map(_ / y))
      /*
      // Equivalent for-comprehension
      case Div(ex, ey) => for {
        y <- eval(ey)
        if (y != 0)
        x <- eval(ex)
      } yield (x / y)
      */
    }
  private def binOp(ex: Expr, ey: Expr): ((Int, Int) => Int) => Option[Int] = {
    Apply[Option].apply2(eval(ex), eval(ey))
  }
  def main(args: Array[String]): Unit = {
    // Divide by 2: displays "Some(7)"
    println(eval(Div(Mul(Val(2), Add(Val(3), Val(4))), Val(2))))
    // Divide by 0: displays "None"
    println(eval(Div(Mul(Val(2), Add(Val(3), Val(4))), Val(0))))
  }
}

Haskell:

#!/usr/bin/env stack
-- stack --resolver=lts-12.6 script
module Main (main) where
data Expr =
    Val Int
  | Add Expr Expr
  | Sub Expr Expr
  | Mul Expr Expr
  | Div Expr Expr
eval :: Expr -> Maybe Int
eval (Val x) = pure x
eval (Add ex ey) = (+) <$> eval ex <*> eval ey
eval (Sub ex ey) = (-) <$> eval ex <*> eval ey
eval (Mul ex ey) = (*) <$> eval ex <*> eval ey
eval (Div ex ey) = do
  y <- eval ey
  if y == 0
    then Nothing
    else do
      x <- eval ex
      pure (x `div` y)
main :: IO ()
main = do
  -- Divide by 2: displays "Just 7"
  print $ eval (Div (Mul (Val 2) (Add (Val 3) (Val 4))) (Val 2))
  -- Divide by 0: displays "Nothing"
  print $ eval (Div (Mul (Val 2) (Add (Val 3) (Val 4))) (Val 0))

The two pieces of code look remarkably similar. Haskell has a much more compact (and, in my opinion, more intuitive) way of defining the Expr sum type, and Haskell’s Applicative, baked into the Prelude, provides the elegant <$> and <*> operators where the Scala version reaches for Scalaz’s Apply[Option].apply2. Other than that, the structure is the same; in fact, Scala’s flatMap/for-comprehension and Haskell’s do-notation line up almost line for line.
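For comparison, Scalaz is not strictly required for this example. Below is a minimal sketch using only the standard library’s Option (StdEval is a hypothetical name, and this binOp is a stand-in for the Scalaz version, not a Scalaz API). binOp combines its operands with a for-comprehension instead of Apply[Option].apply2. Strictly speaking, that desugars to flatMap and map, a monadic rather than applicative combination, but for Option the results are the same.

```scala
// Standard-library-only sketch; no Scalaz imports needed.
sealed trait Expr
case class Val(x: Int) extends Expr
case class Add(x: Expr, y: Expr) extends Expr
case class Sub(x: Expr, y: Expr) extends Expr
case class Mul(x: Expr, y: Expr) extends Expr
case class Div(x: Expr, y: Expr) extends Expr

object StdEval {
  def eval(e: Expr): Option[Int] =
    e match {
      case Val(x)      => Some(x)
      case Add(ex, ey) => binOp(ex, ey)(_ + _)
      case Sub(ex, ey) => binOp(ex, ey)(_ - _)
      case Mul(ex, ey) => binOp(ex, ey)(_ * _)
      case Div(ex, ey) =>
        for {
          y <- eval(ey)
          if y != 0 // guard: a zero divisor short-circuits to None
          x <- eval(ex)
        } yield x / y
    }

  // Combine two optional operands with f; None if either operand is None.
  private def binOp(ex: Expr, ey: Expr)(f: (Int, Int) => Int): Option[Int] =
    for {
      x <- eval(ex)
      y <- eval(ey)
    } yield f(x, y)

  def main(args: Array[String]): Unit = {
    println(eval(Div(Mul(Val(2), Add(Val(3), Val(4))), Val(2)))) // Some(7)
    println(eval(Div(Mul(Val(2), Add(Val(3), Val(4))), Val(0)))) // None
  }
}
```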

Overall, all this flatMapping makes a lot more sense now that I have Haskell in my veins, and Scala’s syntactic noise is tolerable. Given that I primarily write Java at work and Haskell at home, Scala feels like a nice midway point between the two. I am confident that Scala will feel like a warm, fuzzy comfort blanket one day!

Update: 2019-02-20

Here’s a translation into Java:

/*
 * High-level commentary:
 * This example suggests that Java is, in many ways, an assembly
 * language: the Java code below is essentially what the equivalent
 * Scala or Haskell code becomes under a desugaring transform.
 */
package org.rcook;
import java.util.Optional;
import java.util.function.BiFunction;
interface Visitor<T> {
    T visit(Val e);
    T visit(Add e);
    T visit(Sub e);
    T visit(Mul e);
    T visit(Div e);
}
interface Expr {
    <T> T accept(Visitor<T> visitor);
}
final class Val implements Expr {
    private final int value;
    public Val(int value) {
        this.value = value;
    }
    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visit(this);
    }
    public int getValue() {
        return value;
    }
}
final class Add implements Expr {
    private final Expr left;
    private final Expr right;
    public Add(Expr left, Expr right) {
        this.left = left;
        this.right = right;
    }
    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visit(this);
    }
    public Expr getLeft() {
        return left;
    }
    public Expr getRight() {
        return right;
    }
}
final class Sub implements Expr {
    private final Expr left;
    private final Expr right;
    public Sub(Expr left, Expr right) {
        this.left = left;
        this.right = right;
    }
    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visit(this);
    }
    public Expr getLeft() {
        return left;
    }
    public Expr getRight() {
        return right;
    }
}
final class Mul implements Expr {
    private final Expr left;
    private final Expr right;
    public Mul(Expr left, Expr right) {
        this.left = left;
        this.right = right;
    }
    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visit(this);
    }
    public Expr getLeft() {
        return left;
    }
    public Expr getRight() {
        return right;
    }
}
final class Div implements Expr {
    private final Expr left;
    private final Expr right;
    public Div(Expr left, Expr right) {
        this.left = left;
        this.right = right;
    }
    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visit(this);
    }
    public Expr getLeft() {
        return left;
    }
    public Expr getRight() {
        return right;
    }
}
public final class Main {
    private static final class EvalVisitor implements Visitor<Optional<Integer>> {
        @Override
        public Optional<Integer> visit(Val e) {
            return Optional.of(e.getValue());
        }
        @Override
        public Optional<Integer> visit(Add e) {
            return binOp(e.getLeft(), e.getRight(), (l, r) -> l + r);
        }
        @Override
        public Optional<Integer> visit(Sub e) {
            return binOp(e.getLeft(), e.getRight(), (l, r) -> l - r);
        }
        @Override
        public Optional<Integer> visit(Mul e) {
            return binOp(e.getLeft(), e.getRight(), (l, r) -> l * r);
        }
        @Override
        public Optional<Integer> visit(Div e) {
            return e
                .getRight()
                .accept(this)
                .filter(r -> r != 0)
                .flatMap(r -> e
                    .getLeft()
                    .accept(this)
                    .flatMap(l -> Optional.of(l / r)));
        }
        private Optional<Integer> binOp(
                Expr left,
                Expr right,
                BiFunction<Integer, Integer, Integer> op) {
            return left
                .accept(this)
                .flatMap(l -> right
                    .accept(this)
                    .flatMap(r -> Optional.of(op.apply(l, r))));
        }
    }
    // Use Optional<Integer> since OptionalInt doesn't support flatMap etc.
    private static Optional<Integer> eval(Expr e) {
        EvalVisitor v = new EvalVisitor();
        return e.accept(v);
    }
    public static void main(String[] args) {
        // Divide by 2: outputs "Optional[7]"
        System.out.println(eval(
            new Div(
                new Mul(
                    new Val(2),
                    new Add(
                        new Val(3),
                        new Val(4))),
                new Val(2))));
        // Divide by 0: outputs "Optional.empty"
        System.out.println(eval(
            new Div(
                new Mul(
                    new Val(2),
                    new Add(
                        new Val(3),
                        new Val(4))),
                new Val(0))));
    }
}
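The Java visitor is doing the work that a sealed trait plus pattern matching does in Scala. One way to see the correspondence is a minimal Scala sketch in which Visitor<T> becomes a trait with one method per constructor, and a single fold function replaces all of the accept methods. ExprAlgebra, fold, evalAlg and sizeAlg are hypothetical names chosen for illustration. One difference from the Java EvalVisitor: fold evaluates both operands of Div before calling div, whereas the Java code checks the divisor first; since evaluation has no side effects here, the results are the same.

```scala
// Same Expr shape as the Java example, redeclared so the sketch is self-contained.
sealed trait Expr
case class Val(x: Int) extends Expr
case class Add(x: Expr, y: Expr) extends Expr
case class Sub(x: Expr, y: Expr) extends Expr
case class Mul(x: Expr, y: Expr) extends Expr
case class Div(x: Expr, y: Expr) extends Expr

// Hypothetical counterpart to Visitor<T>: one method per constructor, but the
// results for sub-expressions are passed in instead of computed via accept().
trait ExprAlgebra[T] {
  def value(x: Int): T
  def add(l: T, r: T): T
  def sub(l: T, r: T): T
  def mul(l: T, r: T): T
  def div(l: T, r: T): T
}

object Fold {
  // This single pattern match replaces every accept() method in the Java version.
  def fold[T](e: Expr, alg: ExprAlgebra[T]): T =
    e match {
      case Val(x)    => alg.value(x)
      case Add(l, r) => alg.add(fold(l, alg), fold(r, alg))
      case Sub(l, r) => alg.sub(fold(l, alg), fold(r, alg))
      case Mul(l, r) => alg.mul(fold(l, alg), fold(r, alg))
      case Div(l, r) => alg.div(fold(l, alg), fold(r, alg))
    }

  // Counterpart to EvalVisitor.
  val evalAlg: ExprAlgebra[Option[Int]] = new ExprAlgebra[Option[Int]] {
    private def both(l: Option[Int], r: Option[Int])(f: (Int, Int) => Int): Option[Int] =
      for { x <- l; y <- r } yield f(x, y)
    def value(x: Int): Option[Int] = Some(x)
    def add(l: Option[Int], r: Option[Int]): Option[Int] = both(l, r)(_ + _)
    def sub(l: Option[Int], r: Option[Int]): Option[Int] = both(l, r)(_ - _)
    def mul(l: Option[Int], r: Option[Int]): Option[Int] = both(l, r)(_ * _)
    // A zero divisor yields None.
    def div(l: Option[Int], r: Option[Int]): Option[Int] =
      for { y <- r if y != 0; x <- l } yield x / y
  }

  // A second interpretation with a different T: count the nodes in the tree.
  val sizeAlg: ExprAlgebra[Int] = new ExprAlgebra[Int] {
    def value(x: Int): Int = 1
    def add(l: Int, r: Int): Int = 1 + l + r
    def sub(l: Int, r: Int): Int = 1 + l + r
    def mul(l: Int, r: Int): Int = 1 + l + r
    def div(l: Int, r: Int): Int = 1 + l + r
  }

  def main(args: Array[String]): Unit = {
    val e = Div(Mul(Val(2), Add(Val(3), Val(4))), Val(2))
    println(fold(e, evalAlg)) // Some(7)
    println(fold(e, sizeAlg)) // 7
  }
}
```

As with Visitor<T>, adding a new interpretation means writing a new ExprAlgebra instance, without touching the Expr classes.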

Update: 2019-02-23

And in OCaml:

open Core.Option
open Printf
type expr =
  | Val of int
  | Add of expr * expr
  | Sub of expr * expr
  | Mul of expr * expr
  | Div of expr * expr
let rec to_sexpr = function
  | Val x -> string_of_int x
  | Add (left, right) ->
    sprintf "(add %s %s)" (to_sexpr left) (to_sexpr right)
  | Sub (left, right) ->
    sprintf "(sub %s %s)" (to_sexpr left) (to_sexpr right)
  | Mul (left, right) ->
    sprintf "(mul %s %s)" (to_sexpr left) (to_sexpr right)
  | Div (left, right) ->
    sprintf "(div %s %s)" (to_sexpr left) (to_sexpr right)
let rec eval = function
  | Val x -> Some x
  | Add (left, right) -> binOp left right (+)
  | Sub (left, right) -> binOp left right (-)
  | Mul (left, right) -> binOp left right ( * )
  | Div (left, right) -> binOpM left right (fun l r ->
      if r = 0
      then None
      else Some (l / r))
(* Core's Option.map2 takes the combining function as the labelled argument ~f *)
and binOp left right f = map2 (eval left) (eval right) ~f
and binOpM left right f =
  eval left >>= fun l ->
  eval right >>= fun r ->
  f l r
let string_of_option f o = match o with
  | None -> "None"
  | Some x -> sprintf "Some %s" (f x)
let () =
  (* Divide by 2 *)
  let e = Div (Mul (Val 2, Add (Val 3, Val 4)), Val 2) in
  print_endline (to_sexpr e);
  print_endline (string_of_option string_of_int (eval e));
  (* Divide by 0 *)
  let e = Div (Mul (Val 2, Add (Val 3, Val 4)), Val 0) in
  print_endline (to_sexpr e);
  print_endline (string_of_option string_of_int (eval e))
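For symmetry with the OCaml to_sexpr function, here is a sketch of the same S-expression printer in Scala. SExpr, toSExpr and the node helper are hypothetical names, not from the original code; node factors out the formatting that the OCaml version repeats in each branch.

```scala
sealed trait Expr
case class Val(x: Int) extends Expr
case class Add(x: Expr, y: Expr) extends Expr
case class Sub(x: Expr, y: Expr) extends Expr
case class Mul(x: Expr, y: Expr) extends Expr
case class Div(x: Expr, y: Expr) extends Expr

object SExpr {
  // Render an expression as an S-expression, mirroring the OCaml to_sexpr.
  def toSExpr(e: Expr): String =
    e match {
      case Val(x)    => x.toString
      case Add(l, r) => node("add", l, r)
      case Sub(l, r) => node("sub", l, r)
      case Mul(l, r) => node("mul", l, r)
      case Div(l, r) => node("div", l, r)
    }

  // Hypothetical helper: format one binary node as "(op left right)".
  private def node(op: String, l: Expr, r: Expr): String =
    s"($op ${toSExpr(l)} ${toSExpr(r)})"

  def main(args: Array[String]): Unit = {
    val e = Div(Mul(Val(2), Add(Val(3), Val(4))), Val(2))
    println(toSExpr(e)) // (div (mul 2 (add 3 4)) 2)
  }
}
```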

Content © 2025 Richard Cook. All rights reserved.