LeanVM Development Progress #2 - Packed Poseidon2 and JAX Code Generation
In our previous report, we showed that our MLIR-based Poseidon2 outperformed Plonky3 on regular PrimeField operations but lagged behind on Packed PrimeField. This report covers progress in four areas:
- Closing the Packed PrimeField performance gap
- JAX integration for code generation
- ZKX opcode expansion for additional provers
- Pure-JAX Poseidon2 implementation
Packed Poseidon2 in MLIR
Benchmark Results
Following the optimizations described below, we achieved performance parity with Plonky3 on Packed PrimeField operations:
| Hardware | Width | Plonky3 | ZKIR | Ratio (Plonky3 / ZKIR) |
|---|---|---|---|---|
| AMD 9950X3D | 16 | 7.38 ms | 7.38 ms | 1.00 |
| Mac M4 Pro | 4 | 5.51 ms | 5.58 ms | 0.98 |
The table shows the total time for Width × 10,000 hash operations, where Width is the number of SIMD lanes in one packed field element (16 with AVX-512, 4 with NEON). The benchmark used BabyBear; we expect similar results for KoalaBear.
In our previous benchmark, the Packed PrimeField ratio was 0.8. That gap is now closed, and ZKIR performs on par with Plonky3.
Optimizations Implemented
To achieve this, we implemented several algebraic rewrites and architecture-specific optimizations in ZKIR's compiler passes:
1. Unrolling Constant Exponentiation
We already lowered field exponentiation with a square-and-multiply algorithm, but its runtime loop was still too slow. When the exponent is a compile-time constant, we now unroll that loop completely, which gave a noticeable speedup: `field.powui %a, %arbitrary_constant` now lowers to a straight-line square-and-multiply sequence.
This optimization matters most for Poseidon2's S-box, which computes $x^7$ over BabyBear.
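The effect of unrolling can be sketched in plain Python. The helpers below (`unroll_pow` and `apply_ops` are hypothetical names; this is not ZKIR's pass) expand a constant exponent into a fixed sequence of squarings and multiplications using the left-to-right binary method. They then evaluate that sequence over BabyBear ($p = 2^{31} - 2^{27} + 1$):

```python
# BabyBear prime: 2^31 - 2^27 + 1 = 15 * 2^27 + 1
BABYBEAR_P = 2**31 - 2**27 + 1


def unroll_pow(exponent: int) -> list[str]:
    """Expand x**exponent into a straight-line op sequence (left-to-right binary method).

    The exponent is known at compile time, so the loop over its bits runs here,
    once, instead of at runtime.
    """
    if exponent < 1:
        raise ValueError("exponent must be positive")
    ops = []
    # Skip the '0b' prefix and the leading 1 bit: the accumulator starts as x.
    for bit in bin(exponent)[3:]:
        ops.append("square")
        if bit == "1":
            ops.append("mul_x")
    return ops


def apply_ops(ops: list[str], x: int, p: int = BABYBEAR_P) -> int:
    """Evaluate an unrolled op sequence on x modulo p."""
    acc = x % p
    for op in ops:
        acc = acc * acc % p if op == "square" else acc * x % p
    return acc


print(unroll_pow(7))  # ['square', 'mul_x', 'square', 'mul_x']
```

For the S-box exponent 7, the result is four operations (x², x³, x⁶, x⁷), with no loop counter or bit tests left at runtime.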
2. Rewriting Power-of-Two Division
When multiplying by the inverse of a power of two (that is, dividing by $2^k$ modulo $p$), we can do better than a standard Montgomery multiplication. We added logic that detects such constant operands and rewrites the operation into a more efficient sequence.
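One standard way to divide by $2^k$ modulo an odd prime without a full multiplication is repeated halving. If $x$ is even, $x/2$ is exact. If $x$ is odd, $x + p$ is even and $(x + p)/2 \equiv x \cdot 2^{-1} \pmod p$. The sketch below is illustrative only. This report does not show the rewrite ZKIR actually emits, and that rewrite may use a different formulation, especially for large $k$. `halve_mod` and `div_pow2_mod` are hypothetical names.

```python
BABYBEAR_P = 2**31 - 2**27 + 1


def halve_mod(x: int, p: int) -> int:
    """Return x * 2^{-1} mod p for odd p and canonical x in [0, p).

    If x is odd, x + p is even, so the shift is exact and the result stays in [0, p).
    In hardware, (x & 1) * p is a select between 0 and p, so only an add,
    a select, and a shift are needed instead of a modular multiplication.
    """
    return (x + (x & 1) * p) >> 1


def div_pow2_mod(x: int, k: int, p: int = BABYBEAR_P) -> int:
    """Return x * 2^{-k} mod p by halving k times."""
    for _ in range(k):
        x = halve_mod(x, p)
    return x


print(div_pow2_mod(6, 1))  # 3
```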
3. Removing Redundant Reduction
When computing a modular power such as the S-box, we do not need to fully reduce an intermediate result into the canonical range, because the next operation is a squaring.
4. Vector/Tensor Constant Folding
We enabled constant folding of tensor and vector types specifically for the mod_arith dialect, allowing compile-time evaluation of modular arithmetic operations on constants.
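As a toy model of what this folding does, the sketch below evaluates an element-wise modular operation on two constant vectors ahead of time, so the compiled code only needs to materialize the result. The op names and the `fold_elementwise` helper are hypothetical and only mirror the idea. They are not ZKIR's folder API.

```python
BABYBEAR_P = 2**31 - 2**27 + 1

# Hypothetical element-wise modular ops and their compile-time evaluators.
FOLDERS = {
    "mod_arith.add": lambda a, b: (a + b) % BABYBEAR_P,
    "mod_arith.sub": lambda a, b: (a - b) % BABYBEAR_P,
    "mod_arith.mul": lambda a, b: (a * b) % BABYBEAR_P,
}


def fold_elementwise(op, lhs, rhs):
    """Fold `op` over two constant vectors of equal length.

    Returns the folded constant vector, or None when the op is unknown or the
    shapes differ (in which case the op would stay in the IR unchanged).
    """
    folder = FOLDERS.get(op)
    if folder is None or len(lhs) != len(rhs):
        return None
    return [folder(a, b) for a, b in zip(lhs, rhs)]


print(fold_elementwise("mod_arith.add", [BABYBEAR_P - 1, 5], [2, 3]))  # [1, 8]
```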
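5. Architecture-Specific SIMD Lowering
For x86 platforms with AVX-512 support (such as the AMD 9950X3D), we rewrite standard arithmetic into custom dual-lane SIMD sequences to keep the CPU fully utilized. As the diagram below shows, a lane-wise multiplication a * b is split into odd-lane and even-lane products, which are then recombined into the high and low halves of the full product. This pays off most for long chains of operations, because MLIR pattern rewriting cancels the lane splitting and merging at each operation boundary.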
```
           a_odd * b_odd
          /             \
 a * b                    ab_high, ab_low
          \             /
           a_even * b_even
```
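The diagram can be modeled in plain Python. On AVX-512, an instruction such as `_mm512_mul_epu32` multiplies only the low 32 bits of each 64-bit element, that is, the even 32-bit lanes, and produces full 64-bit products. The odd lanes are handled by first shifting them down into the even positions. The sketch below assumes that scheme and recombines the two partial results into per-lane high and low halves. The function names are hypothetical, and ZKIR's actual lowering may differ.

```python
MASK32 = (1 << 32) - 1


def mul_even_lanes(a, b):
    """Model of a 'multiply even 32-bit lanes' instruction: for each lane pair
    (2i, 2i+1), return the full 64-bit product a[2i] * b[2i]."""
    return [a[i] * b[i] for i in range(0, len(a), 2)]


def odd_to_even(v):
    """Move each odd lane down into the even slot below it (like a 64-bit
    right shift by 32). Assumes an even number of lanes."""
    out = [0] * len(v)
    out[0::2] = v[1::2]
    return out


def dual_lane_mul(a, b):
    """Full 32x32 -> 64-bit product in every lane, returned as (ab_high, ab_low)."""
    even_products = mul_even_lanes(a, b)                           # a_even * b_even
    odd_products = mul_even_lanes(odd_to_even(a), odd_to_even(b))  # a_odd * b_odd
    products = [0] * len(a)
    products[0::2] = even_products
    products[1::2] = odd_products
    ab_high = [p >> 32 for p in products]
    ab_low = [p & MASK32 for p in products]
    return ab_high, ab_low
```

When several multiplications are chained, a rewrite can keep values in the split even/odd form between operations instead of merging and re-splitting at every boundary. That is the cancellation the paragraph above refers to.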
For ARM platforms (Mac M4 Pro, AWS Graviton), we implemented a separate lowering logic that leverages NEON SIMD instructions. This optimization is currently under code review and pending merge.
JAX Integration
For the first time, we have built a working code generation pipeline from JAX to ZKX. This is a significant milestone: it shows that our compiler infrastructure can take input from a high-level Python framework and still generate optimized ZK prover code.
Current Support
The pipeline currently supports:
- Boolean types
- Integer types
We plan to add prime-field type support within the next two weeks, which will enable direct ZK circuit compilation from JAX code.
Code Example
Here's a simple example demonstrating the pipeline:
```python
import jax

# Define a simple function (named f, so JAX names the lowered module @jit_f)
def f(x):
    return x + 1

fast_f = jax.jit(f)
# Execute it
print(fast_f(1))  # Output: 2
```
This trivial example illustrates the compilation flow. Let's examine how this code is transformed through our pipeline.
Compilation Pipeline
Step 1: StableHLO Generation
JAX first traces the Python function and lowers it to StableHLO, a portable MLIR-based intermediate representation:
```mlir
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.add %arg0, %c : tensor<i32>
    return %0 : tensor<i32>
  }
}
```
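The StableHLO text above can be reproduced from JAX itself with JAX's ahead-of-time lowering API, which traces the function without executing it. This minimal sketch uses only public JAX calls (`jax.jit(...).lower(...)` and `Lowered.as_text()`). The exact attributes in the printed module may vary with the JAX version.

```python
import jax


def f(x):
    return x + 1


# Trace and lower to StableHLO without compiling or running the function.
lowered = jax.jit(f).lower(1)
stablehlo_text = lowered.as_text()  # StableHLO (MLIR) text by default
print(stablehlo_text)
```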
Step 2: HLO Lowering
StableHLO is then lowered to HLO (High-Level Operations), the operation-graph representation that ZKX takes as input:
```
ENTRY %main.4 (Arg_0.1: s32[]) -> s32[] {
  %Arg_0.1 = s32[] parameter(0), metadata={op_name="x"}
  %constant.2 = s32[] constant(1)
  ROOT %add.3 = s32[] add(s32[] %Arg_0.1, s32[] %constant.2)
}
```
Step 3: MLIR Translation
Finally, HLO is converted to MLIR with explicit memory operations using the memref dialect:
```mlir
module {
  func.func @add.3(%arg0: memref<i32> {bufferization.writable = false},
                   %arg1: memref<i32> {bufferization.writable = false})
      -> memref<i32> attributes {llvm.emit_c_interface} {
    // Load the first operand
    %0 = memref.load %arg0[] : memref<i32>
    // Load the second operand
    %1 = memref.load %arg1[] : memref<i32>
    // Perform addition
    %2 = arith.addi %0, %1 : i32
    // Allocate output memory
    %alloc = memref.alloc() : memref<i32>
    // Store result
    memref.store %2, %alloc[] : memref<i32>
    return %alloc : memref<i32>
  }
}
```
This MLIR representation explicitly shows:
- Memory loads from input buffers
- The arithmetic operation
- Memory allocation for the output
- Storing the result
From this MLIR representation, we can apply ZKX optimization passes and generate efficient CPU or GPU code, just as we do for Groth16 and other ZK proving schemes.
Future Roadmap
Once we add prime-field type support, developers will be able to:
- Write ZK proving schemes in pure Python using JAX
- Generate production-optimized prover code via ZKX
This will dramatically improve the developer experience for ZK application development.
ZKX Opcode Expansion
To support a wider range of ZK proving systems, we significantly expanded the opcodes available in ZKX.
Groth16 Opcodes (12 total)
Our initial implementation focused on Groth16, which required these core operations:
- Arithmetic: `add`, `multiply`, `subtract`
- Linear algebra: `dot` (dot product)
- Conversions: `bitcast`, `convert`
- FFT operations: `fft` (forward and inverse)
- Cryptographic: `msm` (multi-scalar multiplication)
- Structural: `const`, `parameter`, `slice`, `tuple`
These 12 opcodes are sufficient to implement the complete Groth16 proving algorithm as demonstrated in RabbitSNARK.
WHIR-p3 Extensions (41+ additional opcodes)
To support WHIR-p3 and other modern proving systems, we added 41 more opcodes. Notable additions, grouped by category:
Comparison and Conditional Logic
- `compare`: element-wise comparison
- `conditional`: if-then-else semantics
- `select`: conditional value selection
- `clamp`: clamping values to a range
Array Manipulation
- `concatenate`: array concatenation
- `dynamic-slice`: slicing at runtime-determined offsets
- `dynamic-update-slice`: updating a slice at runtime-determined offsets
- `iota`: generating sequential values
- `pad`: array padding
- `reshape`: dimension transformation
- `reverse`: array reversal
- `slice`: static slicing
- `transpose`: dimension permutation
Reduction and Aggregation
- `reduce`: generic reduction
- `maximum`: element-wise maximum (also usable as a `reduce` computation)
- `minimum`: element-wise minimum (also usable as a `reduce` computation)
Bitwise Operations
- `and`: bitwise AND
- `or`: bitwise OR
- `xor`: bitwise XOR
- `not`: bitwise NOT
- `shift-left`: left shift
- `shift-right-arithmetic`: arithmetic right shift
- `shift-right-logical`: logical right shift
- `count-leading-zeros`: counting leading zero bits
- `population-count`: counting set bits
Advanced Operations
- `abs`: absolute value
- `sign`: sign extraction
- `remainder`: remainder (modulo) operation
- `sort`: sorting
- `map`: element-wise function application
- `while`: while-loop semantics
- `call`: function invocation
This expanded opcode set lets ZKX cover most single-machine provers available today, making it a general-purpose ZK compilation target rather than one tied to specific proving schemes.
Implementation Progress
The opcode expansion landed through a sequence of pull requests:
- PR #97: Initial opcode expansion
- PR #98: Shape inference enhancements for new opcodes
- PR #99-#106: CPU code generation implementation (6-part series)
Each PR went through rigorous code review and testing to ensure correctness and performance.
Poseidon2 in JAX
As a demonstration of the JAX integration and a practical use case, we implemented the Poseidon2 hash function in pure JAX.
Implementation Details
This is an integer-based early version that will be migrated to proper prime-field types once that support is added. The implementation targets Poseidon2 with BabyBear field and width 16.
Round Constants
Poseidon2 requires three sets of round constants:
Initial Round Constants (IRC)
- 4 rounds × 16 constants
- Applied during initial external rounds
Internal Round Constants (IxRC)
- 13 scalar constants
- Applied during internal rounds
Terminal Round Constants (TRC)
- 4 rounds × 16 constants
- Applied during terminal external rounds
Core Operations
1. S-box Transformation
The S-box adds a round constant and then applies the power map $x \mapsto x^7$:
$$S(x) = (x + rc)^7 \quad \text{in } \mathbb{F}_p$$
In code:
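```python
from functools import partial
import jax

# Excerpt from the Poseidon2 class; `self` is static so jax.jit can trace the method
@partial(jax.jit, static_argnums=0)
def sbox_with_rc(self, elem, rc):
    """Apply S-box with round constant"""
    temp = elem + rc
    return temp ** 7  # Power of 7 for BabyBear (integer version: no reduction mod p yet)
```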
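The exponent 7 is not arbitrary. The power map $x \mapsto x^d$ is a permutation of $\mathbb{F}_p$ exactly when $\gcd(d, p - 1) = 1$. For BabyBear, $p - 1 = 2^{27} \cdot 3 \cdot 5$, which rules out every $d$ from 2 to 6. The plain-Python check below confirms this. It also gives a field-correct reference S-box that could be used to validate the integer JAX version later. `smallest_permutation_exponent` and `sbox_ref` are hypothetical helper names.

```python
from math import gcd

BABYBEAR_P = 2**31 - 2**27 + 1  # p - 1 = 2^27 * 3 * 5


def smallest_permutation_exponent(p: int) -> int:
    """Smallest d > 1 with gcd(d, p - 1) == 1, i.e. x -> x^d is a bijection on F_p."""
    d = 2
    while gcd(d, p - 1) != 1:
        d += 1
    return d


def sbox_ref(elem: int, rc: int, p: int = BABYBEAR_P) -> int:
    """Field-correct reference S-box for BabyBear: (elem + rc)^7 mod p."""
    return pow((elem + rc) % p, 7, p)


print(smallest_permutation_exponent(BABYBEAR_P))  # 7
```

For KoalaBear ($p = 2^{31} - 2^{24} + 1$, so $p - 1 = 2^{24} \cdot 127$), the same check returns 3.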
2. MDS Light Permutation
The MDS (Maximum Distance Separable) light permutation is a two-phase transformation:
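Phase 1: Apply the $4 \times 4$ matrix $M_4$ to each 4-element group
For a 16-element state, we split it into 4 groups of 4 elements each and apply $M_4$ to each group:
$$\left(x'_{4c}, x'_{4c+1}, x'_{4c+2}, x'_{4c+3}\right)^\top = M_4 \left(x_{4c}, x_{4c+1}, x_{4c+2}, x_{4c+3}\right)^\top, \quad c = 0, 1, 2, 3$$
Phase 2: Apply outer circulant transformation using column sums
$$y_i = x'_i + \sum_{c=0}^{3} x'_{4c + (i \bmod 4)}, \quad i = 0, \dots, 15$$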
```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)
def mds_light_permutation(self, state):
    """
    Applies MDS light permutation:
    1. M_4 matrix to each 4-element chunk
    2. Outer circulant transformation
    """
    # Process each 4-element chunk through M_4
    chunks = [state[i:i+4] for i in range(0, 16, 4)]
    transformed = jnp.stack([self.m4_multiply(chunk) for chunk in chunks])  # shape (4, 4)
    # Column sums: sums[k] adds position k across all four chunks
    sums = jnp.sum(transformed, axis=0)
    # Apply outer circulant: y_i = x_i' + sums[i % 4] (broadcast over chunks)
    return (transformed + sums).reshape(16)
```
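A plain-Python reference is useful for checking the JAX version. The sketch below takes $M_4$ as a parameter, because this report does not spell out which $4 \times 4$ matrix is used. The circulant `circ(2, 3, 1, 1)` used as the default is an assumption for illustration only. The reference also builds the equivalent dense $16 \times 16$ matrix, whose diagonal blocks are $2M_4$ and whose off-diagonal blocks are $M_4$, to show that the two phases amount to a single linear map. All helper names are hypothetical.

```python
BABYBEAR_P = 2**31 - 2**27 + 1
# Assumed example M_4 (circulant of 2, 3, 1, 1); substitute the matrix your implementation uses.
M4_EXAMPLE = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]


def mat_vec(m, v, p=BABYBEAR_P):
    """Matrix-vector product modulo p."""
    return [sum(r * x for r, x in zip(row, v)) % p for row in m]


def mds_light_ref(state, m4=M4_EXAMPLE, p=BABYBEAR_P):
    """Phase 1: M_4 on each 4-element chunk. Phase 2: y_i = x'_i + sums[i % 4]."""
    transformed = [x for c in range(4) for x in mat_vec(m4, state[4 * c:4 * c + 4], p)]
    sums = [sum(transformed[4 * c + k] for c in range(4)) % p for k in range(4)]
    return [(transformed[i] + sums[i % 4]) % p for i in range(16)]


def dense_external_matrix(m4=M4_EXAMPLE):
    """The same map as one 16x16 matrix: block (r, c) is 2*M_4 if r == c, else M_4."""
    return [
        [(2 if r == c else 1) * m4[i][j] for c in range(4) for j in range(4)]
        for r in range(4) for i in range(4)
    ]


state = list(range(16))
print(mds_light_ref(state) == mat_vec(dense_external_matrix(), state))  # True
```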
3. Internal Layer Matrix Multiplication
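The internal layer applies a matrix of the form $M_I = \mathbf{1} + D$, where $\mathbf{1}$ is the all-ones matrix and $D = \mathrm{diag}(d_0, \dots, d_{15})$ is diagonal:
$$y_i = d_i x_i + \sum_{j=0}^{15} x_j$$
The diagonal entries $d_i$ vary by position, and some of them are field inverses: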
```python
from functools import partial
import jax

@partial(jax.jit, static_argnums=0)
def internal_layer_mat_mul(self, state, sum_val):
    """
    Applies (1 + diagonal_mat) multiplication; sum_val is the sum of all state elements.
    Different multipliers for different positions
    """
    result = []
    for i, elem in enumerate(state):
        multiplier = self.get_diagonal_multiplier(i)
        result.append(elem * multiplier + sum_val)
    return result
```
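Written out, `elem * multiplier + sum_val` is multiplication by $\mathbf{1} + D$: row $i$ of that matrix has a 1 in every column plus $d_i$ on the diagonal, which gives $\sum_j x_j + d_i x_i$. The plain-Python sketch below checks this identity against the explicit dense matrix. It uses a hypothetical diagonal that mixes small integers with inverses of powers of two. These are not the constants used in the actual implementation.

```python
P = 2**31 - 2**27 + 1  # BabyBear

# Hypothetical diagonal for illustration only: odd positions hold inverses of powers of two.
DIAG_EXAMPLE = [pow(2, -k, P) if k % 2 else k + 2 for k in range(16)]


def internal_layer_ref(state, diag=DIAG_EXAMPLE, p=P):
    """y_i = d_i * x_i + sum_j x_j (mod p), computed with one shared sum."""
    total = sum(state) % p
    return [(d * x + total) % p for d, x in zip(diag, state)]


def dense_internal_matrix(diag=DIAG_EXAMPLE):
    """(1 + D) as an explicit 16x16 matrix (1 = all-ones, D = diag(d))."""
    n = len(diag)
    return [[1 + (diag[i] if i == j else 0) for j in range(n)] for i in range(n)]


def mat_vec(m, v, p=P):
    """Matrix-vector product modulo p."""
    return [sum(r * x for r, x in zip(row, v)) % p for row in m]


state = list(range(16))
print(internal_layer_ref(state) == mat_vec(dense_internal_matrix(), state))  # True
```

The factored form costs one shared sum plus one multiply-add per element, instead of a dense 16 × 16 product.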
Full Permutation
The complete Poseidon2 permutation combines all these operations:
```python
from functools import partial
import jax

@partial(jax.jit, static_argnums=0)
def permute(self, state):
    """
    Full Poseidon2 permutation:
    - Initial external layer (4 rounds)
    - Internal layer (13 rounds)
    - Terminal external layer (4 rounds)
    """
    # Poseidon2 applies the external linear layer once before the first round
    state = self.mds_light_permutation(state)
    # Initial external rounds
    for r in range(4):
        state = self.external_round(state, self.initial_rc[r])
    # Internal rounds
    for r in range(13):
        state = self.internal_round(state, self.internal_rc[r])
    # Terminal external rounds
    for r in range(4):
        state = self.external_round(state, self.terminal_rc[r])
    return state
```
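The report does not show `external_round` and `internal_round`. This sketch assumes the standard Poseidon2 round structure. An external round adds a round constant to every element, applies the S-box to every element, and then applies the MDS light layer. An internal round adds its scalar constant to element 0, applies the S-box to element 0 only, and then applies the internal matrix. The pure-Python code below wires these together with hypothetical placeholders for $M_4$, the diagonal, and all round constants, so its outputs are not Poseidon2 test vectors.

```python
P = 2**31 - 2**27 + 1  # BabyBear
WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS = 16, 4, 13

# Hypothetical placeholders; a real instance uses the published matrices and constants.
M4 = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]
DIAG = [k + 2 for k in range(WIDTH)]
INITIAL_RC = [[(r * WIDTH + i + 1) % P for i in range(WIDTH)] for r in range(HALF_FULL_ROUNDS)]
INTERNAL_RC = [(1000 + r) % P for r in range(PARTIAL_ROUNDS)]
TERMINAL_RC = [[(7 * (r * WIDTH + i) + 3) % P for i in range(WIDTH)] for r in range(HALF_FULL_ROUNDS)]


def sbox(x):
    return pow(x, 7, P)


def mds_light(state):
    chunks = [state[c:c + 4] for c in range(0, WIDTH, 4)]
    # M_4 applied to each chunk; index 4 * chunk + row
    t = [sum(m * x for m, x in zip(row, chunk)) % P for chunk in chunks for row in M4]
    sums = [sum(t[c + k] for c in range(0, WIDTH, 4)) % P for k in range(4)]
    return [(t[i] + sums[i % 4]) % P for i in range(WIDTH)]


def external_round(state, rc):
    # Full round: constant and S-box on every element, then the external matrix.
    return mds_light([sbox((x + c) % P) for x, c in zip(state, rc)])


def internal_round(state, rc):
    # Partial round: constant and S-box on element 0 only, then (1 + D).
    state = [sbox((state[0] + rc) % P)] + state[1:]
    total = sum(state) % P
    return [(d * x + total) % P for d, x in zip(DIAG, state)]


def permute(state):
    state = mds_light(state)  # external matrix applied once up front
    for rc in INITIAL_RC:
        state = external_round(state, rc)
    for rc in INTERNAL_RC:
        state = internal_round(state, rc)
    for rc in TERMINAL_RC:
        state = external_round(state, rc)
    return state


out = permute(list(range(WIDTH)))
print(len(out))  # 16
```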
Next Steps
This integer-based implementation will be migrated to use native prime-field types once that support is added to the JAX integration (estimated within 2 weeks).
Once complete, this will demonstrate:
- Writing ZK hash functions in pure Python
- Automatic compilation to optimized MLIR
- Cross-platform code generation for CPU and GPU
- Performance competitive with hand-written implementations
Conclusion
This progress report demonstrates significant advances across multiple fronts:
- Packed Poseidon2: Achieved performance parity with Plonky3's highly optimized Rust implementation through algebraic rewrites and architecture-specific optimizations
- JAX Integration: Successfully demonstrated end-to-end code generation from Python/JAX to optimized MLIR, paving the way for Python-first ZK development
- Opcode Expansion: Grew from 12 Groth16-specific opcodes to 53+ opcodes covering most single-machine proving systems
- Poseidon2 in JAX: Implemented a complete Poseidon2 hash function in pure JAX, demonstrating practical applications of the compilation pipeline
These developments bring us closer to our vision of making ZK development as accessible as deep learning, where developers can write high-level Python code and automatically get production-quality, cross-platform optimized implementations.
In the next progress report, we expect to share:
- Prime-field type support in JAX
- GPU code generation for Poseidon2
- Additional proving system support (PLONK, STARKs)
- Performance benchmarks on diverse hardware platforms
Stay tuned for more updates!