Turning megamorphic virtual calls into MethodHandles

This thread is about reducing the virtual call overhead (from my perspective at least). It’s understandable that inlining can’t fix megamorphic call sites in general. scala.concurrent.impl.Promise$Transformation.run is an interpreter of Futures, which in itself is a hard task for the optimizer (as you’ve said). In addition, it crosses async boundaries (i.e. the executing thread can change during interpretation), so reliable inlining is pretty much impossible.

In my benchmarks on the parasitic and trivial ExecutionContexts, going from regular2 (bimorphic CPU-predictable call) to regular8 (megamorphic CPU-predictable call) increases execution time by about 20% (the figure is similar for pre-completed and post-completed Futures). That’s non-negligible, especially given that Futures do a nontrivial amount of work in addition to f.apply (they use atomic swaps in the end, regardless of ExecutionContext). With more lightweight types, megamorphic dispatch should account for an even larger share of the total cost.

What the JVM lacks, in my opinion, is native support for free-floating closures. Invoking f.apply requires the same type of dispatch as invoking any other virtual method, so it’s costly. If the JVM had native reified closures, the call would reduce to one or two pointer dereferences plus a jump. OTOH, introducing an incompatible type instead of lightweight SAM types would make gradual migration harder. So in the end we’re left with MethodHandles for low-level optimizations.

I’ve experimented with MethodHandles, but they don’t seem to improve performance over normal virtual method dispatch. There’s MethodHandles.lookup().findSpecial, but:

  • it must be placed inside the subclass, so it rules out SAM types (lambdas can implement only one abstract method of an interface and nothing else), so I’ve experimented with full subclasses
  • you can’t prepare a MethodHandle that is directly invokable via invokeExact with a supertype receiver. You have to do MethodHandles.lookup().findSpecial(...).asType(...) and only then can you call .invokeExact on it
  • alternatively you can do MethodHandles.lookup().findSpecial without .asType but then you can’t do .invokeExact and you’re left with .invoke
  • overall, performance degrades rather than improves
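To make the asType / invokeExact / invoke distinction above concrete, here is a minimal sketch. The class names (Fn, AddFive, FindSpecialSketch) are made up for illustration and are not part of the benchmark; only the java.lang.invoke calls are real API. It assumes the handle is looked up inside the subclass, as findSpecial requires.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

// Hypothetical base type standing in for a function object.
abstract class Fn {
    abstract Object apply(Object a);
}

final class AddFive extends Fn {
    // findSpecial yields a handle whose receiver is the special caller:
    // RAW has type (AddFive, Object)Object.
    static final MethodHandle RAW;
    // After asType the receiver is widened: ADAPTED has type (Fn, Object)Object.
    static final MethodHandle ADAPTED;

    static {
        try {
            // The lookup must be created inside AddFive for findSpecial to succeed.
            RAW = MethodHandles.lookup().findSpecial(
                    AddFive.class, "apply",
                    MethodType.methodType(Object.class, Object.class),
                    AddFive.class);
            ADAPTED = RAW.asType(
                    MethodType.methodType(Object.class, Fn.class, Object.class));
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    @Override
    Object apply(Object a) {
        return (Integer) a + 5;
    }
}

final class FindSpecialSketch {
    // invokeExact: the call-site signature (Fn, Object)Object must equal
    // the handle's type exactly, which is why asType was needed.
    static Object callAdapted(Fn fn, Object arg) {
        try {
            return (Object) AddFive.ADAPTED.invokeExact(fn, arg);
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    // invoke: the unadapted handle works, but the conversion from
    // (Fn, Object)Object to (AddFive, Object)Object happens at the call.
    static Object callRaw(Fn fn, Object arg) {
        try {
            return AddFive.RAW.invoke(fn, arg);
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }
}
```

Calling RAW.invokeExact with a receiver statically typed as Fn would throw WrongMethodTypeException, which is the limitation described in the second bullet.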

My benchmark:

package pl.tarsa.megamorphic_overhead.jmh;

import org.openjdk.jmh.annotations.*;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import static java.lang.invoke.MethodHandles.lookup;

@SuppressWarnings({"WeakerAccess", "SwitchStatementWithTooFewBranches"})
@State(Scope.Benchmark)
@BenchmarkMode({Mode.AverageTime})
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@Warmup(iterations = 1)
@Measurement(iterations = 1)
@Fork(value = 1,
        jvmArgsAppend = {"-Xmx1G", "-Xms1G", "-XX:+AlwaysPreTouch"})
@Threads(value = 1)
public class JavaRioBench {
    @Param({"true", "false"})
    private String useMhStr;

    private Rio<Integer> rio;

    @Setup
    public void setup() {
        rio = wrap(8);
        new Random(0).ints(100, 0, 8).forEachOrdered(index -> {
            boolean useMh;
            switch (useMhStr) {
                case "true":
                    useMh = true;
                    break;
                case "false":
                    useMh = false;
                    break;
                default:
                    throw new RuntimeException();
            }
            rio = rioSelect(index, rio, useMh);
        });
    }

    @Benchmark
    public Object xxx() {
        return interpret(rio);
    }

    class RioType {
        static final int SUCCEED = 0;
        static final int FLAT_MAP = 1;
    }

    @SuppressWarnings("unused")
    abstract class Rio<A> {
        final int rioType;

        Rio(int rioType) {
            this.rioType = rioType;
        }
    }

    abstract class RioFn<A, B> {
        final int rioType;
        final MethodHandle mHandle;
        final boolean useMh;

        RioFn(int rioType, MethodHandle mHandle, boolean useMh) {
            this.rioType = rioType;
            this.mHandle = mHandle;
            this.useMh = useMh;
        }

        abstract Rio<B> apply(A a);
    }

    final class Succeed<A> extends Rio<A> {
        final A value;

        Succeed(A value) {
            super(RioType.SUCCEED);
            this.value = value;
        }
    }

    final class FlatMap<A, B> extends Rio<B> {
        final Rio<A> rioInput;
        final FlatMapFn<A, B> rioFn;

        FlatMap(Rio<A> rioInput, FlatMapFn<A, B> rioFn) {
            super(RioType.FLAT_MAP);
            this.rioInput = rioInput;
            this.rioFn = rioFn;
        }
    }

    abstract class FlatMapFn<A, B> extends RioFn<A, B> {
        FlatMapFn(MethodHandle mHandle, boolean useMh) {
            super(RioType.FLAT_MAP, mHandle, useMh);
        }
    }

    static MethodType flatMapFnParamsType =
            MethodType.methodType(Rio.class, Object.class);

    static MethodType flatMapFnFullType = MethodType.methodType(
            Rio.class, FlatMapFn.class, Object.class);

    <A> Rio<A> wrap(A value) {
        return new Succeed<>(value);
    }

    Rio<Integer> rioSelect(int idx, Rio<Integer> rioInput, boolean useMh) {
        try {
            return rioSelect1(idx, rioInput, useMh);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    Rio<Integer> rioSelect1(int idx, Rio<Integer> rioInput, boolean useMh)
            throws Exception {
        switch (idx) {
            case 0:
                return new FlatMap<>(rioInput, new FlatMapFn0(useMh));
            case 1:
                return new FlatMap<>(rioInput, new FlatMapFn1(useMh));
            case 2:
                return new FlatMap<>(rioInput, new FlatMapFn2(useMh));
            case 3:
                return new FlatMap<>(rioInput, new FlatMapFn3(useMh));
            case 4:
                return new FlatMap<>(rioInput, new FlatMapFn4(useMh));
            case 5:
                return new FlatMap<>(rioInput, new FlatMapFn5(useMh));
            case 6:
                return new FlatMap<>(rioInput, new FlatMapFn6(useMh));
            case 7:
                return new FlatMap<>(rioInput, new FlatMapFn7(useMh));
        }
        throw new RuntimeException();
    }

    final class FlatMapFn0 extends FlatMapFn<Integer, Integer> {
        FlatMapFn0(boolean useMh) throws Exception {
            // NOTE couldn't figure out how to prepare a MethodHandle in single
            // findSpecial call without any adjustments done by asType,
            // invoke (without Exact), etc
            super(lookup().findSpecial(FlatMapFn0.class, "apply",
                    flatMapFnParamsType, FlatMapFn0.class)
                    .asType(flatMapFnFullType), useMh);
        }

        @Override
        Rio<Integer> apply(Integer a) {
            return wrap(a + 5);
        }
    }

    final class FlatMapFn1 extends FlatMapFn<Integer, Integer> {
        FlatMapFn1(boolean useMh) throws Exception {
            super(lookup().findSpecial(FlatMapFn1.class, "apply",
                    flatMapFnParamsType, FlatMapFn1.class)
                    .asType(flatMapFnFullType), useMh);
        }

        @Override
        Rio<Integer> apply(Integer a) {
            return wrap(a * 7);
        }
    }

    final class FlatMapFn2 extends FlatMapFn<Integer, Integer> {
        FlatMapFn2(boolean useMh) throws Exception {
            super(lookup().findSpecial(FlatMapFn2.class, "apply",
                    flatMapFnParamsType, FlatMapFn2.class)
                    .asType(flatMapFnFullType), useMh);
        }

        @Override
        Rio<Integer> apply(Integer a) {
            return wrap(a + (a >> 3));
        }
    }

    final class FlatMapFn3 extends FlatMapFn<Integer, Integer> {
        FlatMapFn3(boolean useMh) throws Exception {
            super(lookup().findSpecial(FlatMapFn3.class, "apply",
                    flatMapFnParamsType, FlatMapFn3.class)
                    .asType(flatMapFnFullType), useMh);
        }

        @Override
        Rio<Integer> apply(Integer a) {
            return wrap(a * 3 + 5);
        }
    }

    final class FlatMapFn4 extends FlatMapFn<Integer, Integer> {
        FlatMapFn4(boolean useMh) throws Exception {
            super(lookup().findSpecial(FlatMapFn4.class, "apply",
                    flatMapFnParamsType, FlatMapFn4.class)
                    .asType(flatMapFnFullType), useMh);
        }

        @Override
        Rio<Integer> apply(Integer a) {
            return wrap(a * 4 + 4);
        }
    }

    final class FlatMapFn5 extends FlatMapFn<Integer, Integer> {
        FlatMapFn5(boolean useMh) throws Exception {
            super(lookup().findSpecial(FlatMapFn5.class, "apply",
                    flatMapFnParamsType, FlatMapFn5.class)
                    .asType(flatMapFnFullType), useMh);
        }

        @Override
        Rio<Integer> apply(Integer a) {
            return wrap(a + 55);
        }
    }

    final class FlatMapFn6 extends FlatMapFn<Integer, Integer> {
        FlatMapFn6(boolean useMh) throws Exception {
            super(lookup().findSpecial(FlatMapFn6.class, "apply",
                    flatMapFnParamsType, FlatMapFn6.class)
                    .asType(flatMapFnFullType), useMh);
        }

        @Override
        Rio<Integer> apply(Integer a) {
            return wrap(a / 3);
        }
    }

    final class FlatMapFn7 extends FlatMapFn<Integer, Integer> {
        FlatMapFn7(boolean useMh) throws Exception {
            super(lookup().findSpecial(FlatMapFn7.class, "apply",
                    flatMapFnParamsType, FlatMapFn7.class)
                    .asType(flatMapFnFullType), useMh);
        }

        @Override
        Rio<Integer> apply(Integer a) {
            return wrap((a << 1) + (a >> 1));
        }
    }

    @SuppressWarnings("unchecked")
    <A> A interpret(Rio<A> originalRio) {
        Rio currentRio = originalRio;
        Deque<RioFn> stack = new ArrayDeque<>();
        while (currentRio.rioType != RioType.SUCCEED || !stack.isEmpty()) {
            switch (currentRio.rioType) {
                case RioType.SUCCEED:
                    Succeed rio1 = (Succeed) currentRio;
                    RioFn rioFn = stack.removeFirst();
                    switch (rioFn.rioType) {
                        case RioType.FLAT_MAP:
                            if (rioFn.useMh) {
                                try {
                                    currentRio = (Rio) rioFn.mHandle
                                            .invokeExact((FlatMapFn) rioFn,
                                                    rio1.value);
                                } catch (Throwable throwable) {
                                    throw new RuntimeException(throwable);
                                }
                            } else {
                                currentRio = rioFn.apply(rio1.value);
                            }
                            break;
                        default:
                            throw new RuntimeException();
                    }
                    break;
                case RioType.FLAT_MAP:
                    FlatMap rio2 = (FlatMap) currentRio;
                    stack.addFirst(rio2.rioFn);
                    currentRio = rio2.rioInput;
                    break;
                default:
                    throw new RuntimeException();
            }
        }
        return ((Succeed<A>) currentRio).value;
    }

    public static void main(String[] args) {
        JavaRioBench self = new JavaRioBench();
        Rio<Integer> rio = self.rioSelect(0, self.wrap(5), true);
        System.out.println(self.interpret(rio));
    }
}

Results:

[info] # Run complete. Total time: 00:00:40
[info] Benchmark         (useMhStr)  Mode  Cnt  Score   Error  Units
[info] JavaRioBench.xxx        true  avgt       1,757          us/op
[info] JavaRioBench.xxx       false  avgt       1,361          us/op

First result (worse one) is for MethodHandles enabled.

If you want the JIT to inline MethodHandles, you should make sure that they are always initialised into final static fields, which I’ve seen @DanHeidinga recommend at conferences.
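For reference, the static final pattern looks roughly like this. This is only a sketch with made-up names (StaticFinalHandleSketch, addFive); whether a given JIT actually inlines through the handle is something to verify with a benchmark rather than assume.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

final class StaticFinalHandleSketch {
    // Hypothetical target method.
    static int addFive(int a) {
        return a + 5;
    }

    // The handle lives in a static final field, so its target is fixed
    // once the class has been initialised.
    private static final MethodHandle ADD_FIVE;

    static {
        try {
            ADD_FIVE = MethodHandles.lookup().findStatic(
                    StaticFinalHandleSketch.class, "addFive",
                    MethodType.methodType(int.class, int.class));
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    static int viaConstantHandle(int a) {
        try {
            // Call-site signature (int)int matches the handle type exactly.
            return (int) ADD_FIVE.invokeExact(a);
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }
}
```

In the benchmark above the handle sits in a per-instance final field (mHandle) instead, so the interpreter loop sees a different handle on each iteration.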

To access a static field I need the class name at compile time. If I have that, then I don’t need any MethodHandles (in this case at least): I can do an explicit downcast and invoke a final method.
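A minimal sketch of that downcast approach, assuming a small closed set of known subclasses. All names here (TaggedFn, PlusFive, TimesSeven, DowncastDispatchSketch) are hypothetical and not taken from the benchmark.

```java
// Hypothetical function type carrying an integer tag, similar in spirit
// to the rioType field in the benchmark.
abstract class TaggedFn {
    final int tag;

    TaggedFn(int tag) {
        this.tag = tag;
    }

    abstract int apply(int a);
}

final class PlusFive extends TaggedFn {
    static final int TAG = 0;

    PlusFive() {
        super(TAG);
    }

    @Override
    int apply(int a) {
        return a + 5;
    }
}

final class TimesSeven extends TaggedFn {
    static final int TAG = 1;

    TimesSeven() {
        super(TAG);
    }

    @Override
    int apply(int a) {
        return a * 7;
    }
}

final class DowncastDispatchSketch {
    static int call(TaggedFn fn, int a) {
        switch (fn.tag) {
            case PlusFive.TAG:
                // Receiver is a final class, so the target is statically known.
                return ((PlusFive) fn).apply(a);
            case TimesSeven.TAG:
                return ((TimesSeven) fn).apply(a);
            default:
                // Unknown subclasses fall back to ordinary virtual dispatch.
                return fn.apply(a);
        }
    }
}
```

The obvious cost is that every known subclass has to be listed at the call site, which is exactly the compile-time knowledge that a MethodHandle was supposed to make unnecessary.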

I was hoping a MethodHandle can work like a pointer to a resolved virtual method, but it doesn’t seem so.

A fun example of this is the FastParse parser combinator library, which in version 2.0 removes the interpretation layer in favor of “direct style” code, taking advantage of the fact that most grammars are in fact static and known up front. Thus it generates programs that are statically known at compile time, allowing optimizations to occur much more freely, whether via scalac macros, scalac’s optimizer, or the JVM’s optimizer, which is designed for statically-known programs.
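To illustrate the difference in shape (this is not FastParse code, just a toy sketch in Java with made-up names), compare a program stored as data and run by a loop with the same program written directly:

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.IntUnaryOperator;

final class DirectStyleSketch {
    // Interpreted style: the program is a data structure of steps.
    static final List<IntUnaryOperator> PROGRAM = Arrays.<IntUnaryOperator>asList(
            a -> a + 5,
            a -> a * 7,
            a -> a + (a >> 3));

    static int interpret(List<IntUnaryOperator> steps, int input) {
        int acc = input;
        for (IntUnaryOperator step : steps) {
            // One call site shared by every step, so it sees many receiver types.
            acc = step.applyAsInt(acc);
        }
        return acc;
    }

    // Direct style: the same three steps as plain code, with no shared
    // call site and nothing left for the optimizer to guess.
    static int direct(int input) {
        int a = input + 5;
        int b = a * 7;
        return b + (b >> 3);
    }
}
```

Both compute the same result; the direct version simply hands the optimizer a statically known program.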

This in turn gives a ~4x speedup over the FastParse 1.0 design which relies on interpreting data structures. This speedup is over real world parsers for real grammars, despite the fact that FastParse 1.0 had already been optimized as much as possible given its design!

This is incorrect. Dispatching to virtual method calls, including megamorphic ones, is already one or two pointer dereferences plus a jump. The JVM also already has native support for free-floating closures: they are called method handles, and they behave as well as they do, which is not significantly faster or slower than closures implemented via anonymous inner classes, as you have already seen.

This should not be surprising at all, because in the end, it goes through the same virtual method call machinery inside the JVM. “Native support” means nothing when the underlying data structures and algorithms are unchanged, as they are here. There were many reasons to change the encoding of lambdas from anonymous inner classes to method handles, but a steady-state performance boost was not one of them.
