Skip to content

Commit a312b02

Browse files
committed
opt: evaluation with interpolation
1 parent 4de3622 commit a312b02

File tree

2 files changed

+127
-61
lines changed

2 files changed

+127
-61
lines changed

src/cac/mod.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ mod tests {
1111
use sha2::{Digest, Sha256};
1212

1313
use super::*;
14-
use crate::cac::{adaptor_sigs::AdaptorInfo, vsss::lagrange_interpolate_at_index};
14+
use crate::cac::{adaptor_sigs::AdaptorInfo, vsss::lagrange_interpolate_whole_polynomial};
1515

1616
#[test]
1717
fn test_full_flow() {
@@ -103,13 +103,16 @@ mod tests {
103103
&[(unused_share_commit.0, unused_share_secret)],
104104
]
105105
.concat();
106-
let missing_shares = (0..n)
106+
107+
let missing_points: Vec<usize> = (0..n)
107108
.filter(|&i| combined_shares.iter().all(|(j, _)| i != *j))
108-
.map(|i| (i, lagrange_interpolate_at_index(&combined_shares, i)))
109-
.collect::<Vec<_>>();
109+
.collect();
110+
111+
let missing_shares =
112+
lagrange_interpolate_whole_polynomial(&combined_shares, &missing_points);
110113

111-
for share in missing_shares {
112-
assert_eq!(share, all_shares[share.0]);
114+
for (x, y) in missing_points.into_iter().zip(missing_shares.into_iter()) {
115+
assert_eq!(all_shares[x].1, y)
113116
}
114117
}
115118
}

src/cac/vsss.rs

Lines changed: 118 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -97,42 +97,104 @@ impl ShareCommits {
9797
}
9898
}
9999

100-
// find the missing point by using the Lagrange interpolation, see https://en.wikipedia.org/wiki/Lagrange_polynomial
101-
// Input is a vec of (index, value), where index is 0-based
102-
pub fn lagrange_interpolate_at_index(points: &[(usize, Fr)], index: usize) -> Fr {
103-
lagrange_interpolate_at_x(points, Fr::from((index + 1) as u64))
104-
}
100+
/// Returns the values of the polynomial defined by known_points at missing_points, in the given order
101+
/// Assumes that points in the two sets are disjoint and their union is set of natural numbers smaller than < n (including 0) for n = len(known_points) + len(missing_points)
102+
/// Uses the fact that the number of missing points will be small compared to the known ones to evalute polynomials with factorials
103+
/// so, assuming field inversion and multiplication complexity are I and M, total complexity is O(I + len(missing_points) * n * O)
104+
pub fn lagrange_interpolate_whole_polynomial(
105+
known_points: &[(usize, Fr)],
106+
missing_points: &[usize],
107+
) -> Vec<Fr> {
108+
assert!(!known_points.is_empty() || !missing_points.is_empty());
109+
110+
let n = known_points.len() + missing_points.len();
111+
let factorial: Vec<Fr> = std::iter::once(Fr::one())
112+
.chain((1..n).scan(Fr::one(), |state, i| {
113+
*state *= Fr::from(i as u64);
114+
Some(*state)
115+
}))
116+
.collect();
117+
118+
// inv_fact[i] = 1 / factorial[i]
119+
let inv_factorial: Vec<Fr> = (0..n)
120+
.rev()
121+
.scan(
122+
factorial[n - 1]
123+
.inverse()
124+
.expect("This is guaranteed to be non-zero"),
125+
|cur_state, i| {
126+
let ith_value = *cur_state;
127+
*cur_state *= Fr::from(i as u64);
128+
Some(ith_value)
129+
},
130+
)
131+
.collect::<Vec<_>>()
132+
.into_iter()
133+
.rev()
134+
.collect();
135+
136+
let inv: Vec<Fr> = (0..n)
137+
.map(|i| {
138+
if i == 0 {
139+
Fr::zero() //This should never be used
140+
} else {
141+
inv_factorial[i] * factorial[i - 1]
142+
}
143+
})
144+
.collect();
145+
146+
// For x, calculates the multiplication of (x - i) for all i in known_points (known_points = 0..n \ missing_points)
147+
// returns the inverse of the multiplication result, based on the parameter
148+
let get_coeff = |x: usize, is_inverse: bool| {
149+
// corner case checks for 0 and n - 1 are not needed since inv_factorial[0] = factorial[0] = 1
150+
let mut result: Fr = if is_inverse {
151+
inv_factorial[x] * inv_factorial[n - 1 - x]
152+
} else {
153+
factorial[x] * factorial[n - 1 - x]
154+
};
155+
if (n - x).is_multiple_of(2) {
156+
result *= -Fr::one();
157+
}
158+
for i in missing_points {
159+
if *i == x {
160+
continue;
161+
};
162+
result *= if is_inverse {
163+
Fr::from(x as i64 - *i as i64)
164+
} else if *i < x {
165+
inv[x - *i]
166+
} else {
167+
-inv[*i - x]
168+
}
169+
}
170+
result
171+
};
172+
173+
let lagrange_basis_polynomial_coeffs: Vec<(usize, Fr)> = known_points
174+
.iter()
175+
.map(|(x, y)| (*x, get_coeff(*x, true) * y))
176+
.collect();
105177

106-
// internal function that allows also queying g(0)
107-
fn lagrange_interpolate_at_x(points: &[(usize, Fr)], x: Fr) -> Fr {
108-
let sc = |val: usize| Fr::from(val as u64);
109-
points
178+
missing_points
110179
.iter()
111-
.enumerate()
112-
.fold(Fr::zero(), |result, (i, (idx, y_i))| {
113-
let x_i = sc(*idx + 1); // share 0 corresponds to x=1
114-
// Compute L_i(x)
115-
let (num, denum) = points.iter().enumerate().filter(|(j, _)| *j != i).fold(
116-
(Fr::one(), Fr::one()),
117-
|(num, denum), (_, (idx, _))| {
118-
let x_j = sc(*idx + 1); // share 0 corresponds to x=1
119-
120-
(num * (x - x_j), denum * (x_i - x_j))
121-
},
122-
);
123-
124-
// calculate li = num / denum = num * denum^{-1}
125-
let denum_inv = denum.inverse().expect("x_i - x_j must be nonzero");
126-
let li = num * denum_inv;
127-
128-
result + *y_i * li
180+
.map(|x| {
181+
let all_differences = get_coeff(*x, false);
182+
lagrange_basis_polynomial_coeffs
183+
.iter()
184+
.fold(Fr::zero(), |result, (i, coeff_i)| {
185+
let ith_diff_inv: Fr = if i < x { inv[x - i] } else { -inv[i - x] };
186+
result + ith_diff_inv * all_differences * *coeff_i
187+
})
129188
})
189+
.collect()
130190
}
131191

132192
#[cfg(test)]
133193
mod tests {
134194
use super::*;
135-
195+
use rand::{SeedableRng, seq::index::sample};
196+
use rand_chacha::ChaCha20Rng;
197+
use std::collections::HashSet;
136198
#[test]
137199
fn test_polynomial_eval() {
138200
let polynomial = Polynomial::<Fr>::rand(rand::thread_rng(), 2);
@@ -146,33 +208,6 @@ mod tests {
146208
}
147209
}
148210

149-
#[test]
150-
fn test_interpolation_from_coefficients() {
151-
let polynomial_degree = 3;
152-
let polynomial = Polynomial::rand(rand::thread_rng(), polynomial_degree);
153-
154-
let num_shares = polynomial_degree + 1;
155-
let points = polynomial.shares(num_shares);
156-
157-
let secret = lagrange_interpolate_at_x(&points, Fr::zero());
158-
159-
assert_eq!(secret, polynomial.0[0]);
160-
}
161-
162-
#[test]
163-
fn test_interpolate_missing_shares() {
164-
let polynomial_degree = 3;
165-
let polynomial = Polynomial::rand(rand::thread_rng(), polynomial_degree);
166-
let points = polynomial.shares(6);
167-
let selected_points = &points[..polynomial_degree + 1];
168-
let missing_points = &points[polynomial_degree + 1..];
169-
170-
for (i, share) in missing_points.iter() {
171-
let reconstructed = lagrange_interpolate_at_index(selected_points, *i);
172-
assert_eq!(reconstructed, *share);
173-
}
174-
}
175-
176211
#[test]
177212
fn test_commit_verification() {
178213
let polynomial_degree = 3;
@@ -187,4 +222,32 @@ mod tests {
187222
let shares = polynomial.shares(num_shares);
188223
share_commits.verify_shares(&shares).unwrap();
189224
}
225+
226+
#[test]
227+
fn test_interpolation() {
228+
for (n_revealed, n_hidden) in vec![(5usize, 2usize), (100, 50), (300, 7)] {
229+
// Assumes one of the revealed ones is 0, as it will be in application, includes it in the n_revealed ones
230+
let n_total = n_revealed + n_hidden;
231+
let mut seed_rng = ChaCha20Rng::seed_from_u64(42);
232+
let hidden_points = sample(&mut seed_rng, n_total, n_hidden)
233+
.into_vec()
234+
.into_iter()
235+
.map(|x| x + 1)
236+
.collect::<Vec<_>>();
237+
let polynomial = Polynomial::rand(seed_rng, n_revealed - 1);
238+
let points = polynomial.shares(n_total); //points[i].0 = i
239+
240+
let aux_set: HashSet<_> = hidden_points.iter().copied().collect();
241+
let known_points: Vec<(usize, Fr)> = points
242+
.clone()
243+
.into_iter()
244+
.filter(|(x, _)| !aux_set.contains(x))
245+
.collect();
246+
let answer = lagrange_interpolate_whole_polynomial(&known_points, &hidden_points);
247+
248+
for (x, y) in hidden_points.into_iter().zip(answer.into_iter()) {
249+
assert_eq!(points[x].1, y);
250+
}
251+
}
252+
}
190253
}

0 commit comments

Comments
 (0)