Custom JIT compilation with Runtime Multi-Stage Programming


For a while now I've been fiddling with a library called Slinc, a way to link C code in Scala. Since my original version, I've been using context values, macros, and other tools to generate code for writing and reading representations of Structs into and out of native memory.

If you're not familiar with how macros work in Scala, the simplest way of writing them involves using quotes and splices.

val additionExpr = '{4 + 3}
val function = '{() => $additionExpr} //produces () => 4+3

Compile time information can be used in combination with creative usage of quotes and splices to generate all kinds of different code:

val exprs = for 
  i <- 0 until 5
yield '{(j: Int) => println(j + ${Expr(i)})} //${Expr(i)} smuggles the compile time value i into our generated code

'{
	val k = 15
	
	${
		Expr.block(
			//would normally produce ((j: Int) => println(j + 0))(k) as an example
			//but Expr.betaReduce simplifies it to println(k + 0)
			//k is referenced directly: it is defined in the enclosing quote, at the same level
			exprs.map(expr => Expr.betaReduce('{$expr(k)})).toList,
			'{}
		)
	}
}
//resulting code:
//{
//  val k = 15
//  println(k + 0)
//  println(k + 1)
//  println(k + 2)
//  println(k + 3)
//  println(k + 4)
//}

Compile-time reflection and terms are also a thing in Scala 3 macros, but the usage of those is a story for another time.

Anyway, one of my first attempts at generating code to write data to native memory would take case classes like these:

case class Y(a: Int, b: Int) derives Struct
case class X(a: Int, y: Y, b: Int) derives Struct

And produce code like below to write it into native memory:
{
  Encoder.given_Encoder_Int.into(
	  a.a,
	  memoryAddress,
	  l.layout.byteOffset(PathElement.groupElement("a")) + offset
  )
  Encoder.given_Encoder_Int.into(
	  a.y.a,
	  memoryAddress,
	  l.layout.byteOffset(
		  PathElement.groupElement("y"),
		  PathElement.groupElement("a")
	  ) + offset
  )
  Encoder.given_Encoder_Int.into(
	  a.y.b,
	  memoryAddress,
	  l.layout.byteOffset(
		  PathElement.groupElement("y"),
		  PathElement.groupElement("b")
	  ) + offset
  )
  Encoder.given_Encoder_Int.into(
	  a.b,
	  memoryAddress,
	  l.layout.byteOffset(PathElement.groupElement("b")) + offset
  )
}

This code basically takes the Int Encoder, as I was calling it back then, and calls its "into" method to write a.a, a.y.a, a.y.b, and a.b into native memory. It provides offsets into the native memory by analyzing the layout of a C struct mirroring the layout of X. This code is fairly readable, and though I've removed the usage of full path names, it's something a human might write to record this case class into native memory.

So how fast was this code?

Benchmark                    Mode  Cnt    Score   Error   Units
AssignBench.writeCaseClass  thrpt    5    4.191 ± 0.061  ops/us
AssignBench.writeInt        thrpt    5  289.714 ± 7.512  ops/us

Not good. Despite only writing 4 integer values, the Encoder for X is something like 70x slower than a single Int write. This is due to querying the layout structure for the offset at which to write each Int.

In my latest code, I'm using a much simpler approach using plain inline methods instead of macros. Inline methods can't do everything a macro can, but in general they tend to be easier to work with, and simpler to support.

  inline def calcSender[A <: Product](
      offsets: Vector[Bytes]
  )(using m: Mirror.ProductOf[A]): Send[Product] =
    (rawMem: Mem, offset: Bytes, a: Product) =>
      calcTupSender[m.MirroredElemTypes](
        Tuple.fromProduct(a).toArray,
        rawMem,
        offset,
        offsets,
        0
      )
  inline def calcTupSender[A <: Tuple](
      a: Array[Object],
      rawMem: Mem,
      offset: Bytes,
      offsets: Vector[Bytes],
      position: Int
  ): Unit =
    inline erasedValue[A] match
      case _: (h *: t) =>
        summonInline[Send[h]].to(
          rawMem,
          offsets(position) + offset,
          a(position).asInstanceOf[h]
        )
        calcTupSender[t](a, rawMem, offset, offsets, position + 1)
      case _: EmptyTuple => ()

Here, "Send" is my modern equivalent of Encoder. I decided that I wasn't actually producing native data, but putting data that was on the JVM into the native space, so "Send" is a much better term for the operation. The above code has two parts: the main inline method, which takes the Struct representation and the offsets for data placement and renders a Tuple representation from it, and a recursive inline method, which summons the "Send" instance for each element of the tuple and writes the data to the appropriate place in native memory. For X, these inlines produce the following code:

generated code for: X
{
  val a$proxy1: Array[Object] = 
    {
      val Tuple_this: Tuple = Tuple.fromProduct(a)
      Tuples.toArray(Tuple_this):Array[Object]
    }
  {
    fr.hammons.sffi.given_Send_Int.to(rawMem, 
      Bytes.+(this.layout.offsets.apply(0))(offset)
    , a$proxy1.apply(0).asInstanceOf[Int])
    {
      AssignBenches.this.Y.derived$Struct.to(rawMem, 
        fr.hammons.sffi.Bytes.+(this.layout.offsets.apply(1))(offset)
      , a$proxy1.apply(1).asInstanceOf[AssignBenches.this.Y])
      {
        fr.hammons.sffi.given_Send_Int.to(rawMem, 
          fr.hammons.sffi.Bytes.+(this.layout.offsets.apply(2))(offset)
        , a$proxy1.apply(2).asInstanceOf[Int])
        ():Unit
      }:Unit
    }:Unit
  }:Unit
}
generated code for: Y
{
  val a$proxy3: Array[Object] = 
    {
      val Tuple_this: Tuple = Tuple.fromProduct(a)
      runtime.Tuples.toArray(Tuple_this):Array[Object]
    }
  {
    fr.hammons.sffi.given_Send_Int.to(rawMem, 
      fr.hammons.sffi.Bytes.+(this.layout.offsets.apply(0))(offset)
    , a$proxy3.apply(0).asInstanceOf[Int])
    {
      fr.hammons.sffi.given_Send_Int.to(rawMem, 
        fr.hammons.sffi.Bytes.+(this.layout.offsets.apply(1))(offset)
      , a$proxy3.apply(1).asInstanceOf[Int])
      ():Unit
    }:Unit
  }:Unit
}
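The recursive-inline pattern behind that output is easier to see in miniature. The sketch below is not Slinc code: "Describe" is a hypothetical type class standing in for "Send", and "Point" is a made-up case class. It only shows how erasedValue and summonInline let the compiler unroll one call per field.

```scala
import scala.compiletime.{erasedValue, summonInline}
import scala.deriving.Mirror

// Hypothetical type class, standing in for Send in this sketch.
trait Describe[A]:
  def describe(a: A): String

given Describe[Int] with
  def describe(a: Int): String = s"Int($a)"

given Describe[String] with
  def describe(a: String): String = s"String($a)"

// Walks the element types at compile time; every step is inlined,
// so the result is a flat chain of describe calls, one per field.
inline def describeAll[T <: Tuple](elems: Array[Object], position: Int): List[String] =
  inline erasedValue[T] match
    case _: (h *: t) =>
      summonInline[Describe[h]].describe(elems(position).asInstanceOf[h]) ::
        describeAll[t](elems, position + 1)
    case _: EmptyTuple => Nil

// Entry point: turns the case class into an array and starts the recursion.
inline def describeProduct[A <: Product](a: A)(using m: Mirror.ProductOf[A]): List[String] =
  describeAll[m.MirroredElemTypes](Tuple.fromProduct(a).toArray, 0)

case class Point(x: Int, label: String)

@main def describeDemo(): Unit =
  println(describeProduct(Point(3, "origin"))) // List(Int(3), String(origin))
```

As in the real code, the only runtime work left is the array lookup and the cast for each field; which instance to call is settled at compile time.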

I haven't removed the full names from the generated code for X and Y, but I think they are clear enough. This code is much easier to generate than the original macro code was (you don't want to see that code, trust me). But how fast does it run?

Benchmark                             Mode  Cnt     Score    Error   Units
AssignBenches.assignCaseClass2       thrpt    5    21.146 ±  0.790  ops/us
AssignBenches.assignInt2             thrpt    5   586.971 ± 54.264  ops/us
AssignBenches.assignCaseClass2          ss       5289.117            us/op
AssignBenches.assignInt2                ss       4339.710            us/op

So the new code is around 5x faster than before, but the write speed for Int has doubled, so this is more of a 2.5x speedup comparatively. Still, being about 28x slower than writing an Int is not a good situation, considering the theoretical best should be around 4x slower for this data structure (it holds four Ints). Can we do better?

Some of my prototypes using other forms of inline methods and macros got up to 50 ops/us, but is it possible to go faster? Well, not really at compile time. One of the reasons writing Int has been so fast is that we know its size at compile time, and knowing where to write it is relatively simple. On the other hand, knowing where to write the elements of X is very platform specific. Alignment, and how data for structs is padded, depends on the host platform. X itself might be a simple case, but we want a general-purpose way to write the data for any Struct, and we don't want to encode alignment and padding rules at compile time that may not apply at runtime.
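To make the platform dependence concrete, here is a small sketch of the usual C rule: each member starts at the next multiple of its own alignment. The "Layout", "Prim" and "Struct" types are a simplified model invented for this illustration, not Slinc's DataLayout, and the sizes are the common ones for 64-bit Linux and 64-bit Windows.

```scala
// Simplified, hypothetical layout model for illustration; not Slinc's DataLayout.
sealed trait Layout:
  def size: Long
  def alignment: Long

final case class Prim(size: Long, alignment: Long) extends Layout

// Rounds n up to the next multiple of alignment.
def alignUp(n: Long, alignment: Long): Long =
  (n + alignment - 1) / alignment * alignment

final case class Struct(members: List[Layout]) extends Layout:
  // A struct is aligned to its most strictly aligned member.
  lazy val alignment: Long = members.map(_.alignment).maxOption.getOrElse(1L)

  // (start offset of each member, first free byte after the last member)
  private lazy val placed: (List[Long], Long) =
    members.foldLeft((List.empty[Long], 0L)) { case ((acc, cursor), m) =>
      val start = alignUp(cursor, m.alignment)
      (acc :+ start, start + m.size)
    }

  def offsets: List[Long] = placed._1

  // Trailing padding keeps arrays of the struct aligned.
  lazy val size: Long = alignUp(placed._2, alignment)

@main def layoutDemo(): Unit =
  val char = Prim(1, 1)
  val int = Prim(4, 4)
  // C's `long` is 8 bytes on 64-bit Linux (LP64) but 4 bytes on 64-bit Windows (LLP64).
  val longLinux = Prim(8, 8)
  val longWindows = Prim(4, 4)

  // struct { char c; long l; int i; }
  println(Struct(List(char, longLinux, int)).offsets)   // List(0, 8, 16)
  println(Struct(List(char, longWindows, int)).offsets) // List(0, 4, 8)

  // The article's X: an Int, a nested Y(Int, Int), and another Int
  println(Struct(List(int, Struct(List(int, int)), int)).offsets) // List(0, 4, 12)
```

The same struct definition gets different offsets on the two platforms, which is why the offsets can't simply be baked in when the Scala code is compiled.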

So we're stuck, right? Wrong. Scala 3 has a new feature that will save us: runtime multi-stage programming. Runtime multi-stage programming in Scala 3 works a lot like macros. The big differences are that it's more limited with regard to what code it can produce, and that it can include runtime data inside the code it generates.

Let me repeat that last part one more time: it can include runtime data inside the code it generates. To put it simply, the slowness of our Send implementations was related to the fact that they were reliant on data that was only available at runtime. To write the Send instance perfectly for a case class, we would have to know in advance what platform it would run on, and it would only work for that platform. Using runtime multi-stage programming, we can avoid this limitation by embedding parsed runtime information into the code we want to generate, allowing us to write something much closer to the perfect implementation for any platform. If this promise sounds a lot like the promises that just-in-time compilation was supposed to bring, you're not far off the mark. The code generation of runtime multi-stage programming lets us write something like our own just-in-time compilation. But enough prattling, let's see the implementation:

  private def sendGenHelper(
      layout: DataLayout,
      rawMem: Expr[Mem],
      offset: Expr[Bytes],
      value: Expr[Any]
  )(using Quotes): Expr[Unit] =
    layout match
      case IntLayout(_, _) =>
        '{ $rawMem.write($value.asInstanceOf[Int], $offset) }
      case StructLayout(_, _, children) =>
        val fns = children.zipWithIndex.map {
          case (StructMember(childLayout, _, subOffset), idx) =>
            sendGenHelper(
              childLayout,
              rawMem,
              '{ $offset + ${ Expr(subOffset) } },
              '{ $value.asInstanceOf[Product].productElement(${ Expr(idx) }) }
            )
        }.toList
        Expr.block(fns, '{})

  def sendStaged(layout: DataLayout)(using Quotes): Expr[Send[Product]] =
    '{ (mem: Mem, offset: Bytes, a: Product) =>
      ${
        sendGenHelper(layout, 'mem, 'offset, 'a)
      }
    }

If this code reminds you of the quotes and splices examples I showed you before, it should. Runtime multi-stage programming uses the same quotes and splices as a normal macro, but there are some limitations. You cannot use generic types in the code, because that's not supported. Therefore, to write a generic "Send" generator with runtime multi-stage programming, we have to use the "Product" type instead of something like A.

Let's go over what this code does. "sendStaged" is what generates the code, and it takes a DataLayout (which describes the native memory layout we want to write to). It then hands off the work to the "sendGenHelper" method, which matches on the DataLayout. Right now we only have support for IntLayout and StructLayout, but adding more is trivial.

The IntLayout case produces a quote that splices in the offset passed into sendGenHelper, and casts the value passed into sendGenHelper to an Int (since it's an Any... no generics can sometimes suck). The write method on rawMem is overloaded to handle int, float, byte, etc.

In the StructLayout case, the children are extracted from the layout, and from each StructMember we take its layout and its offset. The layout is sent to a recursive call of sendGenHelper, along with two new expressions: the original offset with the member's offset spliced in, and the element selected from the product via "productElement(${Expr(idx)})". The resulting expressions are then fed into Expr.block to produce a block of write expressions.
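One thing the snippet above doesn't show is how a staged expression actually gets compiled and run. The following is a minimal sketch, assuming the scala3-staging library is on the classpath (it ships separately from the standard library, versioned with the compiler); "StagedAdder" and "makeAdder" are names made up for this example.

```scala
import scala.quoted.*
import scala.quoted.staging.{Compiler, run}

object StagedAdder:
  // The embedded compiler needs a class loader that can see the program's classes.
  given Compiler = Compiler.make(getClass.getClassLoader)

  // Builds, compiles and loads `(x: Int) => x + n`,
  // with the runtime value n embedded in the code as a literal.
  def makeAdder(n: Int): Int => Int =
    run { '{ (x: Int) => x + ${ Expr(n) } } }

@main def stagedDemo(args: String*): Unit =
  // A value that is only known once the program is running.
  val n = args.headOption.flatMap(_.toIntOption).getOrElse(7)
  val addN = StagedAdder.makeAdder(n)
  println(addN(35))
```

Every call to run invokes the compiler, so in practice you generate a function once and keep hold of it rather than calling run on a hot path.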

What kind of code do sendStaged and sendGenHelper produce for X?

(mem: Mem, offset: Bytes, a: Product) => {
  mem.write(
    a.asInstanceOf[Product].productElement(0).asInstanceOf[Int], 
    offset + 0L
  )
  mem.write(
    a.asInstanceOf[Product].productElement(1)
	    .asInstanceOf[Product].productElement(0).asInstanceOf[Int], 
    offset + 4L + 0L
  )
  mem.write(
    a.asInstanceOf[Product].productElement(1)
	  	 .asInstanceOf[Product].productElement(1).asInstanceOf[Int],
    offset + 4L + 4L
  )
  mem.write(
    a.asInstanceOf[Product].productElement(2).asInstanceOf[Int],
    offset + 12L
  )
}

Almost completely perfect. Instead of having to access an array, or search through a data structure like MemoryLayout, the offsets are embedded right in the code, in a form that the JVM's JIT compiler can easily reduce and optimize. This is close to what someone might write by hand if they knew the target platform in advance, along with how padding and byte alignment work for C on it. Best of all, this code will change as it needs to, depending on the host that runs it.

Before we decide we've achieved nirvana, we should benchmark how long it takes to write X with this...

Benchmark                             Mode  Cnt       Score       Error   Units
AssignBenches.assignCaseClass2       thrpt   25     301.874 ±     4.852  ops/us
AssignBenches.assignInt2             thrpt   25     593.422 ±    12.234  ops/us
AssignBenches.assignCaseClass2          ss    5  815860.330 ± 40784.477   us/op
AssignBenches.assignInt2                ss    5    4391.598 ±   279.541   us/op

The throughput of the code generated by runtime multi-stage programming is amazingly high. Since X holds four Ints and a single Int write can take place 593.4 times a microsecond, about 148 ops/us would already have been a good result, and we're measuring 301.9. I surmise that the JIT is optimizing the code to write at least two of the Ints at once, which probably doesn't happen in the assignInt2 bench since each iteration is a single call of the Int write method. In any case, this seems at first glance to be a wonderful result, but there's a reason I've been including single-shot benches so far...

Single-shot mode in JMH can be useful to measure the cold-run time of the code we're benching. For the Int assignment, it takes 4391 microseconds to do the assignment cold. For the writer generated by runtime multi-stage programming, it takes 815860 microseconds!! Nearly a full second! The compile-time generated code didn't run much slower than the Int writer when cold, so what's going on?

Well, runtime multi-stage programming involves embedding a Scala compiler in your program. And the Scala compiler is not well known for its speed, especially when running cold. In sampling mode with no warmup, the new code sees one sample that takes this long, with all other samples taking far less time. Likewise, only the first run of the embedded Scala compiler is this slow; later compilations of a "Send" as complex as X's take at most 50000 microseconds.

This can be quite expensive. Adding a second to startup (or a random second-long pause) can be nasty, and in the worst case, every 20 compilations by this built-in compiler add another second of pause time to a program. What can we do?

Well, we can do what any good JIT does: keep a slow and a fast version of our code!

      private lazy val sender: AtomicReference[Send[Product]] = 
        jit()
        AtomicReference(
          StructI.calcSender[A](layout.offsets)
        )

      def jit() = if useJit then
        given ExecutionContext = exec
        Future {
          val fn = run {
            val code = Send.sendStaged(layout)
            println(code.show)
            code
          }
          sender.lazySet(fn)
        }

      // jit()

      def to(mem: Mem, offset: Bytes, a: A): Unit =
        import scala.language.unsafeNulls
        sender.get().to(mem, offset, a)

When we first attempt to use "Send" for X, the method in use is the compile-time one that runs at roughly 21 ops/us. This one is slow, but doesn't take long to call when cold. In the meantime, if we've enabled JIT, we compile the much faster staged version on a dedicated compiler thread, and swap out the implementation when it's ready. With this setup, our benchmarks now look like:

Benchmark                             Mode  Cnt      Score      Error   Units
AssignBenches.assignCaseClass2       thrpt   25    261.051 ±    4.875  ops/us
AssignBenches.assignCaseClass2NoJIT  thrpt   25     21.213 ±    1.297  ops/us
AssignBenches.assignInt2             thrpt   25    587.855 ±   12.594  ops/us
AssignBenches.assignCaseClass2          ss    5  11494.275 ± 1899.774   us/op
AssignBenches.assignCaseClass2NoJIT     ss    5   5842.439 ±  572.122   us/op
AssignBenches.assignInt2                ss    5   4536.208 ±  314.503   us/op

Not quite as fast as before, but at the same time, not quite as slow on cold runs. This is a happy medium, with greatly improved performance, while still having decent cold-start performance.
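Stripped of everything Slinc-specific, the swap pattern looks roughly like the sketch below. It uses only the standard library: "Summer", "slowSum", "compileFast" and "Tiered" are hypothetical names, and a Thread.sleep stands in for the embedded compiler's latency.

```scala
import java.util.concurrent.atomic.AtomicReference
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.*

// Hypothetical stand-in for Send[Product]: sums the Int fields of a product.
type Summer = Product => Int

// Slow tier: generic, but available immediately.
val slowSum: Summer = p => p.productIterator.collect { case i: Int => i }.sum

// Pretends to be an expensive compilation step that yields a specialised function.
def compileFast(arity: Int): Summer =
  Thread.sleep(200) // stands in for the embedded compiler's latency
  p =>
    var total = 0
    var i = 0
    while i < arity do
      total += p.productElement(i).asInstanceOf[Int]
      i += 1
    total

final class Tiered(arity: Int)(using ExecutionContext):
  private val current = AtomicReference[Summer](slowSum)

  // Compilation runs in the background; callers keep using the slow tier meanwhile.
  val compiled: Future[Unit] = Future { current.set(compileFast(arity)) }

  def apply(p: Product): Int = current.get()(p)
  def isFast: Boolean = current.get() ne slowSum

@main def tieredDemo(): Unit =
  given ExecutionContext = ExecutionContext.global
  val sum = Tiered(2)
  println(sum((1, 2))) // 3, from whichever tier is installed at this point
  Await.result(sum.compiled, 5.seconds)
  println(sum((1, 2))) // 3, now from the fast tier
```

Callers never block on the compiler: they always read whatever implementation is currently installed, and the extra cost on the hot path is a single volatile read.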

A quick note: I tested this with GraalVM 22-17, because GraalVM is known to optimize the Scala compiler well. It didn't work well for this code though. Maybe the Graal compiler doesn't optimize Panama's constructs as well as the default HotSpot JIT compiler?

In any case, I hope you found this article enlightening, and have come to see the worth of one of the most obscure Scala 3 features.

Happy Scala hacking!!