Understanding Interactive Proof Systems and Sum Check Protocol: Part 2
In Part 1, we covered the math and the logic behind the Sum-Check protocol. In this part, we will hack together a quick implementation of the Sum-Check protocol for demonstration.
First, let's write some utility functions in a file named utils.go.
utils.go:
First, we write a helper function that returns the number of arguments taken by a function.
type FuncType func(...int) int
func Arity(f interface{}) int {
// Get the type of f, which should be a function.
fType := reflect.TypeOf(f)
if fType.Kind() != reflect.Func {
// Optionally, handle the case where f is not a function.
return -1
}
// Return the number of declared input parameters.
// Note: a variadic parameter (such as ...int in FuncType) counts as one.
return fType.NumIn()
}
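As a quick illustration of how Arity behaves, here is a minimal sketch that calls it on a two-argument function, a variadic function, and a non-function value. The sample functions twoArgs and variadic are made up for this example; the helper itself is copied from above.

```go
package main

import (
	"fmt"
	"reflect"
)

// Arity mirrors the helper above: it reports the number of declared parameters.
func Arity(f interface{}) int {
	fType := reflect.TypeOf(f)
	if fType.Kind() != reflect.Func {
		return -1
	}
	return fType.NumIn()
}

func main() {
	// Hypothetical sample functions, used only for illustration.
	twoArgs := func(a, b int) int { return a + b }
	variadic := func(args ...int) int { return len(args) }

	fmt.Println(Arity(twoArgs))  // 2
	fmt.Println(Arity(variadic)) // 1: the variadic parameter counts once
	fmt.Println(Arity(42))       // -1: not a function
}
```

Because a variadic function always reports 1, the number of variables of a polynomial written as func(...int) int cannot be recovered through reflection alone.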
Then we write a helper function that takes an integer n and returns its binary representation as a slice of bits, front-padded with zeroes to a given length (padToLen, also taken as input).
func ToBits(n int, padToLen int) []int {
binStr := strconv.FormatInt(int64(n), 2)
if len(binStr) > padToLen {
padToLen = len(binStr)
}
v := make([]int, len(binStr))
for i, ch := range binStr {
if ch == '1' {
v[i] = 1
} else {
v[i] = 0
}
}
diff := padToLen - len(v)
paddedV := make([]int, diff)
return append(paddedV, v...)
}
1. Binary Representation:
- The line binStr := strconv.FormatInt(int64(n), 2) converts the integer n into a binary string representation. For instance, if n is 6, binStr will be "110".
2. Creating a Slice of Integers:
- The subsequent loop creates a slice v of integers where each element corresponds to a digit in the binary string. '1's in the binary string are represented as 1 in the slice, and '0's as 0.
3. Padding:
- The function then calculates diff, the difference between the desired length padToLen and the actual length of the binary representation. If the binary representation is shorter than padToLen, it needs to be padded with zeroes.
- paddedV := make([]int, diff) creates a slice of zeroes of length diff.
- Finally, the function returns a new slice that concatenates the padding and the binary representation slice (append(paddedV, v...)).
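The steps above can be checked with a small, self-contained run. This is only a sketch: it repeats ToBits in a slightly condensed form (same behavior) and prints a few sample conversions.

```go
package main

import (
	"fmt"
	"strconv"
)

// ToBits is a condensed copy of the helper above.
func ToBits(n int, padToLen int) []int {
	binStr := strconv.FormatInt(int64(n), 2)
	if len(binStr) > padToLen {
		padToLen = len(binStr) // never truncate: widen the target instead
	}
	v := make([]int, len(binStr))
	for i, ch := range binStr {
		if ch == '1' {
			v[i] = 1
		}
	}
	// Front-pad with zeroes up to padToLen.
	return append(make([]int, padToLen-len(v)), v...)
}

func main() {
	fmt.Println(ToBits(6, 5)) // [0 0 1 1 0]
	fmt.Println(ToBits(6, 2)) // [1 1 0]: longer than the pad length, so returned unpadded
	// All 2-bit patterns, in the order the prover will later enumerate them.
	for i := 0; i < 4; i++ {
		fmt.Println(i, ToBits(i, 2))
	}
}
```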
Lastly, we write a utility function that estimates the degree of the j-th variable in g.
func DegJ(g FuncType, j int) int {
exp := 1
for {
args := make([]int, 1)
for i := range args {
if i == j {
args[i] = 100
} else {
args[i] = 1
}
}
out1 := g(args[0])
args[0] = 1000
out2 := g(args[0])
// Example: for f(x) = x³, compare f(100) with f(1000).
// Moving x from 100 to 1000 scales the output by 1000³ / 100³,
// so f(100)/100^exp and f(1000)/1000^exp first agree (to within 1) at exp = 3.
// The loop tries exp = 1, 2, ... and returns the first exponent for which this holds.
// Note: as written, args has length 1, so only the variable at index 0 is probed.
if math.Abs(float64(out1)/math.Pow(100, float64(exp))-float64(out2)/math.Pow(1000, float64(exp))) < 1 {
return exp
} else if exp > 10 {
panic("exp grew larger than 10")
} else {
exp++
}
}
}
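The same scaling heuristic can be extended to polynomials with more than one variable if the number of variables is passed in explicitly. The sketch below is an assumption-laden variant, not the code used in the rest of this post: degJWithArity and its arity parameter are hypothetical names, and it returns -1 instead of panicking when no exponent up to 10 fits.

```go
package main

import (
	"fmt"
	"math"
)

type FuncType func(...int) int

// degJWithArity is a hypothetical variant of DegJ. It fixes every variable
// to 1 except variable j, which is probed at 100 and 1000.
func degJWithArity(g FuncType, arity, j int) int {
	lo := make([]int, arity)
	hi := make([]int, arity)
	for i := range lo {
		lo[i], hi[i] = 1, 1
	}
	lo[j], hi[j] = 100, 1000

	out1, out2 := float64(g(lo...)), float64(g(hi...))
	for exp := 1; exp <= 10; exp++ {
		// Same test as DegJ: the two normalized outputs agree to within 1.
		if math.Abs(out1/math.Pow(100, float64(exp))-out2/math.Pow(1000, float64(exp))) < 1 {
			return exp
		}
	}
	return -1 // no exponent up to 10 matched
}

func main() {
	// g(x, y) = x²y + y
	g := func(a ...int) int { return a[0]*a[0]*a[1] + a[1] }
	fmt.Println(degJWithArity(g, 2, 0)) // 2
	fmt.Println(degJWithArity(g, 2, 1)) // 1
}
```

Like the original, this is a numerical estimate and never reports a degree below 1, so a variable that does not appear at all is still reported as degree 1.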
verifier.go:
This is where all the logic related to the verifier takes place.
Firstly we define the verifier as follows:
type Verifier struct {
g FuncType
gArity int // represents the number of inputs to polynomial
H int // claimed value of the sum
randomChallenges []int
round int
polynomials []FuncType
}
Now, we need a function that accepts the polynomial sent by the prover in each round:
func (v *Verifier) RecievePolynomials(polynomial FuncType) {
v.polynomials = append(v.polynomials, polynomial)
}
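Next, we need a function that verifies that the latest polynomial is a univariate polynomial of degree at most deg_j(g), and that it is consistent with the previous round: s_j(0) + s_j(1) must equal s_{j-1}(r_{j-1}), or the claimed sum H in the first round.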
For this, we write another function on the verifier struct:
func (v *Verifier) CheckLatestPolynomial() error {
latestPoly := v.polynomials[len(v.polynomials)-1]
degLatest := DegJ(latestPoly, 0)
degMax := DegJ(v.g, v.round-1)
if degLatest > degMax {
return fmt.Errorf("Prover sent polynomial of degree %d greater than expected : %d", degLatest, degMax)
}
newSum := latestPoly(0) + latestPoly(1)
var check int
if v.round == 1 {
check = v.H
} else {
check = v.polynomials[len(v.polynomials)-2](v.randomChallenges[len(v.randomChallenges)-1])
}
if check != newSum {
return fmt.Errorf("Prover sent incorrect polynomials: %d, expected %d", newSum, check)
}
return nil
}
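To make the consistency check concrete, here is a small worked example on a two-variable polynomial. Everything in it is an illustrative assumption: the polynomial g(x1, x2) = 2·x1² + x1·x2 + x2, the helper names, and the challenge value 3 are chosen only for this sketch.

```go
package main

import "fmt"

// g(x1, x2) = 2*x1^2 + x1*x2 + x2 (an arbitrary example polynomial).
func g(x1, x2 int) int { return 2*x1*x1 + x1*x2 + x2 }

// hypercubeSum is the value H an honest prover would claim.
func hypercubeSum() int {
	h := 0
	for x1 := 0; x1 <= 1; x1++ {
		for x2 := 0; x2 <= 1; x2++ {
			h += g(x1, x2)
		}
	}
	return h
}

// s1 is the honest round-1 polynomial: x1 stays free, x2 is summed out.
func s1(x int) int { return g(x, 0) + g(x, 1) }

// s2 is the honest round-2 polynomial once x1 is fixed to the challenge r1.
func s2(r1, x int) int { return g(r1, x) }

func main() {
	// Round 1: s1(0) + s1(1) must equal H.
	fmt.Println(hypercubeSum(), s1(0)+s1(1)) // 7 7

	// Round 2: s2(0) + s2(1) must equal s1(r1).
	r1 := 3
	fmt.Println(s1(r1), s2(r1, 0)+s2(r1, 1)) // 40 40
}
```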
As the check above suggests, the verifier also needs a mechanism to generate a random number r and send it as a challenge to the prover. To achieve this:
func (v *Verifier) GetNewRandomValueAndSend(p *Prover) {
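// Seeding on every call is kept for simplicity; rand.Seed is deprecated since Go 1.20.
rand.Seed(time.Now().UnixNano())
// Demo only: challenges are drawn from {0, 1}. A real verifier samples from a large finite field.
v.randomChallenges = append(v.randomChallenges, rand.Intn(2))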
p.ReceiveChallenge(v.randomChallenges[len(v.randomChallenges)-1])
v.round++
}
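One possible way to draw challenges from a wider range is to keep a single seeded generator and sample below a modulus. The sketch below is only an assumption about how that could look: the challenger type, its methods, and the modulus 97 are hypothetical and are not part of the verifier above. math/rand is also not cryptographically secure, so this is for demonstration only.

```go
package main

import (
	"fmt"
	"math/rand"
	"time"
)

// challenger is a hypothetical helper that owns one generator
// instead of reseeding the global source on every call.
type challenger struct {
	rng     *rand.Rand
	modulus int
}

func newChallenger(seed int64, modulus int) *challenger {
	return &challenger{rng: rand.New(rand.NewSource(seed)), modulus: modulus}
}

// next returns a challenge in the range [0, modulus).
func (c *challenger) next() int { return c.rng.Intn(c.modulus) }

func main() {
	c := newChallenger(time.Now().UnixNano(), 97)
	for round := 1; round <= 3; round++ {
		fmt.Printf("round %d challenge: %d\n", round, c.next())
	}
}
```

A larger challenge range matters because a cheating prover is caught with higher probability when the challenge is harder to guess.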
Lastly, in the final step, the verifier holds all the random challenges and uses them to evaluate g directly. This value should match the final polynomial sₙ(Xₙ) sent by the prover, evaluated at the last challenge. If the values match, the verifier accepts.
func (v *Verifier) EvaluateAndCheckGV() (bool, error) {
if len(v.randomChallenges) != v.gArity-1 {
return false, fmt.Errorf("Incorrect number of random challenges")
}
v.randomChallenges = append(v.randomChallenges, rand.Intn(2))
gFinal := v.g(v.randomChallenges...)
check := v.polynomials[len(v.polynomials)-1](v.randomChallenges[len(v.randomChallenges)-1])
if gFinal != check {
return false, fmt.Errorf("Prover sent incorrect final polynomials")
}
fmt.Println("VERIFIER ACCEPTS")
return true, nil
}
And that’s it. We have basic code ready for a verifier in the Sum-Check protocol!
prover.go:
This is where all the logic related to the prover goes.
Firstly, we define the prover struct as follows:
type Prover struct {
gArity int
randomChallenges []int
cachedPolynomials []FuncType
round int
H int
}
Now, remember that in the Sum-Check protocol, a new polynomial is computed in each round, where the number of variables left to sum over = number of original arguments − round number.
The code below iterates from 0 to 2^pad - 1. In each iteration, ToBits is used to create a binary representation of the loop counter i, padded to pad bits. This binary slice is prepended with the first argument (args[0]), and then passed to the polynomial function poly.
This function dynamically generates inputs for the polynomial poly, combining a constant first argument with varying binary patterns, and aggregating the outputs, which is the core of the prover's work in the sum-check protocol.
func (p *Prover) ComputeAndSendNextPolynomial(v *Verifier) {
round := p.round
poly := p.cachedPolynomials[len(p.cachedPolynomials)-1]
gJ := func(args ...int) int {
if len(args) == 0 {
// Handle the case where no arguments are passed
panic("gJ requires at least one argument")
}
pad := p.gArity - round
var sum int
for i := 0; i < (1 << pad); i++ {
// Build the full input: the free variable followed by one boolean assignment.
fullArgs := append([]int{args[0]}, ToBits(i, pad)...)
sum += poly(fullArgs...)
}
return sum
}
v.RecievePolynomials(gJ)
p.round++
}
The iteration in the gJ function is fundamental to the sum-check protocol: it systematically evaluates a polynomial over all possible binary combinations of a certain length. Here's why we iterate this way:
- Exhaustive Evaluation: The goal is to evaluate the polynomial poly for every possible combination of binary inputs, so that the sum covers the entire boolean domain of the remaining variables.
- Binary Combinations: By iterating from 0 to 2^pad − 1 and converting these numbers to binary, we generate all possible combinations of binary digits of length pad. This is a standard way to cover all cases in binary representation.
- Reducing Complexity Per Round: In each round of the sum-check protocol, one less variable is summed over, reducing the computational cost step by step. pad decreases with each round, reflecting this reduction.
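The enumeration described above can be sketched in isolation. In this example, bits and sumOverSuffix are hypothetical helper names; bits uses shifts rather than ToBits and, unlike ToBits, returns an empty slice when pad is 0.

```go
package main

import "fmt"

type FuncType func(...int) int

// bits returns the pad-bit, most-significant-first binary expansion of i.
func bits(i, pad int) []int {
	out := make([]int, pad)
	for k := 0; k < pad; k++ {
		out[pad-1-k] = (i >> k) & 1
	}
	return out
}

// sumOverSuffix keeps the first variable fixed to `first` and sums poly
// over all 2^pad boolean assignments of the remaining pad variables.
func sumOverSuffix(poly FuncType, first, pad int) int {
	sum := 0
	for i := 0; i < (1 << pad); i++ {
		args := append([]int{first}, bits(i, pad)...)
		sum += poly(args...)
	}
	return sum
}

func main() {
	// poly(x, y, z) = x*y + z
	poly := func(a ...int) int { return a[0]*a[1] + a[2] }
	// (y,z) = (0,0), (0,1), (1,0), (1,1) give 0 + 1 + 5 + 6.
	fmt.Println(sumOverSuffix(poly, 5, 2)) // 12
}
```

Each round halves the number of terms in this loop, which is the reduction in work that the last bullet refers to.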
Lastly, we need a function to receive challenges from the verifier:
func (p *Prover) ReceiveChallenge(challenge int) {
p.randomChallenges = append(p.randomChallenges, challenge)
p.CacheNext(challenge)
fmt.Printf("Received challenge %d, initiating round %d\n", challenge, p.round)
}
where CacheNext is simply:
func (p *Prover) CacheNext(challenge int) {
poly := p.cachedPolynomials[len(p.cachedPolynomials)-1]
nextPoly := func(args ...int) int {
return poly(append([]int{challenge}, args...)...)
}
p.cachedPolynomials = append(p.cachedPolynomials, nextPoly)
}
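CacheNext relies on closures to fix one more leading variable each round. A minimal sketch of that idea, outside the Prover struct, is shown here; fixFirst is a hypothetical name for the same wrapping step.

```go
package main

import "fmt"

type FuncType func(...int) int

// fixFirst returns a new function with the first argument of poly
// bound to value, which is what CacheNext does with each challenge.
func fixFirst(poly FuncType, value int) FuncType {
	return func(args ...int) int {
		return poly(append([]int{value}, args...)...)
	}
}

func main() {
	// g(x, y, z) = x + 10y + 100z
	g := func(a ...int) int { return a[0] + 10*a[1] + 100*a[2] }

	h := fixFirst(g, 3)  // h(y, z) = g(3, y, z)
	fmt.Println(h(1, 2)) // 213

	k := fixFirst(h, 1) // k(z) = g(3, 1, z)
	fmt.Println(k(2))   // 213
}
```

Because each wrapper prepends its own value before calling the previous one, the challenges end up in the order in which they were received.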
Now to run this protocol, we can define a simple main.go file with the following code to run the prover and the verifier:
type SumcheckProtocol struct {
gArity int
p *Prover
v *Verifier
round int
done bool
}
// do the initialization
....
// Advance protocol by 1 round
func (s *SumcheckProtocol) AdvanceRound() {
if s.done {
panic("Sumcheck protocol is finished")
}
s.p.ComputeAndSendNextPolynomial(s.v)
if err := s.v.CheckLatestPolynomial(); err != nil {
panic(err) // stop instead of silently ignoring a failed check
}
if s.round == s.gArity {
// final round
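accepted, err := s.v.EvaluateAndCheckGV()
if err != nil {
panic(err) // avoid looping forever in AdvanceToEnd when the final check fails
}
s.done = accepted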
} else {
s.v.GetNewRandomValueAndSend(s.p)
s.round++
}
}
// Advance protocol to the end
func (s *SumcheckProtocol) AdvanceToEnd(verbose bool) {
for !s.done {
if verbose {
fmt.Println("Advance Output:", s)
}
s.AdvanceRound()
}
}
Now it's time to test the protocol!
To do this, we create a new file main_test.go and add the following lines:
func TestSumcheck(t *testing.T) {
g := func(args ...int) int {
a := args[0]
return a + a + a*a
}
protocol := NewSumcheckProtocol(g)
fmt.Print(protocol)
protocol.AdvanceToEnd(true)
f := func(args ...int) int {
a := args[0]
return a*a*a + a + a
}
protocol = NewSumcheckProtocol(f)
protocol.AdvanceToEnd(true)
ff := func(args ...int) int {
a := args[0]
return a*a*a + a + a + a*a
}
protocol = NewSumcheckProtocol(ff)
protocol.AdvanceToEnd(true)
}
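To see the whole exchange in one place on a polynomial with several variables, here is a compact, self-contained sketch that is independent of the Prover and Verifier structs above. It makes several simplifying assumptions: the prover is honest, the degree check is skipped, the challenges are passed in as a fixed slice, and the names sumTail and runSumcheck are hypothetical.

```go
package main

import "fmt"

type FuncType func(...int) int

// sumTail sums g over all boolean assignments of the variables after prefix.
func sumTail(g FuncType, prefix []int, arity int) int {
	if len(prefix) == arity {
		return g(prefix...)
	}
	total := 0
	for b := 0; b <= 1; b++ {
		next := append(append([]int{}, prefix...), b)
		total += sumTail(g, next, arity)
	}
	return total
}

// runSumcheck plays an honest prover against the verifier's consistency
// checks and reports whether the verifier would accept the claimed sum.
func runSumcheck(g FuncType, arity, claimed int, challenges []int) bool {
	expected := claimed
	fixed := []int{}
	for round := 0; round < arity; round++ {
		// Honest round polynomial: earlier variables fixed to challenges,
		// the current one free, the rest summed over {0, 1}.
		s := func(x int) int {
			return sumTail(g, append(append([]int{}, fixed...), x), arity)
		}
		if s(0)+s(1) != expected {
			return false
		}
		r := challenges[round]
		expected = s(r)
		fixed = append(fixed, r)
	}
	// Final check: evaluate g directly at the challenge point.
	return g(fixed...) == expected
}

func main() {
	// g(x, y, z) = x*y + 2z + x²z
	g := func(a ...int) int { return a[0]*a[1] + 2*a[2] + a[0]*a[0]*a[2] }
	H := sumTail(g, nil, 3)
	fmt.Println(H)                                      // 12
	fmt.Println(runSumcheck(g, 3, H, []int{2, 3, 5}))   // true
	fmt.Println(runSumcheck(g, 3, H+1, []int{2, 3, 5})) // false
}
```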
And Hurray 🎉 🎉 🎉 🎉 !! You just wrote the Sum-Check protocol from scratch.