LeanVM Development Progress #1 - Poseidon2 in MLIR and WHIR-p3 Python Implementation
Poseidon2 in MLIR
In hash-based STARK systems, the hash function dominates performance — and Poseidon2 is the most commonly used hash. Optimizing Poseidon2 directly translates to faster STARKs, while also stress-testing and maturing our MLIR infrastructure.
For this, we implemented Poseidon2 in MLIR, which outperformed Plonky3 (with SIMD disabled) on both AMD and Apple Silicon:
| Platform | Plonky3 | ZKIR | Ratio (Plonky3 / ZKIR) |
|---|---|---|---|
| AMD 9950X3D | 6.12 ms | 5.70 ms | 1.07 |
| Mac M4 Pro | 7.76 ms | 5.80 ms | 1.34 |
(The table shows the total time for 10,000 hash operations. The benchmark used BabyBear, but similar results are expected for KoalaBear.)
In practice, Packed PrimeField performance matters more for production workloads. Unfortunately, ZKIR is currently slower here:
| Platform | Plonky3 | ZKIR | Ratio (Plonky3 / ZKIR) |
|---|---|---|---|
| AMD 9950X3D | 7.44 ms | 9.30 ms | 0.8 |
This slowdown is due to engineering optimizations in Plonky3 that we haven't yet implemented. For example, consider computing $(a + b) \cdot c \bmod p$ for $a, b, c \in [0, p)$.

The standard approach performs:

1. Compute $s = a + b$
2. Apply $s \leftarrow \mathrm{red}(s)$
3. Compute $t = s \cdot c$
4. Apply $t \leftarrow \mathrm{red}(t)$

where the reduction function is defined as

$$\mathrm{red}(x) = x \bmod p,$$

where $p = 2^{31} - 2^{27} + 1$ is the BabyBear prime and fully reduced elements lie in $[0, p)$.

Plonky3 skips step (2) and instead performs:

1. Compute $s = a + b$
2. Compute $t = s \cdot c$
3. Apply $t \leftarrow \mathrm{red}(t)$

Since $s$ after step (1) lies within $[0, 2p)$ and the product $s \cdot c < 2p \cdot (p - 1) < 2^{63}$ still fits in a 64-bit word, omitting the first reduction does not affect correctness — the result still satisfies $\mathrm{red}(t) = (a + b) \cdot c \bmod p \in [0, p)$.
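The effect is easy to see in a plain-Python sketch (illustrative only; not the actual Plonky3 or ZKIR code, with Python integers standing in for `u64` arithmetic):

```python
P = 2**31 - 2**27 + 1  # BabyBear prime

def mul_add_standard(a: int, b: int, c: int) -> int:
    # Reduce after every operation.
    s = (a + b) % P      # steps (1) + (2)
    return (s * c) % P   # steps (3) + (4)

def mul_add_lazy(a: int, b: int, c: int) -> int:
    # Defer the first reduction: s may lie in [0, 2P), but
    # s * c < 2P * P < 2**63 still fits in a 64-bit word,
    # so a single final reduction yields the same result.
    s = a + b            # step (1), no reduction
    return (s * c) % P   # steps (2) + (3)

assert all(
    mul_add_standard(a, b, c) == mul_add_lazy(a, b, c)
    for (a, b, c) in [(P - 1, P - 1, P - 1), (123, 456, 789)]
)
```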
We're currently implementing similar optimizations in ZKIR and will share the results soon.
WHIR-p3 in Python
XLA is a just-in-time compiler widely used in deep learning. It defines the HLO IR, and through the JAX library it enables compilation from Python → Jaxpr → HLO → MLIR → LLVM IR → native binary.
Our team has been developing a ZK-specific runtime compiler called ZKX by forking XLA. Using ZKX, we already achieved SOTA Groth16 performance on CPU.
The next goal is to express WHIR-p3 in Python. To that end, we are forking and adapting JAX for ZKX compatibility. In parallel, we're porting WHIR-p3 (including Poseidon2) to Python using JAX. For example, the Rust implementation of poly/dense.rs can be expressed in JAX as shown in the Python code section below.
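As a small, generic JAX illustration of that pipeline (standard JAX APIs only, nothing ZKX-specific), the intermediate Jaxpr and the lowered StableHLO/MLIR of a jitted function can be inspected directly:

```python
import jax
import jax.numpy as jnp

def poly_eval(coeffs, x):
    # Horner evaluation with descending-order coefficients.
    return jnp.polyval(coeffs, x)

coeffs = jnp.array([2.0, 0.0, 1.0])
print(jax.make_jaxpr(poly_eval)(coeffs, 2.0))           # Python -> Jaxpr
print(jax.jit(poly_eval).lower(coeffs, 2.0).as_text())  # -> StableHLO (MLIR) text
```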
MLIR code
// Poseidon2 utility functions for BabyBear field
// Based on Plonky3 implementation: https://github.com/Plonky3/Plonky3
!pf = !field.pf<2013265921 : i32, true>
!pf_std = !field.pf<2013265921 : i32>
!state = memref<16x!pf>
!state_std = memref<16x!pf_std>
func.func @add_rc_and_sbox(%var: !pf, %c: !pf) -> !pf {
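  // Computes (var + c)^7 via the chain x^2, x^4, x^3 = x * x^2, x^7 = x^4 * x^3.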
%sum = field.add %var, %c : !pf
%sum_sq = field.square %sum : !pf
%sum_sq_sq = field.square %sum_sq : !pf
%sum_cu = field.mul %sum, %sum_sq : !pf
%sum_exp7 = field.mul %sum_sq_sq, %sum_cu : !pf
return %sum_exp7 : !pf
}
// In-place version of apply_mat4 using memref.
// Ideally we would express this as a matmul that lowers to the sequence below,
// but that is currently hard to achieve, so we implement it with field additions
// instead of an explicit matrix multiplication.
func.func @apply_mat4(%state: memref<4x!pf, strided<[1], offset: ?>>) {
  // The 4x4 MDS matrix M_4, kept for reference (the additions below implement it directly)
%matrix = arith.constant dense<[
[2, 3, 1, 1],
[1, 2, 3, 1],
[1, 1, 2, 3],
[3, 1, 1, 2]
]> : tensor<4x4xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
// Compute the sum of all 4 elements
%x0 = memref.load %state[%c0] : memref<4x!pf, strided<[1], offset: ?>>
%x1 = memref.load %state[%c1] : memref<4x!pf, strided<[1], offset: ?>>
%x2 = memref.load %state[%c2] : memref<4x!pf, strided<[1], offset: ?>>
%x3 = memref.load %state[%c3] : memref<4x!pf, strided<[1], offset: ?>>
%x01 = field.add %x0, %x1 : !pf
%x23 = field.add %x2, %x3 : !pf
%x0123 = field.add %x01, %x23 : !pf
%x01123 = field.add %x0123, %x1 : !pf
%x01233 = field.add %x0123, %x3 : !pf
%x00 = field.double %x0 : !pf
%x22 = field.double %x2 : !pf
// x[0] = x01123 + x01
%x0_new = field.add %x01123, %x01 : !pf
// x[1] = x01123 + 2*x[2]
%x1_new = field.add %x01123, %x22 : !pf
// x[2] = x01233 + x23
%x2_new = field.add %x01233, %x23 : !pf
// x[3] = x01233 + 2*x[0]
%x3_new = field.add %x01233, %x00 : !pf
  // Write the transformed values back into the state
memref.store %x0_new, %state[%c0] : memref<4x!pf, strided<[1], offset: ?>>
memref.store %x1_new, %state[%c1] : memref<4x!pf, strided<[1], offset: ?>>
memref.store %x2_new, %state[%c2] : memref<4x!pf, strided<[1], offset: ?>>
memref.store %x3_new, %state[%c3] : memref<4x!pf, strided<[1], offset: ?>>
return
}
func.func @mds_light_permutation(%state: !state) {
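  // Overall this implements the Poseidon2 external matrix M_E = circ(2*M_4, M_4, M_4, M_4):
  // apply M_4 block-wise, then add the per-position sums across the four blocks.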
// First, apply M_4 to each consecutive four elements of the state
// This replaces each x_i with x_i'
affine.for %chunk_idx = 0 to 4 {
// Calculate offset for this chunk
%x0 = affine.load %state[%chunk_idx * 4] : !state
%x1 = affine.load %state[%chunk_idx * 4 + 1] : !state
%x01 = field.add %x0, %x1 : !pf
%x2 = affine.load %state[%chunk_idx * 4 + 2] : !state
%x3 = affine.load %state[%chunk_idx * 4 + 3] : !state
%x23 = field.add %x2, %x3 : !pf
%x0123 = field.add %x01, %x23 : !pf
%x01123 = field.add %x0123, %x1 : !pf
%x01233 = field.add %x0123, %x3 : !pf
%x00 = field.double %x0 : !pf
%x22 = field.double %x2 : !pf
// x[0] = x01123 + x01
%x0_new = field.add %x01123, %x01 : !pf
// x[1] = x01123 + 2*x[2]
%x1_new = field.add %x01123, %x22 : !pf
// x[2] = x01233 + x23
%x2_new = field.add %x01233, %x23 : !pf
// x[3] = x01233 + 2*x[0]
%x3_new = field.add %x01233, %x00 : !pf
    // Write the transformed values back into the state
affine.store %x0_new, %state[%chunk_idx * 4] : !state
affine.store %x1_new, %state[%chunk_idx * 4 + 1] : !state
affine.store %x2_new, %state[%chunk_idx * 4 + 2] : !state
affine.store %x3_new, %state[%chunk_idx * 4 + 3] : !state
}
// Now apply the outer circulant matrix
// Precompute the four sums of every four elements
// Compute sums: sums[k] = sum of state[j + k] for j = 0, 4, 8, 12
%sums = memref.alloca() : memref<4x!pf>
affine.for %k = 0 to 4 {
%val0 = affine.load %state[%k] : !state
%val1 = affine.load %state[%k + 4] : !state
%val2 = affine.load %state[%k + 8] : !state
%val3 = affine.load %state[%k + 12] : !state
%sum01 = field.add %val0, %val1 : !pf
%sum23 = field.add %val2, %val3 : !pf
%new_sum = field.add %sum01, %sum23 : !pf
affine.store %new_sum, %sums[%k] : memref<4x!pf>
}
// Apply the formula: y_i = x_i' + sums[i % 4]
affine.for %i = 0 to 4 {
%val0 = affine.load %state[%i] : !state
%val1 = affine.load %state[%i + 4] : !state
%val2 = affine.load %state[%i + 8] : !state
%val3 = affine.load %state[%i + 12] : !state
%sum = affine.load %sums[%i] : memref<4x!pf>
%sum0 = field.add %val0, %sum : !pf
%sum1 = field.add %val1, %sum : !pf
%sum2 = field.add %val2, %sum : !pf
%sum3 = field.add %val3, %sum : !pf
affine.store %sum0, %state[%i] : !state
affine.store %sum1, %state[%i + 4] : !state
affine.store %sum2, %state[%i + 8] : !state
affine.store %sum3, %state[%i + 12] : !state
}
return
}
// Internal layer matrix multiplication
func.func @internal_layer_mat_mul(%state: !state, %sum: !pf) {
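  // Computes state[i] = sum + diag[i] * state[i] for i = 1..15 using the diagonal
  // listed below; state[0] (diag[0] = -2) is handled by the caller.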
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%c9 = arith.constant 9 : index
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
// Precompute powers of 2 inverses using powui
// [-2, 1, 2, 1/2, 3, 4, -1/2, -3, -4, 1/2^8, 1/4, 1/8, 1/2^27, -1/2^8, -1/16, -1/2^27]
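  // NOTE: these constants appear to be the Montgomery-form (R = 2^32) encodings of the
  // inverses above, e.g. 1/2 -> 2^31 mod p = 134217727, 1/4 -> 2^30 = 1073741824.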
%inv_two = field.constant 134217727 : !pf
%inv_four = field.constant 1073741824 : !pf
%inv_eight = field.constant 536870912 : !pf
%inv_sixteen = field.constant 268435456 : !pf
%inv_256 = field.constant 16777216 : !pf
%inv_2_27 = field.constant 32 : !pf
// state[1] += sum
%s1 = memref.load %state[%c1] : !state
%new_s1 = field.add %s1, %sum : !pf
memref.store %new_s1, %state[%c1] : !state
// state[2] = state[2].double() + sum
%s2 = memref.load %state[%c2] : !state
%s2_double = field.double %s2 : !pf
%new_s2 = field.add %s2_double, %sum : !pf
memref.store %new_s2, %state[%c2] : !state
// state[3] = state[3].halve() + sum
%s3 = memref.load %state[%c3] : !state
%s3_halve = field.mul %s3, %inv_two : !pf
%new_s3 = field.add %s3_halve, %sum : !pf
memref.store %new_s3, %state[%c3] : !state
// state[4] = sum + state[4].double() + state[4]
%s4 = memref.load %state[%c4] : !state
%s4_double = field.double %s4 : !pf
%s4_sum = field.add %s4_double, %s4 : !pf
%new_s4 = field.add %sum, %s4_sum : !pf
memref.store %new_s4, %state[%c4] : !state
// state[5] = sum + state[5].double().double()
%s5 = memref.load %state[%c5] : !state
%s5_double = field.double %s5 : !pf
%s5_double_double = field.double %s5_double : !pf
%new_s5 = field.add %sum, %s5_double_double : !pf
memref.store %new_s5, %state[%c5] : !state
// state[6] = sum - state[6].halve()
%s6 = memref.load %state[%c6] : !state
%s6_halve = field.mul %s6, %inv_two : !pf
%new_s6 = field.sub %sum, %s6_halve : !pf
memref.store %new_s6, %state[%c6] : !state
// state[7] = sum - (state[7].double() + state[7])
%s7 = memref.load %state[%c7] : !state
%s7_double = field.double %s7 : !pf
%s7_sum = field.add %s7_double, %s7 : !pf
%new_s7 = field.sub %sum, %s7_sum : !pf
memref.store %new_s7, %state[%c7] : !state
// state[8] = sum - state[8].double().double()
%s8 = memref.load %state[%c8] : !state
%s8_double = field.double %s8 : !pf
%s8_double_double = field.double %s8_double : !pf
%new_s8 = field.sub %sum, %s8_double_double : !pf
memref.store %new_s8, %state[%c8] : !state
// state[9] = state[9] * inv_256 + sum
%s9 = memref.load %state[%c9] : !state
%s9_div_256 = field.mul %s9, %inv_256 : !pf
%new_s9 = field.add %s9_div_256, %sum : !pf
memref.store %new_s9, %state[%c9] : !state
// state[10] = state[10] * inv_four + sum
%s10 = memref.load %state[%c10] : !state
%s10_div_4 = field.mul %s10, %inv_four : !pf
%new_s10 = field.add %s10_div_4, %sum : !pf
memref.store %new_s10, %state[%c10] : !state
// state[11] = state[11] * inv_eight + sum
%s11 = memref.load %state[%c11] : !state
%s11_div_8 = field.mul %s11, %inv_eight : !pf
%new_s11 = field.add %s11_div_8, %sum : !pf
memref.store %new_s11, %state[%c11] : !state
// state[12] = state[12] * inv_2_27 + sum
%s12 = memref.load %state[%c12] : !state
%s12_div_27 = field.mul %s12, %inv_2_27 : !pf
%new_s12 = field.add %s12_div_27, %sum : !pf
memref.store %new_s12, %state[%c12] : !state
// state[13] = sum - state[13] * inv_256
%s13 = memref.load %state[%c13] : !state
%s13_div_256 = field.mul %s13, %inv_256 : !pf
%new_s13 = field.sub %sum, %s13_div_256 : !pf
memref.store %new_s13, %state[%c13] : !state
// state[14] = sum - state[14] * inv_sixteen
%s14 = memref.load %state[%c14] : !state
%s14_div_16 = field.mul %s14, %inv_sixteen : !pf
%new_s14 = field.sub %sum, %s14_div_16 : !pf
memref.store %new_s14, %state[%c14] : !state
// state[15] = sum - state[15] * inv_2_27
%s15 = memref.load %state[%c15] : !state
%s15_div_27 = field.mul %s15, %inv_2_27 : !pf
%new_s15 = field.sub %sum, %s15_div_27 : !pf
memref.store %new_s15, %state[%c15] : !state
return
}
// Internal layers: add RC to the first element, apply the S-box to it, then the internal diffusion matrix
func.func @permute_state(%state: !state) {
  // Index constants for in-place memref access
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%c9 = arith.constant 9 : index
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
// BABYBEAR_RC16_INTERNAL (13 scalar constants)
%rc_internal = arith.constant dense<[250494022, 528496384, 1472966118, 977089650, 1885890237, 1094557811, 147492661, 664163003, 398852570, 336233633, 1628648315, 888594966, 586791090]> : tensor<13xi32>
%rc_internal_mont = field.bitcast %rc_internal : tensor<13xi32> -> tensor<13x!pf>
// For each internal constant: add RC and S-box to first element, then apply matrix multiplication
affine.for %round = 0 to 13 {
// Get current round constant via tensor.extract
%rc = tensor.extract %rc_internal_mont[%round] : tensor<13x!pf>
// Add RC and apply S-box to first element
%s0 = memref.load %state[%c0] : !state
%elem0 = func.call @add_rc_and_sbox(%s0, %rc) : (!pf, !pf) -> !pf
    // Compute the sum of all elements.
    // NOTE: an affine.for reduction (commented out below) is extremely slow, so we add them manually.
// %zero = field.constant 0 : !pf
// %sum = affine.for %i = 0 to 16 iter_args(%acc = %zero) -> (!pf) {
// %elem = tensor.extract %t[%i] : tensor<16x!pf>
// %new_acc = field.add %acc, %elem : !pf
// affine.yield %new_acc : !pf
// }
%elem1 = memref.load %state[%c1] : memref<16x!pf>
%elem2 = memref.load %state[%c2] : memref<16x!pf>
%elem3 = memref.load %state[%c3] : memref<16x!pf>
%elem4 = memref.load %state[%c4] : memref<16x!pf>
%elem5 = memref.load %state[%c5] : memref<16x!pf>
%elem6 = memref.load %state[%c6] : memref<16x!pf>
%elem7 = memref.load %state[%c7] : memref<16x!pf>
%elem8 = memref.load %state[%c8] : memref<16x!pf>
%elem9 = memref.load %state[%c9] : memref<16x!pf>
%elem10 = memref.load %state[%c10] : memref<16x!pf>
%elem11 = memref.load %state[%c11] : memref<16x!pf>
%elem12 = memref.load %state[%c12] : memref<16x!pf>
%elem13 = memref.load %state[%c13] : memref<16x!pf>
%elem14 = memref.load %state[%c14] : memref<16x!pf>
%elem15 = memref.load %state[%c15] : memref<16x!pf>
    // --- Sum the elements using a reduction tree ---
    // This structure allows for maximum parallel execution by the CPU.
    // Level 1 (7 parallel additions)
%sum2_3 = field.add %elem2, %elem3 : !pf
%sum4_5 = field.add %elem4, %elem5 : !pf
%sum6_7 = field.add %elem6, %elem7 : !pf
%sum8_9 = field.add %elem8, %elem9 : !pf
%sum10_11 = field.add %elem10, %elem11 : !pf
%sum12_13 = field.add %elem12, %elem13 : !pf
%sum14_15 = field.add %elem14, %elem15 : !pf
// Level 2 (4 parallel additions)
%sum1_3 = field.add %elem1, %sum2_3 : !pf
%sum4_7 = field.add %sum4_5, %sum6_7 : !pf
%sum8_11 = field.add %sum8_9, %sum10_11 : !pf
%sum12_15 = field.add %sum12_13, %sum14_15 : !pf
// Level 3 (2 parallel additions)
%sum1_7 = field.add %sum1_3, %sum4_7 : !pf
%sum8_15 = field.add %sum8_11, %sum12_15 : !pf
// Level 4 (Partial sum)
%partial_sum = field.add %sum1_7, %sum8_15 : !pf
%total_sum = field.add %partial_sum, %elem0 : !pf
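    // diag[0] = -2, so new state[0] = total_sum - 2 * elem0 = partial_sum - elem0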
%new_s0 = field.sub %partial_sum, %elem0 : !pf
memref.store %new_s0, %state[%c0] : !state
// Apply internal layer matrix multiplication
func.call @internal_layer_mat_mul(%state, %total_sum) : (!state, !pf) -> ()
}
return
}
// External layer: terminal permutation (4 rounds: add RC, S-box, MDS)
func.func @permute_state_terminal(%state: !state) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
// BABYBEAR_RC16_EXTERNAL_FINAL (4 rounds x 16 constants)
%rc_external_const = arith.constant dense<[
[999830298, 304461056, 552699684, 450698925, 667466464, 1736509752, 1327760865, 1153241151, 816675655, 1076172858, 1914832527, 1668723429, 1365579850, 975704528, 1031625628, 1393317533],
[1554700828, 1023828605, 1610378860, 347744760, 1909572073, 739227895, 428565985, 633143046, 121797685, 94048546, 1369350241, 1250010422, 114268841, 515033604, 49052844, 1962329907],
[1380892638, 1860017417, 64711457, 9758460, 1681838395, 710850601, 1020228997, 1414164790, 1531515535, 36158805, 713604525, 89935127, 1870801994, 395985906, 1122769045, 1760811055],
[819787042, 134654834, 1755145179, 18433016, 1701878989, 1782339297, 1483861396, 962480061, 1857590724, 222440409, 63223417, 515206622, 1348364213, 973414686, 1591066884, 705852913]
]> : tensor<4x16xi32>
%rc_external_final = field.bitcast %rc_external_const : tensor<4x16xi32> -> tensor<4x16x!pf>
%state_tensor = bufferization.to_tensor %state restrict : memref<16x!pf> to tensor<16x!pf>
// Loop through 4 rounds of external terminal permutation
affine.for %round = 0 to 4 {
affine.for %i = 0 to 16 {
%s = tensor.extract %state_tensor[%i] : tensor<16x!pf>
%c = tensor.extract %rc_external_final[%round, %i] : tensor<4x16x!pf>
%sbox = func.call @add_rc_and_sbox(%s, %c) : (!pf, !pf) -> !pf
affine.store %sbox, %state[%i] : !state
}
// Apply MDS light permutation (in-place)
func.call @mds_light_permutation(%state) : (!state) -> ()
}
return
}
// External layer: initial permutation (MDS light + terminal permutation)
func.func @permute_state_initial(%state: !state) {
// First apply MDS light permutation
func.call @mds_light_permutation(%state) : (!state) -> ()
// Round constants for 16-width Poseidon2 on BabyBear
// BABYBEAR_RC16_EXTERNAL_INITIAL (4 rounds x 16 constants)
%rc_external_const = arith.constant dense<[
[1582131512, 1899519471, 1641921850, 462688640, 1293997949, 1380417575, 1932416963, 283521298, 1016708647, 35751290, 1270782647, 851730739, 795004022, 929571430, 523703523, 1593957757],
[895976710, 1742343460, 917700746, 1516725708, 1170237629, 785693164, 613651155, 352999196, 678775274, 1005433272, 1704854670, 1174551920, 508930349, 530338447, 1327158816, 1417652352],
[1153538870, 583201050, 397833841, 1440603828, 454600685, 174490638, 171758601, 1998476616, 1403697810, 1807736944, 450348306, 1458895865, 787037868, 1063762964, 1987002214, 481645916],
[1231767638, 1323639433, 238360103, 2012412459, 1024945356, 1108359895, 1284135849, 606928406, 1021455954, 719347978, 659671051, 769588663, 805534062, 592213995, 1752728055, 663410947]
]> : tensor<4x16xi32>
%rc_external_final = field.bitcast %rc_external_const : tensor<4x16xi32> -> tensor<4x16x!pf>
%state_tensor = bufferization.to_tensor %state restrict : memref<16x!pf> to tensor<16x!pf>
  // Then apply the 4 initial external rounds (same structure as the terminal
  // permutation, but with the initial external constants)
affine.for %round = 0 to 4 {
affine.for %i = 0 to 16 {
%s = tensor.extract %state_tensor[%i] : tensor<16x!pf>
%c = tensor.extract %rc_external_final[%round, %i] : tensor<4x16x!pf>
%sbox = func.call @add_rc_and_sbox(%s, %c) : (!pf, !pf) -> !pf
affine.store %sbox, %state[%i] : !state
}
// Apply MDS light permutation (in-place)
func.call @mds_light_permutation(%state) : (!state) -> ()
}
return
}
// Complete Poseidon2 permutation
func.func @poseidon2_permute(%state: !state) {
func.call @permute_state_initial(%state) : (!state) -> ()
func.call @permute_state(%state) : (!state) -> ()
func.call @permute_state_terminal(%state) : (!state) -> ()
return
}
func.func @permute_10000(%state : !state) attributes { llvm.emit_c_interface } {
affine.for %i = 0 to 10000 {
func.call @poseidon2_permute(%state) : (!state) -> ()
}
return
}
Python code
This is a Python implementation of poly/dense.rs from WHIR-p3:
import jax
import jax.random as rnd
import jax.numpy as jnp
import jax.lax as lax
import numpy.random as nprnd
### is_zero
@jax.jit
def is_zero(poly):
# A zero polynomial is all zeros, or an empty array
poly = jnp.array(poly)
return (poly.size == 0) | jnp.all(poly == 0)
### evaluate - use as is
@jax.jit
def evaluate(poly, x):
poly = jnp.array(poly)
return jnp.polyval(poly, x)
### random
@jax.jit
def random(key, poly):
    # Generate uniform random reals in [0, 999), then round to the nearest integer
    # ("discrete real numbers"; randint would return an integer dtype, here we keep floats)
poly = jnp.array(poly)
rand = rnd.uniform(key, poly.shape, minval=0., maxval=999)
return jnp.round(rand)
### lagrange_interpolation
@jax.jit
def has_dup_x_jit(xs):
xs_sorted = jnp.sort(xs)
return jnp.any(xs_sorted[1:] == xs_sorted[:-1])
@jax.jit
def lagrange_interpolation(values):
size = len(values)
if size == 0:
return jnp.zeros(0)
# Unzip the list of (x, y) pairs to separate arrays
xs, ys = zip(*values)
xs = jnp.array(xs)
ys = jnp.array(ys)
    # Check for duplicate x-coordinates using JAX ops; a traced branch cannot
    # return None, so the duplicate case returns the zero polynomial instead.
    def dup_case(_):
        return jnp.zeros(size)  # stand-in for the Rust implementation's None

    def no_dup_case(_):
        # Newton-form interpolation: build the result polynomial incrementally.
monomial_init = jnp.zeros(size)
def body(i, acc):
result_poly, basis_poly = acc
current_y = jnp.polyval(result_poly, xs[i])
delta = ys[i] - current_y
basis_eval = jnp.polyval(basis_poly, xs[i])
c_i = delta / basis_eval
            # scale the basis polynomial by the Newton coefficient c_i
term = basis_poly * c_i
result_poly = result_poly + term
monomial = monomial_init.at[size - 1].set(-xs[i])
monomial_all = monomial.at[size - 2].set(1)
# After i steps, B(x) = (x - x_0)(x - x_1)...(x - x_{i-1})
basis_poly = jnp.polymul(basis_poly, monomial_all)
basis_poly = basis_poly[-size:]
return (result_poly, basis_poly)
# The result polynomial P(x) starts at zero and is updated iteratively.
zero_poly = jnp.zeros(size)
# The basis polynomial B(x) starts at 1.
basis_poly = zero_poly.at[size - 1].set(1)
result_poly, basis_poly = jax.lax.fori_loop(0, size, body, (zero_poly, basis_poly))
return result_poly
    return lax.cond(has_dup_x_jit(xs), dup_case, no_dup_case, None)
### add - use as is
@jax.jit
def add(a, b):
jax_array_a = jnp.array(a)
jax_array_b = jnp.array(b)
return jnp.polyadd(jax_array_a, jax_array_b)
### mul - use as is
@jax.jit
def mul(a, b):
jax_array_a = jnp.array(a)
jax_array_b = jnp.array(b)
return jnp.polymul(jax_array_a, jax_array_b)
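### usage sketch (illustrative only; not part of dense.rs)
# Hypothetical example of the functions above, using float arithmetic rather than
# finite-field arithmetic; `points` and `coeffs` are names introduced here.
points = [(0.0, 1.0), (1.0, 3.0), (2.0, 9.0)]
coeffs = lagrange_interpolation(points)  # coefficients of 2x^2 + 1, descending order
print(evaluate(coeffs, 2.0))             # ~9.0
print(is_zero(coeffs))                   # False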