JAGS: normal and not-normal performance

Previously, we walked through a coin-tossing example in JAGS and looked at its runtime performance. In this episode, we’ll look at the cost of different distributions.

So far, we’ve used a uniform prior distribution. Let’s baseline on a 100,000-step chain with 100 datapoints. With a uniform prior, JAGS takes a mere 0.2 seconds. But change to a normal prior, such as dnorm(0.5,0.1), and JAGS takes 3.3 seconds, with __ieee754_log_avx (called from DBern::logDensity) taking up about 71% of the CPU time, according to perf:


  70.90%  jags-terminal  libm-2.19.so         [.] __ieee754_log_avx
   9.27%  jags-terminal  libjags.so.4.0.2     [.] _ZNK4jags20ScalarStochasticNode10logDensityEjNS_7PDFTypeE
   5.09%  jags-terminal  bugs.so              [.] _ZNK4jags4bugs5DBern10logDensityEdNS_7PDFTypeERKSt6vectorIPKdSaIS5_EES5_S5_

If we go from Bernoulli data with a normal prior to normal data with normal priors on mean/sd, it gets more expensive again (4.8 seconds instead of 3.3) as the conditional posterior gets more complex. But it’s still all about the logarithms.
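To make the "it’s all about the logarithms" point concrete, here is a minimal Python sketch of a Bernoulli log-likelihood with a call counter. It is not JAGS’s actual code: the function names, the toy data and the evenly spaced theta values (standing in for whatever values a sampler would try) are all assumptions for illustration. The point is just that each evaluation of the log-density costs one log() per datapoint.

```python
import math

log_calls = 0  # global counter, purely for illustration


def counted_log(value):
    """math.log wrapper that counts how often it is called."""
    global log_calls
    log_calls += 1
    return math.log(value)


def bernoulli_log_density(x, theta):
    """Log-density of one Bernoulli observation x (0 or 1)."""
    p = theta if x == 1 else 1.0 - theta
    return counted_log(p)


def log_likelihood(data, theta):
    """Sum of per-datapoint log-densities: one log() per datapoint."""
    return sum(bernoulli_log_density(x, theta) for x in data)


data = [1, 0] * 50   # 100 hypothetical coin tosses
evaluations = 1000   # hypothetical number of density evaluations

for step in range(evaluations):
    # Stand-in for the values a sampler would try; strictly inside (0, 1).
    theta = (step + 0.5) / evaluations
    log_likelihood(data, theta)

print(log_calls)  # number of datapoints times number of evaluations
```

With 100 datapoints, every extra density evaluation adds another 100 calls to log(), so the total scales with steps × datapoints × evaluations per step.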

Logarithms aren’t straightforward to calculate. Computers usually try to do a fast table lookup, falling back to a series-expansion approach if that doesn’t work. On Linux, the implementation comes as part of GNU libc (glibc), with the name suggesting that it uses AVX instructions if your CPU is modern enough (there’s no “logarithm” machine code instruction, but you can use AVX/SSE/etc. to help with your logarithm implementation).

Notably, JAGS is only using a single core throughout all of this. If we wanted to compute multiple chains (e.g. to check convergence) then the simple approach of running each chain in a separate JAGS process works fine, which is what jags.parfit in the dclone R package does. But could you leverage the SIMD nature of SSE/AVX instructions to run several chains in parallel? To be honest, the last time I was at this low level, SSE was just replacing MMX! But since AVX registers are 256 bits wide, holding eight 32-bit floats (or four 64-bit doubles, which is what JAGS computes with), perhaps there’s the potential to do 8 chains (or 4 in double precision) in blazing data-parallel fashion?

(Or alternatively, how many people in the world care about both Bayesian statistics and assembly-level optimization?!)

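Just for reference, on my laptop a simple “double x=0; for (int i=0; i<1e8; i++) x = log(x);” loop takes 6.5 seconds, with 75% being in __ieee754_log_avx, meaning each log() takes about 48ns (6.5s × 0.75 / 1e8). One caveat: because x starts at 0, it becomes -inf after the first call and NaN from the second call onwards, so this mostly times log() on NaN input.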
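The per-call figure is simple arithmetic, sketched below in Python. The helper name is made up for this post; the inputs are the numbers quoted above.

```python
def ns_per_log_call(total_seconds, fraction_in_log, iterations):
    """Nanoseconds spent inside log() per loop iteration."""
    return total_seconds * fraction_in_log / iterations * 1e9


# 6.5 s total, 75% of samples in __ieee754_log_avx, 1e8 iterations.
per_call_ns = ns_per_log_call(6.5, 0.75, 1e8)
print(f"{per_call_ns:.2f} ns per log() call")
```

This gives a little under 49ns, which is rounded to 48ns in the rest of the post.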

To complete the cycle, let’s go back to JAGS with a simple uniform prior and Bernoulli likelihood, do only ten updates with a single datapoint, and see how many times ‘log’ is called. For this, we can use ‘ltrace’, which traces calls into shared libraries, such as log() in libm:


$ ltrace -xlog   $JAGS/jags/libexec/jags-terminal  example.jags  2>&1 | grep -c log@libm.so

Rather surprisingly, the answer is not stable! I’ve seen anything from 20 to 32 calls to log(), i.e. roughly 2 to 3 per update, even though the model and data aren’t changing (but the random seed presumably is). Does that line up with the 3.4 seconds to do 10 million steps @ 10 data points, if log() takes 48ns? If we assume 2 calls to log() per datapoint per step, then 10e6 * 10 * 2 * 48e-9 = 9.6 seconds. So the estimate is about 2.8x too high, but in the right ballpark.
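The same back-of-envelope estimate as a small Python sketch, so the assumptions are easy to vary. The function name is hypothetical, and the "2 calls to log() per datapoint per step" figure is the guess from the text, not something read out of the JAGS source.

```python
def estimated_log_seconds(steps, datapoints, logs_per_datapoint, seconds_per_log):
    """Estimated time spent in log() over a whole run."""
    return steps * datapoints * logs_per_datapoint * seconds_per_log


# Numbers from the text: 10 million steps, 10 datapoints,
# an assumed 2 log() calls per datapoint per step, 48 ns per call.
estimate = estimated_log_seconds(10e6, 10, 2, 48e-9)
measured = 3.4  # seconds, as quoted in the text

print(f"estimate {estimate:.1f}s, measured {measured}s, "
      f"ratio {estimate / measured:.1f}x")
```

Plugging in other step counts, datapoint counts or calls-per-datapoint (for example the 100,000-step, 100-datapoint baseline from earlier) shows how sensitive the estimate is to those assumptions.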

The next step is to read through the JAGS code to understand the Gibbs sampler in detail. I’ve already read through the two parsers and some of the Graph stuff, but want to complete my understanding of the performance characteristics.