LeanVM Development Progress #2 - Packed Poseidon2 and JAX Code Generation


In our previous report, we showed that our MLIR-based Poseidon2 outperformed Plonky3 on regular PrimeField operations but lagged behind on Packed PrimeField. This report covers improvements in four major areas:

  • Packed Poseidon2 in MLIR
  • JAX Integration
  • ZKX Opcode Expansion
  • Poseidon2 in JAX

Packed Poseidon2 in MLIR

Benchmark Results

Following the optimizations described below, we achieved performance parity with Plonky3 on Packed PrimeField operations:

Hardware        Width   Plonky3   ZKIR      Ratio (Plonky3 / ZKIR)
AMD 9950X3D     16      7.38 ms   7.38 ms   1.00
Mac M4 Pro      4       5.51 ms   5.58 ms   0.98

The table shows the total time for Width × 10,000 hash operations. The benchmark used BabyBear, but similar results are expected for KoalaBear.

Compared to our previous benchmark, where the Packed PrimeField ratio was 0.8, we have now closed the gap and are competitive with Plonky3.

Optimizations Implemented

To achieve this performance improvement, we implemented several algebraic rewrites and architecture-specific optimizations in the ZKIR compiler pass:

1. Unrolling Square-and-Mul

For field exponentiation, we already had a square-and-multiply algorithm, but it was still not fast enough. Unrolling the while loop when the exponent is a compile-time constant gave us a nice boost: field.powui %a, %arbitrary_constant now lowers to a fully unrolled square-and-multiply sequence.

$a^5 \to (a^2)^2 \times a$

This optimization is particularly important in Poseidon2's S-box operation, which computes $x^7$.
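As a rough illustration of what the unrolled lowering computes (a minimal Python sketch, not the MLIR that ZKIR actually emits), $x^7$ needs only four multiplications:

def pow7(x, mul):
    """Unrolled square-and-multiply for x^7, given a field multiplication `mul`."""
    x2 = mul(x, x)     # x^2
    x3 = mul(x2, x)    # x^3
    x6 = mul(x3, x3)   # x^6
    return mul(x6, x)  # x^7

# Example over BabyBear: pow7(3, lambda a, b: a * b % 2013265921)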

2. Rewriting Power-of-Two Division

When multiplying by the modular inverse of a power of two (i.e., dividing by a power of two in the field), we can do better than standard Montgomery multiplication. We added logic to detect such constant operands and rewrite the operation to a more efficient algorithm.
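One well-known way to exploit this, shown here as a hedged sketch rather than the exact ZKIR rewrite: halving modulo an odd prime p needs no multiplication at all, and dividing by 2^k is just k halvings.

def div_by_pow2_mod_p(x, k, p=2013265921):
    """Compute x * 2^(-k) mod p for an odd prime p (default: BabyBear)."""
    for _ in range(k):
        # If x is odd, x + p is even (p is odd) and (x + p)/2 ≡ x/2 (mod p).
        x = x >> 1 if x % 2 == 0 else (x + p) >> 1
    return x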

3. Removing Redundant Reduction

When computing modular $(a - b)^2$, we don't actually need to reduce $a - b$'s range from $(-P, P) \rightarrow [0, P)$, since the next operation is squaring.
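The reason is that shifting the unreduced difference by $P$ cannot change the square modulo $P$:

$(a - b + P)^2 = (a - b)^2 + 2P(a - b) + P^2 \equiv (a - b)^2 \pmod{P}$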

4. Vector/Tensor Constant Folding

We enabled constant folding of tensor and vector types specifically for the mod_arith dialect, allowing compile-time evaluation of modular arithmetic operations on constants.

5. Specializing for AVX-512

For x86 platforms with AVX-512 support (like the AMD 9950X3D), we rewrite standard arithmetic ops into custom dual-lane SIMD instructions for peak CPU utilization. This is particularly beneficial when there are many chained operations, since the lane splitting and recombining at each operation boundary cancels out thanks to MLIR pattern rewriting.

       a_odd  *  b_odd 
      /                \
a * b                    ab_high, ab_low
      \                / 
       a_even * b_even
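As a rough NumPy illustration of the diagram above (not the actual AVX-512 lowering, which uses widening-multiply intrinsics on 512-bit registers), the even and odd 32-bit lanes are multiplied separately into full 64-bit products and then recombined into high and low halves:

import numpy as np

def mul_even_odd_lanes(a, b):
    """Split 32-bit lanes into even/odd, widen to 64-bit products, recombine."""
    a64, b64 = a.astype(np.uint64), b.astype(np.uint64)
    prod_even = a64[0::2] * b64[0::2]   # full 64-bit products of the even lanes
    prod_odd = a64[1::2] * b64[1::2]    # full 64-bit products of the odd lanes
    ab_low, ab_high = np.empty_like(a), np.empty_like(a)
    ab_low[0::2], ab_low[1::2] = prod_even & 0xFFFFFFFF, prod_odd & 0xFFFFFFFF
    ab_high[0::2], ab_high[1::2] = prod_even >> 32, prod_odd >> 32
    return ab_high, ab_low

When several such multiplies are chained, the interleave at the end of one operation and the split at the start of the next are inverse shuffles, which is exactly the cancellation the pattern rewriter exploits.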

6. Specializing for ARM-NEON

For ARM platforms (Mac M4 Pro, AWS Graviton), we implemented a separate lowering logic that leverages NEON SIMD instructions. This optimization is currently under code review and pending merge.

JAX Integration

For the first time, we have successfully implemented a code-generation pipeline from JAX to ZKX. This is a significant milestone: it demonstrates that our compiler infrastructure can accept programs written in high-level Python frameworks while generating optimized ZK prover code.

Current Support

The pipeline currently supports:

  • Boolean types
  • Integer types

We plan to add prime-field type support within the next two weeks, which will enable direct ZK circuit compilation from JAX code.

Code Example

Here's a simple example demonstrating the pipeline:

import jax

# Define a simple function
fast_f = jax.jit(lambda x: x + 1)

# Execute it
print(fast_f(1))  # Output: 2

This trivial example illustrates the compilation flow. Let's examine how this code is transformed through our pipeline.

Compilation Pipeline

Step 1: StableHLO Generation

JAX first compiles the Python code to StableHLO (Stable High-Level Operations), which is a portable intermediate representation:

module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.add %arg0, %c : tensor<i32>
    return %0 : tensor<i32>
  }
}
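For reference, this StableHLO can be printed directly from JAX without executing the function (a minimal sketch; the exact API surface may vary between JAX versions):

import jax

fast_f = jax.jit(lambda x: x + 1)

# Lower for a concrete argument and print the StableHLO module JAX emits.
print(fast_f.lower(1).as_text())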

Step 2: HLO Lowering

StableHLO is then lowered to HLO (High-Level Operations), XLA's internal representation, which is closer to the computation that actually executes:

ENTRY %main.4 (Arg_0.1: s32[]) -> s32[] {
  %Arg_0.1 = s32[] parameter(0), metadata={op_name="x"}
  %constant.2 = s32[] constant(1)
  ROOT %add.3 = s32[] add(s32[] %Arg_0.1, s32[] %constant.2)
}

Step 3: MLIR Translation

Finally, HLO is converted to MLIR with explicit memory operations using the memref dialect:

module {
  func.func @add.3(%arg0: memref<i32> {bufferization.writable = false},
                   %arg1: memref<i32> {bufferization.writable = false})
      -> memref<i32> attributes {llvm.emit_c_interface} {
    // Load the first operand
    %0 = memref.load %arg0[] : memref<i32>

    // Load the second operand
    %1 = memref.load %arg1[] : memref<i32>

    // Perform addition
    %2 = arith.addi %0, %1 : i32

    // Allocate output memory
    %alloc = memref.alloc() : memref<i32>

    // Store result
    memref.store %2, %alloc[] : memref<i32>

    return %alloc : memref<i32>
  }
}

This MLIR representation explicitly shows:

  • Memory loads from input buffers
  • The arithmetic operation
  • Memory allocation for the output
  • Storing the result

From this MLIR representation, we can apply ZKX optimization passes and generate efficient CPU or GPU code, just as we do for Groth16 and other ZK proving schemes.

Future Roadmap

Once we add prime-field type support, developers will be able to:

  • Write ZK proving schemes in pure Python using JAX
  • Generate production-optimized prover code via ZKX

This will dramatically improve the developer experience for ZK application development.

ZKX Opcode Expansion

To support a wider range of ZK proving systems, we significantly expanded the opcodes available in ZKX.

Groth16 Opcodes (12 total)

Our initial implementation focused on Groth16, which required these core operations:

  • Arithmetic: add, multiply, subtract
  • Linear algebra: dot (dot product)
  • Conversions: bitcast, convert
  • FFT operations: fft (forward and inverse)
  • Cryptographic: msm (multi-scalar multiplication)
  • Structural: const, parameter, slice, tuple

These 12 opcodes are sufficient to implement the complete Groth16 proving algorithm as demonstrated in RabbitSNARK.

WHIR-p3 Extensions (41+ additional opcodes)

To support WHIR-p3 and other modern proving systems, we added 41 additional opcodes across several categories:

Comparison and Conditional Logic

  • compare - element-wise comparison operations
  • conditional - if-then-else semantics
  • select - conditional value selection
  • clamp - value clamping to range

Array Manipulation

  • concatenate - array concatenation
  • dynamic-slice - runtime-determined slicing
  • dynamic-update-slice - runtime updates
  • iota - generate sequential values
  • pad - array padding
  • reshape - dimension transformation
  • reverse - array reversal
  • slice - static slicing
  • transpose - dimension permutation

Reduction and Aggregation

  • reduce - generic reduction operation
  • maximum - maximum reduction
  • minimum - minimum reduction

Bitwise Operations

  • and - bitwise AND
  • or - bitwise OR
  • xor - bitwise XOR
  • not - bitwise NOT
  • shift-left - left shift
  • shift-right-arithmetic - arithmetic right shift
  • shift-right-logical - logical right shift
  • count-leading-zeros - count leading zero bits
  • population-count - count set bits

Advanced Operations

  • abs - absolute value
  • sign - sign extraction
  • remainder - modulo operation
  • sort - sorting operation
  • map - element-wise function application
  • while - while-loop semantics
  • call - function invocation

This expanded opcode set enables ZKX to cover most single-machine provers available today, making it a general-purpose ZK compilation target rather than one limited to specific proving schemes.

Implementation Progress

The opcode expansion was completed through a series of sequential pull requests:

  1. PR #97: Initial opcode expansion
  2. PR #98: Shape inference enhancements for new opcodes
  3. PR #99-#106: CPU code generation implementation (6-part series)

Each PR went through rigorous code review and testing to ensure correctness and performance.

Poseidon2 in JAX

As a demonstration of the JAX integration and a practical use case, we implemented the Poseidon2 hash function in pure JAX.

Implementation Details

This is an integer-based early version that will be migrated to proper prime-field types once that support is added. The implementation targets Poseidon2 with the BabyBear field and width 16.

Round Constants

Poseidon2 requires three sets of round constants:

Initial Round Constants (IRC)

  • 4 rounds × 16 constants
  • Applied during initial external rounds

Internal Round Constants (IxRC)

  • 13 scalar constants
  • Applied during internal rounds

Terminal Round Constants (TRC)

  • 4 rounds × 16 constants
  • Applied during terminal external rounds

Core Operations

1. S-box Transformation

The S-box applies the power map $x \mapsto x^7$ with addition of a round constant:

$\text{sbox}(x, rc) = (x + rc)^7$

In code:

@jit
def sbox_with_rc(self, elem, rc):
    """Apply S-box with round constant: (x + rc)^7 mod p."""
    p = 2013265921  # BabyBear prime, 2^31 - 2^27 + 1
    temp = (elem + rc) % p
    # x^7 via unrolled square-and-multiply, reducing after each step
    # so intermediate products stay within 64-bit integer range.
    t2 = (temp * temp) % p
    t4 = (t2 * t2) % p
    return (t4 * t2) % p * temp % p

2. MDS Light Permutation

The MDS (Maximum Distance Separable) light permutation is a two-phase transformation:

Phase 1: Apply the $M_4$ matrix to each 4-element group

For a 16-element state, we split it into 4 groups of 4 elements each and apply the $M_4$ matrix to each group:

$M_4 = \begin{bmatrix} 2 & 1 & 1 & 1 \\ 1 & 2 & 1 & 1 \\ 1 & 1 & 2 & 1 \\ 1 & 1 & 1 & 2 \end{bmatrix}$
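The m4_multiply helper used below is not shown in the excerpt; applying this $M_4$ to a 4-element chunk could look like the following hypothetical sketch (each row has 2 on the diagonal and 1 elsewhere, so row $i$ evaluates to the chunk sum plus $x_i$):

def m4_multiply(self, chunk):
    """Hypothetical sketch: multiply a 4-element chunk by the M_4 matrix above."""
    s = chunk[0] + chunk[1] + chunk[2] + chunk[3]
    # Row i of M_4 is (1, ..., 2, ..., 1), so (M_4 x)_i = s + x_i.
    return [s + chunk[i] for i in range(4)]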

Phase 2: Apply outer circulant transformation using column sums

@jit
def mds_light_permutation(self, state):
    """
    Applies MDS light permutation:
    1. M_4 matrix to each 4-element chunk
    2. Outer circulant transformation
    """
    # Process each 4-element chunk through M_4
    chunks = [state[i:i+4] for i in range(0, 16, 4)]
    transformed = [self.m4_multiply(chunk) for chunk in chunks]

    # Flatten the transformed chunks back into a 16-element state
    flat = [elem for chunk in transformed for elem in chunk]

    # Column sums: position j summed across the four transformed chunks
    sums = [sum(chunk[j] for chunk in transformed) for j in range(4)]

    # Apply outer circulant: y_i = x'_i + sums[i % 4]
    return [flat[i] + sums[i % 4] for i in range(16)]

3. Internal Layer Matrix Multiplication

The internal layer applies a matrix of the form $1 + D$, where $1$ denotes the all-ones matrix and $D$ is diagonal, so each output element is its diagonally scaled input plus the sum of the whole state:

$\text{InternalLayer}(x) = (1 + D)\,x, \qquad y_i = d_i\, x_i + \sum_{j} x_j$

The diagonal elements vary by position, with some being inverted:

@jit
def internal_layer_mat_mul(self, state, sum_val):
    """
    Applies (1 + diagonal_mat) multiplication
    Different multipliers for different positions
    """
    result = []
    for i, elem in enumerate(state):
        multiplier = self.get_diagonal_multiplier(i)
        result.append(elem * multiplier + sum_val)
    return result
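For context, `sum_val` is presumably the sum of the state elements, and internal rounds apply the round constant and S-box to the first element only. A hypothetical sketch of the `internal_round` helper called by `permute` below (not shown in the excerpt) might look like:

def internal_round(self, state, rc):
    """Hypothetical sketch of one internal round."""
    state = list(state)
    # Internal rounds add the round constant and apply the S-box to element 0 only.
    state[0] = self.sbox_with_rc(state[0], rc)
    # The shared sum feeds the all-ones part of the (1 + D) matrix.
    sum_val = sum(state)
    return self.internal_layer_mat_mul(state, sum_val)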

Full Permutation

The complete Poseidon2 permutation combines all these operations:

@jit
def permute(self, state):
    """
    Full Poseidon2 permutation:
    - Initial external layer (4 rounds)
    - Internal layer (13 rounds)
    - Terminal external layer (4 rounds)
    """
    # Initial external rounds
    for round in range(4):
        state = self.external_round(state, self.initial_rc[round])

    # Internal rounds
    for round in range(13):
        state = self.internal_round(state, self.internal_rc[round])

    # Terminal external rounds
    for round in range(4):
        state = self.external_round(state, self.terminal_rc[round])

    return state

Next Steps

This integer-based implementation will be migrated to use native prime-field types once that support is added to the JAX integration (estimated within 2 weeks).

Once complete, this will demonstrate:

  • Writing ZK hash functions in pure Python
  • Automatic compilation to optimized MLIR
  • Cross-platform code generation for CPU and GPU
  • Performance competitive with hand-written implementations

Conclusion

This progress report demonstrates significant advances across multiple fronts:

  1. Packed Poseidon2: Achieved performance parity with Plonky3's highly-optimized Rust implementation through algebraic rewrites and architecture-specific optimizations

  2. JAX Integration: Successfully demonstrated end-to-end code generation from Python/JAX to optimized MLIR, paving the way for Python-first ZK development

  3. Opcode Expansion: Grew from 12 Groth16-specific opcodes to 53+ opcodes covering most single-machine proving systems

  4. Poseidon2 in JAX: Implemented a complete Poseidon2 hash function in pure JAX, demonstrating practical applications of the compilation pipeline

These developments bring us closer to our vision of making ZK development as accessible as deep learning, where developers can write high-level Python code and automatically get production-quality, cross-platform optimized implementations.

In the next progress report, we expect to share:

  • Prime-field type support in JAX
  • GPU code generation for Poseidon2
  • Additional proving system support (PLONK, STARKs)
  • Performance benchmarks on diverse hardware platforms

Stay tuned for more updates!