1
2
3
4
5 package hpke
6
7 import (
8 "crypto/ecdh"
9 "crypto/internal/rand"
10 "errors"
11 "internal/byteorder"
12 "slices"
13 )
14
15
16
// KEM describes a key encapsulation mechanism as used by this package. An
// implementation reports its identifier and can generate, parse, or derive
// keys that belong to it.
//
// The unexported encSize method means the interface can only be implemented
// by types declared in this package.
type KEM interface {
	// ID returns the 16-bit identifier of the KEM.
	ID() uint16

	// GenerateKey returns a newly generated private key for this KEM, or an
	// error if key generation fails.
	GenerateKey() (PrivateKey, error)

	// NewPublicKey builds a PublicKey of this KEM from its byte encoding.
	// It returns an error if the encoding is not accepted.
	NewPublicKey([]byte) (PublicKey, error)

	// NewPrivateKey builds a PrivateKey of this KEM from its byte encoding.
	// It returns an error if the encoding is not accepted.
	NewPrivateKey([]byte) (PrivateKey, error)

	// DeriveKeyPair derives a private key (from which the public key can be
	// obtained) from the input keying material ikm.
	DeriveKeyPair(ikm []byte) (PrivateKey, error)

	// encSize reports the size of the encapsulation produced by
	// PublicKey.encap (presumably in bytes; confirm against callers).
	encSize() int
}
41
42
43
44
45
46 func NewKEM(id uint16) (KEM, error) {
47 switch id {
48 case 0x0010:
49 return DHKEM(ecdh.P256()), nil
50 case 0x0011:
51 return DHKEM(ecdh.P384()), nil
52 case 0x0012:
53 return DHKEM(ecdh.P521()), nil
54 case 0x0020:
55 return DHKEM(ecdh.X25519()), nil
56 case 0x0041:
57 return MLKEM768(), nil
58 case 0x0042:
59 return MLKEM1024(), nil
60 case 0x647a:
61 return MLKEM768X25519(), nil
62 case 0x0050:
63 return MLKEM768P256(), nil
64 case 0x0051:
65 return MLKEM1024P384(), nil
66 default:
67 return nil, errors.New("unsupported KEM")
68 }
69 }
70
71
72
73
74
75
// PublicKey is the public half of a KEM key pair.
//
// The unexported encap method means the interface can only be implemented by
// types declared in this package.
type PublicKey interface {
	// KEM returns the KEM this key belongs to.
	KEM() KEM

	// Bytes returns the byte encoding of the public key.
	Bytes() []byte

	// encap performs an encapsulation to this public key. It returns the
	// resulting shared secret together with enc, the value the holder of the
	// private key passes to decap to obtain the same secret.
	encap() (sharedSecret, enc []byte, err error)
}
85
86
87
88
89
90
// PrivateKey is the private half of a KEM key pair.
//
// The unexported decap method means the interface can only be implemented by
// types declared in this package.
type PrivateKey interface {
	// KEM returns the KEM this key belongs to.
	KEM() KEM

	// Bytes returns the byte encoding of the private key. It returns an error
	// when the underlying key cannot be exported as bytes (for example, the
	// DHKEM implementation in this file fails when the key is not an
	// *ecdh.PrivateKey).
	Bytes() ([]byte, error)

	// PublicKey returns the public key corresponding to this private key.
	PublicKey() PublicKey

	// decap recovers the shared secret from enc, the encapsulation produced
	// by PublicKey.encap.
	decap(enc []byte) (sharedSecret []byte, err error)
}
107
// dhKEM holds the parameters of one Diffie-Hellman based KEM.
type dhKEM struct {
	kdf     KDF        // KDF used for labeledExtract / labeledExpand
	id      uint16     // KEM identifier; also encoded into the suite ID
	curve   ecdh.Curve // curve used for key generation, parsing and ECDH
	Nsecret uint16     // output length requested for the shared secret
	Nsk     uint16     // output length requested when deriving private keys
	Nenc    int        // value reported by encSize
}
116
117 func (kem *dhKEM) extractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
118 suiteID := byteorder.BEAppendUint16([]byte("KEM"), kem.id)
119 eaePRK, err := kem.kdf.labeledExtract(suiteID, nil, "eae_prk", dhKey)
120 if err != nil {
121 return nil, err
122 }
123 return kem.kdf.labeledExpand(suiteID, eaePRK, "shared_secret", kemContext, kem.Nsecret)
124 }
125
// ID returns the KEM identifier stored in kem.
func (kem *dhKEM) ID() uint16 {
	return kem.id
}
129
// encSize returns the Nenc parameter of kem.
func (kem *dhKEM) encSize() int {
	return kem.Nenc
}
133
134 var dhKEMP256 = &dhKEM{HKDFSHA256(), 0x0010, ecdh.P256(), 32, 32, 65}
135 var dhKEMP384 = &dhKEM{HKDFSHA384(), 0x0011, ecdh.P384(), 48, 48, 97}
136 var dhKEMP521 = &dhKEM{HKDFSHA512(), 0x0012, ecdh.P521(), 64, 66, 133}
137 var dhKEMX25519 = &dhKEM{HKDFSHA256(), 0x0020, ecdh.X25519(), 32, 32, 32}
138
139
140
141
142
143
144
145
146
147 func DHKEM(curve ecdh.Curve) KEM {
148 switch curve {
149 case ecdh.P256():
150 return dhKEMP256
151 case ecdh.P384():
152 return dhKEMP384
153 case ecdh.P521():
154 return dhKEMP521
155 case ecdh.X25519():
156 return dhKEMX25519
157 default:
158
159
160
161 return unsupportedCurveKEM{}
162 }
163 }
164
165 type unsupportedCurveKEM struct{}
166
167 func (unsupportedCurveKEM) ID() uint16 {
168 return 0
169 }
170 func (unsupportedCurveKEM) GenerateKey() (PrivateKey, error) {
171 return nil, errors.New("unsupported curve")
172 }
173 func (unsupportedCurveKEM) NewPublicKey([]byte) (PublicKey, error) {
174 return nil, errors.New("unsupported curve")
175 }
176 func (unsupportedCurveKEM) NewPrivateKey([]byte) (PrivateKey, error) {
177 return nil, errors.New("unsupported curve")
178 }
179 func (unsupportedCurveKEM) DeriveKeyPair([]byte) (PrivateKey, error) {
180 return nil, errors.New("unsupported curve")
181 }
182 func (unsupportedCurveKEM) encSize() int {
183 return 0
184 }
185
// dhKEMPublicKey is the PublicKey implementation for Diffie-Hellman KEMs. It
// pairs an ECDH public key with the KEM parameters for its curve.
type dhKEMPublicKey struct {
	kem *dhKEM          // parameters of the KEM this key belongs to
	pub *ecdh.PublicKey // underlying ECDH public key
}
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204 func NewDHKEMPublicKey(pub *ecdh.PublicKey) (PublicKey, error) {
205 kem, ok := DHKEM(pub.Curve()).(*dhKEM)
206 if !ok {
207 return nil, errors.New("unsupported curve")
208 }
209 return &dhKEMPublicKey{
210 kem: kem,
211 pub: pub,
212 }, nil
213 }
214
215 func (kem *dhKEM) NewPublicKey(data []byte) (PublicKey, error) {
216 pub, err := kem.curve.NewPublicKey(data)
217 if err != nil {
218 return nil, err
219 }
220 return NewDHKEMPublicKey(pub)
221 }
222
// KEM returns the DHKEM parameter set this public key belongs to.
func (pk *dhKEMPublicKey) KEM() KEM {
	return pk.kem
}
226
// Bytes returns the encoding of the underlying ECDH public key, as produced
// by its own Bytes method.
func (pk *dhKEMPublicKey) Bytes() []byte {
	return pk.pub.Bytes()
}
230
231
232
// testingOnlyGenerateKey, when non-nil, supplies the ephemeral private key
// that encap uses in place of a randomly generated one. It is nil by default.
var testingOnlyGenerateKey func() *ecdh.PrivateKey
234
235 func (pk *dhKEMPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
236 privEph, err := pk.pub.Curve().GenerateKey(rand.Reader)
237 if err != nil {
238 return nil, nil, err
239 }
240 if testingOnlyGenerateKey != nil {
241 privEph = testingOnlyGenerateKey()
242 }
243 dhVal, err := privEph.ECDH(pk.pub)
244 if err != nil {
245 return nil, nil, err
246 }
247 encPubEph := privEph.PublicKey().Bytes()
248
249 encPubRecip := pk.pub.Bytes()
250 kemContext := append(encPubEph, encPubRecip...)
251 sharedSecret, err = pk.kem.extractAndExpand(dhVal, kemContext)
252 if err != nil {
253 return nil, nil, err
254 }
255 return sharedSecret, encPubEph, nil
256 }
257
// dhKEMPrivateKey is the PrivateKey implementation for Diffie-Hellman KEMs.
// The key material is held behind the ecdh.KeyExchanger interface, so it does
// not have to be an *ecdh.PrivateKey.
type dhKEMPrivateKey struct {
	kem  *dhKEM            // parameters of the KEM this key belongs to
	priv ecdh.KeyExchanger // performs ECDH and exposes the public key
}
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277 func NewDHKEMPrivateKey(priv ecdh.KeyExchanger) (PrivateKey, error) {
278 kem, ok := DHKEM(priv.Curve()).(*dhKEM)
279 if !ok {
280 return nil, errors.New("unsupported curve")
281 }
282 return &dhKEMPrivateKey{
283 kem: kem,
284 priv: priv,
285 }, nil
286 }
287
288 func (kem *dhKEM) GenerateKey() (PrivateKey, error) {
289 priv, err := kem.curve.GenerateKey(rand.Reader)
290 if err != nil {
291 return nil, err
292 }
293 return NewDHKEMPrivateKey(priv)
294 }
295
296 func (kem *dhKEM) NewPrivateKey(ikm []byte) (PrivateKey, error) {
297 priv, err := kem.curve.NewPrivateKey(ikm)
298 if err != nil {
299 return nil, err
300 }
301 return NewDHKEMPrivateKey(priv)
302 }
303
304 func (kem *dhKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
305
306 suiteID := byteorder.BEAppendUint16([]byte("KEM"), kem.id)
307 prk, err := kem.kdf.labeledExtract(suiteID, nil, "dkp_prk", ikm)
308 if err != nil {
309 return nil, err
310 }
311 if kem == dhKEMX25519 {
312 s, err := kem.kdf.labeledExpand(suiteID, prk, "sk", nil, kem.Nsk)
313 if err != nil {
314 return nil, err
315 }
316 return kem.NewPrivateKey(s)
317 }
318 var counter uint8
319 for counter < 4 {
320 s, err := kem.kdf.labeledExpand(suiteID, prk, "candidate", []byte{counter}, kem.Nsk)
321 if err != nil {
322 return nil, err
323 }
324 if kem == dhKEMP521 {
325 s[0] &= 0x01
326 }
327 r, err := kem.NewPrivateKey(s)
328 if err != nil {
329 counter++
330 continue
331 }
332 return r, nil
333 }
334 panic("chance of four rejections is < 2^-128")
335 }
336
// KEM returns the DHKEM parameter set this private key belongs to.
func (k *dhKEMPrivateKey) KEM() KEM {
	return k.kem
}
340
341 func (k *dhKEMPrivateKey) Bytes() ([]byte, error) {
342
343
344
345
346
347
348
349
350
351 priv, ok := k.priv.(*ecdh.PrivateKey)
352 if !ok {
353 return nil, errors.New("ecdh: private key does not support Bytes")
354 }
355 if k.kem == dhKEMX25519 {
356 b := priv.Bytes()
357 b[0] &= 248
358 b[31] &= 127
359 b[31] |= 64
360 return b, nil
361 }
362 return priv.Bytes(), nil
363 }
364
365 func (k *dhKEMPrivateKey) PublicKey() PublicKey {
366 return &dhKEMPublicKey{
367 kem: k.kem,
368 pub: k.priv.PublicKey(),
369 }
370 }
371
372 func (k *dhKEMPrivateKey) decap(encPubEph []byte) ([]byte, error) {
373 pubEph, err := k.priv.Curve().NewPublicKey(encPubEph)
374 if err != nil {
375 return nil, err
376 }
377 dhVal, err := k.priv.ECDH(pubEph)
378 if err != nil {
379 return nil, err
380 }
381 kemContext := append(slices.Clip(encPubEph), k.priv.PublicKey().Bytes()...)
382 return k.kem.extractAndExpand(dhVal, kemContext)
383 }
384
View as plain text