LeanVM Development Progress #1 - Poseidon2 in MLIR and WHIR-p3 Python Implementation


Poseidon2 in MLIR

In hash-based STARK systems, the hash function dominates performance, and Poseidon2 is the most commonly used hash. Optimizing Poseidon2 therefore translates directly into faster STARKs, and the work also stress-tests and matures our MLIR infrastructure.

To that end, we implemented Poseidon2 in MLIR (the ZKIR column below). It outperformed Plonky3 (with SIMD disabled) on both AMD and Apple Silicon:

| CPU | Plonky3 | ZKIR | Ratio (Plonky3 / ZKIR) |
|---|---|---|---|
| AMD 9950X3D | 6.12 ms | 5.70 ms | 1.07 |
| Mac M4 Pro | 7.76 ms | 5.80 ms | 1.34 |

(The table shows the total time for 10,000 hash operations. The benchmark used BabyBear, but similar results are expected for KoalaBear.)

In practice, Packed PrimeField performance matters more for production workloads. Unfortunately, ZKIR is currently slower here:

| CPU | Plonky3 | ZKIR | Ratio (Plonky3 / ZKIR) |
|---|---|---|---|
| AMD 9950X3D | 7.44 ms | 9.30 ms | 0.80 |

This slowdown is due to engineering optimizations in Plonky3 that we haven't yet implemented. For example, consider computing $(x - y)^2$ for $0 \le x, y < P$.

The standard approach performs:

  1. Compute $v = P + x - y$
  2. Apply $\mathsf{reduce}_1(v)$
  3. Compute $v^2$
  4. Apply $\mathsf{reduce}_2(v^2)$

Where the reduction functions are defined as:

$\mathsf{reduce}_1(a) \to b$, where $0 \le a < 2P$ and $0 \le b < P$

$\mathsf{reduce}_2(a) \to b$, where $0 \le a < P^2$ and $0 \le b < P$

Plonky3 skips step (2) and instead performs:

  1. Compute $v = x - y$
  2. Compute $v^2$
  3. Apply $\mathsf{reduce}_2(v^2)$

Since $v$ after step (1) lies within $-P < v < P$, omitting the first reduction does not affect correctness: the result $v^2$ still satisfies $0 \le v^2 < P^2$.
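The two procedures above can be compared with a small model in plain Python. This is only a sketch of the arithmetic, not Plonky3's or ZKIR's code: `reduce_1` and `reduce_2` are hypothetical stand-ins (a conditional subtraction and a plain `%`), and the real implementations use Montgomery-style reductions on fixed-width integers.

```python
# BabyBear prime: P = 15 * 2^27 + 1
P = 2013265921


def reduce_1(a):
    """Map 0 <= a < 2P to 0 <= b < P with one conditional subtraction."""
    assert 0 <= a < 2 * P
    return a - P if a >= P else a


def reduce_2(a):
    """Map 0 <= a < P^2 to 0 <= b < P (stand-in for the real reduction)."""
    assert 0 <= a < P * P
    return a % P


def sub_square_standard(x, y):
    v = reduce_1(P + x - y)  # steps 1 and 2: keep v in [0, P)
    return reduce_2(v * v)   # steps 3 and 4


def sub_square_lazy(x, y):
    v = x - y                # -P < v < P, no reduction applied
    return reduce_2(v * v)   # v * v already lies in [0, P^2)


# Both routes agree on a few edge cases and arbitrary inputs.
for x, y in [(0, 0), (0, P - 1), (P - 1, 0), (5, 7), (123456789, 987654321)]:
    assert sub_square_standard(x, y) == sub_square_lazy(x, y)
```

The assertions inside `reduce_2` make the range argument explicit: the lazy variant never feeds it a value outside $[0, P^2)$.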

We're currently implementing similar optimizations in ZKIR and will share the results soon.

WHIR-p3 in Python

XLA is a runtime compiler used in deep learning. It defines HLO and allows compilation from Python → Jaxpr → HLO → MLIR → LLVM IR → binary via the JAX library.
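The first stages of that pipeline can be inspected directly from stock JAX. The sketch below traces a toy function to a Jaxpr and lowers it to the compiler IR text that JAX hands to XLA (StableHLO in MLIR syntax on recent JAX versions). The function `f` and its inputs are made up for illustration, and the later LLVM IR and binary stages happen inside the compiler and are not shown.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Toy computation: one multiply and one add per element
    return x * y + x


x = jnp.arange(4, dtype=jnp.int32)
y = jnp.ones(4, dtype=jnp.int32)

# Python -> Jaxpr: trace the function with abstract inputs
jaxpr = jax.make_jaxpr(f)(x, y)

# Jaxpr -> compiler IR: ahead-of-time lowering without executing anything
lowered = jax.jit(f).lower(x, y)
hlo_text = lowered.as_text()

print(jaxpr)
print(hlo_text)
```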

Our team has been developing a ZK-specific runtime compiler called ZKX by forking XLA. Using ZKX, we have already achieved state-of-the-art Groth16 performance on CPU.

The next goal is to express WHIR-p3 in Python. To that end, we are forking and adapting JAX for ZKX compatibility. In parallel, we're porting WHIR-p3 (including Poseidon2) to Python using JAX. For example, the Rust implementation of poly/dense.rs can be expressed in JAX as shown in the Python code section below.

MLIR code

// Poseidon2 utility functions for BabyBear field
// Based on Plonky3 implementation: https://github.com/Plonky3/Plonky3

!pf = !field.pf<2013265921 : i32, true>
!pf_std = !field.pf<2013265921 : i32>
!state = memref<16x!pf>
!state_std = memref<16x!pf_std>

// Adds the round constant and applies the S-box: (var + c)^7, computed as x^4 * x^3
func.func @add_rc_and_sbox(%var: !pf, %c: !pf) -> !pf {
  %sum = field.add %var, %c : !pf
  %sum_sq = field.square %sum : !pf
  %sum_sq_sq = field.square %sum_sq : !pf
  %sum_cu = field.mul %sum, %sum_sq : !pf
  %sum_exp7 = field.mul %sum_sq_sq, %sum_cu : !pf
  return %sum_exp7 : !pf
}

// In-place version of apply_mat4 using memref.
// Ideally this would be a matmul that lowers to the sequence below, but that
// seems hard to achieve at the moment. The product with the 4x4 MDS matrix
//   [2, 3, 1, 1]
//   [1, 2, 3, 1]
//   [1, 1, 2, 3]
//   [3, 1, 1, 2]
// is therefore written out as field additions and doublings.
func.func @apply_mat4(%state: memref<4x!pf, strided<[1], offset: ?>>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index

  // Load the four elements, then build the shared partial sums
  %x0 = memref.load %state[%c0] : memref<4x!pf, strided<[1], offset: ?>>
  %x1 = memref.load %state[%c1] : memref<4x!pf, strided<[1], offset: ?>>
  %x2 = memref.load %state[%c2] : memref<4x!pf, strided<[1], offset: ?>>
  %x3 = memref.load %state[%c3] : memref<4x!pf, strided<[1], offset: ?>>

  %x01 = field.add %x0, %x1 : !pf
  %x23 = field.add %x2, %x3 : !pf
  %x0123 = field.add %x01, %x23 : !pf
  %x01123 = field.add %x0123, %x1 : !pf
  %x01233 = field.add %x0123, %x3 : !pf

  %x00 = field.double %x0 : !pf
  %x22 = field.double %x2 : !pf


  // x[0] = x01123 + x01
  %x0_new = field.add %x01123, %x01 : !pf
  // x[1] = x01123 + 2*x[2]
  %x1_new = field.add %x01123, %x22 : !pf
  // x[2] = x01233 + x23
  %x2_new = field.add %x01233, %x23 : !pf
  // x[3] = x01233 + 2*x[0]
  %x3_new = field.add %x01233, %x00 : !pf

  // Write the four results back in place
  memref.store %x0_new, %state[%c0] : memref<4x!pf, strided<[1], offset: ?>>
  memref.store %x1_new, %state[%c1] : memref<4x!pf, strided<[1], offset: ?>>
  memref.store %x2_new, %state[%c2] : memref<4x!pf, strided<[1], offset: ?>>
  memref.store %x3_new, %state[%c3] : memref<4x!pf, strided<[1], offset: ?>>
  return
}

func.func @mds_light_permutation(%state: !state) {
  // First, apply M_4 to each consecutive four elements of the state
  // This replaces each x_i with x_i'
  affine.for %chunk_idx = 0 to 4 {
    // Calculate offset for this chunk
    %x0 = affine.load %state[%chunk_idx * 4] : !state
    %x1 = affine.load %state[%chunk_idx * 4 + 1] : !state
    %x01 = field.add %x0, %x1 : !pf
    %x2 = affine.load %state[%chunk_idx * 4 + 2] : !state
    %x3 = affine.load %state[%chunk_idx * 4 + 3] : !state
    %x23 = field.add %x2, %x3 : !pf
    %x0123 = field.add %x01, %x23 : !pf
    %x01123 = field.add %x0123, %x1 : !pf
    %x01233 = field.add %x0123, %x3 : !pf

    %x00 = field.double %x0 : !pf
    %x22 = field.double %x2 : !pf

    // x[0] = x01123 + x01
    %x0_new = field.add %x01123, %x01 : !pf
    // x[1] = x01123 + 2*x[2]
    %x1_new = field.add %x01123, %x22 : !pf
    // x[2] = x01233 + x23
    %x2_new = field.add %x01233, %x23 : !pf
    // x[3] = x01233 + 2*x[0]
    %x3_new = field.add %x01233, %x00 : !pf

    // Write the four results back in place
    affine.store %x0_new, %state[%chunk_idx * 4] : !state
    affine.store %x1_new, %state[%chunk_idx * 4 + 1] : !state
    affine.store %x2_new, %state[%chunk_idx * 4 + 2] : !state
    affine.store %x3_new, %state[%chunk_idx * 4 + 3] : !state
  }

  // Now apply the outer circulant matrix
  // Precompute the four sums of every four elements
  // Compute sums: sums[k] = sum of state[j + k] for j = 0, 4, 8, 12
  %sums = memref.alloca() : memref<4x!pf>
  affine.for %k = 0 to 4 {
    %val0 = affine.load %state[%k] : !state
    %val1 = affine.load %state[%k + 4] : !state
    %val2 = affine.load %state[%k + 8] : !state
    %val3 = affine.load %state[%k + 12] : !state
    %sum01 = field.add %val0, %val1 : !pf
    %sum23 = field.add %val2, %val3 : !pf
    %new_sum = field.add %sum01, %sum23 : !pf
    affine.store %new_sum, %sums[%k] : memref<4x!pf>
  }

  // Apply the formula: y_i = x_i' + sums[i % 4]
  affine.for %i = 0 to 4 {
    %val0 = affine.load %state[%i] : !state
    %val1 = affine.load %state[%i + 4] : !state
    %val2 = affine.load %state[%i + 8] : !state
    %val3 = affine.load %state[%i + 12] : !state
    %sum = affine.load %sums[%i] : memref<4x!pf>
    %sum0 = field.add %val0, %sum : !pf
    %sum1 = field.add %val1, %sum : !pf
    %sum2 = field.add %val2, %sum : !pf
    %sum3 = field.add %val3, %sum : !pf
    affine.store %sum0, %state[%i] : !state
    affine.store %sum1, %state[%i + 4] : !state
    affine.store %sum2, %state[%i + 8] : !state
    affine.store %sum3, %state[%i + 12] : !state
  }
  return
}

// Internal layer matrix multiplication
func.func @internal_layer_mat_mul(%state: !state, %sum: !pf) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %c7 = arith.constant 7 : index
  %c8 = arith.constant 8 : index
  %c9 = arith.constant 9 : index
  %c10 = arith.constant 10 : index
  %c11 = arith.constant 11 : index
  %c12 = arith.constant 12 : index
  %c13 = arith.constant 13 : index
  %c14 = arith.constant 14 : index
  %c15 = arith.constant 15 : index

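  // Inverse powers of two used by the diagonal below. The literals equal
  // 2^(32-k) mod P, i.e. the Montgomery representation (R = 2^32) of 1/2^k.
  // Diagonal: [-2, 1, 2, 1/2, 3, 4, -1/2, -3, -4, 1/2^8, 1/4, 1/8, 1/2^27, -1/2^8, -1/16, -1/2^27]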
  %inv_two = field.constant 134217727 : !pf
  %inv_four = field.constant 1073741824 : !pf
  %inv_eight = field.constant 536870912 : !pf
  %inv_sixteen = field.constant 268435456 : !pf
  %inv_256 = field.constant 16777216 : !pf
  %inv_2_27 = field.constant 32 : !pf

  // state[1] += sum
  %s1 = memref.load %state[%c1] : !state
  %new_s1 = field.add %s1, %sum : !pf
  memref.store %new_s1, %state[%c1] : !state

  // state[2] = state[2].double() + sum
  %s2 = memref.load %state[%c2] : !state
  %s2_double = field.double %s2 : !pf
  %new_s2 = field.add %s2_double, %sum : !pf
  memref.store %new_s2, %state[%c2] : !state

  // state[3] = state[3].halve() + sum
  %s3 = memref.load %state[%c3] : !state
  %s3_halve = field.mul %s3, %inv_two : !pf
  %new_s3 = field.add %s3_halve, %sum : !pf
  memref.store %new_s3, %state[%c3] : !state

  // state[4] = sum + state[4].double() + state[4]
  %s4 = memref.load %state[%c4] : !state
  %s4_double = field.double %s4 : !pf
  %s4_sum = field.add %s4_double, %s4 : !pf
  %new_s4 = field.add %sum, %s4_sum : !pf
  memref.store %new_s4, %state[%c4] : !state

  // state[5] = sum + state[5].double().double()
  %s5 = memref.load %state[%c5] : !state
  %s5_double = field.double %s5 : !pf
  %s5_double_double = field.double %s5_double : !pf
  %new_s5 = field.add %sum, %s5_double_double : !pf
  memref.store %new_s5, %state[%c5] : !state

  // state[6] = sum - state[6].halve()
  %s6 = memref.load %state[%c6] : !state
  %s6_halve = field.mul %s6, %inv_two : !pf
  %new_s6 = field.sub %sum, %s6_halve : !pf
  memref.store %new_s6, %state[%c6] : !state

  // state[7] = sum - (state[7].double() + state[7])
  %s7 = memref.load %state[%c7] : !state
  %s7_double = field.double %s7 : !pf
  %s7_sum = field.add %s7_double, %s7 : !pf
  %new_s7 = field.sub %sum, %s7_sum : !pf
  memref.store %new_s7, %state[%c7] : !state

  // state[8] = sum - state[8].double().double()
  %s8 = memref.load %state[%c8] : !state
  %s8_double = field.double %s8 : !pf
  %s8_double_double = field.double %s8_double : !pf
  %new_s8 = field.sub %sum, %s8_double_double : !pf
  memref.store %new_s8, %state[%c8] : !state

  // state[9] = state[9] * inv_256 + sum
  %s9 = memref.load %state[%c9] : !state
  %s9_div_256 = field.mul %s9, %inv_256 : !pf
  %new_s9 = field.add %s9_div_256, %sum : !pf
  memref.store %new_s9, %state[%c9] : !state

  // state[10] = state[10] * inv_four + sum
  %s10 = memref.load %state[%c10] : !state
  %s10_div_4 = field.mul %s10, %inv_four : !pf
  %new_s10 = field.add %s10_div_4, %sum : !pf
  memref.store %new_s10, %state[%c10] : !state

  // state[11] = state[11] * inv_eight + sum
  %s11 = memref.load %state[%c11] : !state
  %s11_div_8 = field.mul %s11, %inv_eight : !pf
  %new_s11 = field.add %s11_div_8, %sum : !pf
  memref.store %new_s11, %state[%c11] : !state

  // state[12] = state[12] * inv_2_27 + sum
  %s12 = memref.load %state[%c12] : !state
  %s12_div_27 = field.mul %s12, %inv_2_27 : !pf
  %new_s12 = field.add %s12_div_27, %sum : !pf
  memref.store %new_s12, %state[%c12] : !state

  // state[13] = sum - state[13] * inv_256
  %s13 = memref.load %state[%c13] : !state
  %s13_div_256 = field.mul %s13, %inv_256 : !pf
  %new_s13 = field.sub %sum, %s13_div_256 : !pf
  memref.store %new_s13, %state[%c13] : !state

  // state[14] = sum - state[14] * inv_sixteen
  %s14 = memref.load %state[%c14] : !state
  %s14_div_16 = field.mul %s14, %inv_sixteen : !pf
  %new_s14 = field.sub %sum, %s14_div_16 : !pf
  memref.store %new_s14, %state[%c14] : !state

  // state[15] = sum - state[15] * inv_2_27
  %s15 = memref.load %state[%c15] : !state
  %s15_div_27 = field.mul %s15, %inv_2_27 : !pf
  %new_s15 = field.sub %sum, %s15_div_27 : !pf
  memref.store %new_s15, %state[%c15] : !state

  return
}

// Internal layer: 13 partial rounds (add RC and apply the S-box to state[0] only, then internal diffusion)
func.func @permute_state(%state: !state) {
  // Index constants for in-place access to the state
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %c7 = arith.constant 7 : index
  %c8 = arith.constant 8 : index
  %c9 = arith.constant 9 : index
  %c10 = arith.constant 10 : index
  %c11 = arith.constant 11 : index
  %c12 = arith.constant 12 : index
  %c13 = arith.constant 13 : index
  %c14 = arith.constant 14 : index
  %c15 = arith.constant 15 : index

  // BABYBEAR_RC16_INTERNAL (13 scalar constants)
  %rc_internal = arith.constant dense<[250494022, 528496384, 1472966118, 977089650, 1885890237, 1094557811, 147492661, 664163003, 398852570, 336233633, 1628648315, 888594966, 586791090]> : tensor<13xi32>
  %rc_internal_mont = field.bitcast %rc_internal : tensor<13xi32> -> tensor<13x!pf>

  // For each internal constant: add RC and S-box to first element, then apply matrix multiplication
  affine.for %round = 0 to 13 {
    // Get current round constant via tensor.extract
    %rc = tensor.extract %rc_internal_mont[%round] : tensor<13x!pf>

    // Add RC and apply S-box to first element
    %s0 = memref.load %state[%c0] : !state
    %elem0 = func.call @add_rc_and_sbox(%s0, %rc) : (!pf, !pf) -> !pf

    // Compute the sum of all elements.
    // NOTE: the affine.for reduction below was extremely slow, so the loads and additions are unrolled by hand.
    // %zero = field.constant 0 : !pf
    // %sum = affine.for %i = 0 to 16 iter_args(%acc = %zero) -> (!pf) {
    //   %elem = tensor.extract %t[%i] : tensor<16x!pf>
    //   %new_acc = field.add %acc, %elem : !pf
    //   affine.yield %new_acc : !pf
    // }
    %elem1 = memref.load %state[%c1] : memref<16x!pf>
    %elem2 = memref.load %state[%c2] : memref<16x!pf>
    %elem3 = memref.load %state[%c3] : memref<16x!pf>
    %elem4 = memref.load %state[%c4] : memref<16x!pf>
    %elem5 = memref.load %state[%c5] : memref<16x!pf>
    %elem6 = memref.load %state[%c6] : memref<16x!pf>
    %elem7 = memref.load %state[%c7] : memref<16x!pf>
    %elem8 = memref.load %state[%c8] : memref<16x!pf>
    %elem9 = memref.load %state[%c9] : memref<16x!pf>
    %elem10 = memref.load %state[%c10] : memref<16x!pf>
    %elem11 = memref.load %state[%c11] : memref<16x!pf>
    %elem12 = memref.load %state[%c12] : memref<16x!pf>
    %elem13 = memref.load %state[%c13] : memref<16x!pf>
    %elem14 = memref.load %state[%c14] : memref<16x!pf>
    %elem15 = memref.load %state[%c15] : memref<16x!pf>

    // Sum elements 1..15 with a reduction tree. The additions within each
    // level are independent, so the CPU can execute them in parallel.

    // Level 1 (7 independent additions; elem1 joins at level 2)
    %sum2_3 = field.add %elem2, %elem3 : !pf
    %sum4_5 = field.add %elem4, %elem5 : !pf
    %sum6_7 = field.add %elem6, %elem7 : !pf
    %sum8_9 = field.add %elem8, %elem9 : !pf
    %sum10_11 = field.add %elem10, %elem11 : !pf
    %sum12_13 = field.add %elem12, %elem13 : !pf
    %sum14_15 = field.add %elem14, %elem15 : !pf

    // Level 2 (4 parallel additions)
    %sum1_3 = field.add %elem1, %sum2_3 : !pf
    %sum4_7 = field.add %sum4_5, %sum6_7 : !pf
    %sum8_11 = field.add %sum8_9, %sum10_11 : !pf
    %sum12_15 = field.add %sum12_13, %sum14_15 : !pf

    // Level 3 (2 parallel additions)
    %sum1_7 = field.add %sum1_3, %sum4_7 : !pf
    %sum8_15 = field.add %sum8_11, %sum12_15 : !pf

    // Level 4 (Partial sum)
    %partial_sum = field.add %sum1_7, %sum8_15 : !pf

    %total_sum = field.add %partial_sum, %elem0 : !pf
    %new_s0 = field.sub %partial_sum, %elem0 : !pf
    memref.store %new_s0, %state[%c0] : !state

    // Apply internal layer matrix multiplication
    func.call @internal_layer_mat_mul(%state, %total_sum) : (!state, !pf) -> ()
  }
  return
}

// External layer: terminal permutation (4 rounds: add RC, S-box, MDS)
func.func @permute_state_terminal(%state: !state) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c1 = arith.constant 1 : index

  // BABYBEAR_RC16_EXTERNAL_FINAL (4 rounds x 16 constants)
  %rc_external_const = arith.constant dense<[
    [999830298, 304461056, 552699684, 450698925, 667466464, 1736509752, 1327760865, 1153241151, 816675655, 1076172858, 1914832527, 1668723429, 1365579850, 975704528, 1031625628, 1393317533],
    [1554700828, 1023828605, 1610378860, 347744760, 1909572073, 739227895, 428565985, 633143046, 121797685, 94048546, 1369350241, 1250010422, 114268841, 515033604, 49052844, 1962329907],
    [1380892638, 1860017417, 64711457, 9758460, 1681838395, 710850601, 1020228997, 1414164790, 1531515535, 36158805, 713604525, 89935127, 1870801994, 395985906, 1122769045, 1760811055],
    [819787042, 134654834, 1755145179, 18433016, 1701878989, 1782339297, 1483861396, 962480061, 1857590724, 222440409, 63223417, 515206622, 1348364213, 973414686, 1591066884, 705852913]
  ]> : tensor<4x16xi32>

  %rc_external_final = field.bitcast %rc_external_const : tensor<4x16xi32> -> tensor<4x16x!pf>
  %state_tensor = bufferization.to_tensor %state restrict : memref<16x!pf> to tensor<16x!pf>

  // Loop through 4 rounds of external terminal permutation
  affine.for %round = 0 to 4 {
    affine.for %i = 0 to 16 {
      %s = tensor.extract %state_tensor[%i] : tensor<16x!pf>
      %c = tensor.extract %rc_external_final[%round, %i] : tensor<4x16x!pf>
      %sbox = func.call @add_rc_and_sbox(%s, %c) : (!pf, !pf) -> !pf
      affine.store %sbox, %state[%i] : !state
    }

    // Apply MDS light permutation (in-place)
    func.call @mds_light_permutation(%state) : (!state) -> ()
  }

  return
}

// External layer: initial permutation (MDS light, then 4 rounds with the initial external constants)
func.func @permute_state_initial(%state: !state) {
  // First apply MDS light permutation
  func.call @mds_light_permutation(%state) : (!state) -> ()

  // Round constants for 16-width Poseidon2 on BabyBear
  // BABYBEAR_RC16_EXTERNAL_INITIAL (4 rounds x 16 constants)
  %rc_external_const = arith.constant dense<[
    [1582131512, 1899519471, 1641921850, 462688640, 1293997949, 1380417575, 1932416963, 283521298, 1016708647, 35751290, 1270782647, 851730739, 795004022, 929571430, 523703523, 1593957757],
    [895976710, 1742343460, 917700746, 1516725708, 1170237629, 785693164, 613651155, 352999196, 678775274, 1005433272, 1704854670, 1174551920, 508930349, 530338447, 1327158816, 1417652352],
    [1153538870, 583201050, 397833841, 1440603828, 454600685, 174490638, 171758601, 1998476616, 1403697810, 1807736944, 450348306, 1458895865, 787037868, 1063762964, 1987002214, 481645916],
    [1231767638, 1323639433, 238360103, 2012412459, 1024945356, 1108359895, 1284135849, 606928406, 1021455954, 719347978, 659671051, 769588663, 805534062, 592213995, 1752728055, 663410947]
  ]> : tensor<4x16xi32>

  %rc_external_initial = field.bitcast %rc_external_const : tensor<4x16xi32> -> tensor<4x16x!pf>
  %state_tensor = bufferization.to_tensor %state restrict : memref<16x!pf> to tensor<16x!pf>

  // Then run the 4 initial external rounds (add RC, S-box, MDS light)
  affine.for %round = 0 to 4 {
    affine.for %i = 0 to 16 {
      %s = tensor.extract %state_tensor[%i] : tensor<16x!pf>
      %c = tensor.extract %rc_external_initial[%round, %i] : tensor<4x16x!pf>
      %sbox = func.call @add_rc_and_sbox(%s, %c) : (!pf, !pf) -> !pf
      affine.store %sbox, %state[%i] : !state
    }

    // Apply MDS light permutation (in-place)
    func.call @mds_light_permutation(%state) : (!state) -> ()
  }
  return
}

// Complete Poseidon2 permutation
func.func @poseidon2_permute(%state: !state) {
  func.call @permute_state_initial(%state) : (!state) -> ()
  func.call @permute_state(%state) : (!state) -> ()
  func.call @permute_state_terminal(%state) : (!state) -> ()
  return
}

func.func @permute_10000(%state : !state) attributes { llvm.emit_c_interface } {
  affine.for %i = 0 to 10000 {
    func.call @poseidon2_permute(%state) : (!state) -> ()
  }
  return
}
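A small reference model in plain Python is a convenient cross-check for the hand-unrolled parts of the listing above. The sketch below mirrors the S-box chain of `@add_rc_and_sbox`, compares the addition-only form of `@apply_mat4` against a direct product with the 4x4 matrix, and recomputes the inverse-power-of-two literals. The helper names are hypothetical, and the Montgomery radix R = 2^32 is an assumption inferred from the constants rather than something the listing states.

```python
P = 2013265921  # BabyBear prime, 15 * 2^27 + 1
R = 1 << 32     # assumed Montgomery radix for the 32-bit field type

M4 = [[2, 3, 1, 1],
      [1, 2, 3, 1],
      [1, 1, 2, 3],
      [3, 1, 1, 2]]


def sbox7(x):
    """x^7 via the same chain as the listing: x^2, x^4, x^3, then x^4 * x^3."""
    x2 = x * x % P
    x4 = x2 * x2 % P
    x3 = x * x2 % P
    return x4 * x3 % P


def apply_mat4_adds(x):
    """Addition-only evaluation of M4 * x, mirroring the MLIR sequence."""
    x0, x1, x2, x3 = x
    x01 = (x0 + x1) % P
    x23 = (x2 + x3) % P
    x0123 = (x01 + x23) % P
    x01123 = (x0123 + x1) % P
    x01233 = (x0123 + x3) % P
    return [(x01123 + x01) % P,
            (x01123 + 2 * x2) % P,
            (x01233 + x23) % P,
            (x01233 + 2 * x0) % P]


def apply_mat4_matmul(x):
    """Direct matrix-vector product with M4, used as the reference."""
    return [sum(m * v for m, v in zip(row, x)) % P for row in M4]


def to_montgomery(a):
    """Montgomery representation a * R mod P."""
    return a * R % P


# 1/2^k in Montgomery form for the exponents used by the internal layer
mont_inverses = {k: to_montgomery(pow(2, -k, P)) for k in (1, 2, 3, 4, 8, 27)}

sample = [1, 2, 3, P - 1]
assert apply_mat4_adds(sample) == apply_mat4_matmul(sample)
assert mont_inverses[1] == 134217727  # matches %inv_two in the listing
```

Under the R = 2^32 assumption, every entry of `mont_inverses` reproduces the corresponding `field.constant` literal in `@internal_layer_mat_mul`.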

Python code

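This is a Python implementation of poly/dense.rs from WHIR-p3. It operates on JAX's default floating-point arrays, with coefficients ordered from the highest degree down, as jnp.polyval expects: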

import jax
import jax.random as rnd
import jax.numpy as jnp
import jax.lax as lax


### is_zero
@jax.jit
def is_zero(poly):
  # A zero polynomial is all zeros, or an empty array
  poly = jnp.array(poly)
  return (poly.size == 0) | jnp.all(poly == 0)


### evaluate - use as is
@jax.jit
def evaluate(poly, x):
  poly = jnp.array(poly)
  return jnp.polyval(poly, x)


### random
@jax.jit
def random(key, poly):
  # Draw uniform floats in [0, 999) and round them to the nearest integer, giving
  # integer-valued coefficients stored as floats (rnd.randint would return an int dtype)
  poly = jnp.array(poly)
  rand = rnd.uniform(key, poly.shape, minval=0., maxval=999)
  return jnp.round(rand)


### lagrange_interpolation
@jax.jit
def has_dup_x_jit(xs):
  xs_sorted = jnp.sort(xs)
  return jnp.any(xs_sorted[1:] == xs_sorted[:-1])


@jax.jit
def lagrange_interpolation(values):
  size = len(values)
  if size == 0:
    return jnp.zeros(0)

  # Unzip the list of (x, y) pairs to separate arrays
  xs, ys = zip(*values)
  xs = jnp.array(xs)
  ys = jnp.array(ys)

  # lax.cond requires both branches to return arrays of the same shape, so
  # duplicate x-coordinates are signalled with an all-zero polynomial instead of None.
  def dup_case(_):
    return jnp.zeros(size)

  def no_dup_case(_):
    # Incremental interpolation: add one point per loop iteration
    monomial_init = jnp.zeros(size)

    def body(i, acc):
      result_poly, basis_poly = acc
      current_y = jnp.polyval(result_poly, xs[i])
      delta = ys[i] - current_y
      basis_eval = jnp.polyval(basis_poly, xs[i])
      c_i = delta / basis_eval
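      # Scale the basis polynomial by c_i and add it to the running result
      term = basis_poly * c_i
      result_poly = result_poly + term
      # Build the monomial (x - x_i) as a length-`size` coefficient array
      monomial = monomial_init.at[size - 1].set(-xs[i])
      monomial_all = monomial.at[size - 2].set(1)
      # B(x) <- B(x) * (x - x_i); after i steps, B(x) = (x - x_0)(x - x_1)...(x - x_{i-1})
      basis_poly = jnp.polymul(basis_poly, monomial_all)
      # Keep only the `size` lowest-order coefficients so the loop carry keeps a fixed shape
      basis_poly = basis_poly[-size:]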
      return (result_poly, basis_poly)

    # The result polynomial P(x) starts at zero and is updated iteratively.
    zero_poly = jnp.zeros(size)
    # The basis polynomial B(x) starts at 1.
    basis_poly = zero_poly.at[size - 1].set(1)
    result_poly, basis_poly = jax.lax.fori_loop(0, size, body, (zero_poly, basis_poly))
    return result_poly

  return lax.cond(has_dup_x_jit(xs), dup_case, no_dup_case, operand=None)


### add - use as is
@jax.jit
def add(a, b):
  jax_array_a = jnp.array(a)
  jax_array_b = jnp.array(b)
  return jnp.polyadd(jax_array_a, jax_array_b)


### mul - use as is
@jax.jit
def mul(a, b):
  jax_array_a = jnp.array(a)
  jax_array_b = jnp.array(b)
  return jnp.polymul(jax_array_a, jax_array_b)
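
The listing above works on floating-point arrays, whereas a prover needs exact field arithmetic. As a rough reference when checking a port, here is a minimal pure-Python sketch of the same incremental interpolation over BabyBear. The function names, the lowest-degree-first coefficient order, and the use of `None` for duplicate x-coordinates are assumptions made for this sketch, not the WHIR-p3 API.

```python
P = 2013265921  # BabyBear prime


def poly_eval(coeffs, x):
    """Horner evaluation mod P; coeffs are ordered lowest degree first."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc


def lagrange_interpolation_mod_p(points):
    """Interpolate (x, y) pairs mod P; returns None on duplicate x values."""
    xs = [x % P for x, _ in points]
    ys = [y % P for _, y in points]
    if len(set(xs)) != len(xs):
        return None
    n = len(points)
    result = [0] * n             # P(x) starts at zero
    basis = [1] + [0] * (n - 1)  # B(x) starts at 1
    for x_i, y_i in zip(xs, ys):
        # Correct the running result so that it also passes through (x_i, y_i)
        delta = (y_i - poly_eval(result, x_i)) % P
        c_i = delta * pow(poly_eval(basis, x_i), -1, P) % P
        result = [(r + c_i * b) % P for r, b in zip(result, basis)]
        # B(x) <- B(x) * (x - x_i): shift up one degree, subtract x_i * B(x).
        # The top coefficient is dropped on the last step, where B is no longer used.
        shifted = [0] + basis[:-1]
        basis = [(s - x_i * b) % P for s, b in zip(shifted, basis)]
    return result


# y = x^2 + 1 sampled at x = 1, 2, 3
coeffs = lagrange_interpolation_mod_p([(1, 2), (2, 5), (3, 10)])
print(coeffs)
```

Division becomes multiplication by a modular inverse (`pow(..., -1, P)`, available from Python 3.8), which is the main difference from the floating-point version.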