Extreme JNI Performance

Extreme performance requires pushing the engineering limits
of the time. Using continuations in C++11 and Java 7 (via java.lang.invoke)
is just that: taking the very latest features and using them to go very, very
fast indeed.

(The image is from Wikimedia Commons.)
JNI exists for two reasons: (1) to integrate legacy native code with Java, and (2) to let native code do the things it does faster or better than Java. But calling Java (or any other JVM language) from native code is SLOW. Using functional techniques and method handles, we can unlock potentially large performance gains in this area!

Continuations are a way of keeping a program's logical forward flow without relying on the stack to return control to a central point. What they say is 'do this; when you have finished, do that'. In other words, they pass the chain of control to the thing being called, rather than following the traditional stack-based approach of calling something, returning, and then calling something else from a central point of control.
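To make the idea concrete, here is a minimal C++11 sketch of a continuation chain that does not use JNI. The names `Continuation`, `drive` and `runChain` are illustrative and not part of the post's code. Each step returns a pointer to the next step, and a null pointer ends the sequence. The code later in this post follows the same convention.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// A continuation does its piece of work, then says what happens next by
// returning a pointer to the next continuation (nullptr means "finished").
struct Continuation {
    std::function<const Continuation*()> body;
};

// The driver contains no program logic of its own: it just keeps running
// whatever continuation the previous one handed back. Returns the step count.
inline int drive(const Continuation* next) {
    int steps = 0;
    while (next != nullptr) {
        next = next->body();
        ++steps;
    }
    return steps;
}

// "Do A; when you have finished, do B; when you have finished, stop."
inline std::vector<std::string> runChain() {
    std::vector<std::string> log;
    Continuation a;
    Continuation b;
    a.body = [&log, &b]() -> const Continuation* {
        log.push_back("A");
        return &b;  // pass control on to B
    };
    b.body = [&log]() -> const Continuation* {
        log.push_back("B");
        return nullptr;  // end of the chain
    };
    drive(&a);
    return log;
}
```

The loop in `drive` never decides anything. All of the sequencing lives in the steps themselves.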

In C++11 it is easy to write continuations using closures (lambdas that capture their enclosing state). In Java 7 a similar effect can be achieved using method handles and partial application. But why would we want to do this? It is all driven by the poor performance of calling JVM languages from native code.
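The two ingredients named above, closures and partial application, look like this in C++11. This is a small illustrative sketch: `describe`, `closeOver` and `partiallyApply` are hypothetical names. On the Java 7 side, `MethodHandles.insertArguments` plays a comparable role for method handles.

```cpp
#include <cassert>
#include <functional>
#include <string>

// An ordinary three-argument function.
inline std::string describe(const std::string& owner, int step, const std::string& text) {
    return owner + "#" + std::to_string(step) + ": " + text;
}

// Closure: the lambda captures 'owner' and 'step' by value from the
// enclosing scope, so callers supply only the remaining argument later.
inline std::function<std::string(const std::string&)> closeOver(const std::string& owner, int step) {
    return [owner, step](const std::string& text) { return describe(owner, step, text); };
}

// Partial application: std::bind fixes the first two arguments and leaves
// the third open as the placeholder _1.
inline std::function<std::string(const std::string&)> partiallyApply(const std::string& owner, int step) {
    using namespace std::placeholders;
    return std::bind(describe, owner, step, _1);
}
```

Both functions produce the same kind of object: a callable that has already absorbed part of its context and waits for the rest.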

Calling native code from the JVM is very fast. The JVM's JIT optimiser can optimise all the code around the call and then simply jump into the native code. Not only is it technically possible to make this fast; the Java standard library itself calls native methods heavily, so there is a big incentive to make JVM->native calls fast.

Calling JVM code from native is slow. The JVM knows nothing about the native code and so cannot optimise its calls into the JVM. The route taken is very similar to making a reflective call, which is very slow. There is also little incentive for JVM developers to improve the performance of calls into the JVM from native code.

The latter is a problem because using the services of the JVM from native code is a very powerful technique. Also, even when a JVM language is calling native code, it is often useful to bubble back up from the native code into the JVM to perform work.

Note: from now on I will talk about Java, but this could apply to any JVM language.

Now, this is one of the most complex posts I have ever done, and I am happy to answer any questions about it. I must also stress that this is a proof of concept only. I have written code which proves the principle, and it is included at the bottom of the post. Before giving all of the code, here is a description of the pieces. Please note that this code compiles under Oracle's Java 7u3 and Microsoft's Visual C++ 11 beta. If you are not familiar with Java 7 or C++11, quite a bit of it might not make much sense, as it uses the new features of both!

How This Works Via Images:

A traditional call system where C++ calls up to Java using the
slow Call*Method interface in JNI.

The fast system, where C++ calls up to Java and Java
calls back to C++ using the fast Java->native call path. C++
retains flow control by passing continuations to Java on the
initial call and in the return values of all the native calls.

The same approach works where a complex C++ call
from Java would normally cause many calls up from C++
to Java. This works where, for example, C++ needs to access
or update Java objects based on logic choices in C++ which
cannot be anticipated at the start of the native call.

Also: the running sequence of the program (below) is described in detail in this video on YouTube, which shows me stepping through the program in the VC++11 debugger:

On YouTube: http://www.youtube.com/watch?v=Ea4BfPHdepI&hd=1



Creating and managing a JVM:
Here I am using the JVM as a service to C++. The technique could work just as well the other way around. However, because I am doing it this way around, I want to manage the life-cycle of the JVM from C++. RAII is the technique I chose for this:


// An RAII object for managing JVM instances
  class JVM{
 
  private:
    // RAII
    JVM(const JVM &);
    JVM& operator=(const JVM &);

    // Real
    JNIEnv* create_vm(const char*);
    JNIEnv* env; // the env used for creation; do not use it in calls from the JVM!
    JavaVM* jvm;
    vector<jobject> globals;

  public:
    JVM(const char* cp)  {create_vm(cp);}
    ~JVM();

    void        bindNativeMethod(char*,char*,char*,void*);
    jmethodInfo findJVMMethod   (char*, char*, char*,bool);
    jfieldInfo  findJVMField    (char*, char*, char*,bool);
    JNIEnv*     getEnv          (){return env;}
    void        makeJNICall     (function<void()>);
  };
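The RAII idea behind this class can be seen on its own in the JNI-free sketch below. `FakeRefTable` and `RefOwner` are hypothetical stand-ins for the JVM's global-reference table and the `JVM` class. The sketch uses the C++11 `= delete` spelling to forbid copying, where the class above uses the older private-declaration idiom. It also releases in reverse order of registration, which is the usual RAII convention; the post's destructor walks `globals` front to back.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the JVM's table of global references (no JNI here).
struct FakeRefTable {
    int live = 0;                       // references currently held
    std::vector<std::string> released;  // order in which they were freed
};

// RAII owner: every reference it registers is released when it is destroyed.
class RefOwner {
public:
    explicit RefOwner(FakeRefTable& table) : table_(table) {}
    RefOwner(const RefOwner&) = delete;             // C++11 way to forbid copying
    RefOwner& operator=(const RefOwner&) = delete;

    void add(const std::string& ref) {
        ++table_.live;
        owned_.push_back(ref);
    }

    // Release everything, newest first, when the owner goes out of scope.
    ~RefOwner() {
        for (auto it = owned_.rbegin(); it != owned_.rend(); ++it) {
            table_.released.push_back(*it);
            --table_.live;
        }
    }

private:
    FakeRefTable& table_;
    std::vector<std::string> owned_;
};
```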


When a JVM object is destroyed, it deletes every global reference it registered and then destroys the JVM. The class also offers services from the JVM to the rest of the code. Before looking at those services, we should take a look at a really simple function which everything uses to access the JVM:


// Simple way to handle JNI calls safely
  void makeJNICall(JNIEnv* env,function<void()> func){
    env->ExceptionClear();
    func();
    if(env->ExceptionCheck()){
      // read the exception first: ExceptionDescribe clears it as a side effect
      jthrowable ex=env->ExceptionOccurred();
      env->ExceptionDescribe();
      throw ex;
    }
  }


This uses the new lambda (closure) support in C++11 to create a very simple way of coupling the exception-handling systems of the JVM and C++. My code does not do anything sensible with the thrown exceptions, but at least it does the coupling.
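In the real JNI API, ExceptionDescribe clears the pending exception as a side effect. That means the jthrowable has to be fetched with ExceptionOccurred before it is described.

The following JNI-free sketch shows the same coupling pattern with that ordering. It throws a dedicated C++ exception type instead of a raw jthrowable. `FakeEnv`, `VmException` and `makeCall` are hypothetical stand-ins, not JNI names.

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for JNIEnv's pending-exception state (no JNI here).
struct FakeEnv {
    std::string pending;  // empty means "no exception pending"
    bool exceptionCheck() const { return !pending.empty(); }
    void exceptionClear() { pending.clear(); }
};

// A proper C++ exception type carrying the VM-side description.
struct VmException : std::runtime_error {
    explicit VmException(const std::string& what) : std::runtime_error(what) {}
};

// Run 'call', then turn any pending VM exception into a C++ throw.
// The pending value is copied out *before* it is cleared.
inline void makeCall(FakeEnv& env, const std::function<void()>& call) {
    env.exceptionClear();
    call();
    if (env.exceptionCheck()) {
        std::string description = env.pending;
        env.exceptionClear();
        throw VmException(description);
    }
}
```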

The JVM class makes heavy use of this wrapper, through a member overload that supplies its own env. One example is where it binds a C++ function to a native method declared in Java:


void JVM::bindNativeMethod(char* className,char* name,char* signature,void* functionPointer){
    jclass toBind;
    makeJNICall([&]{toBind=env->FindClass(className);});
    // TODO: a throw in here could leak local references
    makeJNICall([&]{
      JNINativeMethod ms[]={name,signature,functionPointer};
      env->RegisterNatives(toBind,ms,1);
    });
    makeJNICall([&]{
      env->DeleteLocalRef(toBind); 
    });
  }
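The TODO above notes that a throw could leak the local reference returned by FindClass. A scope guard is one common remedy. Below is a minimal JNI-free sketch in which the hypothetical `ScopeGuard` class and an integer counter stand in for live local references. In the real method, the guard could be created right after FindClass with a cleanup lambda that calls DeleteLocalRef, replacing the final makeJNICall.

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>
#include <utility>

// Runs a cleanup action when it goes out of scope, including during stack
// unwinding, so a throw between "acquire" and "release" cannot leak.
class ScopeGuard {
public:
    explicit ScopeGuard(std::function<void()> cleanup) : cleanup_(std::move(cleanup)) {}
    ScopeGuard(const ScopeGuard&) = delete;
    ScopeGuard& operator=(const ScopeGuard&) = delete;
    ~ScopeGuard() {
        if (cleanup_) cleanup_();
    }

private:
    std::function<void()> cleanup_;
};

// Simulated bind: 'localRefs' counts live local references. In real JNI code
// the cleanup would be something like [&]{ env->DeleteLocalRef(toBind); }.
inline void bindWithGuard(int& localRefs, bool registerFails) {
    ++localRefs;                                        // like FindClass
    ScopeGuard release([&localRefs] { --localRefs; });  // like DeleteLocalRef
    if (registerFails) {
        throw std::runtime_error("RegisterNatives failed");  // like a JNI error
    }
}
```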

The Main Continuation Performing Loop:

C++

  // JNICALL gives the function the calling convention the JVM expects
  void JNICALL performContinue(JNIEnv *env,jobject self,jlongArray refs,jobject payload){
    cout << "Called from JVM to Native." <<  endl;
    jlong lptr[2];
    makeJNICall(env,[&]{
      env->GetLongArrayRegion(refs,0,1,lptr);
    }); 
    // slot 0 holds the address of the continuation to run
    continuation* func = (continuation*)lptr[0];
    pair<jlong,jlong> next=(*func)(env,payload);
    // hand back {next continuation, Java dispatch index}
    lptr[0]=next.first;
    lptr[1]=next.second;
    makeJNICall(env,[&]{
      env->SetLongArrayRegion(refs,0,2,lptr);
    });
  }


Java

  private native void performContinue(long[] continuation,ContinuationPayload payload);
  
  public void run(ContinuationPayload payload) throws Throwable{
    MethodHandle[] dispatch=javaDispatch.toArray(new MethodHandle[javaDispatch.size()]);
    performContinue(continuation,payload);
    while(continuation[0]!=0){
      dispatch[(int) continuation[1]].invokeExact(payload);
      performContinue(continuation,payload);
    }
  }


The Java loop keeps calling into the C++ via the fast Java->native path, with no idea what the C++ is going to do. It passes down a long[] whose first slot holds a pointer to a closure, which the C++ runs as a continuation. When that continuation finishes, it returns a pointer to the next continuation, and performContinue writes it back into the array. This continues until a continuation returns a null pointer to indicate that all the work is done.
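performContinue relies on a C++ pointer surviving a trip through a Java long. The sketch below isolates that step without JNI. `std::int64_t` stands in for jlong, an `int*` stands in for the payload, and `toHandle`/`fromHandle` are hypothetical helpers. Converting through `std::intptr_t` keeps the conversion well-defined whether pointers are 32 or 64 bits wide.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <utility>

// Same shape as the post's continuation typedef, with JNI types swapped for
// plain ones: std::int64_t plays jlong and an int* plays the payload.
typedef std::pair<std::int64_t, std::int64_t> Result;
typedef std::function<Result(int*)> continuation;

// Encode a continuation's address as a 64-bit integer, the way the post
// stores it in the Java long[] array.
inline std::int64_t toHandle(const continuation* c) {
    return static_cast<std::int64_t>(reinterpret_cast<std::intptr_t>(c));
}

inline const continuation* fromHandle(std::int64_t h) {
    return reinterpret_cast<const continuation*>(static_cast<std::intptr_t>(h));
}

// Mirrors performContinue: run the continuation named in slot 0, then write
// {next continuation, dispatch index} back into slots 0 and 1.
inline void performContinue(std::int64_t refs[2], int* payload) {
    const continuation* func = fromHandle(refs[0]);
    Result next = (*func)(payload);
    refs[0] = next.first;
    refs[1] = next.second;
}
```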

In this example I use method handles in the Java so that the C++ can tell the Java what to do. As well as a pointer to the next continuation, the C++ returns an index into a table of method handles. The Java loop uses that index to invoke the action the C++ requires. Other approaches, such as Java objects that all implement a simple interface, would also work and might be somewhat faster.
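The following JNI-free C++ sketch imitates the whole round trip in one process:

- `run` plays the Java loop.
- A vector of `std::function` objects plays the `MethodHandle[]` dispatch table.
- The three steps replay the post's `runTest`.

All names are hypothetical, and the point is only the control flow. The "Java" side never knows what comes next; it simply does what each returned index tells it.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical payload shared by both sides (mirrors ContinuationPayload).
struct Payload {
    std::vector<std::string> strings;
    std::vector<std::string> log;
    Payload() : strings(1) {}
};

struct Cont;
// A continuation returns {next continuation or nullptr, index of the action
// the "Java side" should run before calling back into native code}.
typedef std::pair<const Cont*, std::size_t> Next;
struct Cont {
    std::function<Next(Payload&)> body;
};
typedef std::function<void(Payload&)> Action;

// The "Java side": it knows only its dispatch table and this loop.
inline void run(const Cont* start, const std::vector<Action>& dispatch, Payload& p) {
    Next next = start->body(p);
    while (next.first != nullptr) {
        dispatch.at(next.second)(p);  // do what the native side asked for
        next = next.first->body(p);   // then hand control straight back
    }
}

// Same behaviour as the post's swapCase: all-upper becomes all-lower,
// anything else becomes all-upper.
inline void swapCase(Payload& p) {
    std::string& s = p.strings[0];
    std::string upper = s;
    std::transform(upper.begin(), upper.end(), upper.begin(),
                   [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    if (upper != s) {
        s = upper;
        return;
    }
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
}

// The three steps from the post's runTest, without any JNI.
inline Payload demo() {
    Payload p;
    Cont step1, step2, step3;
    step1.body = [&step2](Payload& q) -> Next {
        q.strings[0] = "1234abcd4321";
        q.log.push_back("step1");
        return Next(&step2, 0);
    };
    step2.body = [&step3](Payload& q) -> Next {
        q.log.push_back("step2:" + q.strings[0]);
        return Next(&step3, 0);
    };
    step3.body = [](Payload& q) -> Next {
        q.log.push_back("step3:" + q.strings[0]);
        return Next(nullptr, 0);
    };
    std::vector<Action> dispatch;
    dispatch.push_back(swapCase);  // index 0
    run(&step1, dispatch, p);
    return p;
}
```

The steps capture each other by reference, so they can be defined in any order as long as they outlive the call to `run`.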

And Here It Is:
Here is the code. It is not optimized for speed and is still rather clunky; it is just a proof of principle. At some point I would like to clean it up and do some performance tests.

>=====================================< 
>============ C++  ===================< 
>=====================================<

#include "targetver.h"
#include <stdio.h>
#include <tchar.h>
#include "jni.h"
#include <iostream>
#include <functional>
#include <string>
#include <memory>
#include <utility>
#include <vector>
#include <algorithm>
#include <cstring>   // memcpy

using namespace std;

namespace nerds_central{

  typedef pair<jclass,jmethodID>               jmethodInfo;
  typedef pair<jclass,jfieldID>                jfieldInfo;
  typedef function<pair<jlong,jlong>(JNIEnv*,jobject)> continuation;

  // An RAII object for managing JVM instances
  class JVM{
 
  private:
    // RAII
    JVM(const JVM &);
    JVM& operator=(const JVM &);

    // Real
    JNIEnv* create_vm(const char*);
    JNIEnv* env; // the env used for creation; do not use it in calls from the JVM!
    JavaVM* jvm;
    vector<jobject> globals;

  public:
    JVM(const char* cp)  {create_vm(cp);}
    ~JVM();

    void        bindNativeMethod(char*,char*,char*,void*);
    jmethodInfo findJVMMethod   (char*, char*, char*,bool);
    jfieldInfo  findJVMField    (char*, char*, char*,bool);
    JNIEnv*     getEnv          (){return env;}
    void        makeJNICall     (function<void()>);
  };

  typedef shared_ptr<JVM> JVM_ptr;

  JVM::~JVM(){
    for_each(globals.begin(),globals.end(),[&](jobject ref){
      env->DeleteGlobalRef(ref); // TODO handle exceptions - but cannot throw!
    });
    jvm->DestroyJavaVM();
  }

  JNIEnv* JVM::create_vm(const char* classPath) {
    JavaVMInitArgs args;
    JavaVMOption options[1];

    args.version = JNI_VERSION_1_6;
    args.nOptions = 1;

    //TODO: clunky
    // create the constant data incoming to something the JVM can use
    string classPathBuilder = "-Djava.class.path=";
    classPathBuilder+=classPath;
    const char* cp=classPathBuilder.c_str();
    // clean way of getting rid of const
    unique_ptr<char[]> cp_ptr(new char[classPathBuilder.size() +1]);
    memcpy(cp_ptr.get(),cp,classPathBuilder.size() +1);
    // I assume the string is copied in the JVM!
    options[0].optionString = cp_ptr.get(); 
    cout << "Launching with options '" << options[0].optionString << "'" << endl;

    // end of clunk - back to setting up the JVM
    args.options = options;
    args.ignoreUnrecognized = JNI_FALSE;

    JNI_CreateJavaVM(&jvm, (void **)&env, &args);
    return env;
  }

  void JVM::bindNativeMethod(char* className,char* name,char* signature,void* functionPointer){
    jclass toBind;
    makeJNICall([&]{toBind=env->FindClass(className);});
    // TODO: a throw in here could leak local references
    makeJNICall([&]{
      JNINativeMethod ms[]={name,signature,functionPointer};
      env->RegisterNatives(toBind,ms,1);
    });
    makeJNICall([&]{
      env->DeleteLocalRef(toBind); 
    });
  }

  // Simple way to handle JNI calls safely
  void makeJNICall(JNIEnv* env,function<void()> func){
    env->ExceptionClear();
    func();
    if(env->ExceptionCheck()){
      // read the exception first: ExceptionDescribe clears it as a side effect
      jthrowable ex=env->ExceptionOccurred();
      env->ExceptionDescribe();
      throw ex;
    }
  }

  void JVM::makeJNICall(function<void()> func){
    nerds_central::makeJNICall(env,func);
  }

  // JNICALL gives the function the calling convention the JVM expects
  void JNICALL performContinue(JNIEnv *env,jobject self,jlongArray refs,jobject payload){
    cout << "Called from JVM to Native." <<  endl;
    jlong lptr[2];
    makeJNICall(env,[&]{
      env->GetLongArrayRegion(refs,0,1,lptr);
    }); 
    // slot 0 holds the address of the continuation to run
    continuation* func = (continuation*)lptr[0];
    pair<jlong,jlong> next=(*func)(env,payload);
    // hand back {next continuation, Java dispatch index}
    lptr[0]=next.first;
    lptr[1]=next.second;
    makeJNICall(env,[&]{
      env->SetLongArrayRegion(refs,0,2,lptr);
    });
  }

  jmethodInfo JVM::findJVMMethod(char* clazz, char* name, char* signature,bool isStatic){
    jmethodInfo ret;
    makeJNICall([&]{
      ret.first=env->FindClass(clazz);
    });
    makeJNICall([&]{
      ret.first=(jclass)(env->NewGlobalRef(ret.first));
      globals.push_back(ret.first);
    });
    makeJNICall([&]{
      ret.second=isStatic
        ?env->GetStaticMethodID(ret.first,name,signature)
        :env->GetMethodID(ret.first,name,signature);
    });
    return ret;
  }

  
  jfieldInfo JVM::findJVMField(char* clazz, char* name, char* signature,bool isStatic){
    jfieldInfo ret;
    makeJNICall([&]{
      ret.first=env->FindClass(clazz);
    });
    makeJNICall([&]{
      ret.first=(jclass)(env->NewGlobalRef(ret.first));
      globals.push_back(ret.first);
    });
    makeJNICall([&]{
      ret.second=isStatic
        ?env->GetStaticFieldID(ret.first,name,signature)
        :env->GetFieldID(ret.first,name,signature);
    });
    return ret;
  }

  JVM_ptr setUpJVM(_TCHAR *cp){
    // TODO: quick and dirty conversion - should use locale
    string chars;
    for(int i=0;cp[i]!=L'\0';++i){
      chars.push_back((char)(cp[i]) & 255);
    }
    const char* cpc = chars.c_str();
    cout << "About to create JVM" << endl;
    shared_ptr<JVM> vm(new JVM(cpc));
    cout << "Created JVM" << endl;
    vm->bindNativeMethod(
      "com/nerdscentral/jni/ContinuationDriver",
      "performContinue",
      "([JLcom/nerdscentral/jni/ContinuationPayload;)V",
      &performContinue
    );
    cout << "performContinue bound" << endl;
    return vm;
  }

  void setPayloadString(JNIEnv* env,jobject payload,jfieldInfo strings,jint index,jstring data){
    makeJNICall(env,[&]{
      env->SetObjectArrayElement((jobjectArray)env->GetObjectField(payload,strings.second),index,data);
    });
  }

  const char* getPayloadString(JNIEnv* env,jobject payload,jfieldInfo strings,jint index){
    jstring js;
    makeJNICall(env,[&]{
      js=(jstring)env->GetObjectArrayElement((jobjectArray)env->GetObjectField(payload,strings.second),index);
    });
    const char* ret;
    makeJNICall(env,[&]{
      jboolean isCopy;
      ret=env->GetStringUTFChars(js,&isCopy);
    });
    return ret;
  }

  void releasePayloadString(JNIEnv* env,jobject payload,jfieldInfo strings,jint index,const char* chars){
    jstring js;
    makeJNICall(env,[&]{
      js=(jstring)env->GetObjectArrayElement((jobjectArray)env->GetObjectField(payload,strings.second),index);
    });
    makeJNICall(env,[&]{
      env->ReleaseStringUTFChars(js,chars);
    });
  }

  void runTest(_TCHAR *cp){
    JVM_ptr       vm = setUpJVM(cp);
    jmethodInfo test = vm->findJVMMethod(
      "com/nerdscentral/jni/ContinuationTester",
      "test",
      "(J)V",
      true
    );
    jfieldInfo strings=vm->findJVMField(
      "com/nerdscentral/jni/ContinuationPayload",
      "strings",
      "[Ljava/lang/String;",
      false
      );
    // Call this method with the initial continuation being C++ which 
    // runs the test - all in line here.
    continuation step1;
    continuation step2;
    continuation step3;
    string data="1234abcd4321";
    string* dataPtr=&data;
    cout << "Before first call string = '" << data << "'" << endl;

    // The continuations capture each other by reference, so &step2 and &step3
    // are the addresses of the locals declared above. These live until runTest
    // returns, so the lambdas no longer have to be written in reverse order.
    // Index 0 of payload.strings is the slot the Java swapCase works on.
    step3=[strings](JNIEnv* env,jobject payload)->pair<jlong,jlong>{
      const char* str = getPayloadString(env,payload,strings,0);
      cout << "In step3: '" << str << "'" << endl;
      releasePayloadString(env,payload,strings,0,str);
      // returning NULL as the lambda pointer ends the call back sequence
      return pair<jlong,jlong>((jlong)NULL,0);
    };

    step2=[&step3,strings](JNIEnv* env,jobject payload)->pair<jlong,jlong>{
      const char* str = getPayloadString(env,payload,strings,0);
      cout << "In step2: '" << str << "'" << endl;
      releasePayloadString(env,payload,strings,0,str);
      return pair<jlong,jlong>((jlong)&step3,0);
    };

    // This is the code which will actually run.
    step1=[dataPtr,&step2,strings](JNIEnv* env,jobject payload)->pair<jlong,jlong>{
      jstring jdata=env->NewStringUTF(dataPtr->c_str());
      setPayloadString(env,payload,strings,0,jdata);
      cout << "In step1: jstring has been created" << endl;
      return pair<jlong,jlong>((jlong)&step2,0);
    };

    // Start the call sequence - this will call step1
    // step1 will set up the string and return
    // the Java will then call step2
    // then step 3
    // then return and the C++ will end
    vm->makeJNICall([vm,test,&step1]{
      // pass the address as a jlong to match the (J)V signature
      vm->getEnv()->CallStaticVoidMethod(test.first,test.second,reinterpret_cast<jlong>(&step1));
    });
    cout << "*** ALL DONE *** " << endl;
  }

};

int _tmain(int argc, _TCHAR* argv[]){
  if(argc!=2){
    cout << "One argument required which is the classpath" << endl;
    return 1;
  }
  nerds_central::runTest(argv[1]);
  return 0;
}


>=====================================<
>============ JAVA ===================<
>=====================================<

*------------------------------------------------------------------------*

package com.nerdscentral.jni;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ContinuationDriver {
  private List<MethodHandle> javaDispatch=new ArrayList<>();
  private long[] continuation; // [0]: pointer to the next C++ lambda to perform, [1]: index of the Java action to run first
  
  public ContinuationDriver(long initialContinuation) {
    continuation=new long[2];
    continuation[0]=initialContinuation;
  }

  private native void performContinue(long[] continuation,ContinuationPayload payload);
  
  public void run(ContinuationPayload payload) throws Throwable{
    MethodHandle[] dispatch=javaDispatch.toArray(new MethodHandle[javaDispatch.size()]);
    performContinue(continuation,payload);
    while(continuation[0]!=0){
      dispatch[(int) continuation[1]].invokeExact(payload);
      performContinue(continuation,payload);
    }
  }

  public MethodHandle addDispatch(int index,MethodHandle location){
    if(index+1>javaDispatch.size()){
      javaDispatch.addAll(Arrays.asList(new MethodHandle[1+index-javaDispatch.size()]));
    }
    javaDispatch.set(index, location);
    return location;
  } 
}

*------------------------------------------------------------------------*

package com.nerdscentral.jni;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

public class ContinuationTester {
  
  /** C++ calls up to here */
  public static void test(long initialContinuation) throws Throwable{
    ContinuationDriver cd=new ContinuationDriver(initialContinuation);
    cd.addDispatch(
        0,
        MethodHandles.lookup().findStatic(
            ContinuationTester.class,
            "swapCase",
            MethodType.methodType(
                void.class,
                ContinuationPayload.class
            )
          )
        );
    cd.run(new ContinuationPayload(10));
  }
  
  /** Java calls this method to flip payload.strings[0] between all upper case and all lower case */
  public static void swapCase(ContinuationPayload payload){
    String s = payload.strings[0];
    s=s.toUpperCase().equals(s)?s.toLowerCase():s.toUpperCase();
    payload.strings[0]=s;
  }

}

*------------------------------------------------------------------------*

package com.nerdscentral.jni;

public class ContinuationPayload{
  long[]   longs;
  double[] doubles;
  String[] strings;
  public ContinuationPayload(int size){
    longs   = new long[size];
    doubles = new double[size];
    strings = new String[size];
  }
}


Here is the output:
  About to create JVM
  Launching with options '-Djava.class.path=E:\workspace\JNI-FreeChart\bin'
  Created JVM
  performContinue bound
  Before first call string = '1234abcd4321'
  Called from JVM to Native.
  In step1: jstring has been created
  Called from JVM to Native.
  In step2: '1234ABCD4321'
  Called from JVM to Native.
  In step3: '1234abcd4321'
  *** ALL DONE ***