Files

151 lines
5.4 KiB
Go

package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"errors"
"testing"
)
// encryptAESGCM is a test helper that seals plaintext with AES-GCM under key and nonce,
// authenticating the caller-supplied aad as additional data. It returns the sealed
// output (ciphertext followed by the GCM tag). If the AES cipher or the GCM wrapper
// cannot be constructed, the test is aborted via t.Fatalf.
func encryptAESGCM(t *testing.T, key, nonce, plaintext, aad []byte) []byte {
	t.Helper()

	blk, cipherErr := aes.NewCipher(key)
	if cipherErr != nil {
		t.Fatalf("aes.NewCipher: %v", cipherErr)
	}

	gcm, gcmErr := cipher.NewGCM(blk)
	if gcmErr != nil {
		t.Fatalf("cipher.NewGCM: %v", gcmErr)
	}

	// A nil destination makes Seal allocate a fresh slice for the result.
	var sealed []byte
	sealed = gcm.Seal(sealed, nonce, plaintext, aad)
	return sealed
}
// testPlaintextPayloadLen is the size of the plaintext that gets AES-GCM sealed inside
// meta.local_encryptor_data: the blob length minus the nonce and the 16-byte tag.
// 96 (blob) - 12 (nonce) - 16 (tag) = 68 bytes.
const testPlaintextPayloadLen = yandexIntKeyBlobLen - (gcmNonceSize + 16)
func buildLocalEncryptorBlob(t *testing.T, masterKey, dataKey []byte) []byte {
t.Helper()
nonce := bytes.Repeat([]byte{0xAB}, gcmNonceSize)
plaintext := append([]byte{}, yandexSignature...)
plaintext = append(plaintext, dataKey...)
plaintext = append(plaintext, make([]byte, testPlaintextPayloadLen-len(plaintext))...)
ciphertext := encryptAESGCM(t, masterKey, nonce, plaintext, nil)
if len(ciphertext) != yandexIntKeyBlobLen-gcmNonceSize {
t.Fatalf("unexpected ciphertext len: got %d want %d", len(ciphertext), yandexIntKeyBlobLen-gcmNonceSize)
}
blob := []byte{0x01, 0x02, 0x03, 0x04} // arbitrary protobuf preamble
blob = append(blob, localEncryptorPrefix...)
blob = append(blob, nonce...)
blob = append(blob, ciphertext...)
blob = append(blob, 0xFF, 0xFE) // trailing junk should be ignored
return blob
}
// TestDecryptYandexIntermediateKey_RoundTrip checks that a blob produced by
// buildLocalEncryptorBlob decrypts back to the data key that was embedded in it.
func TestDecryptYandexIntermediateKey_RoundTrip(t *testing.T) {
	var (
		master = bytes.Repeat([]byte{0x11}, 32)
		want   = bytes.Repeat([]byte{0x22}, yandexDataKeyLen)
	)
	encoded := buildLocalEncryptorBlob(t, master, want)

	got, err := DecryptYandexIntermediateKey(master, encoded)
	if err != nil {
		t.Fatalf("DecryptYandexIntermediateKey: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Errorf("key mismatch: got %x want %x", got, want)
	}
}
// TestDecryptYandexIntermediateKey_MissingMarker expects errYandexMarkerNotFound when
// the input is a short ASCII string rather than a blob built around the marker.
func TestDecryptYandexIntermediateKey_MissingMarker(t *testing.T) {
	masterKey := bytes.Repeat([]byte{0x11}, 32)
	input := []byte("no marker here")

	if _, err := DecryptYandexIntermediateKey(masterKey, input); !errors.Is(err, errYandexMarkerNotFound) {
		t.Fatalf("expected errYandexMarkerNotFound, got %v", err)
	}
}
func TestDecryptYandexIntermediateKey_Truncated(t *testing.T) {
blob := append([]byte{0x00, 0x00}, localEncryptorPrefix...)
blob = append(blob, bytes.Repeat([]byte{0x55}, yandexIntKeyBlobLen-1)...)
_, err := DecryptYandexIntermediateKey(bytes.Repeat([]byte{0x11}, 32), blob)
if !errors.Is(err, errYandexBlobShort) {
t.Fatalf("expected errYandexBlobShort, got %v", err)
}
}
// TestDecryptYandexIntermediateKey_BadSignature seals a payload that starts with
// 0xDEADBEEF instead of yandexSignature and expects errYandexBadSignature, even though
// the GCM seal itself was made with the correct master key.
func TestDecryptYandexIntermediateKey_BadSignature(t *testing.T) {
	masterKey := bytes.Repeat([]byte{0x11}, 32)
	iv := bytes.Repeat([]byte{0xAB}, gcmNonceSize)

	// Bogus leading bytes, then a data key, zero-padded to the fixed payload size.
	payload := []byte{0xDE, 0xAD, 0xBE, 0xEF}
	payload = append(payload, bytes.Repeat([]byte{0x22}, yandexDataKeyLen)...)
	payload = append(payload, make([]byte, testPlaintextPayloadLen-len(payload))...)

	sealed := encryptAESGCM(t, masterKey, iv, payload, nil)

	blob := make([]byte, 0, len(localEncryptorPrefix)+len(iv)+len(sealed))
	blob = append(blob, localEncryptorPrefix...)
	blob = append(blob, iv...)
	blob = append(blob, sealed...)

	if _, err := DecryptYandexIntermediateKey(masterKey, blob); !errors.Is(err, errYandexBadSignature) {
		t.Fatalf("expected errYandexBadSignature, got %v", err)
	}
}
// TestDecryptYandexIntermediateKey_TrailingDataIgnored verifies that only the 32 bytes
// following the signature are returned as the key; anything after signature+32 in the
// decrypted payload is discarded.
func TestDecryptYandexIntermediateKey_TrailingDataIgnored(t *testing.T) {
	masterKey := bytes.Repeat([]byte{0x11}, 32)
	iv := bytes.Repeat([]byte{0xAB}, gcmNonceSize)

	// Signature, 16 bytes of 0x22, then zero padding up to the fixed payload size.
	payload := make([]byte, 0, testPlaintextPayloadLen)
	payload = append(payload, yandexSignature...)
	payload = append(payload, bytes.Repeat([]byte{0x22}, 16)...)
	payload = append(payload, make([]byte, testPlaintextPayloadLen-len(payload))...)

	sealed := encryptAESGCM(t, masterKey, iv, payload, nil)

	blob := make([]byte, 0, len(localEncryptorPrefix)+len(iv)+len(sealed))
	blob = append(blob, localEncryptorPrefix...)
	blob = append(blob, iv...)
	blob = append(blob, sealed...)

	got, err := DecryptYandexIntermediateKey(masterKey, blob)
	if err != nil {
		t.Fatalf("DecryptYandexIntermediateKey: %v", err)
	}

	// Expected key: the 16 marker bytes followed by the first 16 bytes of zero padding.
	want := make([]byte, 32)
	copy(want, bytes.Repeat([]byte{0x22}, 16))
	if !bytes.Equal(got, want) {
		t.Errorf("key mismatch: got %x want %x", got, want)
	}
}
// TestAESGCMDecryptBlob_RoundTrip seals a plaintext with AAD, prepends the nonce to the
// sealed output, and checks that AESGCMDecryptBlob recovers the original plaintext
// when given the same key and AAD.
func TestAESGCMDecryptBlob_RoundTrip(t *testing.T) {
	var (
		key  = bytes.Repeat([]byte{0x55}, 32)
		iv   = bytes.Repeat([]byte{0x66}, gcmNonceSize)
		aad  = []byte("row-aad")
		want = []byte("row-plaintext")
	)

	sealed := encryptAESGCM(t, key, iv, want, aad)
	// Blob layout under test: nonce immediately followed by the sealed bytes.
	blob := make([]byte, 0, len(iv)+len(sealed))
	blob = append(blob, iv...)
	blob = append(blob, sealed...)

	got, err := AESGCMDecryptBlob(key, blob, aad)
	if err != nil {
		t.Fatalf("AESGCMDecryptBlob: %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Errorf("plaintext mismatch: got %q want %q", got, want)
	}
}
// TestAESGCMDecryptBlob_BadAAD seals with one AAD value and decrypts with another,
// expecting AESGCMDecryptBlob to return a non-nil error.
func TestAESGCMDecryptBlob_BadAAD(t *testing.T) {
	key := bytes.Repeat([]byte{0x55}, 32)
	iv := bytes.Repeat([]byte{0x66}, gcmNonceSize)

	sealed := encryptAESGCM(t, key, iv, []byte("x"), []byte("aad-A"))
	blob := make([]byte, 0, len(iv)+len(sealed))
	blob = append(blob, iv...)
	blob = append(blob, sealed...)

	_, err := AESGCMDecryptBlob(key, blob, []byte("aad-B"))
	if err == nil {
		t.Fatal("expected authentication failure with mismatched AAD")
	}
}
// TestAESGCMDecryptBlob_TooShort passes a two-byte blob and expects errShortCiphertext.
func TestAESGCMDecryptBlob_TooShort(t *testing.T) {
	key := bytes.Repeat([]byte{0x55}, 32)
	tiny := []byte{0x01, 0x02}

	if _, err := AESGCMDecryptBlob(key, tiny, nil); !errors.Is(err, errShortCiphertext) {
		t.Fatalf("expected errShortCiphertext, got %v", err)
	}
}