
Traversing a Circuit

Adam Izraelevitz edited this page Sep 25, 2016 · 3 revisions

Understanding IR node children

Writing a Firrtl pass usually requires writing functions that walk the Firrtl data structure to either collect information or replace IR nodes with new IR nodes.

The IR data structure is a tree, where each IR node can have some number of child nodes (which in turn can have more child nodes, etc.). IR nodes without children are called leaves.

Different IR nodes can have different child types. The following table shows the possible child types for each IR node type:

+------------+-----------------------------+
|    Node    |          Children           |
+------------+-----------------------------+
| Circuit    | DefModule                   |
| DefModule  | Port, Statement             |
| Port       | Type, Direction             |
| Statement  | Statement, Expression, Type |
| Expression | Expression, Type            |
| Type       | Type, Width                 |
| Width      |                             |
| Direction  |                             |
+------------+-----------------------------+
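A small self-contained sketch of what the table describes, covering only the Expression, Type and Width rows. The classes below (Node, Expr, Width, UIntType, UIntLit, Add) are hypothetical stand-ins invented for illustration, not Firrtl's IR classes; the generic children method exists only to make the parent/child relationships visible, and leaves collects the nodes that have no children.

```scala
// Hypothetical toy IR, NOT Firrtl's classes: it mirrors a few rows of the table.
sealed trait Node { def children: Seq[Node] }
sealed trait Expr extends Node

// Width: a leaf (no children)
final case class Width(bits: Int) extends Node {
  def children: Seq[Node] = Seq.empty
}
// Type: children are Types or Widths
final case class UIntType(width: Width) extends Node {
  def children: Seq[Node] = Seq(width)
}
// Expression: children are Expressions or Types
final case class UIntLit(value: Int, tpe: UIntType) extends Expr {
  def children: Seq[Node] = Seq(tpe)
}
final case class Add(a: Expr, b: Expr, tpe: UIntType) extends Expr {
  def children: Seq[Node] = Seq(a, b, tpe)
}

// Collect every leaf (node without children), left to right.
def leaves(n: Node): Seq[Node] =
  if (n.children.isEmpty) Seq(n) else n.children.flatMap(leaves)

val one = UIntLit(1, UIntType(Width(1)))
val sum = Add(one, one, UIntType(Width(2))) // 1 + 1
val sumLeaves = leaves(sum) // the three Width nodes: Width(1), Width(1), Width(2)
```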

The map function

To write a function that traverses a Circuit, we first need to understand the functional programming concept of map.

Understanding Seq.map

A Scala sequence of strings can be represented as a tree with a root node Seq and child nodes "a", "b", and "c":

val s = Seq("a", "b", "c")
    Seq
 /   |   \
"a" "b" "c"

Suppose we define a function f that, given a String argument x, concatenates x with itself:

def f(x: String): String = x + x

We can call s.map to return a new Seq[String] whose children are the results of applying f to each child of s:

val s = Seq("a", "b", "c")
def f(x: String): String = x + x  // repeated declaration for clarity
val t = s.map(f)
println(t) // prints List(aa, bb, cc), i.e. Seq("aa", "bb", "cc")
     Seq
 /    |    \
"aa" "bb" "cc"
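A runnable version of the snippet above, checking two properties that the Firrtl map methods described next share: map builds a new collection and leaves s untouched, and the result has the same number of children in the same order.

```scala
val s = Seq("a", "b", "c")
def f(x: String): String = x + x

val t = s.map(f) // a new Seq; s itself is not modified

println(s) // List(a, b, c)
println(t) // List(aa, bb, cc)
```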

Understanding Firrtl's map

We use this "mapping" idea to create our own, custom map methods on IR nodes. Suppose we have a DoPrim expression representing 1 + 1; this can be depicted as a tree of expressions with a root node DoPrim:

        DoPrim
     /          \
UIntValue    UIntValue

If we have a function f that takes an Expression argument and returns a new Expression, we can "map" it onto all Expression children of a given IR node, like our DoPrim. This returns the following new DoPrim, whose children are the results of applying f to every Expression child of the original DoPrim:

        DoPrim
     /          \
f(UIntValue)    f(UIntValue)
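A minimal sketch of this idea, assuming simplified stand-in classes that borrow the names UIntValue and DoPrim but are not Firrtl's (real Firrtl nodes carry more fields, such as types). The method is called mapExpr here only to keep the sketch simple. Note that it applies f to the direct children only; it does not recurse on its own.

```scala
// Hypothetical simplified stand-ins, NOT Firrtl's classes.
sealed trait Expression {
  // Apply f to each direct Expression child and rebuild this node.
  def mapExpr(f: Expression => Expression): Expression
}
final case class UIntValue(value: Int) extends Expression {
  def mapExpr(f: Expression => Expression): Expression = this // no Expression children
}
final case class DoPrim(op: String, args: Seq[Expression]) extends Expression {
  def mapExpr(f: Expression => Expression): Expression = DoPrim(op, args.map(f))
}

// Example f: bump every literal by one, leave other expressions unchanged.
def f(e: Expression): Expression = e match {
  case UIntValue(v) => UIntValue(v + 1)
  case other        => other
}

val onePlusOne = DoPrim("add", Seq(UIntValue(1), UIntValue(1)))
val mapped = onePlusOne.mapExpr(f) // equals DoPrim("add", Seq(UIntValue(2), UIntValue(2)))

// Not recursive: f sees the inner DoPrim as a whole and returns it unchanged.
val nested = DoPrim("add", Seq(onePlusOne, UIntValue(5)))
val mappedNested = nested.mapExpr(f)
```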

Sometimes IR nodes have children of multiple types. For example, Conditionally has both Expression and Statement children. In this case, the map will only apply its function to the children whose type matches the function's argument type (and return value type):

val c = Conditionally(info, e, s1, s2) // info: FileInfo, e: Expression, s1, s2: Statement
def fExp(e: Expression): Expression = ...
def fStmt(s: Statement): Statement = ...
c.map(fExp)  // Conditionally(info, fExp(e), s1, s2)
c.map(fStmt) // Conditionally(info, e, fStmt(s1), fStmt(s2))
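A self-contained sketch of type-directed map on a hypothetical toy IR (Ref, Connect and Conditionally here are simplified stand-ins, not Firrtl's classes, and the info field is omitted). The overload on the function's type uses Scala's DummyImplicit trick so that both map methods can coexist after erasure; this is just one way to express the idea in a standalone example and is not necessarily how Firrtl implements its map.

```scala
// Hypothetical toy IR, NOT Firrtl's classes (the info field is omitted).
sealed trait Expression
final case class Ref(name: String) extends Expression

sealed trait Statement {
  // Applies f to Expression children only
  def map(f: Expression => Expression): Statement
  // Applies f to Statement children only; DummyImplicit just gives this
  // overload a different erased signature so both can coexist
  def map(f: Statement => Statement)(implicit d: DummyImplicit): Statement
}

final case class Connect(loc: String, expr: Expression) extends Statement {
  def map(f: Expression => Expression): Statement = Connect(loc, f(expr))
  def map(f: Statement => Statement)(implicit d: DummyImplicit): Statement = this
}

final case class Conditionally(pred: Expression, conseq: Statement, alt: Statement)
    extends Statement {
  def map(f: Expression => Expression): Statement = Conditionally(f(pred), conseq, alt)
  def map(f: Statement => Statement)(implicit d: DummyImplicit): Statement =
    Conditionally(pred, f(conseq), f(alt))
}

// Function values, so overload resolution picks the map whose parameter type matches
val fExp: Expression => Expression = { case Ref(n) => Ref(n + "_x") }
val fStmt: Statement => Statement = {
  case Connect(loc, e) => Connect(loc + "_s", e)
  case other           => other
}

val c = Conditionally(Ref("en"), Connect("a", Ref("b")), Connect("c", Ref("d")))
val c1 = c.map(fExp)  // only the predicate changes: Ref("en_x")
val c2 = c.map(fStmt) // only the two branches change: locs "a_s" and "c_s"
```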

Side comment: Scala has "infix notation", which lets you drop the . and the parentheses when calling a method that takes a single argument. We often write these map calls in infix notation:

c map fExp  // equivalent to c.map(fExp)
c map fStmt // equivalent to c.map(fStmt)
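The same notation works on ordinary Scala collections. This sketch (double and scale are illustrative names) also shows the form used by the traversal functions later on: when the right-hand side is itself a call such as scale(10), the whole call is the argument, so xs map scale(10) means xs.map(scale(10)).

```scala
val xs = Seq(1, 2, 3)
def double(x: Int): Int = x * 2
def scale(k: Int)(x: Int): Int = k * x // curried: scale(10) is an Int => Int function

val a = xs.map(double)
val b = xs map double    // infix: same call as xs.map(double)
val c = xs map scale(10) // parsed as xs.map(scale(10)), i.e. x => 10 * x
```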

Pre-order traversal

To traverse a Firrtl tree, we use map to write recursive functions which visit every child of every node we care about.

Suppose we want to collect the names of every register declared in the design; this requires visiting every Statement. However, some Statement nodes have Statement children. Thus, we need a function that checks whether its input argument is a DefRegister and, if not, recursively applies itself to all Statement children of its input argument.

The following function, f, implements this idea, but it takes two arguments: a mutable HashSet of register names, and a Statement. Using currying, we can pass only the first argument to get a new function with the desired type signature (Statement=>Statement):

import scala.collection.mutable
import firrtl.ir._       // IR node classes (Statement, DefRegister, ...)
import firrtl.Mappers._  // provides the map methods on IR nodes

def f(regNames: mutable.HashSet[String])(s: Statement): Statement = s match {
  // If register, add name to regNames
  case r: DefRegister =>
    regNames += r.name
    r // Return argument unchanged (ok because DefRegister has no Statement children)
  // If not, apply f(regNames) to all Statement children
  case _ => s map f(regNames) // Note that f(regNames) is of type Statement=>Statement
}

This pattern is very common in Firrtl and is called "pre-order traversal" because the recursive function matches on the original IR node before recursively applying itself to that node's children.
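A runnable sketch of the same pre-order pattern on a hypothetical toy IR: DefRegister, DefWire, Block and Conditionally here are simplified stand-ins, not Firrtl's classes, and map is defined directly on the toy trait instead of coming from Firrtl. It also shows the currying step: collectRegs(regs) is passed to map as a Statement => Statement function.

```scala
import scala.collection.mutable

// Hypothetical toy IR, NOT Firrtl's classes: only Statement children are modeled.
sealed trait Statement {
  // Apply f to each direct Statement child and rebuild this node.
  def map(f: Statement => Statement): Statement = this match {
    case Block(stmts)           => Block(stmts.map(f))
    case Conditionally(p, c, a) => Conditionally(p, f(c), f(a))
    case leaf                   => leaf // DefRegister and DefWire have no Statement children
  }
}
final case class DefRegister(name: String) extends Statement
final case class DefWire(name: String) extends Statement
final case class Block(stmts: Seq[Statement]) extends Statement
final case class Conditionally(pred: String, conseq: Statement, alt: Statement) extends Statement

// Pre-order: match on s first, and only then recurse into its children.
def collectRegs(regNames: mutable.HashSet[String])(s: Statement): Statement = s match {
  case r: DefRegister =>
    regNames += r.name
    r
  case _ => s map collectRegs(regNames) // collectRegs(regNames): Statement => Statement
}

val body = Block(Seq(
  DefWire("w"),
  DefRegister("r1"),
  Conditionally("en", Block(Seq(DefRegister("r2"))), DefWire("v"))
))
val regs = mutable.HashSet[String]()
val result = collectRegs(regs)(body) // regs now holds "r1" and "r2" (r2 is nested in the Conditionally)
```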

Post-order traversal

We can write the previous example in a "post-order traversal" as follows:

def f(regNames: mutable.HashSet[String])(s: Statement): Statement =
  // Note: we immediately recurse to the children nodes, then match on the result
  s map f(regNames) match {
    // If register, add name to regNames
    case r: DefRegister =>
      regNames += r.name
      r // Return argument unchanged (ok because DefRegister has no Statement children)
    // If not, return the rebuilt node, whose Statement children already had f(regNames) applied
    case sx => sx
  }
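To make the difference between the two orders visible, this sketch (again on a hypothetical toy IR, not Firrtl's classes; pre, post and label are illustrative names) records each node at the point where it is matched. Both versions would find the same registers; only the order in which nodes are handled differs. The post-order version handles the node after the recursive call and returns the rebuilt node sx.

```scala
import scala.collection.mutable

// Hypothetical toy IR, NOT Firrtl's classes.
sealed trait Statement {
  // Apply f to each direct Statement child and rebuild this node.
  def map(f: Statement => Statement): Statement = this match {
    case Block(stmts) => Block(stmts.map(f))
    case leaf         => leaf
  }
}
final case class DefRegister(name: String) extends Statement
final case class Block(stmts: Seq[Statement]) extends Statement

// A short label for each node, so the visit order is easy to read.
def label(s: Statement): String = s match {
  case DefRegister(name) => name
  case Block(_)          => "block"
}

// Pre-order: handle the node first, then recurse into its children.
def pre(visited: mutable.ListBuffer[String])(s: Statement): Statement = {
  visited += label(s)
  s map pre(visited)
}

// Post-order: recurse into the children first, then handle the rebuilt node.
def post(visited: mutable.ListBuffer[String])(s: Statement): Statement = {
  val sx = s map post(visited)
  visited += label(sx)
  sx
}

val tree = Block(Seq(DefRegister("r1"), Block(Seq(DefRegister("r2")))))
val preOrder = mutable.ListBuffer[String]()
val postOrder = mutable.ListBuffer[String]()
pre(preOrder)(tree)
post(postOrder)(tree)
// preOrder:  block, r1, block, r2
// postOrder: r1, r2, block, block
```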

While the traversal order differs between these two examples, it makes no difference for this use case (and many others). However, it is an important tool to keep in your back pocket for when the order does matter, for example when handling a node relies on its children having already been transformed.
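One case where the order does matter, sketched on a hypothetical toy expression IR (Lit, Add, foldHere, foldPre and foldPost are illustrative names, not Firrtl APIs): a single-pass constant folder that only inspects the current node. In post-order the children are folded before their parent is inspected, so nested additions collapse in one pass; in pre-order the parent is inspected while its children are still unfolded.

```scala
// Hypothetical toy IR, NOT Firrtl's classes.
sealed trait Expression {
  // Apply f to each direct Expression child and rebuild this node.
  def map(f: Expression => Expression): Expression = this match {
    case Add(a, b) => Add(f(a), f(b))
    case leaf      => leaf
  }
}
final case class Lit(value: Int) extends Expression
final case class Add(a: Expression, b: Expression) extends Expression

// Fold Add(Lit, Lit) into a single Lit, looking only at this node.
def foldHere(e: Expression): Expression = e match {
  case Add(Lit(x), Lit(y)) => Lit(x + y)
  case other               => other
}

// Pre-order: fold this node first, then recurse into its children.
def foldPre(e: Expression): Expression = foldHere(e) map foldPre

// Post-order: recurse into the children first, then fold this node.
def foldPost(e: Expression): Expression = foldHere(e map foldPost)

val expr = Add(Add(Lit(1), Lit(1)), Lit(1)) // (1 + 1) + 1
val pre = foldPre(expr)   // Add(Lit(2), Lit(1)): the root was checked before its child was folded
val post = foldPost(expr) // Lit(3)
```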