diff --git a/cmd/restore.go b/cmd/restore.go index 2cdb7da..dd96b99 100755 --- a/cmd/restore.go +++ b/cmd/restore.go @@ -10,6 +10,7 @@ import ( "syscall" "time" + "dbbackup/internal/backup" "dbbackup/internal/cloud" "dbbackup/internal/database" "dbbackup/internal/restore" @@ -28,6 +29,10 @@ var ( restoreTarget string restoreVerbose bool restoreNoProgress bool + + // Encryption flags + restoreEncryptionKeyFile string + restoreEncryptionKeyEnv string = "DBBACKUP_ENCRYPTION_KEY" ) // restoreCmd represents the restore command @@ -156,6 +161,8 @@ func init() { restoreSingleCmd.Flags().StringVar(&restoreTarget, "target", "", "Target database name (defaults to original)") restoreSingleCmd.Flags().BoolVar(&restoreVerbose, "verbose", false, "Show detailed restore progress") restoreSingleCmd.Flags().BoolVar(&restoreNoProgress, "no-progress", false, "Disable progress indicators") + restoreSingleCmd.Flags().StringVar(&restoreEncryptionKeyFile, "encryption-key-file", "", "Path to encryption key file (required for encrypted backups)") + restoreSingleCmd.Flags().StringVar(&restoreEncryptionKeyEnv, "encryption-key-env", "DBBACKUP_ENCRYPTION_KEY", "Environment variable containing encryption key") // Cluster restore flags restoreClusterCmd.Flags().BoolVar(&restoreConfirm, "confirm", false, "Confirm and execute restore (required)") @@ -164,6 +171,8 @@ func init() { restoreClusterCmd.Flags().IntVar(&restoreJobs, "jobs", 0, "Number of parallel decompression jobs (0 = auto)") restoreClusterCmd.Flags().BoolVar(&restoreVerbose, "verbose", false, "Show detailed restore progress") restoreClusterCmd.Flags().BoolVar(&restoreNoProgress, "no-progress", false, "Disable progress indicators") + restoreClusterCmd.Flags().StringVar(&restoreEncryptionKeyFile, "encryption-key-file", "", "Path to encryption key file (required for encrypted backups)") + restoreClusterCmd.Flags().StringVar(&restoreEncryptionKeyEnv, "encryption-key-env", "DBBACKUP_ENCRYPTION_KEY", "Environment variable containing encryption key") } // runRestoreSingle restores a single database @@ -214,6 +223,20 @@ func runRestoreSingle(cmd *cobra.Command, args []string) error { } } + // Check if backup is encrypted and decrypt if necessary + if backup.IsBackupEncrypted(archivePath) { + log.Info("Encrypted backup detected, decrypting...") + key, err := loadEncryptionKey(restoreEncryptionKeyFile, restoreEncryptionKeyEnv) + if err != nil { + return fmt.Errorf("encrypted backup requires encryption key: %w", err) + } + // Decrypt in-place (same path) + if err := backup.DecryptBackupFile(archivePath, archivePath, key, log); err != nil { + return fmt.Errorf("decryption failed: %w", err) + } + log.Info("Decryption completed successfully") + } + // Detect format format := restore.DetectArchiveFormat(archivePath) if format == restore.FormatUnknown { @@ -340,6 +363,20 @@ func runRestoreCluster(cmd *cobra.Command, args []string) error { return fmt.Errorf("archive not found: %s", archivePath) } + // Check if backup is encrypted and decrypt if necessary + if backup.IsBackupEncrypted(archivePath) { + log.Info("Encrypted cluster backup detected, decrypting...") + key, err := loadEncryptionKey(restoreEncryptionKeyFile, restoreEncryptionKeyEnv) + if err != nil { + return fmt.Errorf("encrypted backup requires encryption key: %w", err) + } + // Decrypt in-place (same path) + if err := backup.DecryptBackupFile(archivePath, archivePath, key, log); err != nil { + return fmt.Errorf("decryption failed: %w", err) + } + log.Info("Cluster decryption completed successfully") + } + // Verify it's a cluster backup format := restore.DetectArchiveFormat(archivePath) if !format.IsClusterBackup() {