Source file src/crypto/hpke/hpke_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpke
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/ecdh"
    10  	"crypto/mlkem"
    11  	"crypto/mlkem/mlkemtest"
    12  	"crypto/sha3"
    13  	"encoding/hex"
    14  	"encoding/json"
    15  	"fmt"
    16  	"io"
    17  	"os"
    18  	"testing"
    19  )
    20  
    21  func Example() {
    22  	// In this example, we use MLKEM768-X25519 as the KEM, HKDF-SHA256 as the
    23  	// KDF, and AES-256-GCM as the AEAD to encrypt a single message from a
    24  	// sender to a recipient using the one-shot API.
    25  
    26  	kem, kdf, aead := MLKEM768X25519(), HKDFSHA256(), AES256GCM()
    27  
    28  	// Recipient side
    29  	var (
    30  		recipientPrivateKey PrivateKey
    31  		publicKeyBytes      []byte
    32  	)
    33  	{
    34  		k, err := kem.GenerateKey()
    35  		if err != nil {
    36  			panic(err)
    37  		}
    38  		recipientPrivateKey = k
    39  		publicKeyBytes = k.PublicKey().Bytes()
    40  	}
    41  
    42  	// Sender side
    43  	var ciphertext []byte
    44  	{
    45  		publicKey, err := kem.NewPublicKey(publicKeyBytes)
    46  		if err != nil {
    47  			panic(err)
    48  		}
    49  
    50  		message := []byte("|-()-|")
    51  		ct, err := Seal(publicKey, kdf, aead, []byte("example"), message)
    52  		if err != nil {
    53  			panic(err)
    54  		}
    55  
    56  		ciphertext = ct
    57  	}
    58  
    59  	// Recipient side
    60  	{
    61  		plaintext, err := Open(recipientPrivateKey, kdf, aead, []byte("example"), ciphertext)
    62  		if err != nil {
    63  			panic(err)
    64  		}
    65  		fmt.Printf("Decrypted message: %s\n", plaintext)
    66  	}
    67  
    68  	// Output:
    69  	// Decrypted message: |-()-|
    70  }
    71  
    72  func TestRoundTrip(t *testing.T) {
    73  	kems := []KEM{
    74  		DHKEM(ecdh.P256()),
    75  		DHKEM(ecdh.P384()),
    76  		DHKEM(ecdh.P521()),
    77  		DHKEM(ecdh.X25519()),
    78  		MLKEM768(),
    79  		MLKEM1024(),
    80  		MLKEM768P256(),
    81  		MLKEM1024P384(),
    82  		MLKEM768X25519(),
    83  	}
    84  	kdfs := []KDF{
    85  		HKDFSHA256(),
    86  		HKDFSHA384(),
    87  		HKDFSHA512(),
    88  		SHAKE128(),
    89  		SHAKE256(),
    90  	}
    91  	aeads := []AEAD{
    92  		AES128GCM(),
    93  		AES256GCM(),
    94  		ChaCha20Poly1305(),
    95  	}
    96  
    97  	for _, kem := range kems {
    98  		t.Run(fmt.Sprintf("KEM_%04x", kem.ID()), func(t *testing.T) {
    99  			k, err := kem.GenerateKey()
   100  			if err != nil {
   101  				t.Fatal(err)
   102  			}
   103  			kb, err := k.Bytes()
   104  			if err != nil {
   105  				t.Fatal(err)
   106  			}
   107  			kk, err := kem.NewPrivateKey(kb)
   108  			if err != nil {
   109  				t.Fatal(err)
   110  			}
   111  			if got, err := kk.Bytes(); err != nil {
   112  				t.Fatal(err)
   113  			} else if !bytes.Equal(got, kb) {
   114  				t.Errorf("re-serialized key mismatch: got %x, want %x", got, kb)
   115  			}
   116  			pk, err := kem.NewPublicKey(k.PublicKey().Bytes())
   117  			if err != nil {
   118  				t.Fatal(err)
   119  			}
   120  			if got := pk.Bytes(); !bytes.Equal(got, k.PublicKey().Bytes()) {
   121  				t.Errorf("re-serialized public key mismatch: got %x, want %x", got, k.PublicKey().Bytes())
   122  			}
   123  
   124  			for _, kdf := range kdfs {
   125  				t.Run(fmt.Sprintf("KDF_%04x", kdf.ID()), func(t *testing.T) {
   126  					for _, aead := range aeads {
   127  						t.Run(fmt.Sprintf("AEAD_%04x", aead.ID()), func(t *testing.T) {
   128  							c, err := Seal(pk, kdf, aead, []byte("info"), []byte("plaintext"))
   129  							if err != nil {
   130  								t.Fatal(err)
   131  							}
   132  							p, err := Open(kk, kdf, aead, []byte("info"), c)
   133  							if err != nil {
   134  								t.Fatal(err)
   135  							}
   136  							if !bytes.Equal(p, []byte("plaintext")) {
   137  								t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
   138  							}
   139  
   140  							p, err = Open(kk, kdf, aead, []byte("wrong"), c)
   141  							if err == nil {
   142  								t.Errorf("expected error when opening with wrong info, got plaintext %x", p)
   143  							}
   144  							c[len(c)-1] ^= 0xFF
   145  							p, err = Open(kk, kdf, aead, []byte("info"), c)
   146  							if err == nil {
   147  								t.Errorf("expected error when opening with corrupted ciphertext, got plaintext %x", p)
   148  							}
   149  
   150  							c, err = Seal(k.PublicKey(), kdf, aead, nil, nil)
   151  							if err != nil {
   152  								t.Fatal(err)
   153  							}
   154  							p, err = Open(k, kdf, aead, nil, c)
   155  							if err != nil {
   156  								t.Fatal(err)
   157  							}
   158  							if len(p) != 0 {
   159  								t.Errorf("unexpected plaintext: got %x, want empty", p)
   160  							}
   161  
   162  							// Test that Seal and Open don't modify the excess capacity of input
   163  							// slices. This is a regression test for a bug where decap would
   164  							// append to the enc slice, corrupting the ciphertext if they shared
   165  							// a backing array.
   166  							padSlice := func(b []byte) []byte {
   167  								s := make([]byte, len(b), len(b)+2000)
   168  								copy(s, b)
   169  								for i := len(b); i < cap(s); i++ {
   170  									s[:cap(s)][i] = 0xAA
   171  								}
   172  								return s[:len(b)]
   173  							}
   174  							checkSlice := func(name string, s []byte) {
   175  								for i := len(s); i < cap(s); i++ {
   176  									if s[:cap(s)][i] != 0xAA {
   177  										t.Errorf("%s: modified byte at index %d beyond slice length", name, i)
   178  										return
   179  									}
   180  								}
   181  							}
   182  
   183  							infoS := padSlice([]byte("info"))
   184  							plaintextS := padSlice([]byte("plaintext"))
   185  							c, err = Seal(pk, kdf, aead, infoS, plaintextS)
   186  							if err != nil {
   187  								t.Fatal(err)
   188  							}
   189  							checkSlice("Seal info", infoS)
   190  							checkSlice("Seal plaintext", plaintextS)
   191  
   192  							infoO := padSlice([]byte("info"))
   193  							ciphertextO := padSlice(c)
   194  							p, err = Open(kk, kdf, aead, infoO, ciphertextO)
   195  							if err != nil {
   196  								t.Fatalf("Open with large capacity slices failed: %v", err)
   197  							}
   198  							if !bytes.Equal(p, []byte("plaintext")) {
   199  								t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
   200  							}
   201  							checkSlice("Open info", infoO)
   202  							checkSlice("Open ciphertext", ciphertextO)
   203  
   204  							// Also test the Sender.Seal and Recipient.Open methods.
   205  							infoSender := padSlice([]byte("info"))
   206  							enc, sender, err := NewSender(pk, kdf, aead, infoSender)
   207  							if err != nil {
   208  								t.Fatal(err)
   209  							}
   210  							checkSlice("NewSender info", infoSender)
   211  
   212  							aadSeal := padSlice([]byte("aad"))
   213  							plaintextSeal := padSlice([]byte("plaintext"))
   214  							ct, err := sender.Seal(aadSeal, plaintextSeal)
   215  							if err != nil {
   216  								t.Fatal(err)
   217  							}
   218  							checkSlice("Sender.Seal aad", aadSeal)
   219  							checkSlice("Sender.Seal plaintext", plaintextSeal)
   220  
   221  							infoRecipient := padSlice([]byte("info"))
   222  							encPadded := padSlice(enc)
   223  							recipient, err := NewRecipient(encPadded, kk, kdf, aead, infoRecipient)
   224  							if err != nil {
   225  								t.Fatal(err)
   226  							}
   227  							checkSlice("NewRecipient info", infoRecipient)
   228  							checkSlice("NewRecipient enc", encPadded)
   229  
   230  							aadOpen := padSlice([]byte("aad"))
   231  							ctPadded := padSlice(ct)
   232  							p, err = recipient.Open(aadOpen, ctPadded)
   233  							if err != nil {
   234  								t.Fatalf("Recipient.Open failed: %v", err)
   235  							}
   236  							if !bytes.Equal(p, []byte("plaintext")) {
   237  								t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
   238  							}
   239  							checkSlice("Recipient.Open aad", aadOpen)
   240  							checkSlice("Recipient.Open ciphertext", ctPadded)
   241  						})
   242  					}
   243  				})
   244  			}
   245  		})
   246  	}
   247  }
   248  
   249  func mustDecodeHex(t *testing.T, in string) []byte {
   250  	t.Helper()
   251  	b, err := hex.DecodeString(in)
   252  	if err != nil {
   253  		t.Fatal(err)
   254  	}
   255  	return b
   256  }
   257  
   258  func TestVectors(t *testing.T) {
   259  	t.Run("rfc9180", func(t *testing.T) {
   260  		testVectors(t, "rfc9180")
   261  	})
   262  	t.Run("hpke-pq", func(t *testing.T) {
   263  		testVectors(t, "hpke-pq")
   264  	})
   265  }
   266  
   267  func testVectors(t *testing.T, name string) {
   268  	vectorsJSON, err := os.ReadFile("testdata/" + name + ".json")
   269  	if err != nil {
   270  		t.Fatal(err)
   271  	}
   272  	var vectors []struct {
   273  		Mode        uint16 `json:"mode"`
   274  		KEM         uint16 `json:"kem_id"`
   275  		KDF         uint16 `json:"kdf_id"`
   276  		AEAD        uint16 `json:"aead_id"`
   277  		Info        string `json:"info"`
   278  		IkmE        string `json:"ikmE"`
   279  		IkmR        string `json:"ikmR"`
   280  		SkRm        string `json:"skRm"`
   281  		PkRm        string `json:"pkRm"`
   282  		Enc         string `json:"enc"`
   283  		Encryptions []struct {
   284  			Aad   string `json:"aad"`
   285  			Ct    string `json:"ct"`
   286  			Nonce string `json:"nonce"`
   287  			Pt    string `json:"pt"`
   288  		} `json:"encryptions"`
   289  		Exports []struct {
   290  			Context string `json:"exporter_context"`
   291  			L       int    `json:"L"`
   292  			Value   string `json:"exported_value"`
   293  		} `json:"exports"`
   294  
   295  		// Instead of checking in a very large rfc9180.json, we computed
   296  		// alternative accumulated values.
   297  		AccEncryptions string `json:"encryptions_accumulated"`
   298  		AccExports     string `json:"exports_accumulated"`
   299  	}
   300  	if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
   301  		t.Fatal(err)
   302  	}
   303  
   304  	for _, vector := range vectors {
   305  		name := fmt.Sprintf("mode %04x kem %04x kdf %04x aead %04x",
   306  			vector.Mode, vector.KEM, vector.KDF, vector.AEAD)
   307  		t.Run(name, func(t *testing.T) {
   308  			if vector.Mode != 0 {
   309  				t.Skip("only mode 0 (base) is supported")
   310  			}
   311  			if vector.KEM == 0x0021 {
   312  				t.Skip("KEM 0x0021 (DHKEM(X448)) not supported")
   313  			}
   314  			if vector.KEM == 0x0040 {
   315  				t.Skip("KEM 0x0040 (ML-KEM-512) not supported")
   316  			}
   317  			if vector.KDF == 0x0012 || vector.KDF == 0x0013 {
   318  				t.Skipf("TurboSHAKE KDF not supported")
   319  			}
   320  
   321  			kdf, err := NewKDF(vector.KDF)
   322  			if err != nil {
   323  				t.Fatal(err)
   324  			}
   325  			if kdf.ID() != vector.KDF {
   326  				t.Errorf("unexpected KDF ID: got %04x, want %04x", kdf.ID(), vector.KDF)
   327  			}
   328  
   329  			aead, err := NewAEAD(vector.AEAD)
   330  			if err != nil {
   331  				t.Fatal(err)
   332  			}
   333  			if aead.ID() != vector.AEAD {
   334  				t.Errorf("unexpected AEAD ID: got %04x, want %04x", aead.ID(), vector.AEAD)
   335  			}
   336  
   337  			kem, err := NewKEM(vector.KEM)
   338  			if err != nil {
   339  				t.Fatal(err)
   340  			}
   341  			if kem.ID() != vector.KEM {
   342  				t.Errorf("unexpected KEM ID: got %04x, want %04x", kem.ID(), vector.KEM)
   343  			}
   344  
   345  			pubKeyBytes := mustDecodeHex(t, vector.PkRm)
   346  			kemSender, err := kem.NewPublicKey(pubKeyBytes)
   347  			if err != nil {
   348  				t.Fatal(err)
   349  			}
   350  			if kemSender.KEM() != kem {
   351  				t.Errorf("unexpected KEM from sender: got %04x, want %04x", kemSender.KEM().ID(), kem.ID())
   352  			}
   353  			if !bytes.Equal(kemSender.Bytes(), pubKeyBytes) {
   354  				t.Errorf("unexpected KEM bytes: got %x, want %x", kemSender.Bytes(), pubKeyBytes)
   355  			}
   356  
   357  			ikmE := mustDecodeHex(t, vector.IkmE)
   358  			setupDerandomizedEncap(t, ikmE, kemSender)
   359  
   360  			info := mustDecodeHex(t, vector.Info)
   361  			encap, sender, err := NewSender(kemSender, kdf, aead, info)
   362  			if err != nil {
   363  				t.Fatal(err)
   364  			}
   365  			if len(encap) != kem.encSize() {
   366  				t.Errorf("unexpected encapsulated key size: got %d, want %d", len(encap), kem.encSize())
   367  			}
   368  
   369  			expectedEncap := mustDecodeHex(t, vector.Enc)
   370  			if !bytes.Equal(encap, expectedEncap) {
   371  				t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
   372  			}
   373  
   374  			privKeyBytes := mustDecodeHex(t, vector.SkRm)
   375  			kemRecipient, err := kem.NewPrivateKey(privKeyBytes)
   376  			if err != nil {
   377  				t.Fatal(err)
   378  			}
   379  			if kemRecipient.KEM() != kem {
   380  				t.Errorf("unexpected KEM from recipient: got %04x, want %04x", kemRecipient.KEM().ID(), kem.ID())
   381  			}
   382  			kemRecipientBytes, err := kemRecipient.Bytes()
   383  			if err != nil {
   384  				t.Fatal(err)
   385  			}
   386  			// X25519 serialized keys must be clamped, so the bytes might not match.
   387  			if !bytes.Equal(kemRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
   388  				t.Errorf("unexpected KEM bytes: got %x, want %x", kemRecipientBytes, privKeyBytes)
   389  			}
   390  			if vector.KEM == DHKEM(ecdh.X25519()).ID() {
   391  				kem2, err := kem.NewPrivateKey(kemRecipientBytes)
   392  				if err != nil {
   393  					t.Fatal(err)
   394  				}
   395  				kemRecipientBytes2, err := kem2.Bytes()
   396  				if err != nil {
   397  					t.Fatal(err)
   398  				}
   399  				if !bytes.Equal(kemRecipientBytes2, kemRecipientBytes) {
   400  					t.Errorf("X25519 re-serialized key differs: got %x, want %x", kemRecipientBytes2, kemRecipientBytes)
   401  				}
   402  				if !bytes.Equal(kem2.PublicKey().Bytes(), pubKeyBytes) {
   403  					t.Errorf("X25519 re-derived public key differs: got %x, want %x", kem2.PublicKey().Bytes(), pubKeyBytes)
   404  				}
   405  			}
   406  			if !bytes.Equal(kemRecipient.PublicKey().Bytes(), pubKeyBytes) {
   407  				t.Errorf("unexpected KEM sender bytes: got %x, want %x", kemRecipient.PublicKey().Bytes(), pubKeyBytes)
   408  			}
   409  
   410  			ikm := mustDecodeHex(t, vector.IkmR)
   411  			derivRecipient, err := kem.DeriveKeyPair(ikm)
   412  			if err != nil {
   413  				t.Fatal(err)
   414  			}
   415  			derivRecipientBytes, err := derivRecipient.Bytes()
   416  			if err != nil {
   417  				t.Fatal(err)
   418  			}
   419  			if !bytes.Equal(derivRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
   420  				t.Errorf("unexpected KEM bytes from seed: got %x, want %x", derivRecipientBytes, privKeyBytes)
   421  			}
   422  			if !bytes.Equal(derivRecipient.PublicKey().Bytes(), pubKeyBytes) {
   423  				t.Errorf("unexpected KEM sender bytes from seed: got %x, want %x", derivRecipient.PublicKey().Bytes(), pubKeyBytes)
   424  			}
   425  
   426  			recipient, err := NewRecipient(encap, kemRecipient, kdf, aead, info)
   427  			if err != nil {
   428  				t.Fatal(err)
   429  			}
   430  
   431  			if aead != ExportOnly() && len(vector.AccEncryptions) != 0 {
   432  				source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
   433  				for range 1000 {
   434  					aad, plaintext := drawRandomInput(t, source), drawRandomInput(t, source)
   435  					ciphertext, err := sender.Seal(aad, plaintext)
   436  					if err != nil {
   437  						t.Fatal(err)
   438  					}
   439  					sink.Write(ciphertext)
   440  					got, err := recipient.Open(aad, ciphertext)
   441  					if err != nil {
   442  						t.Fatal(err)
   443  					}
   444  					if !bytes.Equal(got, plaintext) {
   445  						t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
   446  					}
   447  				}
   448  				encryptions := make([]byte, 16)
   449  				sink.Read(encryptions)
   450  				expectedEncryptions := mustDecodeHex(t, vector.AccEncryptions)
   451  				if !bytes.Equal(encryptions, expectedEncryptions) {
   452  					t.Errorf("unexpected accumulated encryptions, got: %x, want %x", encryptions, expectedEncryptions)
   453  				}
   454  			} else if aead != ExportOnly() {
   455  				for _, enc := range vector.Encryptions {
   456  					aad := mustDecodeHex(t, enc.Aad)
   457  					plaintext := mustDecodeHex(t, enc.Pt)
   458  					expectedCiphertext := mustDecodeHex(t, enc.Ct)
   459  
   460  					ciphertext, err := sender.Seal(aad, plaintext)
   461  					if err != nil {
   462  						t.Fatal(err)
   463  					}
   464  					if !bytes.Equal(ciphertext, expectedCiphertext) {
   465  						t.Errorf("unexpected ciphertext, got: %x, want %x", ciphertext, expectedCiphertext)
   466  					}
   467  
   468  					got, err := recipient.Open(aad, ciphertext)
   469  					if err != nil {
   470  						t.Fatal(err)
   471  					}
   472  					if !bytes.Equal(got, plaintext) {
   473  						t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
   474  					}
   475  				}
   476  			} else {
   477  				if _, err := sender.Seal(nil, nil); err == nil {
   478  					t.Error("expected error from Seal with export-only AEAD")
   479  				}
   480  				if _, err := recipient.Open(nil, nil); err == nil {
   481  					t.Error("expected error from Open with export-only AEAD")
   482  				}
   483  			}
   484  
   485  			if len(vector.AccExports) != 0 {
   486  				source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
   487  				for l := range 1000 {
   488  					context := string(drawRandomInput(t, source))
   489  					value, err := sender.Export(context, l)
   490  					if err != nil {
   491  						t.Fatal(err)
   492  					}
   493  					sink.Write(value)
   494  					got, err := recipient.Export(context, l)
   495  					if err != nil {
   496  						t.Fatal(err)
   497  					}
   498  					if !bytes.Equal(got, value) {
   499  						t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
   500  					}
   501  				}
   502  				exports := make([]byte, 16)
   503  				sink.Read(exports)
   504  				expectedExports := mustDecodeHex(t, vector.AccExports)
   505  				if !bytes.Equal(exports, expectedExports) {
   506  					t.Errorf("unexpected accumulated exports, got: %x, want %x", exports, expectedExports)
   507  				}
   508  			} else {
   509  				for _, exp := range vector.Exports {
   510  					context := string(mustDecodeHex(t, exp.Context))
   511  					expectedValue := mustDecodeHex(t, exp.Value)
   512  
   513  					value, err := sender.Export(context, exp.L)
   514  					if err != nil {
   515  						t.Fatal(err)
   516  					}
   517  					if !bytes.Equal(value, expectedValue) {
   518  						t.Errorf("unexpected exported value, got: %x, want %x", value, expectedValue)
   519  					}
   520  
   521  					got, err := recipient.Export(context, exp.L)
   522  					if err != nil {
   523  						t.Fatal(err)
   524  					}
   525  					if !bytes.Equal(got, value) {
   526  						t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
   527  					}
   528  				}
   529  			}
   530  		})
   531  	}
   532  }
   533  
   534  func drawRandomInput(t *testing.T, r io.Reader) []byte {
   535  	t.Helper()
   536  	l := make([]byte, 1)
   537  	if _, err := r.Read(l); err != nil {
   538  		t.Fatal(err)
   539  	}
   540  	n := int(l[0])
   541  	b := make([]byte, n)
   542  	if _, err := r.Read(b); err != nil {
   543  		t.Fatal(err)
   544  	}
   545  	return b
   546  }
   547  
   548  func setupDerandomizedEncap(t *testing.T, randBytes []byte, pk PublicKey) {
   549  	t.Cleanup(func() {
   550  		testingOnlyGenerateKey = nil
   551  		testingOnlyEncapsulate = nil
   552  	})
   553  	switch pk.KEM() {
   554  	case DHKEM(ecdh.P256()), DHKEM(ecdh.P384()), DHKEM(ecdh.P521()), DHKEM(ecdh.X25519()):
   555  		r, err := pk.KEM().DeriveKeyPair(randBytes)
   556  		if err != nil {
   557  			t.Fatal(err)
   558  		}
   559  		testingOnlyGenerateKey = func() *ecdh.PrivateKey {
   560  			return r.(*dhKEMPrivateKey).priv.(*ecdh.PrivateKey)
   561  		}
   562  	case mlkem768:
   563  		pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey768)
   564  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   565  			ss, ct, err := mlkemtest.Encapsulate768(pq, randBytes)
   566  			if err != nil {
   567  				t.Fatal(err)
   568  			}
   569  			return ss, ct
   570  		}
   571  	case mlkem1024:
   572  		pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey1024)
   573  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   574  			ss, ct, err := mlkemtest.Encapsulate1024(pq, randBytes)
   575  			if err != nil {
   576  				t.Fatal(err)
   577  			}
   578  			return ss, ct
   579  		}
   580  	case mlkem768X25519:
   581  		pqRand, tRand := randBytes[:32], randBytes[32:]
   582  		pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
   583  		k, err := ecdh.X25519().NewPrivateKey(tRand)
   584  		if err != nil {
   585  			t.Fatal(err)
   586  		}
   587  		testingOnlyGenerateKey = func() *ecdh.PrivateKey {
   588  			return k
   589  		}
   590  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   591  			ss, ct, err := mlkemtest.Encapsulate768(pq, pqRand)
   592  			if err != nil {
   593  				t.Fatal(err)
   594  			}
   595  			return ss, ct
   596  		}
   597  	case mlkem768P256:
   598  		// The rest of randBytes are the following candidates for rejection
   599  		// sampling, but they are never reached.
   600  		pqRand, tRand := randBytes[:32], randBytes[32:64]
   601  		pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
   602  		k, err := ecdh.P256().NewPrivateKey(tRand)
   603  		if err != nil {
   604  			t.Fatal(err)
   605  		}
   606  		testingOnlyGenerateKey = func() *ecdh.PrivateKey {
   607  			return k
   608  		}
   609  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   610  			ss, ct, err := mlkemtest.Encapsulate768(pq, pqRand)
   611  			if err != nil {
   612  				t.Fatal(err)
   613  			}
   614  			return ss, ct
   615  		}
   616  	case mlkem1024P384:
   617  		pqRand, tRand := randBytes[:32], randBytes[32:]
   618  		pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey1024)
   619  		k, err := ecdh.P384().NewPrivateKey(tRand)
   620  		if err != nil {
   621  			t.Fatal(err)
   622  		}
   623  		testingOnlyGenerateKey = func() *ecdh.PrivateKey {
   624  			return k
   625  		}
   626  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   627  			ss, ct, err := mlkemtest.Encapsulate1024(pq, pqRand)
   628  			if err != nil {
   629  				t.Fatal(err)
   630  			}
   631  			return ss, ct
   632  		}
   633  	default:
   634  		t.Fatalf("unsupported KEM %04x", pk.KEM().ID())
   635  	}
   636  }
   637  
   638  func TestSingletons(t *testing.T) {
   639  	if HKDFSHA256() != HKDFSHA256() {
   640  		t.Error("HKDFSHA256() != HKDFSHA256()")
   641  	}
   642  	if HKDFSHA384() != HKDFSHA384() {
   643  		t.Error("HKDFSHA384() != HKDFSHA384()")
   644  	}
   645  	if HKDFSHA512() != HKDFSHA512() {
   646  		t.Error("HKDFSHA512() != HKDFSHA512()")
   647  	}
   648  	if AES128GCM() != AES128GCM() {
   649  		t.Error("AES128GCM() != AES128GCM()")
   650  	}
   651  	if AES256GCM() != AES256GCM() {
   652  		t.Error("AES256GCM() != AES256GCM()")
   653  	}
   654  	if ChaCha20Poly1305() != ChaCha20Poly1305() {
   655  		t.Error("ChaCha20Poly1305() != ChaCha20Poly1305()")
   656  	}
   657  	if ExportOnly() != ExportOnly() {
   658  		t.Error("ExportOnly() != ExportOnly()")
   659  	}
   660  	if DHKEM(ecdh.P256()) != DHKEM(ecdh.P256()) {
   661  		t.Error("DHKEM(P-256) != DHKEM(P-256)")
   662  	}
   663  	if DHKEM(ecdh.P384()) != DHKEM(ecdh.P384()) {
   664  		t.Error("DHKEM(P-384) != DHKEM(P-384)")
   665  	}
   666  	if DHKEM(ecdh.P521()) != DHKEM(ecdh.P521()) {
   667  		t.Error("DHKEM(P-521) != DHKEM(P-521)")
   668  	}
   669  	if DHKEM(ecdh.X25519()) != DHKEM(ecdh.X25519()) {
   670  		t.Error("DHKEM(X25519) != DHKEM(X25519)")
   671  	}
   672  	if MLKEM768() != MLKEM768() {
   673  		t.Error("MLKEM768() != MLKEM768()")
   674  	}
   675  	if MLKEM1024() != MLKEM1024() {
   676  		t.Error("MLKEM1024() != MLKEM1024()")
   677  	}
   678  	if MLKEM768X25519() != MLKEM768X25519() {
   679  		t.Error("MLKEM768X25519() != MLKEM768X25519()")
   680  	}
   681  	if MLKEM768P256() != MLKEM768P256() {
   682  		t.Error("MLKEM768P256() != MLKEM768P256()")
   683  	}
   684  	if MLKEM1024P384() != MLKEM1024P384() {
   685  		t.Error("MLKEM1024P384() != MLKEM1024P384()")
   686  	}
   687  }
   688  

View as plain text