This repository has been archived by the owner on Nov 21, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
/
curvepoint.go
252 lines (205 loc) · 6.27 KB
/
curvepoint.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
// Copyright (c) 2017 Clearmatics Technologies Ltd
// SPDX-License-Identifier: LGPL-3.0+
package main
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"github.com/clearmatics/bn256"
"math/big"
)
// CurvePoint wraps a single element of the bn256 G1 group.
//
// The zero value holds no point (z is nil); GetXY reports that state as
// (nil, nil). Methods that delegate directly to z (Marshal, ScalarMult, ...)
// presumably need it populated first, e.g. via SetFromXY.
type CurvePoint struct {
	z *bn256.G1 // underlying G1 element; nil until set
}
// MarshalJSON encodes the point as a JSON object with "x" and "y" members.
// Each coordinate comes from GetXY and is serialised through the package's
// hexBig type. For a CurvePoint with no underlying point, GetXY yields nil
// coordinates, and those nil pointers are what get encoded.
func (c *CurvePoint) MarshalJSON() ([]byte, error) {
	type xyJSON struct {
		X *hexBig `json:"x"`
		Y *hexBig `json:"y"`
	}
	px, py := c.GetXY()
	payload := &xyJSON{X: (*hexBig)(px), Y: (*hexBig)(py)}
	return json.Marshal(payload)
}
// UnmarshalJSON decodes a JSON object with "x" and "y" members (each parsed
// through hexBig) and stores the resulting point in c via SetFromXY.
//
// It returns the json.Unmarshal error if decoding fails. It returns an error
// if either coordinate is missing or null. It also returns an error if
// SetFromXY rejects the pair, in which case c is left unchanged.
func (c *CurvePoint) UnmarshalJSON(data []byte) error {
	aux := struct {
		X *hexBig `json:"x"`
		Y *hexBig `json:"y"`
	}{}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	switch {
	case aux.X == nil || aux.Y == nil:
		return errors.New("Invalid Point, no X or Y specified")
	case c.SetFromXY((*big.Int)(aux.X), (*big.Int)(aux.Y)) == nil:
		return errors.New("Failed to deserialize CurvePoint")
	}
	return nil
}
// Equals reports whether c and d produce byte-for-byte identical encodings
// under Marshal. c is marshalled first, then d.
func (c CurvePoint) Equals(d *CurvePoint) bool {
	return bytes.Equal(c.Marshal(), d.Marshal())
}
// Prime returns bn256.P, the prime modulus that point coordinates are reduced
// by (see SetFromXY). The receiver is not consulted.
//
// The shared package-level value is returned, not a copy, so callers must
// treat it as read-only.
func (CurvePoint) Prime() *big.Int {
	return bn256.P
}
// Order returns bn256.Order, the order of the BN256 group. The receiver is
// not consulted.
//
// The shared package-level value is returned, not a copy, so callers must
// treat it as read-only.
func (CurvePoint) Order() *big.Int {
	return bn256.Order
}
// isBetween reports whether lower < number < upper; both bounds are
// exclusive.
func isBetween(number *big.Int, lower *big.Int, upper *big.Int) bool {
	aboveLower := number.Cmp(lower) > 0
	belowUpper := number.Cmp(upper) < 0
	return aboveLower && belowUpper
}
// randomPositiveBelow returns a uniformly random integer in [1, below-1],
// drawn from crypto/rand.
//
// It returns nil when that range is empty (below <= 1) or when the system
// randomness source reports an error.
func randomPositiveBelow(below *big.Int) *big.Int {
	// No positive integer lies strictly below 1, and rand.Int panics for a
	// bound <= 0, so reject such bounds up front instead of looping forever.
	if below.Cmp(big.NewInt(1)) <= 0 {
		return nil
	}
	for {
		// Sample from [0, below) itself; sampling from a fixed bound (such as
		// the group order) would skew or truncate the range for other bounds.
		number, err := rand.Int(rand.Reader, below)
		if err != nil {
			return nil
		}
		// Rejecting only 0 leaves a uniform draw over [1, below-1].
		if number.Sign() > 0 {
			return number
		}
	}
}
// RandomN returns a uniformly random integer in [1, N-1], where N is the BN256
// group order returned by Order. It returns nil if randomPositiveBelow fails,
// e.g. when the system randomness source reports an error.
func (c CurvePoint) RandomN() *big.Int {
	return randomPositiveBelow(c.Order())
}
// RandomP returns a random integer in [1, P-1], where P is the prime returned
// by Prime. The draw is delegated to randomPositiveBelow(P), which returns nil
// if the system randomness source fails.
func (c CurvePoint) RandomP() *big.Int {
	p := c.Prime()
	return randomPositiveBelow(p)
}
// GetXY returns the X and Y coordinates of c.
//
// They are decoded from c's Marshal encoding, whose first two 32-byte words
// are read as big-endian integers X and Y. When c holds no point (nil z),
// both results are nil.
func (c CurvePoint) GetXY() (*big.Int, *big.Int) {
	if c.z == nil {
		return nil, nil
	}
	const width = 256 / 8 // bytes per coordinate
	raw := c.z.Marshal()
	x := new(big.Int).SetBytes(raw[:width])
	y := new(big.Int).SetBytes(raw[width : 2*width])
	return x, y
}
// SetFromXY points c at the curve point with coordinates (x mod P, y mod P)
// and returns c.
//
// Both coordinates are reduced modulo bn256.P first, so out-of-range or
// negative inputs are taken as their residues. The reduced pair is laid out as
// two zero-padded 32-byte big-endian words and handed to bn256's G1.Unmarshal,
// the only way bn256 exposes to build a point from coordinates. If Unmarshal
// rejects the pair, c is left unchanged and nil is returned.
func (c *CurvePoint) SetFromXY(x *big.Int, y *big.Int) *CurvePoint {
	const width = 256 / 8 // bytes per coordinate
	encoded := make([]byte, 2*width)
	for i, coord := range []*big.Int{x, y} {
		digits := new(big.Int).Mod(coord, bn256.P).Bytes()
		end := (i + 1) * width
		// Right-align so shorter values are left-padded with zeros.
		copy(encoded[end-len(digits):end], digits)
	}
	point, ok := new(bn256.G1).Unmarshal(encoded)
	if !ok {
		return nil
	}
	c.z = point
	return c
}
// Marshal returns the binary encoding of the point produced by bn256's
// G1.Marshal. This is not JSON; MarshalJSON produces the JSON form.
//
// GetXY reads this encoding as two consecutive 32-byte big-endian
// coordinates, X followed by Y.
// NOTE(review): c.z is dereferenced unconditionally; a zero-value CurvePoint
// presumably panics here — confirm against bn256 before relying on it.
func (c CurvePoint) Marshal() []byte {
	return c.z.Marshal()
}
// Unmarshal decodes m, a binary encoding in bn256's G1.Marshal format (not
// JSON; UnmarshalJSON handles that), into the G1 element c points at. It
// reports whether decoding succeeded.
//
// The receiver is a value, so the result can only be delivered by overwriting
// *c.z in place. Every CurvePoint copy sharing that *bn256.G1 therefore sees
// the new value. A CurvePoint with no underlying G1 (c.z == nil) has nowhere
// to store the result, so false is returned.
//
// The input is decoded into a scratch G1 first, so a rejected input leaves
// *c.z untouched.
func (c CurvePoint) Unmarshal(m []byte) bool {
	if c.z == nil {
		return false
	}
	decoded, ok := new(bn256.G1).Unmarshal(m)
	if !ok {
		return false
	}
	*c.z = *decoded
	return true
}
// IsOnCurve reports whether the wrapped G1 element lies on the curve, as
// determined by bn256's G1.IsOnCurve.
func (c CurvePoint) IsOnCurve() bool {
	g := c.z
	return g.IsOnCurve()
}
// String implements fmt.Stringer. It renders the point as
// "CurvePoint(...)" around the default formatting of the underlying bn256 G1.
func (c CurvePoint) String() string {
	return "CurvePoint(" + fmt.Sprint(c.z) + ")"
}
// NewCurvePointFromString maps an arbitrary byte string to a curve point. The
// input is hashed with SHA-256 and the digest is passed to
// NewCurvePointFromHash, so equal inputs always yield the same point.
func NewCurvePointFromString(s []byte) *CurvePoint {
	digest := sha256.Sum256(s)
	return NewCurvePointFromHash(digest)
}
// NewCurvePointFromHash maps a SHA-256 digest to a curve point using the
// "try-and-increment" method of hashing into the curve.
//
// Algorithm:
//   - The digest is read as a big-endian integer x and reduced modulo the
//     group order N.
//   - x is incremented until x³ + B has a square root y modulo P that
//     SetFromXY accepts.
//   - The resulting point (x, y) is returned.
//
// NOTE(review): x is reduced modulo N (the group order), not P. This is kept
// deliberately: any other implementation that must agree with this mapping
// (e.g. a verifier on another platform) presumably performs the same reduction.
// Confirm before changing it.
//
// NOTE(review): the number of iterations depends on the input, so this is not
// constant-time, and the loop has no iteration bound.
func NewCurvePointFromHash(h [sha256.Size]byte) *CurvePoint {
	P := CurvePoint{}.Prime()
	N := CurvePoint{}.Order()

	// A = (P+1)/4 for the BN256 prime P: the square-root exponent that applies
	// when P ≡ 3 (mod 4). (The previous comment read "(p+1) / 1", a typo.) Any
	// mismatch would only yield rejected candidates, because every root is
	// verified by squaring below. The literal is valid hex, so the ok flag from
	// SetString is deliberately ignored.
	A, _ := new(big.Int).SetString("c19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f52", 16)

	x := new(big.Int).SetBytes(h[:])
	x.Mod(x, N)

	// TODO: limit number of iterations?
	for {
		// beta = x³ + B (mod P), the right-hand side of y² = x³ + B.
		xx := new(big.Int).Mul(x, x) // x²
		xx.Mod(xx, P)
		xxx := xx.Mul(xx, x) // x³ (reuses xx's storage)
		xxx.Mod(xxx, P)
		beta := new(big.Int).Add(xxx, curveB)
		beta.Mod(beta, P)

		// Candidate root y = beta^A mod P. Exp cannot return nil here because
		// A > 0, so there is no nil check. Squaring y tells us whether beta
		// really has a square root.
		y := new(big.Int).Exp(beta, A, P)
		ySquared := new(big.Int).Mul(y, y)
		ySquared.Mod(ySquared, P)
		if ySquared.Cmp(beta) == 0 {
			if point := new(CurvePoint).SetFromXY(x, y); point != nil {
				return point
			}
		}
		x.Add(x, bigOne)
	}
}
// ScalarBaseMult returns x·G, where G is bn256's standard G1 generator. The
// receiver is ignored; the method exists so the operation can be reached from
// any CurvePoint value, e.g. CurvePoint{}.ScalarBaseMult(k).
func (CurvePoint) ScalarBaseMult(x *big.Int) CurvePoint {
	g := new(bn256.G1).ScalarBaseMult(x)
	return CurvePoint{z: g}
}
// ScalarMult returns x·c as a new CurvePoint. The product is written into a
// freshly allocated G1, so c's own G1 is never used as the destination.
func (c CurvePoint) ScalarMult(x *big.Int) CurvePoint {
	product := new(bn256.G1).ScalarMult(c.z, x)
	return CurvePoint{z: product}
}
// Add returns the group sum c + y as a new CurvePoint. The sum is written into
// a freshly allocated G1 rather than into either operand.
func (c CurvePoint) Add(y CurvePoint) CurvePoint {
	sum := new(bn256.G1).Add(c.z, y.z)
	return CurvePoint{z: sum}
}
// ParameterPointAdd returns tj·G + cj·c, where G is the generator used by
// ScalarBaseMult.
func (c CurvePoint) ParameterPointAdd(tj *big.Int, cj *big.Int) CurvePoint {
	return CurvePoint{}.ScalarBaseMult(tj).Add(c.ScalarMult(cj))
}
// HashPointAdd returns tj·c + cj·hashSP.
func (c CurvePoint) HashPointAdd(hashSP CurvePoint, tj *big.Int, cj *big.Int) CurvePoint {
	return c.ScalarMult(tj).Add(hashSP.ScalarMult(cj))
}
// ParseCurvePoint builds a CurvePoint from textual X and Y coordinates.
//
// Both strings are parsed by ParseBigInt; per the original documentation, hex
// or base-10 encodings are accepted. It returns nil if either string fails to
// parse or if SetFromXY rejects the coordinates.
func ParseCurvePoint(pointX string, pointY string) *CurvePoint {
	x, errX := ParseBigInt(pointX)
	y, errY := ParseBigInt(pointY)
	if errX != nil || errY != nil {
		return nil
	}
	// SetFromXY returns its receiver on success and nil on rejection, which is
	// exactly this function's contract.
	return new(CurvePoint).SetFromXY(x, y)
}