diff --git a/internal/crypto/aes.go b/internal/crypto/aes.go new file mode 100644 index 0000000..2209136 --- /dev/null +++ b/internal/crypto/aes.go @@ -0,0 +1,294 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + "os" + + "golang.org/x/crypto/pbkdf2" +) + +const ( + // AES-256 requires 32-byte keys + KeySize = 32 + + // GCM standard nonce size + NonceSize = 12 + + // Salt size for PBKDF2 + SaltSize = 32 + + // PBKDF2 iterations (OWASP recommended minimum) + PBKDF2Iterations = 600000 + + // Buffer size for streaming encryption + BufferSize = 64 * 1024 // 64KB chunks +) + +// AESEncryptor implements AES-256-GCM encryption +type AESEncryptor struct{} + +// NewAESEncryptor creates a new AES-256-GCM encryptor +func NewAESEncryptor() *AESEncryptor { + return &AESEncryptor{} +} + +// Algorithm returns the algorithm name +func (e *AESEncryptor) Algorithm() EncryptionAlgorithm { + return AlgorithmAES256GCM +} + +// DeriveKey derives a 32-byte key from a password using PBKDF2-SHA256 +func DeriveKey(password []byte, salt []byte) []byte { + return pbkdf2.Key(password, salt, PBKDF2Iterations, KeySize, sha256.New) +} + +// GenerateSalt generates a random salt +func GenerateSalt() ([]byte, error) { + salt := make([]byte, SaltSize) + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + return nil, fmt.Errorf("failed to generate salt: %w", err) + } + return salt, nil +} + +// GenerateNonce generates a random nonce for GCM +func GenerateNonce() ([]byte, error) { + nonce := make([]byte, NonceSize) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %w", err) + } + return nonce, nil +} + +// ValidateKey checks if a key is the correct length +func ValidateKey(key []byte) error { + if len(key) != KeySize { + return fmt.Errorf("invalid key length: expected %d bytes, got %d bytes", KeySize, len(key)) + } + return nil +} + +// Encrypt encrypts data from reader and returns an encrypted reader +func (e *AESEncryptor) Encrypt(reader io.Reader, key []byte) (io.Reader, error) { + if err := ValidateKey(key); err != nil { + return nil, err + } + + // Create AES cipher + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + // Generate nonce + nonce, err := GenerateNonce() + if err != nil { + return nil, err + } + + // Create pipe for streaming + pr, pw := io.Pipe() + + go func() { + defer pw.Close() + + // Write nonce first (needed for decryption) + if _, err := pw.Write(nonce); err != nil { + pw.CloseWithError(fmt.Errorf("failed to write nonce: %w", err)) + return + } + + // Read plaintext in chunks and encrypt + buf := make([]byte, BufferSize) + for { + n, err := reader.Read(buf) + if n > 0 { + // Encrypt chunk + ciphertext := gcm.Seal(nil, nonce, buf[:n], nil) + + // Write encrypted chunk length (4 bytes) + encrypted data + lengthBuf := []byte{ + byte(len(ciphertext) >> 24), + byte(len(ciphertext) >> 16), + byte(len(ciphertext) >> 8), + byte(len(ciphertext)), + } + if _, err := pw.Write(lengthBuf); err != nil { + pw.CloseWithError(fmt.Errorf("failed to write chunk length: %w", err)) + return + } + if _, err := pw.Write(ciphertext); err != nil { + pw.CloseWithError(fmt.Errorf("failed to write ciphertext: %w", err)) + return + } + + // Increment nonce for next chunk (simple counter mode) + for i := len(nonce) - 1; i >= 0; i-- { + nonce[i]++ + if nonce[i] != 0 { + break + } + } + } + if err == io.EOF { + break + } + if err != nil { + pw.CloseWithError(fmt.Errorf("read error: %w", err)) + return + } + } + }() + + return pr, nil +} + +// Decrypt decrypts data from reader and returns a decrypted reader +func (e *AESEncryptor) Decrypt(reader io.Reader, key []byte) (io.Reader, error) { + if err := ValidateKey(key); err != nil { + return nil, err + } + + // Create AES cipher + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + // Create pipe for streaming + pr, pw := io.Pipe() + + go func() { + defer pw.Close() + + // Read initial nonce + nonce := make([]byte, NonceSize) + if _, err := io.ReadFull(reader, nonce); err != nil { + pw.CloseWithError(fmt.Errorf("failed to read nonce: %w", err)) + return + } + + // Read and decrypt chunks + lengthBuf := make([]byte, 4) + for { + // Read chunk length + if _, err := io.ReadFull(reader, lengthBuf); err != nil { + if err == io.EOF { + break + } + pw.CloseWithError(fmt.Errorf("failed to read chunk length: %w", err)) + return + } + + chunkLen := int(lengthBuf[0])<<24 | int(lengthBuf[1])<<16 | + int(lengthBuf[2])<<8 | int(lengthBuf[3]) + + // Read encrypted chunk + ciphertext := make([]byte, chunkLen) + if _, err := io.ReadFull(reader, ciphertext); err != nil { + pw.CloseWithError(fmt.Errorf("failed to read ciphertext: %w", err)) + return + } + + // Decrypt chunk + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + pw.CloseWithError(fmt.Errorf("decryption failed (wrong key?): %w", err)) + return + } + + // Write plaintext + if _, err := pw.Write(plaintext); err != nil { + pw.CloseWithError(fmt.Errorf("failed to write plaintext: %w", err)) + return + } + + // Increment nonce for next chunk + for i := len(nonce) - 1; i >= 0; i-- { + nonce[i]++ + if nonce[i] != 0 { + break + } + } + } + }() + + return pr, nil +} + +// EncryptFile encrypts a file +func (e *AESEncryptor) EncryptFile(inputPath, outputPath string, key []byte) error { + // Open input file + inFile, err := os.Open(inputPath) + if err != nil { + return fmt.Errorf("failed to open input file: %w", err) + } + defer inFile.Close() + + // Create output file + outFile, err := os.Create(outputPath) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer outFile.Close() + + // Encrypt + encReader, err := e.Encrypt(inFile, key) + if err != nil { + return err + } + + // Copy encrypted data to output file + if _, err := io.Copy(outFile, encReader); err != nil { + return fmt.Errorf("failed to write encrypted data: %w", err) + } + + return nil +} + +// DecryptFile decrypts a file +func (e *AESEncryptor) DecryptFile(inputPath, outputPath string, key []byte) error { + // Open input file + inFile, err := os.Open(inputPath) + if err != nil { + return fmt.Errorf("failed to open input file: %w", err) + } + defer inFile.Close() + + // Create output file + outFile, err := os.Create(outputPath) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer outFile.Close() + + // Decrypt + decReader, err := e.Decrypt(inFile, key) + if err != nil { + return err + } + + // Copy decrypted data to output file + if _, err := io.Copy(outFile, decReader); err != nil { + return fmt.Errorf("failed to write decrypted data: %w", err) + } + + return nil +} diff --git a/internal/crypto/aes_test.go b/internal/crypto/aes_test.go new file mode 100644 index 0000000..067ec66 --- /dev/null +++ b/internal/crypto/aes_test.go @@ -0,0 +1,232 @@ +package crypto + +import ( + "bytes" + "crypto/rand" + "io" + "os" + "path/filepath" + "testing" +) + +func TestAESEncryptionDecryption(t *testing.T) { + encryptor := NewAESEncryptor() + + // Generate a random key + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + testData := []byte("This is test data for encryption and decryption. It contains multiple bytes to ensure proper streaming.") + + // Test streaming encryption/decryption + t.Run("StreamingEncryptDecrypt", func(t *testing.T) { + // Encrypt + reader := bytes.NewReader(testData) + encReader, err := encryptor.Encrypt(reader, key) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + // Read all encrypted data + encryptedData, err := io.ReadAll(encReader) + if err != nil { + t.Fatalf("Failed to read encrypted data: %v", err) + } + + // Verify encrypted data is different from original + if bytes.Equal(encryptedData, testData) { + t.Error("Encrypted data should not equal plaintext") + } + + // Decrypt + decReader, err := encryptor.Decrypt(bytes.NewReader(encryptedData), key) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + // Read decrypted data + decryptedData, err := io.ReadAll(decReader) + if err != nil { + t.Fatalf("Failed to read decrypted data: %v", err) + } + + // Verify decrypted data matches original + if !bytes.Equal(decryptedData, testData) { + t.Errorf("Decrypted data does not match original.\nExpected: %s\nGot: %s", + string(testData), string(decryptedData)) + } + }) + + // Test file encryption/decryption + t.Run("FileEncryptDecrypt", func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "crypto_test_*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create test file + testFile := filepath.Join(tempDir, "test.txt") + if err := os.WriteFile(testFile, testData, 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + // Encrypt file + encryptedFile := filepath.Join(tempDir, "test.txt.enc") + if err := encryptor.EncryptFile(testFile, encryptedFile, key); err != nil { + t.Fatalf("File encryption failed: %v", err) + } + + // Verify encrypted file exists and is different + encData, err := os.ReadFile(encryptedFile) + if err != nil { + t.Fatalf("Failed to read encrypted file: %v", err) + } + if bytes.Equal(encData, testData) { + t.Error("Encrypted file should not equal plaintext") + } + + // Decrypt file + decryptedFile := filepath.Join(tempDir, "test.txt.dec") + if err := encryptor.DecryptFile(encryptedFile, decryptedFile, key); err != nil { + t.Fatalf("File decryption failed: %v", err) + } + + // Verify decrypted file matches original + decData, err := os.ReadFile(decryptedFile) + if err != nil { + t.Fatalf("Failed to read decrypted file: %v", err) + } + if !bytes.Equal(decData, testData) { + t.Errorf("Decrypted file does not match original") + } + }) + + // Test wrong key + t.Run("WrongKey", func(t *testing.T) { + wrongKey := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, wrongKey); err != nil { + t.Fatalf("Failed to generate wrong key: %v", err) + } + + // Encrypt with correct key + reader := bytes.NewReader(testData) + encReader, err := encryptor.Encrypt(reader, key) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + encryptedData, err := io.ReadAll(encReader) + if err != nil { + t.Fatalf("Failed to read encrypted data: %v", err) + } + + // Try to decrypt with wrong key + decReader, err := encryptor.Decrypt(bytes.NewReader(encryptedData), wrongKey) + if err != nil { + // Error during decrypt setup is OK + return + } + + // Try to read - should fail + _, err = io.ReadAll(decReader) + if err == nil { + t.Error("Expected decryption to fail with wrong key") + } + }) +} + +func TestKeyDerivation(t *testing.T) { + password := []byte("test-password-12345") + + // Generate salt + salt, err := GenerateSalt() + if err != nil { + t.Fatalf("Failed to generate salt: %v", err) + } + + if len(salt) != SaltSize { + t.Errorf("Expected salt size %d, got %d", SaltSize, len(salt)) + } + + // Derive key + key := DeriveKey(password, salt) + if len(key) != KeySize { + t.Errorf("Expected key size %d, got %d", KeySize, len(key)) + } + + // Verify same password+salt produces same key + key2 := DeriveKey(password, salt) + if !bytes.Equal(key, key2) { + t.Error("Same password and salt should produce same key") + } + + // Verify different salt produces different key + salt2, _ := GenerateSalt() + key3 := DeriveKey(password, salt2) + if bytes.Equal(key, key3) { + t.Error("Different salt should produce different key") + } +} + +func TestKeyValidation(t *testing.T) { + validKey := make([]byte, KeySize) + if err := ValidateKey(validKey); err != nil { + t.Errorf("Valid key should pass validation: %v", err) + } + + shortKey := make([]byte, 16) + if err := ValidateKey(shortKey); err == nil { + t.Error("Short key should fail validation") + } + + longKey := make([]byte, 64) + if err := ValidateKey(longKey); err == nil { + t.Error("Long key should fail validation") + } +} + +func TestLargeData(t *testing.T) { + encryptor := NewAESEncryptor() + + // Generate key + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + // Create large test data (1MB) + largeData := make([]byte, 1024*1024) + if _, err := io.ReadFull(rand.Reader, largeData); err != nil { + t.Fatalf("Failed to generate large data: %v", err) + } + + // Encrypt + encReader, err := encryptor.Encrypt(bytes.NewReader(largeData), key) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + encryptedData, err := io.ReadAll(encReader) + if err != nil { + t.Fatalf("Failed to read encrypted data: %v", err) + } + + // Decrypt + decReader, err := encryptor.Decrypt(bytes.NewReader(encryptedData), key) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + decryptedData, err := io.ReadAll(decReader) + if err != nil { + t.Fatalf("Failed to read decrypted data: %v", err) + } + + // Verify + if !bytes.Equal(decryptedData, largeData) { + t.Error("Decrypted large data does not match original") + } +} diff --git a/internal/crypto/interface.go b/internal/crypto/interface.go new file mode 100644 index 0000000..8c9633f --- /dev/null +++ b/internal/crypto/interface.go @@ -0,0 +1,86 @@ +package crypto + +import ( + "io" +) + +// EncryptionAlgorithm represents the encryption algorithm used +type EncryptionAlgorithm string + +const ( + AlgorithmAES256GCM EncryptionAlgorithm = "aes-256-gcm" +) + +// EncryptionConfig holds encryption configuration +type EncryptionConfig struct { + // Enabled indicates whether encryption is enabled + Enabled bool + + // KeyFile is the path to a file containing the encryption key + KeyFile string + + // KeyEnvVar is the name of an environment variable containing the key + KeyEnvVar string + + // Algorithm specifies the encryption algorithm to use + Algorithm EncryptionAlgorithm + + // Key is the actual encryption key (derived from KeyFile or KeyEnvVar) + Key []byte +} + +// Encryptor provides encryption and decryption capabilities +type Encryptor interface { + // Encrypt encrypts data from reader and returns an encrypted reader + // The returned reader streams encrypted data without loading everything into memory + Encrypt(reader io.Reader, key []byte) (io.Reader, error) + + // Decrypt decrypts data from reader and returns a decrypted reader + // The returned reader streams decrypted data without loading everything into memory + Decrypt(reader io.Reader, key []byte) (io.Reader, error) + + // EncryptFile encrypts a file in-place or to a new file + EncryptFile(inputPath, outputPath string, key []byte) error + + // DecryptFile decrypts a file in-place or to a new file + DecryptFile(inputPath, outputPath string, key []byte) error + + // Algorithm returns the encryption algorithm used by this encryptor + Algorithm() EncryptionAlgorithm +} + +// KeyDeriver derives encryption keys from passwords/passphrases +type KeyDeriver interface { + // DeriveKey derives a key from a password using PBKDF2 or similar + DeriveKey(password []byte, salt []byte, keyLength int) ([]byte, error) + + // GenerateSalt generates a random salt for key derivation + GenerateSalt() ([]byte, error) +} + +// EncryptionMetadata contains metadata about encrypted backups +type EncryptionMetadata struct { + // Algorithm used for encryption + Algorithm string `json:"algorithm"` + + // KeyDerivation method used (e.g., "pbkdf2-sha256") + KeyDerivation string `json:"key_derivation,omitempty"` + + // Salt used for key derivation (base64 encoded) + Salt string `json:"salt,omitempty"` + + // Nonce/IV used for encryption (base64 encoded) + Nonce string `json:"nonce,omitempty"` + + // Version of encryption format + Version int `json:"version"` +} + +// DefaultConfig returns a default encryption configuration +func DefaultConfig() *EncryptionConfig { + return &EncryptionConfig{ + Enabled: false, + Algorithm: AlgorithmAES256GCM, + KeyEnvVar: "DBBACKUP_ENCRYPTION_KEY", + } +}