ci: add golangci-lint config and fix formatting

- Add .golangci.yml with minimal linters (govet, ineffassign)
- Run gofmt -s and goimports on all files to fix formatting
- Disable fieldalignment and copylocks checks in govet
This commit is contained in:
2025-12-11 17:53:28 +01:00
parent 6b66ae5429
commit 914307ac8f
89 changed files with 1516 additions and 1618 deletions

View File

@@ -16,13 +16,13 @@ import (
type AuthMethod string
const (
AuthPeer AuthMethod = "peer"
AuthIdent AuthMethod = "ident"
AuthMD5 AuthMethod = "md5"
AuthScramSHA256 AuthMethod = "scram-sha-256"
AuthPassword AuthMethod = "password"
AuthTrust AuthMethod = "trust"
AuthUnknown AuthMethod = "unknown"
AuthPeer AuthMethod = "peer"
AuthIdent AuthMethod = "ident"
AuthMD5 AuthMethod = "md5"
AuthScramSHA256 AuthMethod = "scram-sha-256"
AuthPassword AuthMethod = "password"
AuthTrust AuthMethod = "trust"
AuthUnknown AuthMethod = "unknown"
)
// DetectPostgreSQLAuthMethod attempts to detect the authentication method
@@ -108,7 +108,7 @@ func parseHbaContent(content string, user string) AuthMethod {
for _, line := range lines {
line = strings.TrimSpace(line)
// Skip comments and empty lines
if line == "" || strings.HasPrefix(line, "#") {
continue
@@ -198,29 +198,29 @@ func buildAuthMismatchMessage(osUser, dbUser string, method AuthMethod) string {
msg.WriteString("\n⚠ Authentication Mismatch Detected\n")
msg.WriteString(strings.Repeat("=", 60) + "\n\n")
msg.WriteString(fmt.Sprintf(" PostgreSQL is using '%s' authentication\n", method))
msg.WriteString(fmt.Sprintf(" OS user '%s' cannot authenticate as DB user '%s'\n\n", osUser, dbUser))
msg.WriteString("💡 Solutions (choose one):\n\n")
msg.WriteString(fmt.Sprintf(" 1. Run as matching user:\n"))
msg.WriteString(fmt.Sprintf(" sudo -u %s %s\n\n", dbUser, getCommandLine()))
msg.WriteString(" 2. Configure ~/.pgpass file (recommended):\n")
msg.WriteString(fmt.Sprintf(" echo \"localhost:5432:*:%s:your_password\" > ~/.pgpass\n", dbUser))
msg.WriteString(" chmod 0600 ~/.pgpass\n\n")
msg.WriteString(" 3. Set PGPASSWORD environment variable:\n")
msg.WriteString(fmt.Sprintf(" export PGPASSWORD=your_password\n"))
msg.WriteString(fmt.Sprintf(" %s\n\n", getCommandLine()))
msg.WriteString(" 4. Provide password via flag:\n")
msg.WriteString(fmt.Sprintf(" %s --password your_password\n\n", getCommandLine()))
msg.WriteString("📝 Note: For production use, ~/.pgpass or PGPASSWORD are recommended\n")
msg.WriteString(" to avoid exposing passwords in command history.\n\n")
msg.WriteString(strings.Repeat("=", 60) + "\n")
return msg.String()
@@ -231,29 +231,29 @@ func getCommandLine() string {
if len(os.Args) == 0 {
return "./dbbackup"
}
// Build command without password if present
var parts []string
skipNext := false
for _, arg := range os.Args {
if skipNext {
skipNext = false
continue
}
if arg == "--password" || arg == "-p" {
skipNext = true
continue
}
if strings.HasPrefix(arg, "--password=") {
continue
}
parts = append(parts, arg)
}
return strings.Join(parts, " ")
}
@@ -298,7 +298,7 @@ func parsePgpass(path string, cfg *config.Config) string {
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Skip comments and empty lines
if line == "" || strings.HasPrefix(line, "#") {
continue

View File

@@ -14,7 +14,7 @@ import (
// The original file is replaced with the encrypted version
func EncryptBackupFile(backupPath string, key []byte, log logger.Logger) error {
log.Info("Encrypting backup file", "file", filepath.Base(backupPath))
// Validate key
if err := crypto.ValidateKey(key); err != nil {
return fmt.Errorf("invalid encryption key: %w", err)
@@ -81,25 +81,25 @@ func IsBackupEncrypted(backupPath string) bool {
// All databases are unencrypted
return false
}
// Try single database metadata
if meta, err := metadata.Load(backupPath); err == nil {
return meta.Encrypted
}
// Fallback: check if file starts with encryption nonce
file, err := os.Open(backupPath)
if err != nil {
return false
}
defer file.Close()
// Try to read nonce - if it succeeds, likely encrypted
nonce := make([]byte, crypto.NonceSize)
if n, err := file.Read(nonce); err != nil || n != crypto.NonceSize {
return false
}
return true
}

View File

@@ -20,11 +20,11 @@ import (
"dbbackup/internal/cloud"
"dbbackup/internal/config"
"dbbackup/internal/database"
"dbbackup/internal/security"
"dbbackup/internal/logger"
"dbbackup/internal/metadata"
"dbbackup/internal/metrics"
"dbbackup/internal/progress"
"dbbackup/internal/security"
"dbbackup/internal/swap"
)
@@ -42,7 +42,7 @@ type Engine struct {
func New(cfg *config.Config, log logger.Logger, db database.Database) *Engine {
progressIndicator := progress.NewIndicator(true, "line") // Use line-by-line indicator
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
return &Engine{
cfg: cfg,
log: log,
@@ -56,7 +56,7 @@ func New(cfg *config.Config, log logger.Logger, db database.Database) *Engine {
// NewWithProgress creates a new backup engine with a custom progress indicator
func NewWithProgress(cfg *config.Config, log logger.Logger, db database.Database, progressIndicator progress.Indicator) *Engine {
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
return &Engine{
cfg: cfg,
log: log,
@@ -73,9 +73,9 @@ func NewSilent(cfg *config.Config, log logger.Logger, db database.Database, prog
if progressIndicator == nil {
progressIndicator = progress.NewNullIndicator()
}
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
return &Engine{
cfg: cfg,
log: log,
@@ -126,16 +126,16 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
// Start detailed operation tracking
operationID := generateOperationID()
tracker := e.detailedReporter.StartOperation(operationID, databaseName, "backup")
// Add operation details
tracker.SetDetails("database", databaseName)
tracker.SetDetails("type", "single")
tracker.SetDetails("compression", strconv.Itoa(e.cfg.CompressionLevel))
tracker.SetDetails("format", "custom")
// Start preparing backup directory
prepStep := tracker.AddStep("prepare", "Preparing backup directory")
// Validate and sanitize backup directory path
validBackupDir, err := security.ValidateBackupPath(e.cfg.BackupDir)
if err != nil {
@@ -144,7 +144,7 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
return fmt.Errorf("invalid backup directory path: %w", err)
}
e.cfg.BackupDir = validBackupDir
if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
err = fmt.Errorf("failed to create backup directory %s. Check write permissions or use --backup-dir to specify writable location: %w", e.cfg.BackupDir, err)
prepStep.Fail(err)
@@ -153,20 +153,20 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
}
prepStep.Complete("Backup directory prepared")
tracker.UpdateProgress(10, "Backup directory prepared")
// Generate timestamp and filename
timestamp := time.Now().Format("20060102_150405")
var outputFile string
if e.cfg.IsPostgreSQL() {
outputFile = filepath.Join(e.cfg.BackupDir, fmt.Sprintf("db_%s_%s.dump", databaseName, timestamp))
} else {
outputFile = filepath.Join(e.cfg.BackupDir, fmt.Sprintf("db_%s_%s.sql.gz", databaseName, timestamp))
}
tracker.SetDetails("output_file", outputFile)
tracker.UpdateProgress(20, "Generated backup filename")
// Build backup command
cmdStep := tracker.AddStep("command", "Building backup command")
options := database.BackupOptions{
@@ -177,15 +177,15 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
NoOwner: false,
NoPrivileges: false,
}
cmd := e.db.BuildBackupCommand(databaseName, outputFile, options)
cmdStep.Complete("Backup command prepared")
tracker.UpdateProgress(30, "Backup command prepared")
// Execute backup command with progress monitoring
execStep := tracker.AddStep("execute", "Executing database backup")
tracker.UpdateProgress(40, "Starting database backup...")
if err := e.executeCommandWithProgress(ctx, cmd, outputFile, tracker); err != nil {
err = fmt.Errorf("backup failed for %s: %w. Check database connectivity and disk space", databaseName, err)
execStep.Fail(err)
@@ -194,7 +194,7 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
}
execStep.Complete("Database backup completed")
tracker.UpdateProgress(80, "Database backup completed")
// Verify backup file
verifyStep := tracker.AddStep("verify", "Verifying backup file")
if info, err := os.Stat(outputFile); err != nil {
@@ -209,7 +209,7 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
verifyStep.Complete(fmt.Sprintf("Backup file verified: %s", size))
tracker.UpdateProgress(90, fmt.Sprintf("Backup verified: %s", size))
}
// Calculate and save checksum
checksumStep := tracker.AddStep("checksum", "Calculating SHA-256 checksum")
if checksum, err := security.ChecksumFile(outputFile); err != nil {
@@ -223,7 +223,7 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
e.log.Info("Backup checksum", "sha256", checksum)
}
}
// Create metadata file
metaStep := tracker.AddStep("metadata", "Creating metadata file")
if err := e.createMetadata(outputFile, databaseName, "single", ""); err != nil {
@@ -232,12 +232,12 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
} else {
metaStep.Complete("Metadata file created")
}
// Record metrics for observability
if info, err := os.Stat(outputFile); err == nil && metrics.GlobalMetrics != nil {
metrics.GlobalMetrics.RecordOperation("backup_single", databaseName, time.Now().Add(-time.Minute), info.Size(), true, 0)
}
// Cloud upload if enabled
if e.cfg.CloudEnabled && e.cfg.CloudAutoUpload {
if err := e.uploadToCloud(ctx, outputFile, tracker); err != nil {
@@ -245,39 +245,39 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
// Don't fail the backup if cloud upload fails
}
}
// Complete operation
tracker.UpdateProgress(100, "Backup operation completed successfully")
tracker.Complete(fmt.Sprintf("Single database backup completed: %s", filepath.Base(outputFile)))
return nil
}
// BackupSample performs a sample database backup
func (e *Engine) BackupSample(ctx context.Context, databaseName string) error {
operation := e.log.StartOperation("Sample Database Backup")
// Ensure backup directory exists
if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
operation.Fail("Failed to create backup directory")
return fmt.Errorf("failed to create backup directory: %w", err)
}
// Generate timestamp and filename
timestamp := time.Now().Format("20060102_150405")
outputFile := filepath.Join(e.cfg.BackupDir,
outputFile := filepath.Join(e.cfg.BackupDir,
fmt.Sprintf("sample_%s_%s%d_%s.sql", databaseName, e.cfg.SampleStrategy, e.cfg.SampleValue, timestamp))
operation.Update("Starting sample database backup")
e.progress.Start(fmt.Sprintf("Creating sample backup of '%s' (%s=%d)", databaseName, e.cfg.SampleStrategy, e.cfg.SampleValue))
// For sample backups, we need to get the schema first, then sample data
if err := e.createSampleBackup(ctx, databaseName, outputFile); err != nil {
e.progress.Fail(fmt.Sprintf("Sample backup failed: %v", err))
operation.Fail("Sample backup failed")
return fmt.Errorf("sample backup failed: %w", err)
}
// Check output file
if info, err := os.Stat(outputFile); err != nil {
e.progress.Fail("Sample backup file not created")
@@ -288,12 +288,12 @@ func (e *Engine) BackupSample(ctx context.Context, databaseName string) error {
e.progress.Complete(fmt.Sprintf("Sample backup completed: %s (%s)", filepath.Base(outputFile), size))
operation.Complete(fmt.Sprintf("Sample backup created: %s (%s)", outputFile, size))
}
// Create metadata file
if err := e.createMetadata(outputFile, databaseName, "sample", e.cfg.SampleStrategy); err != nil {
e.log.Warn("Failed to create metadata file", "error", err)
}
return nil
}
@@ -302,19 +302,19 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
if !e.cfg.IsPostgreSQL() {
return fmt.Errorf("cluster backup is only supported for PostgreSQL")
}
operation := e.log.StartOperation("Cluster Backup")
// Setup swap file if configured
var swapMgr *swap.Manager
if e.cfg.AutoSwap && e.cfg.SwapFileSizeGB > 0 {
swapMgr = swap.NewManager(e.cfg.SwapFilePath, e.cfg.SwapFileSizeGB, e.log)
if swapMgr.IsSupported() {
e.log.Info("Setting up temporary swap file for large backup",
"path", e.cfg.SwapFilePath,
e.log.Info("Setting up temporary swap file for large backup",
"path", e.cfg.SwapFilePath,
"size_gb", e.cfg.SwapFileSizeGB)
if err := swapMgr.Setup(); err != nil {
e.log.Warn("Failed to setup swap file (continuing without it)", "error", err)
} else {
@@ -329,7 +329,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
e.log.Warn("Swap file management not supported on this platform", "os", swapMgr)
}
}
// Use appropriate progress indicator based on silent mode
var quietProgress progress.Indicator
if e.silent {
@@ -340,42 +340,42 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
quietProgress = progress.NewQuietLineByLine()
quietProgress.Start("Starting cluster backup (all databases)")
}
// Ensure backup directory exists
if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
operation.Fail("Failed to create backup directory")
quietProgress.Fail("Failed to create backup directory")
return fmt.Errorf("failed to create backup directory: %w", err)
}
// Check disk space before starting backup (cached for performance)
e.log.Info("Checking disk space availability")
spaceCheck := checks.CheckDiskSpaceCached(e.cfg.BackupDir)
if !e.silent {
// Show disk space status in CLI mode
fmt.Println("\n" + checks.FormatDiskSpaceMessage(spaceCheck))
}
if spaceCheck.Critical {
operation.Fail("Insufficient disk space")
quietProgress.Fail("Insufficient disk space - free up space and try again")
return fmt.Errorf("insufficient disk space: %.1f%% used, operation blocked", spaceCheck.UsedPercent)
}
if spaceCheck.Warning {
e.log.Warn("Low disk space - backup may fail if database is large",
e.log.Warn("Low disk space - backup may fail if database is large",
"available_gb", float64(spaceCheck.AvailableBytes)/(1024*1024*1024),
"used_percent", spaceCheck.UsedPercent)
}
// Generate timestamp and filename
timestamp := time.Now().Format("20060102_150405")
outputFile := filepath.Join(e.cfg.BackupDir, fmt.Sprintf("cluster_%s.tar.gz", timestamp))
tempDir := filepath.Join(e.cfg.BackupDir, fmt.Sprintf(".cluster_%s", timestamp))
operation.Update("Starting cluster backup")
// Create temporary directory
if err := os.MkdirAll(filepath.Join(tempDir, "dumps"), 0755); err != nil {
operation.Fail("Failed to create temporary directory")
@@ -383,7 +383,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
return fmt.Errorf("failed to create temp directory: %w", err)
}
defer os.RemoveAll(tempDir)
// Backup globals
e.printf(" Backing up global objects...\n")
if err := e.backupGlobals(ctx, tempDir); err != nil {
@@ -391,7 +391,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
operation.Fail("Global backup failed")
return fmt.Errorf("failed to backup globals: %w", err)
}
// Get list of databases
e.printf(" Getting database list...\n")
databases, err := e.db.ListDatabases(ctx)
@@ -400,31 +400,31 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
operation.Fail("Database listing failed")
return fmt.Errorf("failed to list databases: %w", err)
}
// Create ETA estimator for database backups
estimator := progress.NewETAEstimator("Backing up cluster", len(databases))
quietProgress.SetEstimator(estimator)
// Backup each database
parallelism := e.cfg.ClusterParallelism
if parallelism < 1 {
parallelism = 1 // Ensure at least sequential
}
if parallelism == 1 {
e.printf(" Backing up %d databases sequentially...\n", len(databases))
} else {
e.printf(" Backing up %d databases with %d parallel workers...\n", len(databases), parallelism)
}
// Use worker pool for parallel backup
var successCount, failCount int32
var mu sync.Mutex // Protect shared resources (printf, estimator)
// Create semaphore to limit concurrency
semaphore := make(chan struct{}, parallelism)
var wg sync.WaitGroup
for i, dbName := range databases {
// Check if context is cancelled before starting new backup
select {
@@ -435,14 +435,14 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
return fmt.Errorf("backup cancelled: %w", ctx.Err())
default:
}
wg.Add(1)
semaphore <- struct{}{} // Acquire
go func(idx int, name string) {
defer wg.Done()
defer func() { <-semaphore }() // Release
// Check for cancellation at start of goroutine
select {
case <-ctx.Done():
@@ -451,14 +451,14 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
return
default:
}
// Update estimator progress (thread-safe)
mu.Lock()
estimator.UpdateProgress(idx)
e.printf(" [%d/%d] Backing up database: %s\n", idx+1, len(databases), name)
quietProgress.Update(fmt.Sprintf("Backing up database %d/%d: %s", idx+1, len(databases), name))
mu.Unlock()
// Check database size and warn if very large
if size, err := e.db.GetDatabaseSize(ctx, name); err == nil {
sizeStr := formatBytes(size)
@@ -469,17 +469,17 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
}
mu.Unlock()
}
dumpFile := filepath.Join(tempDir, "dumps", name+".dump")
compressionLevel := e.cfg.CompressionLevel
if compressionLevel > 6 {
compressionLevel = 6
}
format := "custom"
parallel := e.cfg.DumpJobs
if size, err := e.db.GetDatabaseSize(ctx, name); err == nil {
if size > 5*1024*1024*1024 {
format = "plain"
@@ -490,7 +490,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
mu.Unlock()
}
}
options := database.BackupOptions{
Compression: compressionLevel,
Parallel: parallel,
@@ -499,14 +499,14 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
NoOwner: false,
NoPrivileges: false,
}
cmd := e.db.BuildBackupCommand(name, dumpFile, options)
dbCtx, cancel := context.WithTimeout(ctx, 2*time.Hour)
defer cancel()
err := e.executeCommand(dbCtx, cmd, dumpFile)
cancel()
if err != nil {
e.log.Warn("Failed to backup database", "database", name, "error", err)
mu.Lock()
@@ -526,15 +526,15 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
}
}(i, dbName)
}
// Wait for all backups to complete
wg.Wait()
successCountFinal := int(atomic.LoadInt32(&successCount))
failCountFinal := int(atomic.LoadInt32(&failCount))
e.printf(" Backup summary: %d succeeded, %d failed\n", successCountFinal, failCountFinal)
// Create archive
e.printf(" Creating compressed archive...\n")
if err := e.createArchive(ctx, tempDir, outputFile); err != nil {
@@ -542,7 +542,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
operation.Fail("Archive creation failed")
return fmt.Errorf("failed to create archive: %w", err)
}
// Check output file
if info, err := os.Stat(outputFile); err != nil {
quietProgress.Fail("Cluster backup archive not created")
@@ -553,12 +553,12 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
quietProgress.Complete(fmt.Sprintf("Cluster backup completed: %s (%s)", filepath.Base(outputFile), size))
operation.Complete(fmt.Sprintf("Cluster backup created: %s (%s)", outputFile, size))
}
// Create cluster metadata file
if err := e.createClusterMetadata(outputFile, databases, successCountFinal, failCountFinal); err != nil {
e.log.Warn("Failed to create cluster metadata file", "error", err)
}
return nil
}
@@ -567,11 +567,11 @@ func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []strin
if len(cmdArgs) == 0 {
return fmt.Errorf("empty command")
}
e.log.Debug("Executing backup command with progress", "cmd", cmdArgs[0], "args", cmdArgs[1:])
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
// Set environment variables for database tools
cmd.Env = os.Environ()
if e.cfg.Password != "" {
@@ -581,51 +581,51 @@ func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []strin
cmd.Env = append(cmd.Env, "MYSQL_PWD="+e.cfg.Password)
}
}
// For MySQL, handle compression and redirection differently
if e.cfg.IsMySQL() && e.cfg.CompressionLevel > 0 {
return e.executeMySQLWithProgressAndCompression(ctx, cmdArgs, outputFile, tracker)
}
// Get stderr pipe for progress monitoring
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to get stderr pipe: %w", err)
}
// Start the command
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start command: %w", err)
}
// Monitor progress via stderr
go e.monitorCommandProgress(stderr, tracker)
// Wait for command to complete
if err := cmd.Wait(); err != nil {
return fmt.Errorf("backup command failed: %w", err)
}
return nil
}
// monitorCommandProgress monitors command output for progress information
func (e *Engine) monitorCommandProgress(stderr io.ReadCloser, tracker *progress.OperationTracker) {
defer stderr.Close()
scanner := bufio.NewScanner(stderr)
scanner.Buffer(make([]byte, 64*1024), 1024*1024) // 64KB initial, 1MB max for performance
progressBase := 40 // Start from 40% since command preparation is done
progressBase := 40 // Start from 40% since command preparation is done
progressIncrement := 0
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
e.log.Debug("Command output", "line", line)
// Increment progress gradually based on output
if progressBase < 75 {
progressIncrement++
@@ -634,7 +634,7 @@ func (e *Engine) monitorCommandProgress(stderr io.ReadCloser, tracker *progress.
tracker.UpdateProgress(progressBase, "Processing data...")
}
}
// Look for specific progress indicators
if strings.Contains(line, "COPY") {
tracker.UpdateProgress(progressBase+5, "Copying table data...")
@@ -654,55 +654,55 @@ func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmd
if e.cfg.Password != "" {
dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
}
// Create gzip command
gzipCmd := exec.CommandContext(ctx, "gzip", fmt.Sprintf("-%d", e.cfg.CompressionLevel))
// Create output file
outFile, err := os.Create(outputFile)
if err != nil {
return fmt.Errorf("failed to create output file: %w", err)
}
defer outFile.Close()
// Set up pipeline: mysqldump | gzip > outputfile
pipe, err := dumpCmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create pipe: %w", err)
}
gzipCmd.Stdin = pipe
gzipCmd.Stdout = outFile
// Get stderr for progress monitoring
stderr, err := dumpCmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to get stderr pipe: %w", err)
}
// Start monitoring progress
go e.monitorCommandProgress(stderr, tracker)
// Start both commands
if err := gzipCmd.Start(); err != nil {
return fmt.Errorf("failed to start gzip: %w", err)
}
if err := dumpCmd.Start(); err != nil {
return fmt.Errorf("failed to start mysqldump: %w", err)
}
// Wait for mysqldump to complete
if err := dumpCmd.Wait(); err != nil {
return fmt.Errorf("mysqldump failed: %w", err)
}
// Close pipe and wait for gzip
pipe.Close()
if err := gzipCmd.Wait(); err != nil {
return fmt.Errorf("gzip failed: %w", err)
}
return nil
}
@@ -714,17 +714,17 @@ func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []stri
if e.cfg.Password != "" {
dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
}
// Create gzip command
gzipCmd := exec.CommandContext(ctx, "gzip", fmt.Sprintf("-%d", e.cfg.CompressionLevel))
// Create output file
outFile, err := os.Create(outputFile)
if err != nil {
return fmt.Errorf("failed to create output file: %w", err)
}
defer outFile.Close()
// Set up pipeline: mysqldump | gzip > outputfile
stdin, err := dumpCmd.StdoutPipe()
if err != nil {
@@ -732,20 +732,20 @@ func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []stri
}
gzipCmd.Stdin = stdin
gzipCmd.Stdout = outFile
// Start both commands
if err := gzipCmd.Start(); err != nil {
return fmt.Errorf("failed to start gzip: %w", err)
}
if err := dumpCmd.Run(); err != nil {
return fmt.Errorf("mysqldump failed: %w", err)
}
if err := gzipCmd.Wait(); err != nil {
return fmt.Errorf("gzip failed: %w", err)
}
return nil
}
@@ -757,23 +757,23 @@ func (e *Engine) createSampleBackup(ctx context.Context, databaseName, outputFil
// 2. Get list of tables
// 3. For each table, run sampling query
// 4. Combine into single SQL file
// For now, we'll use a simple approach with schema-only backup first
// Then add sample data
file, err := os.Create(outputFile)
if err != nil {
return fmt.Errorf("failed to create sample backup file: %w", err)
}
defer file.Close()
// Write header
fmt.Fprintf(file, "-- Sample Database Backup\n")
fmt.Fprintf(file, "-- Database: %s\n", databaseName)
fmt.Fprintf(file, "-- Strategy: %s = %d\n", e.cfg.SampleStrategy, e.cfg.SampleValue)
fmt.Fprintf(file, "-- Created: %s\n", time.Now().Format(time.RFC3339))
fmt.Fprintf(file, "-- WARNING: This backup may have referential integrity issues!\n\n")
// For PostgreSQL, we can use pg_dump --schema-only first
if e.cfg.IsPostgreSQL() {
// Get schema
@@ -781,61 +781,61 @@ func (e *Engine) createSampleBackup(ctx context.Context, databaseName, outputFil
SchemaOnly: true,
Format: "plain",
})
cmd := exec.CommandContext(ctx, schemaCmd[0], schemaCmd[1:]...)
cmd.Env = os.Environ()
if e.cfg.Password != "" {
cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
}
cmd.Stdout = file
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to export schema: %w", err)
}
fmt.Fprintf(file, "\n-- Sample data follows\n\n")
// Get tables and sample data
tables, err := e.db.ListTables(ctx, databaseName)
if err != nil {
return fmt.Errorf("failed to list tables: %w", err)
}
strategy := database.SampleStrategy{
Type: e.cfg.SampleStrategy,
Value: e.cfg.SampleValue,
}
for _, table := range tables {
fmt.Fprintf(file, "-- Data for table: %s\n", table)
sampleQuery := e.db.BuildSampleQuery(databaseName, table, strategy)
fmt.Fprintf(file, "\\copy (%s) TO STDOUT\n\n", sampleQuery)
}
}
return nil
}
// backupGlobals creates a backup of global PostgreSQL objects
func (e *Engine) backupGlobals(ctx context.Context, tempDir string) error {
globalsFile := filepath.Join(tempDir, "globals.sql")
cmd := exec.CommandContext(ctx, "pg_dumpall", "--globals-only")
if e.cfg.Host != "localhost" {
cmd.Args = append(cmd.Args, "-h", e.cfg.Host, "-p", fmt.Sprintf("%d", e.cfg.Port))
}
cmd.Args = append(cmd.Args, "-U", e.cfg.User)
cmd.Env = os.Environ()
if e.cfg.Password != "" {
cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
}
output, err := cmd.Output()
if err != nil {
return fmt.Errorf("pg_dumpall failed: %w", err)
}
return os.WriteFile(globalsFile, output, 0644)
}
@@ -844,13 +844,13 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
// Use pigz for faster parallel compression if available, otherwise use standard gzip
compressCmd := "tar"
compressArgs := []string{"-czf", outputFile, "-C", sourceDir, "."}
// Check if pigz is available for faster parallel compression
if _, err := exec.LookPath("pigz"); err == nil {
// Use pigz with number of cores for parallel compression
compressArgs = []string{"-cf", "-", "-C", sourceDir, "."}
cmd := exec.CommandContext(ctx, "tar", compressArgs...)
// Create output file
outFile, err := os.Create(outputFile)
if err != nil {
@@ -858,10 +858,10 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
goto regularTar
}
defer outFile.Close()
// Pipe to pigz for parallel compression
pigzCmd := exec.CommandContext(ctx, "pigz", "-p", strconv.Itoa(e.cfg.Jobs))
tarOut, err := cmd.StdoutPipe()
if err != nil {
outFile.Close()
@@ -870,7 +870,7 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
}
pigzCmd.Stdin = tarOut
pigzCmd.Stdout = outFile
// Start both commands
if err := pigzCmd.Start(); err != nil {
outFile.Close()
@@ -881,13 +881,13 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
outFile.Close()
goto regularTar
}
// Wait for tar to finish
if err := cmd.Wait(); err != nil {
pigzCmd.Process.Kill()
return fmt.Errorf("tar failed: %w", err)
}
// Wait for pigz to finish
if err := pigzCmd.Wait(); err != nil {
return fmt.Errorf("pigz compression failed: %w", err)
@@ -898,7 +898,7 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
regularTar:
// Standard tar with gzip (fallback)
cmd := exec.CommandContext(ctx, compressCmd, compressArgs...)
// Stream stderr to avoid memory issues
// Use io.Copy to ensure goroutine completes when pipe closes
stderr, err := cmd.StderrPipe()
@@ -914,7 +914,7 @@ regularTar:
// Scanner will exit when stderr pipe closes after cmd.Wait()
}()
}
if err := cmd.Run(); err != nil {
return fmt.Errorf("tar failed: %w", err)
}
@@ -925,26 +925,26 @@ regularTar:
// createMetadata creates a metadata file for the backup
func (e *Engine) createMetadata(backupFile, database, backupType, strategy string) error {
startTime := time.Now()
// Get backup file information
info, err := os.Stat(backupFile)
if err != nil {
return fmt.Errorf("failed to stat backup file: %w", err)
}
// Calculate SHA-256 checksum
sha256, err := metadata.CalculateSHA256(backupFile)
if err != nil {
return fmt.Errorf("failed to calculate checksum: %w", err)
}
// Get database version
ctx := context.Background()
dbVersion, _ := e.db.GetVersion(ctx)
if dbVersion == "" {
dbVersion = "unknown"
}
// Determine compression format
compressionFormat := "none"
if e.cfg.CompressionLevel > 0 {
@@ -954,7 +954,7 @@ func (e *Engine) createMetadata(backupFile, database, backupType, strategy strin
compressionFormat = fmt.Sprintf("gzip-%d", e.cfg.CompressionLevel)
}
}
// Create backup metadata
meta := &metadata.BackupMetadata{
Version: "2.0",
@@ -973,18 +973,18 @@ func (e *Engine) createMetadata(backupFile, database, backupType, strategy strin
Duration: time.Since(startTime).Seconds(),
ExtraInfo: make(map[string]string),
}
// Add strategy for sample backups
if strategy != "" {
meta.ExtraInfo["sample_strategy"] = strategy
meta.ExtraInfo["sample_value"] = fmt.Sprintf("%d", e.cfg.SampleValue)
}
// Save metadata
if err := meta.Save(); err != nil {
return fmt.Errorf("failed to save metadata: %w", err)
}
// Also save legacy .info file for backward compatibility
legacyMetaFile := backupFile + ".info"
legacyContent := fmt.Sprintf(`{
@@ -998,39 +998,39 @@ func (e *Engine) createMetadata(backupFile, database, backupType, strategy strin
"compression": %d,
"size_bytes": %d
}`, backupType, database, startTime.Format("20060102_150405"),
e.cfg.Host, e.cfg.Port, e.cfg.User, e.cfg.DatabaseType,
e.cfg.Host, e.cfg.Port, e.cfg.User, e.cfg.DatabaseType,
e.cfg.CompressionLevel, info.Size())
if err := os.WriteFile(legacyMetaFile, []byte(legacyContent), 0644); err != nil {
e.log.Warn("Failed to save legacy metadata file", "error", err)
}
return nil
}
// createClusterMetadata creates metadata for cluster backups
func (e *Engine) createClusterMetadata(backupFile string, databases []string, successCount, failCount int) error {
startTime := time.Now()
// Get backup file information
info, err := os.Stat(backupFile)
if err != nil {
return fmt.Errorf("failed to stat backup file: %w", err)
}
// Calculate SHA-256 checksum for archive
sha256, err := metadata.CalculateSHA256(backupFile)
if err != nil {
return fmt.Errorf("failed to calculate checksum: %w", err)
}
// Get database version
ctx := context.Background()
dbVersion, _ := e.db.GetVersion(ctx)
if dbVersion == "" {
dbVersion = "unknown"
}
// Create cluster metadata
clusterMeta := &metadata.ClusterMetadata{
Version: "2.0",
@@ -1050,7 +1050,7 @@ func (e *Engine) createClusterMetadata(backupFile string, databases []string, su
"database_version": dbVersion,
},
}
// Add database names to metadata
for _, dbName := range databases {
dbMeta := metadata.BackupMetadata{
@@ -1061,12 +1061,12 @@ func (e *Engine) createClusterMetadata(backupFile string, databases []string, su
}
clusterMeta.Databases = append(clusterMeta.Databases, dbMeta)
}
// Save cluster metadata
if err := clusterMeta.Save(backupFile); err != nil {
return fmt.Errorf("failed to save cluster metadata: %w", err)
}
// Also save legacy .info file for backward compatibility
legacyMetaFile := backupFile + ".info"
legacyContent := fmt.Sprintf(`{
@@ -1085,18 +1085,18 @@ func (e *Engine) createClusterMetadata(backupFile string, databases []string, su
}`, startTime.Format("20060102_150405"),
e.cfg.Host, e.cfg.Port, e.cfg.User, e.cfg.DatabaseType,
e.cfg.CompressionLevel, info.Size(), len(databases), successCount, failCount)
if err := os.WriteFile(legacyMetaFile, []byte(legacyContent), 0644); err != nil {
e.log.Warn("Failed to save legacy cluster metadata file", "error", err)
}
return nil
}
// uploadToCloud uploads a backup file to cloud storage
func (e *Engine) uploadToCloud(ctx context.Context, backupFile string, tracker *progress.OperationTracker) error {
uploadStep := tracker.AddStep("cloud_upload", "Uploading to cloud storage")
// Create cloud backend
cloudCfg := &cloud.Config{
Provider: e.cfg.CloudProvider,
@@ -1111,23 +1111,23 @@ func (e *Engine) uploadToCloud(ctx context.Context, backupFile string, tracker *
Timeout: 300,
MaxRetries: 3,
}
backend, err := cloud.NewBackend(cloudCfg)
if err != nil {
uploadStep.Fail(fmt.Errorf("failed to create cloud backend: %w", err))
return err
}
// Get file info
info, err := os.Stat(backupFile)
if err != nil {
uploadStep.Fail(fmt.Errorf("failed to stat backup file: %w", err))
return err
}
filename := filepath.Base(backupFile)
e.log.Info("Uploading backup to cloud", "file", filename, "size", cloud.FormatSize(info.Size()))
// Progress callback
var lastPercent int
progressCallback := func(transferred, total int64) {
@@ -1137,14 +1137,14 @@ func (e *Engine) uploadToCloud(ctx context.Context, backupFile string, tracker *
lastPercent = percent
}
}
// Upload to cloud
err = backend.Upload(ctx, backupFile, filename, progressCallback)
if err != nil {
uploadStep.Fail(fmt.Errorf("cloud upload failed: %w", err))
return err
}
// Also upload metadata file
metaFile := backupFile + ".meta.json"
if _, err := os.Stat(metaFile); err == nil {
@@ -1154,10 +1154,10 @@ func (e *Engine) uploadToCloud(ctx context.Context, backupFile string, tracker *
// Don't fail if metadata upload fails
}
}
uploadStep.Complete(fmt.Sprintf("Uploaded to %s/%s/%s", backend.Name(), e.cfg.CloudBucket, filename))
e.log.Info("Backup uploaded to cloud", "provider", backend.Name(), "bucket", e.cfg.CloudBucket, "file", filename)
return nil
}
@@ -1166,9 +1166,9 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
if len(cmdArgs) == 0 {
return fmt.Errorf("empty command")
}
e.log.Debug("Executing backup command", "cmd", cmdArgs[0], "args", cmdArgs[1:])
// Check if pg_dump will write to stdout (which means we need to handle piping to compressor).
// BuildBackupCommand omits --file when format==plain AND compression==0, causing pg_dump
// to write to stdout. In that case we must pipe to external compressor.
@@ -1192,28 +1192,28 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
if isPlainFormat && !hasFileFlag {
usesStdout = true
}
e.log.Debug("Backup command analysis",
"plain_format", isPlainFormat,
"has_file_flag", hasFileFlag,
e.log.Debug("Backup command analysis",
"plain_format", isPlainFormat,
"has_file_flag", hasFileFlag,
"uses_stdout", usesStdout,
"output_file", outputFile)
// For MySQL, handle compression differently
if e.cfg.IsMySQL() && e.cfg.CompressionLevel > 0 {
return e.executeMySQLWithCompression(ctx, cmdArgs, outputFile)
}
// For plain format writing to stdout, use streaming compression
if usesStdout {
e.log.Debug("Using streaming compression for large database")
return e.executeWithStreamingCompression(ctx, cmdArgs, outputFile)
}
// For custom format, pg_dump handles everything (writes directly to file)
// NO GO BUFFERING - pg_dump writes directly to disk
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
// Set environment variables for database tools
cmd.Env = os.Environ()
if e.cfg.Password != "" {
@@ -1223,18 +1223,18 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
cmd.Env = append(cmd.Env, "MYSQL_PWD="+e.cfg.Password)
}
}
// Stream stderr to avoid memory issues with large databases
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %w", err)
}
// Start the command
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start backup command: %w", err)
}
// Stream stderr output (don't buffer it all in memory)
go func() {
scanner := bufio.NewScanner(stderr)
@@ -1246,13 +1246,13 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
}
}
}()
// Wait for command to complete
if err := cmd.Wait(); err != nil {
e.log.Error("Backup command failed", "error", err, "database", filepath.Base(outputFile))
return fmt.Errorf("backup command failed: %w", err)
}
return nil
}
@@ -1260,7 +1260,7 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
// Uses: pg_dump | pigz > file.sql.gz (zero-copy streaming)
func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []string, outputFile string) error {
e.log.Debug("Using streaming compression for large database")
// Derive compressed output filename. If the output was named *.dump we replace that
// with *.sql.gz; otherwise append .gz to the provided output file so we don't
// accidentally create unwanted double extensions.
@@ -1273,43 +1273,43 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
} else {
compressedFile = outputFile + ".gz"
}
// Create pg_dump command
dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
dumpCmd.Env = os.Environ()
if e.cfg.Password != "" && e.cfg.IsPostgreSQL() {
dumpCmd.Env = append(dumpCmd.Env, "PGPASSWORD="+e.cfg.Password)
}
// Check for pigz (parallel gzip)
compressor := "gzip"
compressorArgs := []string{"-c"}
if _, err := exec.LookPath("pigz"); err == nil {
compressor = "pigz"
compressorArgs = []string{"-p", strconv.Itoa(e.cfg.Jobs), "-c"}
e.log.Debug("Using pigz for parallel compression", "threads", e.cfg.Jobs)
}
// Create compression command
compressCmd := exec.CommandContext(ctx, compressor, compressorArgs...)
// Create output file
outFile, err := os.Create(compressedFile)
if err != nil {
return fmt.Errorf("failed to create output file: %w", err)
}
defer outFile.Close()
// Set up pipeline: pg_dump | pigz > file.sql.gz
dumpStdout, err := dumpCmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create dump stdout pipe: %w", err)
}
compressCmd.Stdin = dumpStdout
compressCmd.Stdout = outFile
// Capture stderr from both commands
dumpStderr, err := dumpCmd.StderrPipe()
if err != nil {
@@ -1319,7 +1319,7 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
if err != nil {
e.log.Warn("Failed to capture compress stderr", "error", err)
}
// Stream stderr output
if dumpStderr != nil {
go func() {
@@ -1332,7 +1332,7 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
}
}()
}
if compressStderr != nil {
go func() {
scanner := bufio.NewScanner(compressStderr)
@@ -1344,30 +1344,30 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
}
}()
}
// Start compression first
if err := compressCmd.Start(); err != nil {
return fmt.Errorf("failed to start compressor: %w", err)
}
// Then start pg_dump
if err := dumpCmd.Start(); err != nil {
return fmt.Errorf("failed to start pg_dump: %w", err)
}
// Wait for pg_dump to complete
if err := dumpCmd.Wait(); err != nil {
return fmt.Errorf("pg_dump failed: %w", err)
}
// Close stdout pipe to signal compressor we're done
dumpStdout.Close()
// Wait for compression to complete
if err := compressCmd.Wait(); err != nil {
return fmt.Errorf("compression failed: %w", err)
}
e.log.Debug("Streaming compression completed", "output", compressedFile)
return nil
}
@@ -1384,4 +1384,4 @@ func formatBytes(bytes int64) string {
exp++
}
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
}

View File

@@ -17,19 +17,19 @@ const (
type IncrementalMetadata struct {
// BaseBackupID is the SHA-256 checksum of the base backup this incremental depends on
BaseBackupID string `json:"base_backup_id"`
// BaseBackupPath is the filename of the base backup (e.g., "mydb_20250126_120000.tar.gz")
BaseBackupPath string `json:"base_backup_path"`
// BaseBackupTimestamp is when the base backup was created
BaseBackupTimestamp time.Time `json:"base_backup_timestamp"`
// IncrementalFiles is the number of changed files included in this backup
IncrementalFiles int `json:"incremental_files"`
// TotalSize is the total size of changed files (bytes)
TotalSize int64 `json:"total_size"`
// BackupChain is the list of all backups needed for restore (base + incrementals)
// Ordered from oldest to newest: [base, incr1, incr2, ...]
BackupChain []string `json:"backup_chain"`
@@ -39,16 +39,16 @@ type IncrementalMetadata struct {
type ChangedFile struct {
// RelativePath is the path relative to PostgreSQL data directory
RelativePath string
// AbsolutePath is the full filesystem path
AbsolutePath string
// Size is the file size in bytes
Size int64
// ModTime is the last modification time
ModTime time.Time
// Checksum is the SHA-256 hash of the file content (optional)
Checksum string
}
@@ -57,13 +57,13 @@ type ChangedFile struct {
type IncrementalBackupConfig struct {
// BaseBackupPath is the path to the base backup archive
BaseBackupPath string
// DataDirectory is the PostgreSQL data directory to scan
DataDirectory string
// IncludeWAL determines if WAL files should be included
IncludeWAL bool
// CompressionLevel for the incremental archive (0-9)
CompressionLevel int
}
@@ -72,11 +72,11 @@ type IncrementalBackupConfig struct {
type BackupChainResolver interface {
// FindBaseBackup locates the base backup for an incremental backup
FindBaseBackup(ctx context.Context, incrementalBackupID string) (*BackupInfo, error)
// ResolveChain returns the complete chain of backups needed for restore
// Returned in order: [base, incr1, incr2, ..., target]
ResolveChain(ctx context.Context, targetBackupID string) ([]*BackupInfo, error)
// ValidateChain verifies all backups in the chain exist and are valid
ValidateChain(ctx context.Context, chain []*BackupInfo) error
}
@@ -85,10 +85,10 @@ type BackupChainResolver interface {
type IncrementalBackupEngine interface {
// FindChangedFiles identifies files changed since the base backup
FindChangedFiles(ctx context.Context, config *IncrementalBackupConfig) ([]ChangedFile, error)
// CreateIncrementalBackup creates a new incremental backup
CreateIncrementalBackup(ctx context.Context, config *IncrementalBackupConfig, changedFiles []ChangedFile) error
// RestoreIncremental restores an incremental backup on top of a base backup
RestoreIncremental(ctx context.Context, baseBackupPath, incrementalPath, targetDir string) error
}
@@ -101,8 +101,8 @@ type BackupInfo struct {
Timestamp time.Time `json:"timestamp"`
Size int64 `json:"size"`
Checksum string `json:"checksum"`
// New fields for incremental support
BackupType BackupType `json:"backup_type"` // "full" or "incremental"
BackupType BackupType `json:"backup_type"` // "full" or "incremental"
Incremental *IncrementalMetadata `json:"incremental,omitempty"` // Only present for incremental backups
}

View File

@@ -42,7 +42,7 @@ func (e *MySQLIncrementalEngine) FindChangedFiles(ctx context.Context, config *I
return nil, fmt.Errorf("failed to load base backup info: %w", err)
}
// Validate base backup is full backup
// Validate base backup is full backup
if baseInfo.BackupType != "" && baseInfo.BackupType != "full" {
return nil, fmt.Errorf("base backup must be a full backup, got: %s", baseInfo.BackupType)
}
@@ -52,7 +52,7 @@ func (e *MySQLIncrementalEngine) FindChangedFiles(ctx context.Context, config *I
// Scan data directory for changed files
var changedFiles []ChangedFile
err = filepath.Walk(config.DataDirectory, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
@@ -199,7 +199,7 @@ func (e *MySQLIncrementalEngine) CreateIncrementalBackup(ctx context.Context, co
// Generate output filename: dbname_incr_TIMESTAMP.tar.gz
timestamp := time.Now().Format("20060102_150405")
outputFile := filepath.Join(filepath.Dir(config.BaseBackupPath),
outputFile := filepath.Join(filepath.Dir(config.BaseBackupPath),
fmt.Sprintf("%s_incr_%s.tar.gz", baseInfo.Database, timestamp))
e.log.Info("Creating incremental archive", "output", outputFile)
@@ -229,19 +229,19 @@ func (e *MySQLIncrementalEngine) CreateIncrementalBackup(ctx context.Context, co
// Create incremental metadata
metadata := &metadata.BackupMetadata{
Version: "2.3.0",
Timestamp: time.Now(),
Database: baseInfo.Database,
DatabaseType: baseInfo.DatabaseType,
Host: baseInfo.Host,
Port: baseInfo.Port,
User: baseInfo.User,
BackupFile: outputFile,
SizeBytes: stat.Size(),
SHA256: checksum,
Compression: "gzip",
BackupType: "incremental",
BaseBackup: filepath.Base(config.BaseBackupPath),
Version: "2.3.0",
Timestamp: time.Now(),
Database: baseInfo.Database,
DatabaseType: baseInfo.DatabaseType,
Host: baseInfo.Host,
Port: baseInfo.Port,
User: baseInfo.User,
BackupFile: outputFile,
SizeBytes: stat.Size(),
SHA256: checksum,
Compression: "gzip",
BackupType: "incremental",
BaseBackup: filepath.Base(config.BaseBackupPath),
Incremental: &metadata.IncrementalMetadata{
BaseBackupID: baseInfo.SHA256,
BaseBackupPath: filepath.Base(config.BaseBackupPath),

View File

@@ -40,7 +40,7 @@ func (e *PostgresIncrementalEngine) FindChangedFiles(ctx context.Context, config
return nil, fmt.Errorf("failed to load base backup info: %w", err)
}
// Validate base backup is full backup
// Validate base backup is full backup
if baseInfo.BackupType != "" && baseInfo.BackupType != "full" {
return nil, fmt.Errorf("base backup must be a full backup, got: %s", baseInfo.BackupType)
}
@@ -50,7 +50,7 @@ func (e *PostgresIncrementalEngine) FindChangedFiles(ctx context.Context, config
// Scan data directory for changed files
var changedFiles []ChangedFile
err = filepath.Walk(config.DataDirectory, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
@@ -160,7 +160,7 @@ func (e *PostgresIncrementalEngine) CreateIncrementalBackup(ctx context.Context,
// Generate output filename: dbname_incr_TIMESTAMP.tar.gz
timestamp := time.Now().Format("20060102_150405")
outputFile := filepath.Join(filepath.Dir(config.BaseBackupPath),
outputFile := filepath.Join(filepath.Dir(config.BaseBackupPath),
fmt.Sprintf("%s_incr_%s.tar.gz", baseInfo.Database, timestamp))
e.log.Info("Creating incremental archive", "output", outputFile)
@@ -190,19 +190,19 @@ func (e *PostgresIncrementalEngine) CreateIncrementalBackup(ctx context.Context,
// Create incremental metadata
metadata := &metadata.BackupMetadata{
Version: "2.2.0",
Timestamp: time.Now(),
Database: baseInfo.Database,
DatabaseType: baseInfo.DatabaseType,
Host: baseInfo.Host,
Port: baseInfo.Port,
User: baseInfo.User,
BackupFile: outputFile,
SizeBytes: stat.Size(),
SHA256: checksum,
Compression: "gzip",
BackupType: "incremental",
BaseBackup: filepath.Base(config.BaseBackupPath),
Version: "2.2.0",
Timestamp: time.Now(),
Database: baseInfo.Database,
DatabaseType: baseInfo.DatabaseType,
Host: baseInfo.Host,
Port: baseInfo.Port,
User: baseInfo.User,
BackupFile: outputFile,
SizeBytes: stat.Size(),
SHA256: checksum,
Compression: "gzip",
BackupType: "incremental",
BaseBackup: filepath.Base(config.BaseBackupPath),
Incremental: &metadata.IncrementalMetadata{
BaseBackupID: baseInfo.SHA256,
BaseBackupPath: filepath.Base(config.BaseBackupPath),
@@ -329,7 +329,7 @@ func (e *PostgresIncrementalEngine) CalculateFileChecksum(path string) (string,
// buildBackupChain constructs the backup chain from base backup to current incremental
func buildBackupChain(baseInfo *metadata.BackupMetadata, currentBackup string) []string {
chain := []string{}
// If base backup has a chain (is itself incremental), use that
if baseInfo.Incremental != nil && len(baseInfo.Incremental.BackupChain) > 0 {
chain = append(chain, baseInfo.Incremental.BackupChain...)
@@ -337,9 +337,9 @@ func buildBackupChain(baseInfo *metadata.BackupMetadata, currentBackup string) [
// Base is a full backup, start chain with it
chain = append(chain, filepath.Base(baseInfo.BackupFile))
}
// Add current incremental to chain
chain = append(chain, currentBackup)
return chain
}

View File

@@ -67,7 +67,7 @@ func TestIncrementalBackupRestore(t *testing.T) {
// Step 2: Create base (full) backup
t.Log("Step 2: Creating base backup...")
baseBackupPath := filepath.Join(backupDir, "testdb_base.tar.gz")
// Manually create base backup for testing
baseConfig := &IncrementalBackupConfig{
DataDirectory: dataDir,
@@ -192,7 +192,7 @@ func TestIncrementalBackupRestore(t *testing.T) {
var incrementalBackupPath string
for _, entry := range entries {
if !entry.IsDir() && filepath.Ext(entry.Name()) == ".gz" &&
if !entry.IsDir() && filepath.Ext(entry.Name()) == ".gz" &&
entry.Name() != filepath.Base(baseBackupPath) {
incrementalBackupPath = filepath.Join(backupDir, entry.Name())
break
@@ -209,7 +209,7 @@ func TestIncrementalBackupRestore(t *testing.T) {
incrStat, _ := os.Stat(incrementalBackupPath)
t.Logf("Base backup size: %d bytes", baseStat.Size())
t.Logf("Incremental backup size: %d bytes", incrStat.Size())
// Note: For tiny test files, incremental might be larger due to tar.gz overhead
// In real-world scenarios with larger files, incremental would be much smaller
t.Logf("Incremental contains %d changed files out of %d total",
@@ -273,7 +273,7 @@ func TestIncrementalBackupErrors(t *testing.T) {
// Create a dummy base backup
baseBackupPath := filepath.Join(tempDir, "base.tar.gz")
os.WriteFile(baseBackupPath, []byte("dummy"), 0644)
// Create metadata with current timestamp
baseMetadata := createTestMetadata("testdb", baseBackupPath, 100, "dummychecksum", "full", nil)
saveTestMetadata(baseBackupPath, baseMetadata)
@@ -333,7 +333,7 @@ func saveTestMetadata(backupPath string, metadata map[string]interface{}) error
metadata["timestamp"],
metadata["backup_type"],
)
_, err = file.WriteString(content)
return err
}

View File

@@ -23,7 +23,7 @@ func NewDiskSpaceCache(ttl time.Duration) *DiskSpaceCache {
if ttl <= 0 {
ttl = 30 * time.Second // Default 30 second cache
}
return &DiskSpaceCache{
cache: make(map[string]*cacheEntry),
cacheTTL: ttl,
@@ -40,17 +40,17 @@ func (c *DiskSpaceCache) Get(path string) *DiskSpaceCheck {
}
}
c.mu.RUnlock()
// Cache miss or expired - perform new check
check := CheckDiskSpace(path)
c.mu.Lock()
c.cache[path] = &cacheEntry{
check: check,
timestamp: time.Now(),
}
c.mu.Unlock()
return check
}
@@ -65,7 +65,7 @@ func (c *DiskSpaceCache) Clear() {
func (c *DiskSpaceCache) Cleanup() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
for path, entry := range c.cache {
if now.Sub(entry.timestamp) >= c.cacheTTL {
@@ -80,4 +80,4 @@ var globalDiskCache = NewDiskSpaceCache(30 * time.Second)
// CheckDiskSpaceCached performs cached disk space check
func CheckDiskSpaceCached(path string) *DiskSpaceCheck {
return globalDiskCache.Get(path)
}
}

View File

@@ -54,7 +54,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
check := CheckDiskSpace(path)
requiredBytes := uint64(archiveSize) * 4 // Account for decompression
// Override status based on required space
if check.AvailableBytes < requiredBytes {
check.Critical = true
@@ -64,7 +64,7 @@ func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
check.Warning = true
check.Sufficient = false
}
return check
}
@@ -134,7 +134,3 @@ func EstimateBackupSize(databaseSize uint64, compressionLevel int) uint64 {
// Add 10% buffer for metadata, indexes, etc.
return uint64(float64(estimated) * 1.1)
}

View File

@@ -54,7 +54,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
check := CheckDiskSpace(path)
requiredBytes := uint64(archiveSize) * 4 // Account for decompression
// Override status based on required space
if check.AvailableBytes < requiredBytes {
check.Critical = true
@@ -64,7 +64,7 @@ func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
check.Warning = true
check.Sufficient = false
}
return check
}
@@ -108,4 +108,4 @@ func FormatDiskSpaceMessage(check *DiskSpaceCheck) string {
}
return msg
}
}

View File

@@ -37,7 +37,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
check := CheckDiskSpace(path)
requiredBytes := uint64(archiveSize) * 4 // Account for decompression
// Override status based on required space
if check.AvailableBytes < requiredBytes {
check.Critical = true
@@ -47,7 +47,7 @@ func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
check.Warning = true
check.Sufficient = false
}
return check
}

View File

@@ -29,7 +29,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
// If no volume, try current directory
vol = "."
}
var freeBytesAvailable, totalNumberOfBytes, totalNumberOfFreeBytes uint64
// Call Windows API
@@ -73,7 +73,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
check := CheckDiskSpace(path)
requiredBytes := uint64(archiveSize) * 4 // Account for decompression
// Override status based on required space
if check.AvailableBytes < requiredBytes {
check.Critical = true
@@ -83,7 +83,7 @@ func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
check.Warning = true
check.Sufficient = false
}
return check
}
@@ -128,4 +128,3 @@ func FormatDiskSpaceMessage(check *DiskSpaceCheck) string {
return msg
}

View File

@@ -8,10 +8,10 @@ import (
// Compiled regex patterns for robust error matching
var errorPatterns = map[string]*regexp.Regexp{
"already_exists": regexp.MustCompile(`(?i)(already exists|duplicate key|unique constraint|relation.*exists)`),
"disk_full": regexp.MustCompile(`(?i)(no space left|disk.*full|write.*failed.*space|insufficient.*space)`),
"lock_exhaustion": regexp.MustCompile(`(?i)(max_locks_per_transaction|out of shared memory|lock.*exhausted|could not open large object)`),
"syntax_error": regexp.MustCompile(`(?i)syntax error at.*line \d+`),
"already_exists": regexp.MustCompile(`(?i)(already exists|duplicate key|unique constraint|relation.*exists)`),
"disk_full": regexp.MustCompile(`(?i)(no space left|disk.*full|write.*failed.*space|insufficient.*space)`),
"lock_exhaustion": regexp.MustCompile(`(?i)(max_locks_per_transaction|out of shared memory|lock.*exhausted|could not open large object)`),
"syntax_error": regexp.MustCompile(`(?i)syntax error at.*line \d+`),
"permission_denied": regexp.MustCompile(`(?i)(permission denied|must be owner|access denied)`),
"connection_failed": regexp.MustCompile(`(?i)(connection refused|could not connect|no pg_hba\.conf entry)`),
"version_mismatch": regexp.MustCompile(`(?i)(version mismatch|incompatible|unsupported version)`),
@@ -135,9 +135,9 @@ func ClassifyError(errorMsg string) *ErrorClassification {
}
// Lock exhaustion errors
if strings.Contains(lowerMsg, "max_locks_per_transaction") ||
strings.Contains(lowerMsg, "out of shared memory") ||
strings.Contains(lowerMsg, "could not open large object") {
if strings.Contains(lowerMsg, "max_locks_per_transaction") ||
strings.Contains(lowerMsg, "out of shared memory") ||
strings.Contains(lowerMsg, "could not open large object") {
return &ErrorClassification{
Type: "critical",
Category: "locks",
@@ -173,9 +173,9 @@ func ClassifyError(errorMsg string) *ErrorClassification {
}
// Connection errors
if strings.Contains(lowerMsg, "connection refused") ||
strings.Contains(lowerMsg, "could not connect") ||
strings.Contains(lowerMsg, "no pg_hba.conf entry") {
if strings.Contains(lowerMsg, "connection refused") ||
strings.Contains(lowerMsg, "could not connect") ||
strings.Contains(lowerMsg, "no pg_hba.conf entry") {
return &ErrorClassification{
Type: "critical",
Category: "network",

View File

@@ -26,4 +26,4 @@ func formatBytes(bytes uint64) string {
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
}

View File

@@ -41,7 +41,7 @@ func (pm *ProcessManager) Track(proc *os.Process) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.processes[proc.Pid] = proc
// Auto-cleanup when process exits
go func() {
proc.Wait()
@@ -59,14 +59,14 @@ func (pm *ProcessManager) KillAll() error {
procs = append(procs, proc)
}
pm.mu.RUnlock()
var errors []error
for _, proc := range procs {
if err := proc.Kill(); err != nil {
errors = append(errors, err)
}
}
if len(errors) > 0 {
return fmt.Errorf("failed to kill %d processes: %v", len(errors), errors)
}
@@ -82,18 +82,18 @@ func (pm *ProcessManager) Close() error {
// KillOrphanedProcesses finds and kills any orphaned pg_dump, pg_restore, gzip, or pigz processes
func KillOrphanedProcesses(log logger.Logger) error {
processNames := []string{"pg_dump", "pg_restore", "gzip", "pigz", "gunzip"}
myPID := os.Getpid()
var killed []string
var errors []error
for _, procName := range processNames {
pids, err := findProcessesByName(procName, myPID)
if err != nil {
log.Warn("Failed to search for processes", "process", procName, "error", err)
continue
}
for _, pid := range pids {
if err := killProcessGroup(pid); err != nil {
errors = append(errors, fmt.Errorf("failed to kill %s (PID %d): %w", procName, pid, err))
@@ -102,15 +102,15 @@ func KillOrphanedProcesses(log logger.Logger) error {
}
}
}
if len(killed) > 0 {
log.Info("Cleaned up orphaned processes", "count", len(killed), "processes", strings.Join(killed, ", "))
}
if len(errors) > 0 {
return fmt.Errorf("some processes could not be killed: %v", errors)
}
return nil
}
@@ -126,27 +126,27 @@ func findProcessesByName(name string, excludePID int) ([]int, error) {
}
return nil, err
}
var pids []int
lines := strings.Split(strings.TrimSpace(string(output)), "\n")
for _, line := range lines {
if line == "" {
continue
}
pid, err := strconv.Atoi(line)
if err != nil {
continue
}
// Don't kill our own process
if pid == excludePID {
continue
}
pids = append(pids, pid)
}
return pids, nil
}
@@ -158,17 +158,17 @@ func killProcessGroup(pid int) error {
// Process might already be gone
return nil
}
// Kill the entire process group (negative PID kills the group)
// This catches pipelines like "pg_dump | gzip"
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
// If SIGTERM fails, try SIGKILL
syscall.Kill(-pgid, syscall.SIGKILL)
}
// Also kill the specific PID in case it's not in a group
syscall.Kill(pid, syscall.SIGTERM)
return nil
}
@@ -186,21 +186,21 @@ func KillCommandGroup(cmd *exec.Cmd) error {
if cmd.Process == nil {
return nil
}
pid := cmd.Process.Pid
// Get the process group ID
pgid, err := syscall.Getpgid(pid)
if err != nil {
// Process might already be gone
return nil
}
// Kill the entire process group
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
// If SIGTERM fails, use SIGKILL
syscall.Kill(-pgid, syscall.SIGKILL)
}
return nil
}

View File

@@ -17,18 +17,18 @@ import (
// KillOrphanedProcesses finds and kills any orphaned pg_dump, pg_restore, gzip, or pigz processes (Windows implementation)
func KillOrphanedProcesses(log logger.Logger) error {
processNames := []string{"pg_dump.exe", "pg_restore.exe", "gzip.exe", "pigz.exe", "gunzip.exe"}
myPID := os.Getpid()
var killed []string
var errors []error
for _, procName := range processNames {
pids, err := findProcessesByNameWindows(procName, myPID)
if err != nil {
log.Warn("Failed to search for processes", "process", procName, "error", err)
continue
}
for _, pid := range pids {
if err := killProcessWindows(pid); err != nil {
errors = append(errors, fmt.Errorf("failed to kill %s (PID %d): %w", procName, pid, err))
@@ -37,15 +37,15 @@ func KillOrphanedProcesses(log logger.Logger) error {
}
}
}
if len(killed) > 0 {
log.Info("Cleaned up orphaned processes", "count", len(killed), "processes", strings.Join(killed, ", "))
}
if len(errors) > 0 {
return fmt.Errorf("some processes could not be killed: %v", errors)
}
return nil
}
@@ -58,35 +58,35 @@ func findProcessesByNameWindows(name string, excludePID int) ([]int, error) {
// No processes found or command failed
return []int{}, nil
}
var pids []int
lines := strings.Split(strings.TrimSpace(string(output)), "\n")
for _, line := range lines {
if line == "" {
continue
}
// Parse CSV output: "name","pid","session","mem"
fields := strings.Split(line, ",")
if len(fields) < 2 {
continue
}
// Remove quotes from PID field
pidStr := strings.Trim(fields[1], `"`)
pid, err := strconv.Atoi(pidStr)
if err != nil {
continue
}
// Don't kill our own process
if pid == excludePID {
continue
}
pids = append(pids, pid)
}
return pids, nil
}
@@ -111,7 +111,7 @@ func KillCommandGroup(cmd *exec.Cmd) error {
if cmd.Process == nil {
return nil
}
// On Windows, just kill the process directly
return cmd.Process.Kill()
}
}

View File

@@ -11,22 +11,22 @@ import (
type Backend interface {
// Upload uploads a file to cloud storage
Upload(ctx context.Context, localPath, remotePath string, progress ProgressCallback) error
// Download downloads a file from cloud storage
Download(ctx context.Context, remotePath, localPath string, progress ProgressCallback) error
// List lists all backup files in cloud storage
List(ctx context.Context, prefix string) ([]BackupInfo, error)
// Delete deletes a file from cloud storage
Delete(ctx context.Context, remotePath string) error
// Exists checks if a file exists in cloud storage
Exists(ctx context.Context, remotePath string) (bool, error)
// GetSize returns the size of a remote file
GetSize(ctx context.Context, remotePath string) (int64, error)
// Name returns the backend name (e.g., "s3", "azure", "gcs")
Name() string
}
@@ -137,10 +137,10 @@ func (c *Config) Validate() error {
// ProgressReader wraps an io.Reader to track progress
type ProgressReader struct {
reader io.Reader
total int64
read int64
callback ProgressCallback
reader io.Reader
total int64
read int64
callback ProgressCallback
lastReport time.Time
}
@@ -157,7 +157,7 @@ func NewProgressReader(r io.Reader, total int64, callback ProgressCallback) *Pro
func (pr *ProgressReader) Read(p []byte) (int, error) {
n, err := pr.reader.Read(p)
pr.read += int64(n)
// Report progress every 100ms or when complete
now := time.Now()
if now.Sub(pr.lastReport) > 100*time.Millisecond || err == io.EOF {
@@ -166,6 +166,6 @@ func (pr *ProgressReader) Read(p []byte) (int, error) {
}
pr.lastReport = now
}
return n, err
}

View File

@@ -30,11 +30,11 @@ func NewS3Backend(cfg *Config) (*S3Backend, error) {
}
ctx := context.Background()
// Build AWS config
var awsCfg aws.Config
var err error
if cfg.AccessKey != "" && cfg.SecretKey != "" {
// Use explicit credentials
credsProvider := credentials.NewStaticCredentialsProvider(
@@ -42,7 +42,7 @@ func NewS3Backend(cfg *Config) (*S3Backend, error) {
cfg.SecretKey,
"",
)
awsCfg, err = config.LoadDefaultConfig(ctx,
config.WithCredentialsProvider(credsProvider),
config.WithRegion(cfg.Region),
@@ -53,7 +53,7 @@ func NewS3Backend(cfg *Config) (*S3Backend, error) {
config.WithRegion(cfg.Region),
)
}
if err != nil {
return nil, fmt.Errorf("failed to load AWS config: %w", err)
}
@@ -69,7 +69,7 @@ func NewS3Backend(cfg *Config) (*S3Backend, error) {
}
},
}
client := s3.NewFromConfig(awsCfg, clientOptions...)
return &S3Backend{
@@ -114,7 +114,7 @@ func (s *S3Backend) Upload(ctx context.Context, localPath, remotePath string, pr
// Use multipart upload for files larger than 100MB
const multipartThreshold = 100 * 1024 * 1024 // 100 MB
if fileSize > multipartThreshold {
return s.uploadMultipart(ctx, file, key, fileSize, progress)
}
@@ -137,7 +137,7 @@ func (s *S3Backend) uploadSimple(ctx context.Context, file *os.File, key string,
Key: aws.String(key),
Body: reader,
})
if err != nil {
return fmt.Errorf("failed to upload to S3: %w", err)
}
@@ -151,10 +151,10 @@ func (s *S3Backend) uploadMultipart(ctx context.Context, file *os.File, key stri
uploader := manager.NewUploader(s.client, func(u *manager.Uploader) {
// Part size: 10MB
u.PartSize = 10 * 1024 * 1024
// Upload up to 10 parts concurrently
u.Concurrency = 10
// Leave parts on failure for debugging
u.LeavePartsOnError = false
})
@@ -245,10 +245,10 @@ func (s *S3Backend) List(ctx context.Context, prefix string) ([]BackupInfo, erro
if obj.Key == nil {
continue
}
key := *obj.Key
name := filepath.Base(key)
// Skip if it's just a directory marker
if strings.HasSuffix(key, "/") {
continue
@@ -260,11 +260,11 @@ func (s *S3Backend) List(ctx context.Context, prefix string) ([]BackupInfo, erro
Size: *obj.Size,
LastModified: *obj.LastModified,
}
if obj.ETag != nil {
info.ETag = *obj.ETag
}
if obj.StorageClass != "" {
info.StorageClass = string(obj.StorageClass)
} else {
@@ -285,7 +285,7 @@ func (s *S3Backend) Delete(ctx context.Context, remotePath string) error {
Bucket: aws.String(s.bucket),
Key: aws.String(key),
})
if err != nil {
return fmt.Errorf("failed to delete object: %w", err)
}
@@ -301,7 +301,7 @@ func (s *S3Backend) Exists(ctx context.Context, remotePath string) (bool, error)
Bucket: aws.String(s.bucket),
Key: aws.String(key),
})
if err != nil {
// Check if it's a "not found" error
if strings.Contains(err.Error(), "NotFound") || strings.Contains(err.Error(), "404") {
@@ -321,7 +321,7 @@ func (s *S3Backend) GetSize(ctx context.Context, remotePath string) (int64, erro
Bucket: aws.String(s.bucket),
Key: aws.String(key),
})
if err != nil {
return 0, fmt.Errorf("failed to get object metadata: %w", err)
}
@@ -338,7 +338,7 @@ func (s *S3Backend) BucketExists(ctx context.Context) (bool, error) {
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: aws.String(s.bucket),
})
if err != nil {
if strings.Contains(err.Error(), "NotFound") || strings.Contains(err.Error(), "404") {
return false, nil
@@ -355,7 +355,7 @@ func (s *S3Backend) CreateBucket(ctx context.Context) error {
if err != nil {
return err
}
if exists {
return nil
}
@@ -363,7 +363,7 @@ func (s *S3Backend) CreateBucket(ctx context.Context) error {
_, err = s.client.CreateBucket(ctx, &s3.CreateBucketInput{
Bucket: aws.String(s.bucket),
})
if err != nil {
return fmt.Errorf("failed to create bucket: %w", err)
}

View File

@@ -76,7 +76,7 @@ func ParseCloudURI(uri string) (*CloudURI, error) {
if len(parts) >= 3 {
// Extract bucket name (first part)
bucket = parts[0]
// Extract region if present
// bucket.s3.us-west-2.amazonaws.com -> us-west-2
// bucket.s3-us-west-2.amazonaws.com -> us-west-2

View File

@@ -45,11 +45,11 @@ type Config struct {
SampleValue int
// Output options
NoColor bool
Debug bool
LogLevel string
LogFormat string
NoColor bool
Debug bool
LogLevel string
LogFormat string
// Config persistence
NoSaveConfig bool
NoLoadConfig bool
@@ -194,11 +194,11 @@ func New() *Config {
AutoSwap: getEnvBool("AUTO_SWAP", false),
// Security defaults (MEDIUM priority)
RetentionDays: getEnvInt("RETENTION_DAYS", 30), // Keep backups for 30 days
MinBackups: getEnvInt("MIN_BACKUPS", 5), // Keep at least 5 backups
MaxRetries: getEnvInt("MAX_RETRIES", 3), // Maximum 3 retry attempts
AllowRoot: getEnvBool("ALLOW_ROOT", false), // Disallow root by default
CheckResources: getEnvBool("CHECK_RESOURCES", true), // Check resources by default
RetentionDays: getEnvInt("RETENTION_DAYS", 30), // Keep backups for 30 days
MinBackups: getEnvInt("MIN_BACKUPS", 5), // Keep at least 5 backups
MaxRetries: getEnvInt("MAX_RETRIES", 3), // Maximum 3 retry attempts
AllowRoot: getEnvBool("ALLOW_ROOT", false), // Disallow root by default
CheckResources: getEnvBool("CHECK_RESOURCES", true), // Check resources by default
// TUI automation defaults (for testing)
TUIAutoSelect: getEnvInt("TUI_AUTO_SELECT", -1), // -1 = disabled

View File

@@ -39,7 +39,7 @@ type LocalConfig struct {
// LoadLocalConfig loads configuration from .dbbackup.conf in current directory
func LoadLocalConfig() (*LocalConfig, error) {
configPath := filepath.Join(".", ConfigFileName)
data, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
@@ -54,7 +54,7 @@ func LoadLocalConfig() (*LocalConfig, error) {
for _, line := range lines {
line = strings.TrimSpace(line)
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") {
continue
@@ -143,7 +143,7 @@ func LoadLocalConfig() (*LocalConfig, error) {
// SaveLocalConfig saves configuration to .dbbackup.conf in current directory
func SaveLocalConfig(cfg *LocalConfig) error {
var sb strings.Builder
sb.WriteString("# dbbackup configuration\n")
sb.WriteString("# This file is auto-generated. Edit with care.\n\n")

View File

@@ -1,24 +1,24 @@
package cpu
import (
"bufio"
"fmt"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"os"
"os/exec"
"bufio"
)
// CPUInfo holds information about the system CPU
type CPUInfo struct {
LogicalCores int `json:"logical_cores"`
PhysicalCores int `json:"physical_cores"`
Architecture string `json:"architecture"`
ModelName string `json:"model_name"`
MaxFrequency float64 `json:"max_frequency_mhz"`
CacheSize string `json:"cache_size"`
Vendor string `json:"vendor"`
LogicalCores int `json:"logical_cores"`
PhysicalCores int `json:"physical_cores"`
Architecture string `json:"architecture"`
ModelName string `json:"model_name"`
MaxFrequency float64 `json:"max_frequency_mhz"`
CacheSize string `json:"cache_size"`
Vendor string `json:"vendor"`
Features []string `json:"features"`
}
@@ -78,7 +78,7 @@ func (d *Detector) detectLinux(info *CPUInfo) error {
scanner := bufio.NewScanner(file)
physicalCoreCount := make(map[string]bool)
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "" {
@@ -324,11 +324,11 @@ func (d *Detector) GetCPUInfo() *CPUInfo {
// FormatCPUInfo returns a formatted string representation of CPU info
func (info *CPUInfo) FormatCPUInfo() string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("Architecture: %s\n", info.Architecture))
sb.WriteString(fmt.Sprintf("Logical Cores: %d\n", info.LogicalCores))
sb.WriteString(fmt.Sprintf("Physical Cores: %d\n", info.PhysicalCores))
if info.ModelName != "" {
sb.WriteString(fmt.Sprintf("Model: %s\n", info.ModelName))
}
@@ -341,6 +341,6 @@ func (info *CPUInfo) FormatCPUInfo() string {
if info.CacheSize != "" {
sb.WriteString(fmt.Sprintf("Cache Size: %s\n", info.CacheSize))
}
return sb.String()
}
}

View File

@@ -8,9 +8,9 @@ import (
"dbbackup/internal/config"
"dbbackup/internal/logger"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver (pgx - high performance)
_ "github.com/go-sql-driver/mysql" // MySQL driver
_ "github.com/go-sql-driver/mysql" // MySQL driver
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver (pgx - high performance)
)
// Database represents a database connection and operations
@@ -19,43 +19,43 @@ type Database interface {
Connect(ctx context.Context) error
Close() error
Ping(ctx context.Context) error
// Database discovery
ListDatabases(ctx context.Context) ([]string, error)
ListTables(ctx context.Context, database string) ([]string, error)
// Database operations
CreateDatabase(ctx context.Context, name string) error
DropDatabase(ctx context.Context, name string) error
DatabaseExists(ctx context.Context, name string) (bool, error)
// Information
GetVersion(ctx context.Context) (string, error)
GetDatabaseSize(ctx context.Context, database string) (int64, error)
GetTableRowCount(ctx context.Context, database, table string) (int64, error)
// Backup/Restore command building
BuildBackupCommand(database, outputFile string, options BackupOptions) []string
BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string
BuildSampleQuery(database, table string, strategy SampleStrategy) string
// Validation
ValidateBackupTools() error
}
// BackupOptions holds options for backup operations
type BackupOptions struct {
Compression int
Parallel int
Format string // "custom", "plain", "directory"
Blobs bool
SchemaOnly bool
DataOnly bool
NoOwner bool
NoPrivileges bool
Clean bool
IfExists bool
Role string
Compression int
Parallel int
Format string // "custom", "plain", "directory"
Blobs bool
SchemaOnly bool
DataOnly bool
NoOwner bool
NoPrivileges bool
Clean bool
IfExists bool
Role string
}
// RestoreOptions holds options for restore operations
@@ -77,12 +77,12 @@ type SampleStrategy struct {
// DatabaseInfo holds database metadata
type DatabaseInfo struct {
Name string
Size int64
Owner string
Encoding string
Collation string
Tables []TableInfo
Name string
Size int64
Owner string
Encoding string
Collation string
Tables []TableInfo
}
// TableInfo holds table metadata
@@ -105,10 +105,10 @@ func New(cfg *config.Config, log logger.Logger) (Database, error) {
// Common database implementation
type baseDatabase struct {
cfg *config.Config
log logger.Logger
db *sql.DB
dsn string
cfg *config.Config
log logger.Logger
db *sql.DB
dsn string
}
func (b *baseDatabase) Close() error {
@@ -131,4 +131,4 @@ func buildTimeout(ctx context.Context, timeout time.Duration) (context.Context,
timeout = 30 * time.Second
}
return context.WithTimeout(ctx, timeout)
}
}

View File

@@ -387,7 +387,7 @@ func (m *MySQL) buildDSN() string {
"/tmp/mysql.sock",
"/var/lib/mysql/mysql.sock",
}
// Use the first available socket path, fallback to TCP if none found
socketFound := false
for _, socketPath := range socketPaths {
@@ -397,7 +397,7 @@ func (m *MySQL) buildDSN() string {
break
}
}
// If no socket found, use TCP localhost
if !socketFound {
dsn += "tcp(localhost:" + strconv.Itoa(m.cfg.Port) + ")"

View File

@@ -12,7 +12,7 @@ import (
"dbbackup/internal/auth"
"dbbackup/internal/config"
"dbbackup/internal/logger"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver (pgx)
@@ -43,51 +43,51 @@ func (p *PostgreSQL) Connect(ctx context.Context) error {
p.log.Debug("Loaded password from .pgpass file")
}
}
// Check for authentication mismatch before attempting connection
if mismatch, msg := auth.CheckAuthenticationMismatch(p.cfg); mismatch {
fmt.Println(msg)
return fmt.Errorf("authentication configuration required")
}
// Build PostgreSQL DSN (pgx format)
dsn := p.buildPgxDSN()
p.dsn = dsn
p.log.Debug("Connecting to PostgreSQL with pgx", "dsn", sanitizeDSN(dsn))
// Parse config with optimizations for large databases
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return fmt.Errorf("failed to parse pgx config: %w", err)
}
// Optimize connection pool for backup workloads
config.MaxConns = 10 // Max concurrent connections
config.MinConns = 2 // Keep minimum connections ready
config.MaxConnLifetime = 0 // No limit on connection lifetime
config.MaxConnIdleTime = 0 // No idle timeout
config.HealthCheckPeriod = 1 * time.Minute // Health check every minute
config.MaxConns = 10 // Max concurrent connections
config.MinConns = 2 // Keep minimum connections ready
config.MaxConnLifetime = 0 // No limit on connection lifetime
config.MaxConnIdleTime = 0 // No idle timeout
config.HealthCheckPeriod = 1 * time.Minute // Health check every minute
// Optimize for large query results (BLOB data)
config.ConnConfig.RuntimeParams["work_mem"] = "64MB"
config.ConnConfig.RuntimeParams["maintenance_work_mem"] = "256MB"
// Create connection pool
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return fmt.Errorf("failed to create pgx pool: %w", err)
}
// Test connection
if err := pool.Ping(ctx); err != nil {
pool.Close()
return fmt.Errorf("failed to ping PostgreSQL: %w", err)
}
// Also create stdlib connection for compatibility
db := stdlib.OpenDBFromPool(pool)
p.pool = pool
p.db = db
p.log.Info("Connected to PostgreSQL successfully", "driver", "pgx", "max_conns", config.MaxConns)
@@ -111,17 +111,17 @@ func (p *PostgreSQL) ListDatabases(ctx context.Context) ([]string, error) {
if p.db == nil {
return nil, fmt.Errorf("not connected to database")
}
query := `SELECT datname FROM pg_database
WHERE datistemplate = false
ORDER BY datname`
rows, err := p.db.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query databases: %w", err)
}
defer rows.Close()
var databases []string
for rows.Next() {
var name string
@@ -130,7 +130,7 @@ func (p *PostgreSQL) ListDatabases(ctx context.Context) ([]string, error) {
}
databases = append(databases, name)
}
return databases, rows.Err()
}
@@ -139,18 +139,18 @@ func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string,
if p.db == nil {
return nil, fmt.Errorf("not connected to database")
}
query := `SELECT schemaname||'.'||tablename as full_name
FROM pg_tables
WHERE schemaname NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
ORDER BY schemaname, tablename`
rows, err := p.db.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query tables: %w", err)
}
defer rows.Close()
var tables []string
for rows.Next() {
var name string
@@ -159,7 +159,7 @@ func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string,
}
tables = append(tables, name)
}
return tables, rows.Err()
}
@@ -168,14 +168,14 @@ func (p *PostgreSQL) CreateDatabase(ctx context.Context, name string) error {
if p.db == nil {
return fmt.Errorf("not connected to database")
}
// PostgreSQL doesn't support CREATE DATABASE in transactions or prepared statements
query := fmt.Sprintf("CREATE DATABASE %s", name)
_, err := p.db.ExecContext(ctx, query)
if err != nil {
return fmt.Errorf("failed to create database %s: %w", name, err)
}
p.log.Info("Created database", "name", name)
return nil
}
@@ -185,14 +185,14 @@ func (p *PostgreSQL) DropDatabase(ctx context.Context, name string) error {
if p.db == nil {
return fmt.Errorf("not connected to database")
}
// Force drop connections and drop database
query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", name)
_, err := p.db.ExecContext(ctx, query)
if err != nil {
return fmt.Errorf("failed to drop database %s: %w", name, err)
}
p.log.Info("Dropped database", "name", name)
return nil
}
@@ -202,14 +202,14 @@ func (p *PostgreSQL) DatabaseExists(ctx context.Context, name string) (bool, err
if p.db == nil {
return false, fmt.Errorf("not connected to database")
}
query := `SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)`
var exists bool
err := p.db.QueryRowContext(ctx, query, name).Scan(&exists)
if err != nil {
return false, fmt.Errorf("failed to check database existence: %w", err)
}
return exists, nil
}
@@ -218,13 +218,13 @@ func (p *PostgreSQL) GetVersion(ctx context.Context) (string, error) {
if p.db == nil {
return "", fmt.Errorf("not connected to database")
}
var version string
err := p.db.QueryRowContext(ctx, "SELECT version()").Scan(&version)
if err != nil {
return "", fmt.Errorf("failed to get version: %w", err)
}
return version, nil
}
@@ -233,14 +233,14 @@ func (p *PostgreSQL) GetDatabaseSize(ctx context.Context, database string) (int6
if p.db == nil {
return 0, fmt.Errorf("not connected to database")
}
query := `SELECT pg_database_size($1)`
var size int64
err := p.db.QueryRowContext(ctx, query, database).Scan(&size)
if err != nil {
return 0, fmt.Errorf("failed to get database size: %w", err)
}
return size, nil
}
@@ -249,16 +249,16 @@ func (p *PostgreSQL) GetTableRowCount(ctx context.Context, database, table strin
if p.db == nil {
return 0, fmt.Errorf("not connected to database")
}
// Use pg_stat_user_tables for approximate count (faster)
parts := strings.Split(table, ".")
if len(parts) != 2 {
return 0, fmt.Errorf("table name must be in format schema.table")
}
query := `SELECT COALESCE(n_tup_ins, 0) FROM pg_stat_user_tables
WHERE schemaname = $1 AND relname = $2`
var count int64
err := p.db.QueryRowContext(ctx, query, parts[0], parts[1]).Scan(&count)
if err != nil {
@@ -269,14 +269,14 @@ func (p *PostgreSQL) GetTableRowCount(ctx context.Context, database, table strin
return 0, fmt.Errorf("failed to get table row count: %w", err)
}
}
return count, nil
}
// BuildBackupCommand builds pg_dump command
func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
cmd := []string{"pg_dump"}
// Connection parameters
if p.cfg.Host != "localhost" {
cmd = append(cmd, "-h", p.cfg.Host)
@@ -284,27 +284,27 @@ func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options Bac
cmd = append(cmd, "--no-password")
}
cmd = append(cmd, "-U", p.cfg.User)
// Format and compression
if options.Format != "" {
cmd = append(cmd, "--format="+options.Format)
} else {
cmd = append(cmd, "--format=custom")
}
// For plain format with compression==0, we want to stream to stdout so external
// compression can be used. Set a marker flag so caller knows to pipe stdout.
usesStdout := (options.Format == "plain" && options.Compression == 0)
if options.Compression > 0 {
cmd = append(cmd, "--compress="+strconv.Itoa(options.Compression))
}
// Parallel jobs (only for directory format)
if options.Parallel > 1 && options.Format == "directory" {
cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
}
// Options
if options.Blobs {
cmd = append(cmd, "--blobs")
@@ -324,23 +324,23 @@ func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options Bac
if options.Role != "" {
cmd = append(cmd, "--role="+options.Role)
}
// Database
cmd = append(cmd, "--dbname="+database)
// Output: For plain format with external compression, omit --file so pg_dump
// writes to stdout (caller will pipe to compressor). Otherwise specify output file.
if !usesStdout {
cmd = append(cmd, "--file="+outputFile)
}
return cmd
}
// BuildRestoreCommand builds pg_restore command
func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
cmd := []string{"pg_restore"}
// Connection parameters
if p.cfg.Host != "localhost" {
cmd = append(cmd, "-h", p.cfg.Host)
@@ -348,12 +348,12 @@ func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options Res
cmd = append(cmd, "--no-password")
}
cmd = append(cmd, "-U", p.cfg.User)
// Parallel jobs (incompatible with --single-transaction per PostgreSQL docs)
if options.Parallel > 1 && !options.SingleTransaction {
cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
}
// Options
if options.Clean {
cmd = append(cmd, "--clean")
@@ -370,23 +370,23 @@ func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options Res
if options.SingleTransaction {
cmd = append(cmd, "--single-transaction")
}
// NOTE: --exit-on-error removed because it causes entire restore to fail on
// "already exists" errors. PostgreSQL continues on ignorable errors by default
// and reports error count at the end, which is correct behavior for restores.
// Skip data restore if table creation fails (prevents duplicate data errors)
cmd = append(cmd, "--no-data-for-failed-tables")
// Add verbose flag ONLY if requested (WARNING: can cause OOM on large cluster restores)
if options.Verbose {
cmd = append(cmd, "--verbose")
}
// Database and input
cmd = append(cmd, "--dbname="+database)
cmd = append(cmd, inputFile)
return cmd
}
@@ -395,7 +395,7 @@ func (p *PostgreSQL) BuildSampleQuery(database, table string, strategy SampleStr
switch strategy.Type {
case "ratio":
// Every Nth record using row_number
return fmt.Sprintf("SELECT * FROM (SELECT *, row_number() OVER () as rn FROM %s) t WHERE rn %% %d = 1",
return fmt.Sprintf("SELECT * FROM (SELECT *, row_number() OVER () as rn FROM %s) t WHERE rn %% %d = 1",
table, strategy.Value)
case "percent":
// Percentage sampling using TABLESAMPLE (PostgreSQL 9.5+)
@@ -411,24 +411,24 @@ func (p *PostgreSQL) BuildSampleQuery(database, table string, strategy SampleStr
// ValidateBackupTools checks if required PostgreSQL tools are available
func (p *PostgreSQL) ValidateBackupTools() error {
tools := []string{"pg_dump", "pg_restore", "pg_dumpall", "psql"}
for _, tool := range tools {
if _, err := exec.LookPath(tool); err != nil {
return fmt.Errorf("required tool not found: %s", tool)
}
}
return nil
}
// buildDSN constructs PostgreSQL connection string
func (p *PostgreSQL) buildDSN() string {
dsn := fmt.Sprintf("user=%s dbname=%s", p.cfg.User, p.cfg.Database)
if p.cfg.Password != "" {
dsn += " password=" + p.cfg.Password
}
// For localhost connections, try socket first for peer auth
if p.cfg.Host == "localhost" && p.cfg.Password == "" {
// Try Unix socket connection for peer authentication
@@ -438,7 +438,7 @@ func (p *PostgreSQL) buildDSN() string {
"/tmp",
"/var/lib/pgsql",
}
for _, dir := range socketDirs {
socketPath := fmt.Sprintf("%s/.s.PGSQL.%d", dir, p.cfg.Port)
if _, err := os.Stat(socketPath); err == nil {
@@ -452,7 +452,7 @@ func (p *PostgreSQL) buildDSN() string {
dsn += " host=" + p.cfg.Host
dsn += " port=" + strconv.Itoa(p.cfg.Port)
}
if p.cfg.SSLMode != "" && !p.cfg.Insecure {
// Map SSL modes to supported values for lib/pq
switch strings.ToLower(p.cfg.SSLMode) {
@@ -472,7 +472,7 @@ func (p *PostgreSQL) buildDSN() string {
} else if p.cfg.Insecure {
dsn += " sslmode=disable"
}
return dsn
}
@@ -480,7 +480,7 @@ func (p *PostgreSQL) buildDSN() string {
func (p *PostgreSQL) buildPgxDSN() string {
// pgx supports both URL and keyword=value formats
// Use keyword format for Unix sockets, URL for TCP
// Try Unix socket first for localhost without password
if p.cfg.Host == "localhost" && p.cfg.Password == "" {
socketDirs := []string{
@@ -488,7 +488,7 @@ func (p *PostgreSQL) buildPgxDSN() string {
"/tmp",
"/var/lib/pgsql",
}
for _, dir := range socketDirs {
socketPath := fmt.Sprintf("%s/.s.PGSQL.%d", dir, p.cfg.Port)
if _, err := os.Stat(socketPath); err == nil {
@@ -500,34 +500,34 @@ func (p *PostgreSQL) buildPgxDSN() string {
}
}
}
// Use URL format for TCP connections
var dsn strings.Builder
dsn.WriteString("postgres://")
// User
dsn.WriteString(p.cfg.User)
// Password
if p.cfg.Password != "" {
dsn.WriteString(":")
dsn.WriteString(p.cfg.Password)
}
dsn.WriteString("@")
// Host and Port
dsn.WriteString(p.cfg.Host)
dsn.WriteString(":")
dsn.WriteString(strconv.Itoa(p.cfg.Port))
// Database
dsn.WriteString("/")
dsn.WriteString(p.cfg.Database)
// Parameters
params := make([]string, 0)
// SSL Mode
if p.cfg.Insecure {
params = append(params, "sslmode=disable")
@@ -550,21 +550,21 @@ func (p *PostgreSQL) buildPgxDSN() string {
} else {
params = append(params, "sslmode=prefer")
}
// Connection pool settings
params = append(params, "pool_max_conns=10")
params = append(params, "pool_min_conns=2")
// Performance tuning for large queries
params = append(params, "application_name=dbbackup")
params = append(params, "connect_timeout=30")
// Add parameters to DSN
if len(params) > 0 {
dsn.WriteString("?")
dsn.WriteString(strings.Join(params, "&"))
}
return dsn.String()
}
@@ -573,7 +573,7 @@ func sanitizeDSN(dsn string) string {
// Simple password removal for logging
parts := strings.Split(dsn, " ")
var sanitized []string
for _, part := range parts {
if strings.HasPrefix(part, "password=") {
sanitized = append(sanitized, "password=***")
@@ -581,6 +581,6 @@ func sanitizeDSN(dsn string) string {
sanitized = append(sanitized, part)
}
}
return strings.Join(sanitized, " ")
}
}

View File

@@ -14,38 +14,38 @@ import (
const (
// AES-256 requires 32-byte keys
KeySize = 32
// Nonce size for GCM
NonceSize = 12
// Salt size for key derivation
SaltSize = 32
// PBKDF2 iterations (100,000 is recommended minimum)
PBKDF2Iterations = 100000
// Magic header to identify encrypted files
EncryptedFileMagic = "DBBACKUP_ENCRYPTED_V1"
)
// EncryptionHeader stores metadata for encrypted files
type EncryptionHeader struct {
Magic [22]byte // "DBBACKUP_ENCRYPTED_V1" (21 bytes + null)
Version uint8 // Version number (1)
Algorithm uint8 // Algorithm ID (1 = AES-256-GCM)
Salt [32]byte // Salt for key derivation
Nonce [12]byte // GCM nonce
Reserved [32]byte // Reserved for future use
Magic [22]byte // "DBBACKUP_ENCRYPTED_V1" (21 bytes + null)
Version uint8 // Version number (1)
Algorithm uint8 // Algorithm ID (1 = AES-256-GCM)
Salt [32]byte // Salt for key derivation
Nonce [12]byte // GCM nonce
Reserved [32]byte // Reserved for future use
}
// EncryptionOptions configures encryption behavior
type EncryptionOptions struct {
// Key is the encryption key (32 bytes for AES-256)
Key []byte
// Passphrase for key derivation (alternative to direct key)
Passphrase string
// Salt for key derivation (if empty, will be generated)
Salt []byte
}
@@ -79,7 +79,7 @@ func NewEncryptionWriter(w io.Writer, opts EncryptionOptions) (*EncryptionWriter
// Derive or validate key
var key []byte
var salt []byte
if opts.Passphrase != "" {
// Derive key from passphrase
if len(opts.Salt) == 0 {
@@ -106,25 +106,25 @@ func NewEncryptionWriter(w io.Writer, opts EncryptionOptions) (*EncryptionWriter
} else {
return nil, fmt.Errorf("either Key or Passphrase must be provided")
}
// Create AES cipher
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
// Create GCM mode
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
// Generate nonce
nonce := make([]byte, NonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
// Write header
header := EncryptionHeader{
Version: 1,
@@ -133,11 +133,11 @@ func NewEncryptionWriter(w io.Writer, opts EncryptionOptions) (*EncryptionWriter
copy(header.Magic[:], []byte(EncryptedFileMagic))
copy(header.Salt[:], salt)
copy(header.Nonce[:], nonce)
if err := writeHeader(w, &header); err != nil {
return nil, fmt.Errorf("failed to write header: %w", err)
}
return &EncryptionWriter{
writer: w,
gcm: gcm,
@@ -160,16 +160,16 @@ func (ew *EncryptionWriter) Write(p []byte) (n int, err error) {
if ew.closed {
return 0, fmt.Errorf("writer is closed")
}
// Accumulate data in buffer
ew.buffer = append(ew.buffer, p...)
// If buffer is large enough, encrypt and write
const chunkSize = 64 * 1024 // 64KB chunks
for len(ew.buffer) >= chunkSize {
chunk := ew.buffer[:chunkSize]
encrypted := ew.gcm.Seal(nil, ew.nonce, chunk, nil)
// Write encrypted chunk size (4 bytes) then chunk
size := uint32(len(encrypted))
sizeBytes := []byte{
@@ -184,15 +184,15 @@ func (ew *EncryptionWriter) Write(p []byte) (n int, err error) {
if _, err := ew.writer.Write(encrypted); err != nil {
return n, err
}
// Move remaining data to start of buffer
ew.buffer = ew.buffer[chunkSize:]
n += chunkSize
// Increment nonce for next chunk
incrementNonce(ew.nonce)
}
return len(p), nil
}
@@ -202,11 +202,11 @@ func (ew *EncryptionWriter) Close() error {
return nil
}
ew.closed = true
// Encrypt and write remaining buffer
if len(ew.buffer) > 0 {
encrypted := ew.gcm.Seal(nil, ew.nonce, ew.buffer, nil)
size := uint32(len(encrypted))
sizeBytes := []byte{
byte(size >> 24),
@@ -221,12 +221,12 @@ func (ew *EncryptionWriter) Close() error {
return err
}
}
// Write final zero-length chunk to signal end
if _, err := ew.writer.Write([]byte{0, 0, 0, 0}); err != nil {
return err
}
return nil
}
@@ -237,22 +237,22 @@ func NewDecryptionReader(r io.Reader, opts EncryptionOptions) (*DecryptionReader
if err != nil {
return nil, fmt.Errorf("failed to read header: %w", err)
}
// Verify magic
if string(header.Magic[:len(EncryptedFileMagic)]) != EncryptedFileMagic {
return nil, fmt.Errorf("not an encrypted backup file")
}
// Verify version
if header.Version != 1 {
return nil, fmt.Errorf("unsupported encryption version: %d", header.Version)
}
// Verify algorithm
if header.Algorithm != 1 {
return nil, fmt.Errorf("unsupported encryption algorithm: %d", header.Algorithm)
}
// Derive or validate key
var key []byte
if opts.Passphrase != "" {
@@ -265,22 +265,22 @@ func NewDecryptionReader(r io.Reader, opts EncryptionOptions) (*DecryptionReader
} else {
return nil, fmt.Errorf("either Key or Passphrase must be provided")
}
// Create AES cipher
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
// Create GCM mode
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
nonce := make([]byte, NonceSize)
copy(nonce, header.Nonce[:])
return &DecryptionReader{
reader: r,
gcm: gcm,
@@ -306,12 +306,12 @@ func (dr *DecryptionReader) Read(p []byte) (n int, err error) {
dr.buffer = dr.buffer[n:]
return n, nil
}
// If EOF reached, return EOF
if dr.eof {
return 0, io.EOF
}
// Read next chunk size
sizeBytes := make([]byte, 4)
if _, err := io.ReadFull(dr.reader, sizeBytes); err != nil {
@@ -321,36 +321,36 @@ func (dr *DecryptionReader) Read(p []byte) (n int, err error) {
}
return 0, err
}
size := uint32(sizeBytes[0])<<24 | uint32(sizeBytes[1])<<16 | uint32(sizeBytes[2])<<8 | uint32(sizeBytes[3])
// Zero-length chunk signals end of stream
if size == 0 {
dr.eof = true
return 0, io.EOF
}
// Read encrypted chunk
encrypted := make([]byte, size)
if _, err := io.ReadFull(dr.reader, encrypted); err != nil {
return 0, err
}
// Decrypt chunk
decrypted, err := dr.gcm.Open(nil, dr.nonce, encrypted, nil)
if err != nil {
return 0, fmt.Errorf("decryption failed (wrong key?): %w", err)
}
// Increment nonce for next chunk
incrementNonce(dr.nonce)
// Return as much as fits in p, buffer the rest
n = copy(p, decrypted)
if n < len(decrypted) {
dr.buffer = decrypted[n:]
}
return n, nil
}
@@ -364,7 +364,7 @@ func writeHeader(w io.Writer, h *EncryptionHeader) error {
copy(data[24:56], h.Salt[:])
copy(data[56:68], h.Nonce[:])
copy(data[68:100], h.Reserved[:])
_, err := w.Write(data)
return err
}
@@ -374,7 +374,7 @@ func readHeader(r io.Reader) (*EncryptionHeader, error) {
if _, err := io.ReadFull(r, data); err != nil {
return nil, err
}
header := &EncryptionHeader{
Version: data[22],
Algorithm: data[23],
@@ -383,7 +383,7 @@ func readHeader(r io.Reader) (*EncryptionHeader, error) {
copy(header.Salt[:], data[24:56])
copy(header.Nonce[:], data[56:68])
copy(header.Reserved[:], data[68:100])
return header, nil
}

View File

@@ -9,11 +9,11 @@ import (
func TestEncryptDecrypt(t *testing.T) {
// Test data
original := []byte("This is a secret database backup that needs encryption! 🔒")
// Test with passphrase
t.Run("Passphrase", func(t *testing.T) {
var encrypted bytes.Buffer
// Encrypt
writer, err := NewEncryptionWriter(&encrypted, EncryptionOptions{
Passphrase: "super-secret-password",
@@ -21,23 +21,23 @@ func TestEncryptDecrypt(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create encryption writer: %v", err)
}
if _, err := writer.Write(original); err != nil {
t.Fatalf("Failed to write data: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("Failed to close writer: %v", err)
}
t.Logf("Original size: %d bytes", len(original))
t.Logf("Encrypted size: %d bytes", encrypted.Len())
// Verify encrypted data is different from original
if bytes.Contains(encrypted.Bytes(), original) {
t.Error("Encrypted data contains plaintext - encryption failed!")
}
// Decrypt
reader, err := NewDecryptionReader(&encrypted, EncryptionOptions{
Passphrase: "super-secret-password",
@@ -45,30 +45,30 @@ func TestEncryptDecrypt(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create decryption reader: %v", err)
}
decrypted, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("Failed to read decrypted data: %v", err)
}
// Verify decrypted matches original
if !bytes.Equal(decrypted, original) {
t.Errorf("Decrypted data doesn't match original\nOriginal: %s\nDecrypted: %s",
string(original), string(decrypted))
}
t.Log("✅ Encryption/decryption successful")
})
// Test with direct key
t.Run("DirectKey", func(t *testing.T) {
key, err := GenerateKey()
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
var encrypted bytes.Buffer
// Encrypt
writer, err := NewEncryptionWriter(&encrypted, EncryptionOptions{
Key: key,
@@ -76,15 +76,15 @@ func TestEncryptDecrypt(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create encryption writer: %v", err)
}
if _, err := writer.Write(original); err != nil {
t.Fatalf("Failed to write data: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("Failed to close writer: %v", err)
}
// Decrypt
reader, err := NewDecryptionReader(&encrypted, EncryptionOptions{
Key: key,
@@ -92,23 +92,23 @@ func TestEncryptDecrypt(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create decryption reader: %v", err)
}
decrypted, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("Failed to read decrypted data: %v", err)
}
if !bytes.Equal(decrypted, original) {
t.Errorf("Decrypted data doesn't match original")
}
t.Log("✅ Direct key encryption/decryption successful")
})
// Test wrong password
t.Run("WrongPassword", func(t *testing.T) {
var encrypted bytes.Buffer
// Encrypt
writer, err := NewEncryptionWriter(&encrypted, EncryptionOptions{
Passphrase: "correct-password",
@@ -116,10 +116,10 @@ func TestEncryptDecrypt(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create encryption writer: %v", err)
}
writer.Write(original)
writer.Close()
// Try to decrypt with wrong password
reader, err := NewDecryptionReader(&encrypted, EncryptionOptions{
Passphrase: "wrong-password",
@@ -127,12 +127,12 @@ func TestEncryptDecrypt(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create decryption reader: %v", err)
}
_, err = io.ReadAll(reader)
if err == nil {
t.Error("Expected decryption to fail with wrong password, but it succeeded")
}
t.Logf("✅ Wrong password correctly rejected: %v", err)
})
}
@@ -143,9 +143,9 @@ func TestLargeData(t *testing.T) {
for i := range original {
original[i] = byte(i % 256)
}
var encrypted bytes.Buffer
// Encrypt
writer, err := NewEncryptionWriter(&encrypted, EncryptionOptions{
Passphrase: "test-password",
@@ -153,19 +153,19 @@ func TestLargeData(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create encryption writer: %v", err)
}
if _, err := writer.Write(original); err != nil {
t.Fatalf("Failed to write data: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("Failed to close writer: %v", err)
}
t.Logf("Original size: %d bytes", len(original))
t.Logf("Encrypted size: %d bytes", encrypted.Len())
t.Logf("Overhead: %.2f%%", float64(encrypted.Len()-len(original))/float64(len(original))*100)
// Decrypt
reader, err := NewDecryptionReader(&encrypted, EncryptionOptions{
Passphrase: "test-password",
@@ -173,16 +173,16 @@ func TestLargeData(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create decryption reader: %v", err)
}
decrypted, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("Failed to read decrypted data: %v", err)
}
if !bytes.Equal(decrypted, original) {
t.Errorf("Large data decryption failed")
}
t.Log("✅ Large data encryption/decryption successful")
}
@@ -192,43 +192,43 @@ func TestKeyGeneration(t *testing.T) {
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
if len(key1) != KeySize {
t.Errorf("Key size mismatch: expected %d, got %d", KeySize, len(key1))
}
// Generate another key and verify it's different
key2, err := GenerateKey()
if err != nil {
t.Fatalf("Failed to generate second key: %v", err)
}
if bytes.Equal(key1, key2) {
t.Error("Generated keys are identical - randomness broken!")
}
t.Log("✅ Key generation successful")
}
func TestKeyDerivation(t *testing.T) {
passphrase := "my-secret-passphrase"
salt1, _ := GenerateSalt()
// Derive key twice with same salt - should be identical
key1 := DeriveKey(passphrase, salt1)
key2 := DeriveKey(passphrase, salt1)
if !bytes.Equal(key1, key2) {
t.Error("Key derivation not deterministic")
}
// Derive with different salt - should be different
salt2, _ := GenerateSalt()
key3 := DeriveKey(passphrase, salt2)
if bytes.Equal(key1, key3) {
t.Error("Different salts produced same key")
}
t.Log("✅ Key derivation successful")
}

View File

@@ -16,7 +16,7 @@ type Logger interface {
Info(msg string, keysAndValues ...interface{})
Warn(msg string, keysAndValues ...interface{})
Error(msg string, keysAndValues ...interface{})
// Structured logging methods
WithFields(fields map[string]interface{}) Logger
WithField(key string, value interface{}) Logger
@@ -113,7 +113,7 @@ func (l *logger) Error(msg string, args ...any) {
}
func (l *logger) Time(msg string, args ...any) {
// Time logs are always at info level with special formatting
// Time logs are always at info level with special formatting
l.logWithFields(logrus.InfoLevel, "[TIME] "+msg, args...)
}
@@ -225,7 +225,7 @@ type CleanFormatter struct{}
// Format implements logrus.Formatter interface
func (f *CleanFormatter) Format(entry *logrus.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02T15:04:05")
// Color codes for different log levels
var levelColor, levelText string
switch entry.Level {
@@ -246,24 +246,24 @@ func (f *CleanFormatter) Format(entry *logrus.Entry) ([]byte, error) {
levelText = "INFO "
}
resetColor := "\033[0m"
// Build the message with perfectly aligned columns
var output strings.Builder
// Column 1: Level (with color, fixed width 5 chars)
output.WriteString(levelColor)
output.WriteString(levelText)
output.WriteString(resetColor)
output.WriteString(" ")
// Column 2: Timestamp (fixed format)
output.WriteString("[")
output.WriteString(timestamp)
output.WriteString("] ")
// Column 3: Message
output.WriteString(entry.Message)
// Append important fields in a clean format (skip internal/redundant fields)
if len(entry.Data) > 0 {
// Only show truly important fields, skip verbose ones
@@ -272,7 +272,7 @@ func (f *CleanFormatter) Format(entry *logrus.Entry) ([]byte, error) {
if k == "elapsed" || k == "operation_id" || k == "step" || k == "timestamp" || k == "message" {
continue
}
// Format duration nicely at the end
if k == "duration" {
if str, ok := v.(string); ok {
@@ -280,14 +280,14 @@ func (f *CleanFormatter) Format(entry *logrus.Entry) ([]byte, error) {
}
continue
}
// Only show critical fields (driver, errors, etc)
if k == "driver" || k == "max_conns" || k == "error" || k == "database" {
output.WriteString(fmt.Sprintf(" %s=%v", k, v))
}
}
}
output.WriteString("\n")
return []byte(output.String()), nil
}

View File

@@ -29,11 +29,11 @@ type BackupMetadata struct {
BaseBackup string `json:"base_backup,omitempty"`
Duration float64 `json:"duration_seconds"`
ExtraInfo map[string]string `json:"extra_info,omitempty"`
// Encryption fields (v2.3+)
Encrypted bool `json:"encrypted"` // Whether backup is encrypted
EncryptionAlgorithm string `json:"encryption_algorithm,omitempty"` // e.g., "aes-256-gcm"
// Incremental backup fields (v2.2+)
Incremental *IncrementalMetadata `json:"incremental,omitempty"` // Only present for incremental backups
}
@@ -50,16 +50,16 @@ type IncrementalMetadata struct {
// ClusterMetadata contains metadata for cluster backups
type ClusterMetadata struct {
Version string `json:"version"`
Timestamp time.Time `json:"timestamp"`
ClusterName string `json:"cluster_name"`
DatabaseType string `json:"database_type"`
Host string `json:"host"`
Port int `json:"port"`
Databases []BackupMetadata `json:"databases"`
TotalSize int64 `json:"total_size_bytes"`
Duration float64 `json:"duration_seconds"`
ExtraInfo map[string]string `json:"extra_info,omitempty"`
Version string `json:"version"`
Timestamp time.Time `json:"timestamp"`
ClusterName string `json:"cluster_name"`
DatabaseType string `json:"database_type"`
Host string `json:"host"`
Port int `json:"port"`
Databases []BackupMetadata `json:"databases"`
TotalSize int64 `json:"total_size_bytes"`
Duration float64 `json:"duration_seconds"`
ExtraInfo map[string]string `json:"extra_info,omitempty"`
}
// CalculateSHA256 computes the SHA-256 checksum of a file
@@ -81,7 +81,7 @@ func CalculateSHA256(filePath string) (string, error) {
// Save writes metadata to a .meta.json file
func (m *BackupMetadata) Save() error {
metaPath := m.BackupFile + ".meta.json"
data, err := json.MarshalIndent(m, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
@@ -97,7 +97,7 @@ func (m *BackupMetadata) Save() error {
// Load reads metadata from a .meta.json file
func Load(backupFile string) (*BackupMetadata, error) {
metaPath := backupFile + ".meta.json"
data, err := os.ReadFile(metaPath)
if err != nil {
return nil, fmt.Errorf("failed to read metadata file: %w", err)
@@ -114,7 +114,7 @@ func Load(backupFile string) (*BackupMetadata, error) {
// SaveCluster writes cluster metadata to a .meta.json file
func (m *ClusterMetadata) Save(targetFile string) error {
metaPath := targetFile + ".meta.json"
data, err := json.MarshalIndent(m, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal cluster metadata: %w", err)
@@ -130,7 +130,7 @@ func (m *ClusterMetadata) Save(targetFile string) error {
// LoadCluster reads cluster metadata from a .meta.json file
func LoadCluster(targetFile string) (*ClusterMetadata, error) {
metaPath := targetFile + ".meta.json"
data, err := os.ReadFile(metaPath)
if err != nil {
return nil, fmt.Errorf("failed to read cluster metadata file: %w", err)
@@ -156,13 +156,13 @@ func ListBackups(dir string) ([]*BackupMetadata, error) {
for _, metaFile := range matches {
// Extract backup file path (remove .meta.json suffix)
backupFile := metaFile[:len(metaFile)-len(".meta.json")]
meta, err := Load(backupFile)
if err != nil {
// Skip invalid metadata files
continue
}
backups = append(backups, meta)
}

View File

@@ -39,7 +39,7 @@ func NewMetricsCollector(log logger.Logger) *MetricsCollector {
func (mc *MetricsCollector) RecordOperation(operation, database string, start time.Time, sizeBytes int64, success bool, errorCount int) {
duration := time.Since(start)
throughput := calculateThroughput(sizeBytes, duration)
metric := OperationMetrics{
Operation: operation,
Database: database,
@@ -50,11 +50,11 @@ func (mc *MetricsCollector) RecordOperation(operation, database string, start ti
ErrorCount: errorCount,
Success: success,
}
mc.mu.Lock()
mc.metrics = append(mc.metrics, metric)
mc.mu.Unlock()
// Log structured metrics
if mc.logger != nil {
fields := map[string]interface{}{
@@ -67,7 +67,7 @@ func (mc *MetricsCollector) RecordOperation(operation, database string, start ti
"error_count": errorCount,
"success": success,
}
if success {
mc.logger.WithFields(fields).Info("Operation completed successfully")
} else {
@@ -80,7 +80,7 @@ func (mc *MetricsCollector) RecordOperation(operation, database string, start ti
func (mc *MetricsCollector) RecordCompressionRatio(operation, database string, ratio float64) {
mc.mu.Lock()
defer mc.mu.Unlock()
// Find and update the most recent matching operation
for i := len(mc.metrics) - 1; i >= 0; i-- {
if mc.metrics[i].Operation == operation && mc.metrics[i].Database == database {
@@ -94,7 +94,7 @@ func (mc *MetricsCollector) RecordCompressionRatio(operation, database string, r
func (mc *MetricsCollector) GetMetrics() []OperationMetrics {
mc.mu.RLock()
defer mc.mu.RUnlock()
result := make([]OperationMetrics, len(mc.metrics))
copy(result, mc.metrics)
return result
@@ -104,15 +104,15 @@ func (mc *MetricsCollector) GetMetrics() []OperationMetrics {
func (mc *MetricsCollector) GetAverages() map[string]interface{} {
mc.mu.RLock()
defer mc.mu.RUnlock()
if len(mc.metrics) == 0 {
return map[string]interface{}{}
}
var totalDuration time.Duration
var totalSize, totalThroughput float64
var successCount, errorCount int
for _, m := range mc.metrics {
totalDuration += m.Duration
totalSize += float64(m.SizeBytes)
@@ -122,15 +122,15 @@ func (mc *MetricsCollector) GetAverages() map[string]interface{} {
}
errorCount += m.ErrorCount
}
count := len(mc.metrics)
return map[string]interface{}{
"total_operations": count,
"success_rate": float64(successCount) / float64(count) * 100,
"avg_duration_ms": totalDuration.Milliseconds() / int64(count),
"avg_size_mb": totalSize / float64(count) / 1024 / 1024,
"avg_throughput_mbps": totalThroughput / float64(count),
"total_errors": errorCount,
"total_operations": count,
"success_rate": float64(successCount) / float64(count) * 100,
"avg_duration_ms": totalDuration.Milliseconds() / int64(count),
"avg_size_mb": totalSize / float64(count) / 1024 / 1024,
"avg_throughput_mbps": totalThroughput / float64(count),
"total_errors": errorCount,
}
}
@@ -159,4 +159,4 @@ var GlobalMetrics *MetricsCollector
// InitGlobalMetrics initializes the global metrics collector
func InitGlobalMetrics(log logger.Logger) {
GlobalMetrics = NewMetricsCollector(log)
}
}

View File

@@ -24,18 +24,18 @@ func NewRecoveryConfigGenerator(log logger.Logger) *RecoveryConfigGenerator {
// RecoveryConfig holds all recovery configuration parameters
type RecoveryConfig struct {
// Core recovery settings
Target *RecoveryTarget
WALArchiveDir string
Target *RecoveryTarget
WALArchiveDir string
RestoreCommand string
// PostgreSQL version
PostgreSQLVersion int // Major version (12, 13, 14, etc.)
// Additional settings
PrimaryConnInfo string // For standby mode
PrimarySlotName string // Replication slot name
PrimaryConnInfo string // For standby mode
PrimarySlotName string // Replication slot name
RecoveryMinApplyDelay string // Min delay for replay
// Paths
DataDir string // PostgreSQL data directory
}
@@ -61,7 +61,7 @@ func (rcg *RecoveryConfigGenerator) generateModernRecoveryConfig(config *Recover
// Create recovery.signal file (empty file that triggers recovery mode)
recoverySignalPath := filepath.Join(config.DataDir, "recovery.signal")
rcg.log.Info("Creating recovery.signal file", "path", recoverySignalPath)
signalFile, err := os.Create(recoverySignalPath)
if err != nil {
return fmt.Errorf("failed to create recovery.signal: %w", err)
@@ -180,7 +180,7 @@ func (rcg *RecoveryConfigGenerator) generateLegacyRecoveryConfig(config *Recover
func (rcg *RecoveryConfigGenerator) generateRestoreCommand(walArchiveDir string) string {
// The restore_command is executed by PostgreSQL to fetch WAL files
// %f = WAL filename, %p = full path to copy WAL file to
// Try multiple extensions (.gz.enc, .enc, .gz, plain)
// This handles compressed and/or encrypted WAL files
return fmt.Sprintf(`bash -c 'for ext in .gz.enc .enc .gz ""; do [ -f "%s/%%f$ext" ] && { [ -z "$ext" ] && cp "%s/%%f$ext" "%%p" || case "$ext" in *.gz.enc) gpg -d "%s/%%f$ext" | gunzip > "%%p" ;; *.enc) gpg -d "%s/%%f$ext" > "%%p" ;; *.gz) gunzip -c "%s/%%f$ext" > "%%p" ;; esac; exit 0; }; done; exit 1'`,
@@ -232,14 +232,14 @@ func (rcg *RecoveryConfigGenerator) ValidateDataDirectory(dataDir string) error
// DetectPostgreSQLVersion detects the PostgreSQL version from the data directory
func (rcg *RecoveryConfigGenerator) DetectPostgreSQLVersion(dataDir string) (int, error) {
pgVersionPath := filepath.Join(dataDir, "PG_VERSION")
content, err := os.ReadFile(pgVersionPath)
if err != nil {
return 0, fmt.Errorf("failed to read PG_VERSION: %w", err)
}
versionStr := strings.TrimSpace(string(content))
// Parse major version (e.g., "14" or "14.2")
parts := strings.Split(versionStr, ".")
if len(parts) == 0 {

View File

@@ -10,10 +10,10 @@ import (
// RecoveryTarget represents a PostgreSQL recovery target
type RecoveryTarget struct {
Type string // "time", "xid", "lsn", "name", "immediate"
Value string // The target value (timestamp, XID, LSN, or restore point name)
Action string // "promote", "pause", "shutdown"
Timeline string // Timeline to follow ("latest" or timeline ID)
Type string // "time", "xid", "lsn", "name", "immediate"
Value string // The target value (timestamp, XID, LSN, or restore point name)
Action string // "promote", "pause", "shutdown"
Timeline string // Timeline to follow ("latest" or timeline ID)
Inclusive bool // Whether target is inclusive (default: true)
}
@@ -128,13 +128,13 @@ func (rt *RecoveryTarget) validateTime() error {
// Try parsing various timestamp formats
formats := []string{
"2006-01-02 15:04:05", // Standard format
"2006-01-02 15:04:05.999999", // With microseconds
"2006-01-02T15:04:05", // ISO 8601
"2006-01-02T15:04:05Z", // ISO 8601 with UTC
"2006-01-02T15:04:05-07:00", // ISO 8601 with timezone
time.RFC3339, // RFC3339
time.RFC3339Nano, // RFC3339 with nanoseconds
"2006-01-02 15:04:05", // Standard format
"2006-01-02 15:04:05.999999", // With microseconds
"2006-01-02T15:04:05", // ISO 8601
"2006-01-02T15:04:05Z", // ISO 8601 with UTC
"2006-01-02T15:04:05-07:00", // ISO 8601 with timezone
time.RFC3339, // RFC3339
time.RFC3339Nano, // RFC3339 with nanoseconds
}
var parseErr error
@@ -283,24 +283,24 @@ func FormatConfigLine(key, value string) string {
// String returns a human-readable representation of the recovery target
func (rt *RecoveryTarget) String() string {
var sb strings.Builder
sb.WriteString("Recovery Target:\n")
sb.WriteString(fmt.Sprintf(" Type: %s\n", rt.Type))
if rt.Type != TargetTypeImmediate {
sb.WriteString(fmt.Sprintf(" Value: %s\n", rt.Value))
}
sb.WriteString(fmt.Sprintf(" Action: %s\n", rt.Action))
if rt.Timeline != "" {
sb.WriteString(fmt.Sprintf(" Timeline: %s\n", rt.Timeline))
}
if rt.Type != TargetTypeImmediate && rt.Type != TargetTypeName {
sb.WriteString(fmt.Sprintf(" Inclusive: %v\n", rt.Inclusive))
}
return sb.String()
}

View File

@@ -284,7 +284,7 @@ func (ro *RestoreOrchestrator) startPostgreSQL(ctx context.Context, opts *Restor
}
cmd := exec.CommandContext(ctx, pgCtl, "-D", opts.TargetDataDir, "-l", filepath.Join(opts.TargetDataDir, "logfile"), "start")
output, err := cmd.CombinedOutput()
if err != nil {
ro.log.Error("PostgreSQL startup failed", "output", string(output))
@@ -321,18 +321,18 @@ func (ro *RestoreOrchestrator) monitorRecovery(ctx context.Context, opts *Restor
pidFile := filepath.Join(opts.TargetDataDir, "postmaster.pid")
if _, err := os.Stat(pidFile); err == nil {
ro.log.Info("✅ PostgreSQL is running")
// Check if recovery files still exist
recoverySignal := filepath.Join(opts.TargetDataDir, "recovery.signal")
recoveryConf := filepath.Join(opts.TargetDataDir, "recovery.conf")
if _, err := os.Stat(recoverySignal); os.IsNotExist(err) {
if _, err := os.Stat(recoveryConf); os.IsNotExist(err) {
ro.log.Info("✅ Recovery completed - PostgreSQL promoted to primary")
return nil
}
}
ro.log.Info("Recovery in progress...")
} else {
ro.log.Info("PostgreSQL not yet started or crashed")

View File

@@ -17,32 +17,32 @@ type DetailedReporter struct {
// OperationStatus represents the status of a backup/restore operation
type OperationStatus struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"` // "backup", "restore", "verify"
Status string `json:"status"` // "running", "completed", "failed"
StartTime time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time,omitempty"`
Duration time.Duration `json:"duration"`
Progress int `json:"progress"` // 0-100
Message string `json:"message"`
Details map[string]string `json:"details"`
Steps []StepStatus `json:"steps"`
BytesTotal int64 `json:"bytes_total"`
BytesDone int64 `json:"bytes_done"`
FilesTotal int `json:"files_total"`
FilesDone int `json:"files_done"`
Errors []string `json:"errors,omitempty"`
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"` // "backup", "restore", "verify"
Status string `json:"status"` // "running", "completed", "failed"
StartTime time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time,omitempty"`
Duration time.Duration `json:"duration"`
Progress int `json:"progress"` // 0-100
Message string `json:"message"`
Details map[string]string `json:"details"`
Steps []StepStatus `json:"steps"`
BytesTotal int64 `json:"bytes_total"`
BytesDone int64 `json:"bytes_done"`
FilesTotal int `json:"files_total"`
FilesDone int `json:"files_done"`
Errors []string `json:"errors,omitempty"`
}
// StepStatus represents individual steps within an operation
type StepStatus struct {
Name string `json:"name"`
Status string `json:"status"`
StartTime time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time,omitempty"`
Name string `json:"name"`
Status string `json:"status"`
StartTime time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time,omitempty"`
Duration time.Duration `json:"duration"`
Message string `json:"message"`
Message string `json:"message"`
}
// Logger interface for detailed reporting
@@ -79,7 +79,7 @@ func (dr *DetailedReporter) StartOperation(id, name, opType string) *OperationTr
}
dr.operations = append(dr.operations, operation)
if dr.startTime.IsZero() {
dr.startTime = time.Now()
}
@@ -90,9 +90,9 @@ func (dr *DetailedReporter) StartOperation(id, name, opType string) *OperationTr
}
// Log operation start
dr.logger.Info("Operation started",
"id", id,
"name", name,
dr.logger.Info("Operation started",
"id", id,
"name", name,
"type", opType,
"timestamp", operation.StartTime.Format(time.RFC3339))
@@ -117,7 +117,7 @@ func (ot *OperationTracker) UpdateProgress(progress int, message string) {
if ot.reporter.operations[i].ID == ot.operationID {
ot.reporter.operations[i].Progress = progress
ot.reporter.operations[i].Message = message
// Update visual indicator
if ot.reporter.indicator != nil {
progressMsg := fmt.Sprintf("[%d%%] %s", progress, message)
@@ -150,7 +150,7 @@ func (ot *OperationTracker) AddStep(name, message string) *StepTracker {
for i := range ot.reporter.operations {
if ot.reporter.operations[i].ID == ot.operationID {
ot.reporter.operations[i].Steps = append(ot.reporter.operations[i].Steps, step)
// Log step start
ot.reporter.logger.Info("Step started",
"operation_id", ot.operationID,
@@ -190,7 +190,7 @@ func (ot *OperationTracker) SetFileProgress(filesDone, filesTotal int) {
if ot.reporter.operations[i].ID == ot.operationID {
ot.reporter.operations[i].FilesDone = filesDone
ot.reporter.operations[i].FilesTotal = filesTotal
if filesTotal > 0 {
progress := (filesDone * 100) / filesTotal
ot.reporter.operations[i].Progress = progress
@@ -209,25 +209,25 @@ func (ot *OperationTracker) SetByteProgress(bytesDone, bytesTotal int64) {
if ot.reporter.operations[i].ID == ot.operationID {
ot.reporter.operations[i].BytesDone = bytesDone
ot.reporter.operations[i].BytesTotal = bytesTotal
if bytesTotal > 0 {
progress := int((bytesDone * 100) / bytesTotal)
ot.reporter.operations[i].Progress = progress
// Calculate ETA and speed
elapsed := time.Since(ot.reporter.operations[i].StartTime).Seconds()
if elapsed > 0 && bytesDone > 0 {
speed := float64(bytesDone) / elapsed // bytes/sec
remaining := bytesTotal - bytesDone
eta := time.Duration(float64(remaining)/speed) * time.Second
// Update progress message with ETA and speed
if ot.reporter.indicator != nil {
speedStr := formatSpeed(int64(speed))
etaStr := formatDuration(eta)
progressMsg := fmt.Sprintf("[%d%%] %s / %s (%s/s, ETA: %s)",
progress,
formatBytes(bytesDone),
progressMsg := fmt.Sprintf("[%d%%] %s / %s (%s/s, ETA: %s)",
progress,
formatBytes(bytesDone),
formatBytes(bytesTotal),
speedStr,
etaStr)
@@ -253,7 +253,7 @@ func (ot *OperationTracker) Complete(message string) {
ot.reporter.operations[i].EndTime = &now
ot.reporter.operations[i].Duration = now.Sub(ot.reporter.operations[i].StartTime)
ot.reporter.operations[i].Message = message
// Complete visual indicator
if ot.reporter.indicator != nil {
ot.reporter.indicator.Complete(fmt.Sprintf("✅ %s", message))
@@ -283,7 +283,7 @@ func (ot *OperationTracker) Fail(err error) {
ot.reporter.operations[i].Duration = now.Sub(ot.reporter.operations[i].StartTime)
ot.reporter.operations[i].Message = err.Error()
ot.reporter.operations[i].Errors = append(ot.reporter.operations[i].Errors, err.Error())
// Fail visual indicator
if ot.reporter.indicator != nil {
ot.reporter.indicator.Fail(fmt.Sprintf("❌ %s", err.Error()))
@@ -321,7 +321,7 @@ func (st *StepTracker) Complete(message string) {
st.reporter.operations[i].Steps[j].EndTime = &now
st.reporter.operations[i].Steps[j].Duration = now.Sub(st.reporter.operations[i].Steps[j].StartTime)
st.reporter.operations[i].Steps[j].Message = message
// Log step completion
st.reporter.logger.Info("Step completed",
"operation_id", st.operationID,
@@ -351,7 +351,7 @@ func (st *StepTracker) Fail(err error) {
st.reporter.operations[i].Steps[j].EndTime = &now
st.reporter.operations[i].Steps[j].Duration = now.Sub(st.reporter.operations[i].Steps[j].StartTime)
st.reporter.operations[i].Steps[j].Message = err.Error()
// Log step failure
st.reporter.logger.Error("Step failed",
"operation_id", st.operationID,
@@ -428,8 +428,8 @@ type OperationSummary struct {
func (os *OperationSummary) FormatSummary() string {
return fmt.Sprintf(
"📊 Operations Summary:\n"+
" Total: %d | Completed: %d | Failed: %d | Running: %d\n"+
" Total Duration: %s",
" Total: %d | Completed: %d | Failed: %d | Running: %d\n"+
" Total Duration: %s",
os.TotalOperations,
os.CompletedOperations,
os.FailedOperations,
@@ -461,7 +461,7 @@ func formatBytes(bytes int64) string {
GB = 1024 * MB
TB = 1024 * GB
)
switch {
case bytes >= TB:
return fmt.Sprintf("%.2f TB", float64(bytes)/float64(TB))
@@ -483,7 +483,7 @@ func formatSpeed(bytesPerSec int64) string {
MB = 1024 * KB
GB = 1024 * MB
)
switch {
case bytesPerSec >= GB:
return fmt.Sprintf("%.2f GB", float64(bytesPerSec)/float64(GB))
@@ -494,4 +494,4 @@ func formatSpeed(bytesPerSec int64) string {
default:
return fmt.Sprintf("%d B", bytesPerSec)
}
}
}

View File

@@ -42,11 +42,11 @@ func (e *ETAEstimator) GetETA() time.Duration {
if e.itemsComplete == 0 || e.totalItems == 0 {
return 0
}
elapsed := e.GetElapsed()
avgTimePerItem := elapsed / time.Duration(e.itemsComplete)
remainingItems := e.totalItems - e.itemsComplete
return avgTimePerItem * time.Duration(remainingItems)
}
@@ -83,12 +83,12 @@ func (e *ETAEstimator) GetFullStatus(baseMessage string) string {
// No items to track, just show elapsed
return fmt.Sprintf("%s | Elapsed: %s", baseMessage, e.FormatElapsed())
}
if e.itemsComplete == 0 {
// Just started
return fmt.Sprintf("%s | 0/%d | Starting...", baseMessage, e.totalItems)
}
// Full status with progress and ETA
return fmt.Sprintf("%s | %s | Elapsed: %s | ETA: %s",
baseMessage,
@@ -102,44 +102,44 @@ func FormatDuration(d time.Duration) string {
if d < time.Second {
return "< 1s"
}
hours := int(d.Hours())
minutes := int(d.Minutes()) % 60
seconds := int(d.Seconds()) % 60
if hours > 0 {
if minutes > 0 {
return fmt.Sprintf("%dh %dm", hours, minutes)
}
return fmt.Sprintf("%dh", hours)
}
if minutes > 0 {
if seconds > 5 { // Only show seconds if > 5
return fmt.Sprintf("%dm %ds", minutes, seconds)
}
return fmt.Sprintf("%dm", minutes)
}
return fmt.Sprintf("%ds", seconds)
}
// EstimateSizeBasedDuration estimates duration based on size (fallback when no progress tracking)
func EstimateSizeBasedDuration(sizeBytes int64, cores int) time.Duration {
sizeMB := float64(sizeBytes) / (1024 * 1024)
// Base estimate: ~100MB per minute on average hardware
baseMinutes := sizeMB / 100.0
// Adjust for CPU cores (more cores = faster, but not linear)
// Use square root to represent diminishing returns
if cores > 1 {
speedup := 1.0 + (0.3 * (float64(cores) - 1)) // 30% improvement per core
baseMinutes = baseMinutes / speedup
}
// Add 20% buffer for safety
baseMinutes = baseMinutes * 1.2
return time.Duration(baseMinutes * float64(time.Minute))
}

View File

@@ -7,19 +7,19 @@ import (
func TestNewETAEstimator(t *testing.T) {
estimator := NewETAEstimator("Test Operation", 10)
if estimator.operation != "Test Operation" {
t.Errorf("Expected operation 'Test Operation', got '%s'", estimator.operation)
}
if estimator.totalItems != 10 {
t.Errorf("Expected totalItems 10, got %d", estimator.totalItems)
}
if estimator.itemsComplete != 0 {
t.Errorf("Expected itemsComplete 0, got %d", estimator.itemsComplete)
}
if estimator.startTime.IsZero() {
t.Error("Expected startTime to be set")
}
@@ -27,12 +27,12 @@ func TestNewETAEstimator(t *testing.T) {
func TestUpdateProgress(t *testing.T) {
estimator := NewETAEstimator("Test", 10)
estimator.UpdateProgress(5)
if estimator.itemsComplete != 5 {
t.Errorf("Expected itemsComplete 5, got %d", estimator.itemsComplete)
}
estimator.UpdateProgress(8)
if estimator.itemsComplete != 8 {
t.Errorf("Expected itemsComplete 8, got %d", estimator.itemsComplete)
@@ -41,24 +41,24 @@ func TestUpdateProgress(t *testing.T) {
func TestGetProgress(t *testing.T) {
estimator := NewETAEstimator("Test", 10)
// Test 0% progress
if progress := estimator.GetProgress(); progress != 0 {
t.Errorf("Expected 0%%, got %.2f%%", progress)
}
// Test 50% progress
estimator.UpdateProgress(5)
if progress := estimator.GetProgress(); progress != 50.0 {
t.Errorf("Expected 50%%, got %.2f%%", progress)
}
// Test 100% progress
estimator.UpdateProgress(10)
if progress := estimator.GetProgress(); progress != 100.0 {
t.Errorf("Expected 100%%, got %.2f%%", progress)
}
// Test zero division
zeroEstimator := NewETAEstimator("Test", 0)
if progress := zeroEstimator.GetProgress(); progress != 0 {
@@ -68,10 +68,10 @@ func TestGetProgress(t *testing.T) {
func TestGetElapsed(t *testing.T) {
estimator := NewETAEstimator("Test", 10)
// Wait a bit
time.Sleep(100 * time.Millisecond)
elapsed := estimator.GetElapsed()
if elapsed < 100*time.Millisecond {
t.Errorf("Expected elapsed time >= 100ms, got %v", elapsed)
@@ -80,16 +80,16 @@ func TestGetElapsed(t *testing.T) {
func TestGetETA(t *testing.T) {
estimator := NewETAEstimator("Test", 10)
// No progress yet, ETA should be 0
if eta := estimator.GetETA(); eta != 0 {
t.Errorf("Expected ETA 0 for no progress, got %v", eta)
}
// Simulate 5 items completed in 5 seconds
estimator.startTime = time.Now().Add(-5 * time.Second)
estimator.UpdateProgress(5)
eta := estimator.GetETA()
// Should be approximately 5 seconds (5 items remaining at 1 sec/item)
if eta < 4*time.Second || eta > 6*time.Second {
@@ -99,18 +99,18 @@ func TestGetETA(t *testing.T) {
func TestFormatProgress(t *testing.T) {
estimator := NewETAEstimator("Test", 13)
// Test at 0%
if result := estimator.FormatProgress(); result != "0/13 (0%)" {
t.Errorf("Expected '0/13 (0%%)', got '%s'", result)
}
// Test at 38%
estimator.UpdateProgress(5)
if result := estimator.FormatProgress(); result != "5/13 (38%)" {
t.Errorf("Expected '5/13 (38%%)', got '%s'", result)
}
// Test at 100%
estimator.UpdateProgress(13)
if result := estimator.FormatProgress(); result != "13/13 (100%)" {
@@ -125,16 +125,16 @@ func TestFormatDuration(t *testing.T) {
}{
{500 * time.Millisecond, "< 1s"},
{5 * time.Second, "5s"},
{65 * time.Second, "1m"}, // 5 seconds not shown (<=5)
{125 * time.Second, "2m"}, // 5 seconds not shown (<=5)
{65 * time.Second, "1m"}, // 5 seconds not shown (<=5)
{125 * time.Second, "2m"}, // 5 seconds not shown (<=5)
{3 * time.Minute, "3m"},
{3*time.Minute + 3*time.Second, "3m"}, // < 5 seconds not shown
{3*time.Minute + 10*time.Second, "3m 10s"}, // > 5 seconds shown
{3*time.Minute + 3*time.Second, "3m"}, // < 5 seconds not shown
{3*time.Minute + 10*time.Second, "3m 10s"}, // > 5 seconds shown
{90 * time.Minute, "1h 30m"},
{120 * time.Minute, "2h"},
{150 * time.Minute, "2h 30m"},
}
for _, tt := range tests {
result := FormatDuration(tt.duration)
if result != tt.expected {
@@ -145,16 +145,16 @@ func TestFormatDuration(t *testing.T) {
func TestFormatETA(t *testing.T) {
estimator := NewETAEstimator("Test", 10)
// No progress - should show "calculating..."
if result := estimator.FormatETA(); result != "calculating..." {
t.Errorf("Expected 'calculating...', got '%s'", result)
}
// With progress
estimator.startTime = time.Now().Add(-10 * time.Second)
estimator.UpdateProgress(5)
result := estimator.FormatETA()
if result != "~10s remaining" {
t.Errorf("Expected '~10s remaining', got '%s'", result)
@@ -164,7 +164,7 @@ func TestFormatETA(t *testing.T) {
func TestFormatElapsed(t *testing.T) {
estimator := NewETAEstimator("Test", 10)
estimator.startTime = time.Now().Add(-45 * time.Second)
result := estimator.FormatElapsed()
if result != "45s" {
t.Errorf("Expected '45s', got '%s'", result)
@@ -173,23 +173,23 @@ func TestFormatElapsed(t *testing.T) {
func TestGetFullStatus(t *testing.T) {
estimator := NewETAEstimator("Backing up cluster", 13)
// Just started (0 items)
result := estimator.GetFullStatus("Backing up cluster")
if result != "Backing up cluster | 0/13 | Starting..." {
t.Errorf("Unexpected result for 0 items: '%s'", result)
}
// With progress
estimator.startTime = time.Now().Add(-30 * time.Second)
estimator.UpdateProgress(5)
result = estimator.GetFullStatus("Backing up cluster")
// Should contain all components
if len(result) < 50 { // Reasonable minimum length
t.Errorf("Result too short: '%s'", result)
}
// Check it contains key elements (format may vary slightly)
if !contains(result, "5/13") {
t.Errorf("Result missing progress '5/13': '%s'", result)
@@ -208,7 +208,7 @@ func TestGetFullStatus(t *testing.T) {
func TestGetFullStatusWithZeroItems(t *testing.T) {
estimator := NewETAEstimator("Test Operation", 0)
estimator.startTime = time.Now().Add(-5 * time.Second)
result := estimator.GetFullStatus("Test Operation")
// Should only show elapsed time when no items to track
if !contains(result, "Test Operation") || !contains(result, "Elapsed:") {
@@ -226,13 +226,13 @@ func TestEstimateSizeBasedDuration(t *testing.T) {
if duration < 60*time.Second || duration > 90*time.Second {
t.Errorf("Expected ~1.2 minutes for 100MB/1core, got %v", duration)
}
// Test 100MB with 8 cores (should be faster)
duration8cores := EstimateSizeBasedDuration(100*1024*1024, 8)
if duration8cores >= duration {
t.Errorf("Expected faster with more cores: %v vs %v", duration8cores, duration)
}
// Test larger file
duration1GB := EstimateSizeBasedDuration(1024*1024*1024, 1)
if duration1GB <= duration {
@@ -242,9 +242,8 @@ func TestEstimateSizeBasedDuration(t *testing.T) {
// Helper function
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr ||
len(s) > len(substr) && (
s[:len(substr)] == substr ||
return len(s) >= len(substr) && (s == substr ||
len(s) > len(substr) && (s[:len(substr)] == substr ||
s[len(s)-len(substr):] == substr ||
indexHelper(s, substr) >= 0))
}

View File

@@ -43,11 +43,11 @@ func NewSpinner() *Spinner {
func (s *Spinner) Start(message string) {
s.message = message
s.active = true
go func() {
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
i := 0
lastMessage := ""
for {
@@ -57,12 +57,12 @@ func (s *Spinner) Start(message string) {
case <-ticker.C:
if s.active {
displayMsg := s.message
// Add ETA info if estimator is available
if s.estimator != nil {
displayMsg = s.estimator.GetFullStatus(s.message)
}
currentFrame := fmt.Sprintf("%s %s", s.frames[i%len(s.frames)], displayMsg)
if s.message != lastMessage {
// Print new line for new messages
@@ -130,13 +130,13 @@ func NewDots() *Dots {
func (d *Dots) Start(message string) {
d.message = message
d.active = true
fmt.Fprint(d.writer, message)
go func() {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
count := 0
for {
select {
@@ -191,13 +191,13 @@ func (d *Dots) SetEstimator(estimator *ETAEstimator) {
// ProgressBar creates a visual progress bar
type ProgressBar struct {
writer io.Writer
message string
total int
current int
width int
active bool
stopCh chan bool
writer io.Writer
message string
total int
current int
width int
active bool
stopCh chan bool
}
// NewProgressBar creates a new progress bar
@@ -265,12 +265,12 @@ func (p *ProgressBar) render() {
if !p.active {
return
}
percent := float64(p.current) / float64(p.total)
filled := int(percent * float64(p.width))
bar := strings.Repeat("█", filled) + strings.Repeat("░", p.width-filled)
fmt.Fprintf(p.writer, "\n%s [%s] %d%%", p.message, bar, int(percent*100))
}
@@ -432,7 +432,7 @@ func NewIndicator(interactive bool, indicatorType string) Indicator {
if !interactive {
return NewLineByLine() // Use line-by-line for non-interactive mode
}
switch indicatorType {
case "spinner":
return NewSpinner()
@@ -457,9 +457,9 @@ func NewNullIndicator() *NullIndicator {
return &NullIndicator{}
}
func (n *NullIndicator) Start(message string) {}
func (n *NullIndicator) Update(message string) {}
func (n *NullIndicator) Complete(message string) {}
func (n *NullIndicator) Fail(message string) {}
func (n *NullIndicator) Stop() {}
func (n *NullIndicator) Start(message string) {}
func (n *NullIndicator) Update(message string) {}
func (n *NullIndicator) Complete(message string) {}
func (n *NullIndicator) Fail(message string) {}
func (n *NullIndicator) Stop() {}
func (n *NullIndicator) SetEstimator(estimator *ETAEstimator) {}

View File

@@ -1,3 +1,4 @@
//go:build openbsd
// +build openbsd
package restore

View File

@@ -1,3 +1,4 @@
//go:build netbsd
// +build netbsd
package restore

View File

@@ -1,3 +1,4 @@
//go:build !windows && !openbsd && !netbsd
// +build !windows,!openbsd,!netbsd
package restore

View File

@@ -1,3 +1,4 @@
//go:build windows
// +build windows
package restore

View File

@@ -358,21 +358,21 @@ func (e *Engine) executeRestoreCommand(ctx context.Context, cmdArgs []string) er
e.log.Warn("Restore completed with ignorable errors", "error_count", errorCount, "last_error", lastError)
return nil // Success despite ignorable errors
}
// Classify error and provide helpful hints
if lastError != "" {
classification := checks.ClassifyError(lastError)
e.log.Error("Restore command failed",
"error", err,
"last_stderr", lastError,
e.log.Error("Restore command failed",
"error", err,
"last_stderr", lastError,
"error_count", errorCount,
"error_type", classification.Type,
"hint", classification.Hint,
"action", classification.Action)
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
err, lastError, errorCount, classification.Hint)
}
e.log.Error("Restore command failed", "error", err, "last_stderr", lastError, "error_count", errorCount)
return fmt.Errorf("restore failed: %w", err)
}
@@ -440,21 +440,21 @@ func (e *Engine) executeRestoreWithDecompression(ctx context.Context, archivePat
e.log.Warn("Restore with decompression completed with ignorable errors", "error_count", errorCount, "last_error", lastError)
return nil // Success despite ignorable errors
}
// Classify error and provide helpful hints
if lastError != "" {
classification := checks.ClassifyError(lastError)
e.log.Error("Restore with decompression failed",
"error", err,
"last_stderr", lastError,
e.log.Error("Restore with decompression failed",
"error", err,
"last_stderr", lastError,
"error_count", errorCount,
"error_type", classification.Type,
"hint", classification.Hint,
"action", classification.Action)
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
err, lastError, errorCount, classification.Hint)
}
e.log.Error("Restore with decompression failed", "error", err, "last_stderr", lastError, "error_count", errorCount)
return fmt.Errorf("restore failed: %w", err)
}
@@ -530,20 +530,20 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
operation.Fail("Invalid cluster archive format")
return fmt.Errorf("not a cluster archive: %s (detected format: %s)", archivePath, format)
}
// Check disk space before starting restore
e.log.Info("Checking disk space for restore")
archiveInfo, err := os.Stat(archivePath)
if err == nil {
spaceCheck := checks.CheckDiskSpaceForRestore(e.cfg.BackupDir, archiveInfo.Size())
if spaceCheck.Critical {
operation.Fail("Insufficient disk space")
return fmt.Errorf("insufficient disk space for restore: %.1f%% used - need at least 4x archive size", spaceCheck.UsedPercent)
}
if spaceCheck.Warning {
e.log.Warn("Low disk space - restore may fail",
e.log.Warn("Low disk space - restore may fail",
"available_gb", float64(spaceCheck.AvailableBytes)/(1024*1024*1024),
"used_percent", spaceCheck.UsedPercent)
}
@@ -638,13 +638,13 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
// Check for large objects in dump files and adjust parallelism
hasLargeObjects := e.detectLargeObjectsInDumps(dumpsDir, entries)
// Use worker pool for parallel restore
parallelism := e.cfg.ClusterParallelism
if parallelism < 1 {
parallelism = 1 // Ensure at least sequential
}
// Automatically reduce parallelism if large objects detected
if hasLargeObjects && parallelism > 1 {
e.log.Warn("Large objects detected in dump files - reducing parallelism to avoid lock contention",
@@ -731,13 +731,13 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
mu.Lock()
e.log.Error("Failed to restore database", "name", dbName, "file", dumpFile, "error", restoreErr)
mu.Unlock()
// Check for specific recoverable errors
errMsg := restoreErr.Error()
if strings.Contains(errMsg, "max_locks_per_transaction") {
mu.Lock()
e.log.Warn("Database restore failed due to insufficient locks - this is a PostgreSQL configuration issue",
"database", dbName,
e.log.Warn("Database restore failed due to insufficient locks - this is a PostgreSQL configuration issue",
"database", dbName,
"solution", "increase max_locks_per_transaction in postgresql.conf")
mu.Unlock()
} else if strings.Contains(errMsg, "total errors:") && strings.Contains(errMsg, "2562426") {
@@ -747,7 +747,7 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
"errors", "2562426")
mu.Unlock()
}
failedDBsMu.Lock()
// Include more context in the error message
failedDBs = append(failedDBs, fmt.Sprintf("%s: restore failed: %v", dbName, restoreErr))
@@ -770,16 +770,16 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
if failCountFinal > 0 {
failedList := strings.Join(failedDBs, "\n ")
// Log summary
e.log.Info("Cluster restore completed with failures",
"succeeded", successCountFinal,
"failed", failCountFinal,
"total", totalDBs)
e.progress.Fail(fmt.Sprintf("Cluster restore: %d succeeded, %d failed out of %d total", successCountFinal, failCountFinal, totalDBs))
operation.Complete(fmt.Sprintf("Partial restore: %d/%d databases succeeded", successCountFinal, totalDBs))
return fmt.Errorf("cluster restore completed with %d failures:\n %s", failCountFinal, failedList)
}
@@ -1079,48 +1079,48 @@ func (e *Engine) detectLargeObjectsInDumps(dumpsDir string, entries []os.DirEntr
hasLargeObjects := false
checkedCount := 0
maxChecks := 5 // Only check first 5 dumps to avoid slowdown
for _, entry := range entries {
if entry.IsDir() || checkedCount >= maxChecks {
continue
}
dumpFile := filepath.Join(dumpsDir, entry.Name())
// Skip compressed SQL files (can't easily check without decompressing)
if strings.HasSuffix(dumpFile, ".sql.gz") {
continue
}
// Use pg_restore -l to list contents (fast, doesn't restore data)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
output, err := cmd.Output()
if err != nil {
// If pg_restore -l fails, it might not be custom format - skip
continue
}
checkedCount++
// Check if output contains "BLOB" or "LARGE OBJECT" entries
outputStr := string(output)
if strings.Contains(outputStr, "BLOB") ||
strings.Contains(outputStr, "LARGE OBJECT") ||
strings.Contains(outputStr, " BLOBS ") {
if strings.Contains(outputStr, "BLOB") ||
strings.Contains(outputStr, "LARGE OBJECT") ||
strings.Contains(outputStr, " BLOBS ") {
e.log.Info("Large objects detected in dump file", "file", entry.Name())
hasLargeObjects = true
// Don't break - log all files with large objects
}
}
if hasLargeObjects {
e.log.Warn("Cluster contains databases with large objects - parallel restore may cause lock contention")
}
return hasLargeObjects
}
@@ -1128,13 +1128,13 @@ func (e *Engine) detectLargeObjectsInDumps(dumpsDir string, entries []os.DirEntr
func (e *Engine) isIgnorableError(errorMsg string) bool {
// Convert to lowercase for case-insensitive matching
lowerMsg := strings.ToLower(errorMsg)
// CRITICAL: Syntax errors are NOT ignorable - indicates corrupted dump
if strings.Contains(lowerMsg, "syntax error") {
e.log.Error("CRITICAL: Syntax error in dump file - dump may be corrupted", "error", errorMsg)
return false
}
// CRITICAL: If error count is extremely high (>100k), dump is likely corrupted
if strings.Contains(errorMsg, "total errors:") {
// Extract error count if present in message
@@ -1149,21 +1149,21 @@ func (e *Engine) isIgnorableError(errorMsg string) bool {
}
}
}
// List of ignorable error patterns (objects that already exist)
ignorablePatterns := []string{
"already exists",
"duplicate key",
"does not exist, skipping", // For DROP IF EXISTS
"no pg_hba.conf entry", // Permission warnings (not fatal)
"no pg_hba.conf entry", // Permission warnings (not fatal)
}
for _, pattern := range ignorablePatterns {
if strings.Contains(lowerMsg, pattern) {
return true
}
}
return false
}

View File

@@ -1,24 +1,24 @@
package restore
import (
"compress/gzip"
"io"
"os"
"strings"
"compress/gzip"
"io"
"os"
"strings"
)
// ArchiveFormat represents the type of backup archive
type ArchiveFormat string
const (
FormatPostgreSQLDump ArchiveFormat = "PostgreSQL Dump (.dump)"
FormatPostgreSQLDumpGz ArchiveFormat = "PostgreSQL Dump Compressed (.dump.gz)"
FormatPostgreSQLSQL ArchiveFormat = "PostgreSQL SQL (.sql)"
FormatPostgreSQLSQLGz ArchiveFormat = "PostgreSQL SQL Compressed (.sql.gz)"
FormatMySQLSQL ArchiveFormat = "MySQL SQL (.sql)"
FormatMySQLSQLGz ArchiveFormat = "MySQL SQL Compressed (.sql.gz)"
FormatClusterTarGz ArchiveFormat = "Cluster Archive (.tar.gz)"
FormatUnknown ArchiveFormat = "Unknown"
FormatPostgreSQLDump ArchiveFormat = "PostgreSQL Dump (.dump)"
FormatPostgreSQLDumpGz ArchiveFormat = "PostgreSQL Dump Compressed (.dump.gz)"
FormatPostgreSQLSQL ArchiveFormat = "PostgreSQL SQL (.sql)"
FormatPostgreSQLSQLGz ArchiveFormat = "PostgreSQL SQL Compressed (.sql.gz)"
FormatMySQLSQL ArchiveFormat = "MySQL SQL (.sql)"
FormatMySQLSQLGz ArchiveFormat = "MySQL SQL Compressed (.sql.gz)"
FormatClusterTarGz ArchiveFormat = "Cluster Archive (.tar.gz)"
FormatUnknown ArchiveFormat = "Unknown"
)
// DetectArchiveFormat detects the format of a backup archive from its filename and content
@@ -37,7 +37,7 @@ func DetectArchiveFormat(filename string) ArchiveFormat {
result := isCustomFormat(filename, true)
// If file doesn't exist or we can't read it, trust the extension
// If file exists and has PGDMP signature, it's custom format
// If file exists but doesn't have signature, it might be SQL named as .dump
// If file exists but doesn't have signature, it might be SQL named as .dump
if result == formatCheckCustom || result == formatCheckFileNotFound {
return FormatPostgreSQLDumpGz
}
@@ -81,9 +81,9 @@ func DetectArchiveFormat(filename string) ArchiveFormat {
type formatCheckResult int
const (
formatCheckFileNotFound formatCheckResult = iota
formatCheckCustom
formatCheckNotCustom
formatCheckFileNotFound formatCheckResult = iota
formatCheckCustom
formatCheckNotCustom
)
// isCustomFormat checks if a file is PostgreSQL custom format (has PGDMP signature)

View File

@@ -242,7 +242,7 @@ func (s *Safety) CheckDiskSpaceAt(archivePath string, checkDir string, multiplie
}
archiveSize := stat.Size()
// Estimate required space (archive size * multiplier for decompression/extraction)
requiredSpace := int64(float64(archiveSize) * multiplier)
@@ -323,12 +323,12 @@ func (s *Safety) checkPostgresDatabaseExists(ctx context.Context, dbName string)
"-d", "postgres",
"-tAc", fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname='%s'", dbName),
}
// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
args = append([]string{"-h", s.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "psql", args...)
// Set password if provided
@@ -351,12 +351,12 @@ func (s *Safety) checkMySQLDatabaseExists(ctx context.Context, dbName string) (b
"-u", s.cfg.User,
"-e", fmt.Sprintf("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME='%s'", dbName),
}
// Only add -h flag if host is not localhost (to use Unix socket)
if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
args = append([]string{"-h", s.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "mysql", args...)
if s.cfg.Password != "" {
@@ -386,7 +386,7 @@ func (s *Safety) ListUserDatabases(ctx context.Context) ([]string, error) {
func (s *Safety) listPostgresUserDatabases(ctx context.Context) ([]string, error) {
// Query to get non-template databases excluding 'postgres' system DB
query := "SELECT datname FROM pg_database WHERE datistemplate = false AND datname != 'postgres' ORDER BY datname"
args := []string{
"-p", fmt.Sprintf("%d", s.cfg.Port),
"-U", s.cfg.User,
@@ -394,12 +394,12 @@ func (s *Safety) listPostgresUserDatabases(ctx context.Context) ([]string, error
"-tA", // Tuples only, unaligned
"-c", query,
}
// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
args = append([]string{"-h", s.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "psql", args...)
// Set password if provided
@@ -429,19 +429,19 @@ func (s *Safety) listPostgresUserDatabases(ctx context.Context) ([]string, error
func (s *Safety) listMySQLUserDatabases(ctx context.Context) ([]string, error) {
// Exclude system databases
query := "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') ORDER BY SCHEMA_NAME"
args := []string{
"-P", fmt.Sprintf("%d", s.cfg.Port),
"-u", s.cfg.User,
"-N", // Skip column names
"-e", query,
}
// Only add -h flag if host is not localhost (to use Unix socket)
if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
args = append([]string{"-h", s.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "mysql", args...)
if s.cfg.Password != "" {

View File

@@ -23,7 +23,7 @@ func TestValidateArchive_FileNotFound(t *testing.T) {
func TestValidateArchive_EmptyFile(t *testing.T) {
tmpDir := t.TempDir()
emptyFile := filepath.Join(tmpDir, "empty.dump")
if err := os.WriteFile(emptyFile, []byte{}, 0644); err != nil {
t.Fatalf("Failed to create empty file: %v", err)
}
@@ -43,7 +43,7 @@ func TestCheckDiskSpace_InsufficientSpace(t *testing.T) {
// Just ensure the function doesn't panic
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.dump")
// Create a small test file
if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil {
t.Fatalf("Failed to create test file: %v", err)

View File

@@ -23,21 +23,21 @@ func ParsePostgreSQLVersion(versionStr string) (*VersionInfo, error) {
// Match patterns like "PostgreSQL 17.7", "PostgreSQL 13.11", "PostgreSQL 10.23"
re := regexp.MustCompile(`PostgreSQL\s+(\d+)\.(\d+)`)
matches := re.FindStringSubmatch(versionStr)
if len(matches) < 3 {
return nil, fmt.Errorf("could not parse PostgreSQL version from: %s", versionStr)
}
major, err := strconv.Atoi(matches[1])
if err != nil {
return nil, fmt.Errorf("invalid major version: %s", matches[1])
}
minor, err := strconv.Atoi(matches[2])
if err != nil {
return nil, fmt.Errorf("invalid minor version: %s", matches[2])
}
return &VersionInfo{
Major: major,
Minor: minor,
@@ -53,24 +53,24 @@ func GetDumpFileVersion(dumpPath string) (*VersionInfo, error) {
if err != nil {
return nil, fmt.Errorf("failed to read dump file metadata: %w (output: %s)", err, string(output))
}
// Look for "Dumped from database version: X.Y.Z" in output
re := regexp.MustCompile(`Dumped from database version:\s+(\d+)\.(\d+)`)
matches := re.FindStringSubmatch(string(output))
if len(matches) < 3 {
// Try alternate format in some dumps
re = regexp.MustCompile(`PostgreSQL database dump.*(\d+)\.(\d+)`)
matches = re.FindStringSubmatch(string(output))
}
if len(matches) < 3 {
return nil, fmt.Errorf("could not find version information in dump file")
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return &VersionInfo{
Major: major,
Minor: minor,
@@ -81,18 +81,18 @@ func GetDumpFileVersion(dumpPath string) (*VersionInfo, error) {
// CheckVersionCompatibility checks if restoring from source version to target version is safe
func CheckVersionCompatibility(sourceVer, targetVer *VersionInfo) *VersionCompatibilityResult {
result := &VersionCompatibilityResult{
Compatible: true,
Compatible: true,
SourceVersion: sourceVer,
TargetVersion: targetVer,
}
// Same major version - always compatible
if sourceVer.Major == targetVer.Major {
result.Level = CompatibilityLevelSafe
result.Message = "Same major version - fully compatible"
return result
}
// Downgrade - not supported
if sourceVer.Major > targetVer.Major {
result.Compatible = false
@@ -101,10 +101,10 @@ func CheckVersionCompatibility(sourceVer, targetVer *VersionInfo) *VersionCompat
result.Warnings = append(result.Warnings, "Database downgrades require pg_dump from the target version")
return result
}
// Upgrade - check how many major versions
versionDiff := targetVer.Major - sourceVer.Major
if versionDiff == 1 {
// One major version upgrade - generally safe
result.Level = CompatibilityLevelSafe
@@ -113,7 +113,7 @@ func CheckVersionCompatibility(sourceVer, targetVer *VersionInfo) *VersionCompat
// 2-3 major versions - should work but review release notes
result.Level = CompatibilityLevelWarning
result.Message = fmt.Sprintf("Upgrading from PostgreSQL %d to %d - supported but review release notes", sourceVer.Major, targetVer.Major)
result.Warnings = append(result.Warnings,
result.Warnings = append(result.Warnings,
fmt.Sprintf("You are jumping %d major versions - some features may have changed", versionDiff))
result.Warnings = append(result.Warnings,
"Review release notes for deprecated features or behavior changes")
@@ -134,13 +134,13 @@ func CheckVersionCompatibility(sourceVer, targetVer *VersionInfo) *VersionCompat
result.Recommendations = append(result.Recommendations,
"Review PostgreSQL release notes for versions "+strconv.Itoa(sourceVer.Major)+" through "+strconv.Itoa(targetVer.Major))
}
// Add general upgrade advice
if versionDiff > 0 {
result.Recommendations = append(result.Recommendations,
"Run ANALYZE on all tables after restore for optimal query performance")
}
return result
}
@@ -189,33 +189,33 @@ func (e *Engine) CheckRestoreVersionCompatibility(ctx context.Context, dumpPath
e.log.Warn("Could not determine dump file version", "error", err)
return nil, nil
}
// Get target database version
targetVerStr, err := e.db.GetVersion(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get target database version: %w", err)
}
targetVer, err := ParsePostgreSQLVersion(targetVerStr)
if err != nil {
return nil, fmt.Errorf("failed to parse target version: %w", err)
}
// Check compatibility
result := CheckVersionCompatibility(dumpVer, targetVer)
// Log the results
e.log.Info("Version compatibility check",
"source", dumpVer.Full,
"target", targetVer.Full,
"level", result.Level.String())
if len(result.Warnings) > 0 {
for _, warning := range result.Warnings {
e.log.Warn(warning)
}
}
return result, nil
}

View File

@@ -19,12 +19,12 @@ type Policy struct {
// CleanupResult contains information about cleanup operations
type CleanupResult struct {
TotalBackups int
TotalBackups int
EligibleForDeletion int
Deleted []string
Kept []string
SpaceFreed int64
Errors []error
Deleted []string
Kept []string
SpaceFreed int64
Errors []error
}
// ApplyPolicy enforces the retention policy on backups in a directory
@@ -63,13 +63,13 @@ func ApplyPolicy(backupDir string, policy Policy) (*CleanupResult, error) {
// Check if backup is older than retention period
if backup.Timestamp.Before(cutoffDate) {
result.EligibleForDeletion++
if policy.DryRun {
result.Deleted = append(result.Deleted, backup.BackupFile)
} else {
// Delete backup file and associated metadata
if err := deleteBackup(backup.BackupFile); err != nil {
result.Errors = append(result.Errors,
result.Errors = append(result.Errors,
fmt.Errorf("failed to delete %s: %w", backup.BackupFile, err))
} else {
result.Deleted = append(result.Deleted, backup.BackupFile)
@@ -204,7 +204,7 @@ func CleanupByPattern(backupDir, pattern string, policy Policy) (*CleanupResult,
if backup.Timestamp.Before(cutoffDate) {
result.EligibleForDeletion++
if policy.DryRun {
result.Deleted = append(result.Deleted, backup.BackupFile)
} else {

View File

@@ -9,18 +9,18 @@ import (
// AuditEvent represents an auditable event
type AuditEvent struct {
Timestamp time.Time
User string
Action string
Resource string
Result string
Details map[string]interface{}
Timestamp time.Time
User string
Action string
Resource string
Result string
Details map[string]interface{}
}
// AuditLogger provides audit logging functionality
type AuditLogger struct {
log logger.Logger
enabled bool
log logger.Logger
enabled bool
}
// NewAuditLogger creates a new audit logger

View File

@@ -42,7 +42,7 @@ func VerifyChecksum(path string, expectedChecksum string) error {
func SaveChecksum(archivePath string, checksum string) error {
checksumPath := archivePath + ".sha256"
content := fmt.Sprintf("%s %s\n", checksum, archivePath)
if err := os.WriteFile(checksumPath, []byte(content), 0644); err != nil {
return fmt.Errorf("failed to save checksum: %w", err)
}
@@ -53,7 +53,7 @@ func SaveChecksum(archivePath string, checksum string) error {
// LoadChecksum loads checksum from a .sha256 file
func LoadChecksum(archivePath string) (string, error) {
checksumPath := archivePath + ".sha256"
data, err := os.ReadFile(checksumPath)
if err != nil {
return "", fmt.Errorf("failed to read checksum file: %w", err)

View File

@@ -49,7 +49,7 @@ func ValidateArchivePath(path string) (string, error) {
// Must have a valid archive extension
ext := strings.ToLower(filepath.Ext(cleaned))
validExtensions := []string{".dump", ".sql", ".gz", ".tar"}
valid := false
for _, validExt := range validExtensions {
if strings.HasSuffix(cleaned, validExt) {

View File

@@ -23,20 +23,20 @@ func NewPrivilegeChecker(log logger.Logger) *PrivilegeChecker {
// CheckAndWarn checks if running with elevated privileges and warns
func (pc *PrivilegeChecker) CheckAndWarn(allowRoot bool) error {
isRoot, user := pc.isRunningAsRoot()
if isRoot {
pc.log.Warn("⚠️ Running with elevated privileges (root/Administrator)")
pc.log.Warn("Security recommendation: Create a dedicated backup user with minimal privileges")
if !allowRoot {
return fmt.Errorf("running as root is not recommended, use --allow-root to override")
}
pc.log.Warn("Proceeding with root privileges (--allow-root specified)")
} else {
pc.log.Debug("Running as non-privileged user", "user", user)
}
return nil
}
@@ -52,7 +52,7 @@ func (pc *PrivilegeChecker) isRunningAsRoot() (bool, string) {
func (pc *PrivilegeChecker) isUnixRoot() (bool, string) {
uid := os.Getuid()
user := GetCurrentUser()
isRoot := uid == 0 || user == "root"
return isRoot, user
}
@@ -62,10 +62,10 @@ func (pc *PrivilegeChecker) isWindowsAdmin() (bool, string) {
// Check if running as Administrator on Windows
// This is a simplified check - full implementation would use Windows API
user := GetCurrentUser()
// Common admin user patterns on Windows
isAdmin := user == "Administrator" || user == "SYSTEM"
return isAdmin, user
}
@@ -89,11 +89,11 @@ func (pc *PrivilegeChecker) GetSecurityRecommendations() []string {
"Regularly rotate database passwords",
"Monitor audit logs for unauthorized access attempts",
}
if runtime.GOOS != "windows" {
recommendations = append(recommendations,
fmt.Sprintf("Run as non-root user: sudo -u %s dbbackup ...", pc.GetRecommendedUser()))
}
return recommendations
}

View File

@@ -1,4 +1,5 @@
// go:build !linux
//go:build !linux
// +build !linux
package security

View File

@@ -1,3 +1,4 @@
//go:build !windows
// +build !windows
package security
@@ -19,7 +20,7 @@ func (rc *ResourceChecker) checkPlatformLimits() (*ResourceLimits, error) {
if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err == nil {
limits.MaxOpenFiles = uint64(rLimit.Cur)
rc.log.Debug("Resource limit: max open files", "limit", rLimit.Cur, "max", rLimit.Max)
if rLimit.Cur < 1024 {
rc.log.Warn("⚠️ Low file descriptor limit detected",
"current", rLimit.Cur,

View File

@@ -1,3 +1,4 @@
//go:build windows
// +build windows
package security
@@ -23,5 +24,3 @@ func (rc *ResourceChecker) checkPlatformLimits() (*ResourceLimits, error) {
return limits, nil
}

View File

@@ -46,13 +46,13 @@ func (rp *RetentionPolicy) CleanupOldBackups(backupDir string) (int, int64, erro
}
if len(archives) <= rp.MinBackups {
rp.log.Debug("Keeping all backups (below minimum threshold)",
rp.log.Debug("Keeping all backups (below minimum threshold)",
"count", len(archives), "min_backups", rp.MinBackups)
return 0, 0, nil
}
cutoffTime := time.Now().AddDate(0, 0, -rp.RetentionDays)
// Sort by modification time (oldest first)
sort.Slice(archives, func(i, j int) bool {
return archives[i].ModTime.Before(archives[j].ModTime)
@@ -65,14 +65,14 @@ func (rp *RetentionPolicy) CleanupOldBackups(backupDir string) (int, int64, erro
// Keep minimum number of backups
remaining := len(archives) - i
if remaining <= rp.MinBackups {
rp.log.Debug("Stopped cleanup to maintain minimum backups",
rp.log.Debug("Stopped cleanup to maintain minimum backups",
"remaining", remaining, "min_backups", rp.MinBackups)
break
}
// Delete if older than retention period
if archive.ModTime.Before(cutoffTime) {
rp.log.Info("Removing old backup",
rp.log.Info("Removing old backup",
"file", filepath.Base(archive.Path),
"age_days", int(time.Since(archive.ModTime).Hours()/24),
"size_mb", archive.Size/1024/1024)
@@ -100,7 +100,7 @@ func (rp *RetentionPolicy) CleanupOldBackups(backupDir string) (int, int64, erro
}
if deletedCount > 0 {
rp.log.Info("Cleanup completed",
rp.log.Info("Cleanup completed",
"deleted_backups", deletedCount,
"freed_space_mb", freedSpace/1024/1024,
"retention_days", rp.RetentionDays)
@@ -124,7 +124,7 @@ func (rp *RetentionPolicy) scanBackupArchives(backupDir string) ([]ArchiveInfo,
}
name := entry.Name()
// Skip non-backup files
if !isBackupArchive(name) {
continue
@@ -161,7 +161,7 @@ func isBackupArchive(name string) bool {
// extractDatabaseName extracts database name from archive filename
func extractDatabaseName(filename string) string {
base := filepath.Base(filename)
// Remove extensions
for {
oldBase := base
@@ -170,7 +170,7 @@ func extractDatabaseName(filename string) string {
break
}
}
// Remove timestamp patterns
if len(base) > 20 {
// Typically: db_name_20240101_120000
@@ -184,7 +184,7 @@ func extractDatabaseName(filename string) string {
}
}
}
return base
}

View File

@@ -171,9 +171,9 @@ func (m *Manager) Setup() error {
// Log current swap status
if total, used, free, err := m.GetCurrentSwap(); err == nil {
m.log.Info("Swap status after setup",
"total_mb", total,
"used_mb", used,
m.log.Info("Swap status after setup",
"total_mb", total,
"used_mb", used,
"free_mb", free,
"added_gb", m.sizeGB)
}

View File

@@ -41,13 +41,13 @@ var (
// ArchiveInfo holds information about a backup archive
type ArchiveInfo struct {
Name string
Path string
Format restore.ArchiveFormat
Size int64
Modified time.Time
DatabaseName string
Valid bool
Name string
Path string
Format restore.ArchiveFormat
Size int64
Modified time.Time
DatabaseName string
Valid bool
ValidationMsg string
}
@@ -132,13 +132,13 @@ func loadArchives(cfg *config.Config, log logger.Logger) tea.Cmd {
}
archives = append(archives, ArchiveInfo{
Name: name,
Path: fullPath,
Format: format,
Size: info.Size(),
Modified: info.ModTime(),
DatabaseName: dbName,
Valid: valid,
Name: name,
Path: fullPath,
Format: format,
Size: info.Size(),
Modified: info.ModTime(),
DatabaseName: dbName,
Valid: valid,
ValidationMsg: validationMsg,
})
}
@@ -196,13 +196,13 @@ func (m ArchiveBrowserModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case "enter", " ":
if len(m.archives) > 0 && m.cursor < len(m.archives) {
selected := m.archives[m.cursor]
// Validate selection based on mode
if m.mode == "restore-cluster" && !selected.Format.IsClusterBackup() {
m.message = errorStyle.Render("❌ Please select a cluster backup (.tar.gz)")
return m, nil
}
if m.mode == "restore-single" && selected.Format.IsClusterBackup() {
m.message = errorStyle.Render("❌ Please select a single database backup")
return m, nil
@@ -239,7 +239,7 @@ func (m ArchiveBrowserModel) View() string {
} else if m.mode == "restore-cluster" {
title = "📦 Select Archive to Restore (Cluster)"
}
s.WriteString(titleStyle.Render(title))
s.WriteString("\n\n")

View File

@@ -78,10 +78,10 @@ type backupCompleteMsg struct {
func executeBackupWithTUIProgress(parentCtx context.Context, cfg *config.Config, log logger.Logger, backupType, dbName string, ratio int) tea.Cmd {
return func() tea.Msg {
// Use configurable cluster timeout (minutes) from config; default set in config.New()
// Use parent context to inherit cancellation from TUI
clusterTimeout := time.Duration(cfg.ClusterTimeoutMinutes) * time.Minute
ctx, cancel := context.WithTimeout(parentCtx, clusterTimeout)
// Use configurable cluster timeout (minutes) from config; default set in config.New()
// Use parent context to inherit cancellation from TUI
clusterTimeout := time.Duration(cfg.ClusterTimeoutMinutes) * time.Minute
ctx, cancel := context.WithTimeout(parentCtx, clusterTimeout)
defer cancel()
start := time.Now()
@@ -151,10 +151,10 @@ func (m BackupExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if !m.done {
// Increment spinner frame for smooth animation
m.spinnerFrame = (m.spinnerFrame + 1) % len(spinnerFrames)
// Update status based on elapsed time to show progress
elapsedSec := int(time.Since(m.startTime).Seconds())
if elapsedSec < 2 {
m.status = "Initializing backup..."
} else if elapsedSec < 5 {
@@ -180,7 +180,7 @@ func (m BackupExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.status = fmt.Sprintf("Backing up database '%s'...", m.databaseName)
}
}
return m, backupTickCmd()
}
return m, nil
@@ -239,7 +239,7 @@ func (m BackupExecutionModel) View() string {
s.WriteString(fmt.Sprintf(" %s %s\n", spinnerFrames[m.spinnerFrame], m.status))
} else {
s.WriteString(fmt.Sprintf(" %s\n\n", m.status))
if m.err != nil {
s.WriteString(fmt.Sprintf(" ❌ Error: %v\n", m.err))
} else if m.result != "" {

View File

@@ -52,13 +52,13 @@ func (m BackupManagerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
m.archives = msg.archives
// Calculate total size
m.totalSize = 0
for _, archive := range m.archives {
m.totalSize += archive.Size
}
// Get free space (simplified - just show message)
m.message = fmt.Sprintf("Loaded %d archive(s)", len(m.archives))
return m, nil

View File

@@ -84,7 +84,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.databases = []string{"Error loading databases"}
} else {
m.databases = msg.databases
// Auto-select database if specified
if m.config.TUIAutoDatabase != "" {
for i, db := range m.databases {
@@ -92,7 +92,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.cursor = i
m.selected = db
m.logger.Info("Auto-selected database", "database", db)
// If sample backup, ask for ratio (or auto-use default)
if m.backupType == "sample" {
if m.config.TUIDryRun {
@@ -107,7 +107,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
ValidateInt(1, 100))
return inputModel, nil
}
// For single backup, go directly to execution
executor := NewBackupExecution(m.config, m.logger, m.parent, m.ctx, m.backupType, m.selected, 0)
return executor, executor.Init()
@@ -136,7 +136,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case "enter":
if !m.loading && m.err == nil && len(m.databases) > 0 {
m.selected = m.databases[m.cursor]
// If sample backup, ask for ratio first
if m.backupType == "sample" {
inputModel := NewInputModel(m.config, m.logger, m,
@@ -146,7 +146,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
ValidateInt(1, 100))
return inputModel, nil
}
// For single backup, go directly to execution
executor := NewBackupExecution(m.config, m.logger, m.parent, m.ctx, m.backupType, m.selected, 0)
return executor, executor.Init()

View File

@@ -111,7 +111,7 @@ func (db *DirectoryBrowser) Render() string {
}
var lines []string
// Header
lines = append(lines, fmt.Sprintf(" Current: %s", db.CurrentPath))
lines = append(lines, fmt.Sprintf(" Found %d directories (cursor: %d)", len(db.items), db.cursor))
@@ -121,7 +121,7 @@ func (db *DirectoryBrowser) Render() string {
maxItems := 5 // Show max 5 items to keep it compact
start := 0
end := len(db.items)
if len(db.items) > maxItems {
// Center the cursor in the view
start = db.cursor - maxItems/2
@@ -144,14 +144,14 @@ func (db *DirectoryBrowser) Render() string {
if i == db.cursor {
prefix = " >> "
}
displayName := item
if item == ".." {
displayName = "../ (parent directory)"
} else if item != "[Error reading directory]" {
displayName = item + "/"
}
lines = append(lines, prefix+displayName)
}
@@ -164,4 +164,4 @@ func (db *DirectoryBrowser) Render() string {
lines = append(lines, " ↑/↓: Navigate | Enter/→: Open | ←: Parent | Space: Select | Esc: Cancel")
return strings.Join(lines, "\n")
}
}

View File

@@ -14,12 +14,12 @@ import (
// DirectoryPicker is a simple, fast directory and file picker
type DirectoryPicker struct {
currentPath string
items []FileItem
cursor int
callback func(string)
allowFiles bool // Allow file selection for restore operations
styles DirectoryPickerStyles
currentPath string
items []FileItem
cursor int
callback func(string)
allowFiles bool // Allow file selection for restore operations
styles DirectoryPickerStyles
}
type FileItem struct {
@@ -98,26 +98,26 @@ func (dp *DirectoryPicker) loadItems() {
// Collect directories and optionally files
var dirs []FileItem
var files []FileItem
for _, entry := range entries {
if strings.HasPrefix(entry.Name(), ".") {
continue // Skip hidden files
}
item := FileItem{
Name: entry.Name(),
IsDir: entry.IsDir(),
Path: filepath.Join(dp.currentPath, entry.Name()),
}
if entry.IsDir() {
dirs = append(dirs, item)
} else if dp.allowFiles {
// Only include backup-related files
if strings.HasSuffix(entry.Name(), ".sql") ||
strings.HasSuffix(entry.Name(), ".dump") ||
strings.HasSuffix(entry.Name(), ".gz") ||
strings.HasSuffix(entry.Name(), ".tar") {
if strings.HasSuffix(entry.Name(), ".sql") ||
strings.HasSuffix(entry.Name(), ".dump") ||
strings.HasSuffix(entry.Name(), ".gz") ||
strings.HasSuffix(entry.Name(), ".tar") {
files = append(files, item)
}
}
@@ -242,4 +242,4 @@ func (dp *DirectoryPicker) View() string {
content.WriteString(dp.styles.Help.Render(help))
return dp.styles.Container.Render(content.String())
}
}

View File

@@ -37,14 +37,14 @@ func NewHistoryView(cfg *config.Config, log logger.Logger, parent tea.Model) His
if lastIndex < 0 {
lastIndex = 0
}
// Calculate initial viewport to show the last item
maxVisible := 15
viewOffset := lastIndex - maxVisible + 1
if viewOffset < 0 {
viewOffset = 0
}
return HistoryViewModel{
config: cfg,
logger: log,
@@ -112,7 +112,7 @@ func (m HistoryViewModel) Init() tea.Cmd {
func (m HistoryViewModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
maxVisible := 15 // Show max 15 items at once
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
@@ -136,7 +136,7 @@ func (m HistoryViewModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.viewOffset = m.cursor - maxVisible + 1
}
}
case "pgup":
// Page up - jump by maxVisible items
m.cursor -= maxVisible
@@ -147,7 +147,7 @@ func (m HistoryViewModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.cursor < m.viewOffset {
m.viewOffset = m.cursor
}
case "pgdown":
// Page down - jump by maxVisible items
m.cursor += maxVisible
@@ -158,12 +158,12 @@ func (m HistoryViewModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.cursor >= m.viewOffset+maxVisible {
m.viewOffset = m.cursor - maxVisible + 1
}
case "home", "g":
// Jump to first item
m.cursor = 0
m.viewOffset = 0
case "end", "G":
// Jump to last item
m.cursor = len(m.history) - 1
@@ -187,15 +187,15 @@ func (m HistoryViewModel) View() string {
s.WriteString("📭 No backup history found\n\n")
} else {
maxVisible := 15 // Show max 15 items at once
// Calculate visible range
start := m.viewOffset
end := start + maxVisible
if end > len(m.history) {
end = len(m.history)
}
s.WriteString(fmt.Sprintf("Found %d backup operations (Viewing %d/%d):\n\n",
s.WriteString(fmt.Sprintf("Found %d backup operations (Viewing %d/%d):\n\n",
len(m.history), m.cursor+1, len(m.history)))
// Show scroll indicators
@@ -219,12 +219,12 @@ func (m HistoryViewModel) View() string {
s.WriteString(fmt.Sprintf(" %s\n", line))
}
}
// Show scroll indicator if more entries below
if end < len(m.history) {
s.WriteString(fmt.Sprintf(" ▼ %d more entries below...\n", len(m.history)-end))
}
s.WriteString("\n")
}

View File

@@ -61,7 +61,7 @@ func (m InputModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
}
m.done = true
// If this is from database selector, execute backup with ratio
if selector, ok := m.parent.(DatabaseSelectorModel); ok {
ratio, _ := strconv.Atoi(m.value)

View File

@@ -53,14 +53,14 @@ type dbTypeOption struct {
// MenuModel represents the simple menu state
type MenuModel struct {
choices []string
cursor int
config *config.Config
logger logger.Logger
quitting bool
message string
dbTypes []dbTypeOption
dbTypeCursor int
choices []string
cursor int
config *config.Config
logger logger.Logger
quitting bool
message string
dbTypes []dbTypeOption
dbTypeCursor int
// Background operations
ctx context.Context
@@ -133,7 +133,7 @@ func (m MenuModel) Init() tea.Cmd {
// Auto-select menu option if specified
if m.config.TUIAutoSelect >= 0 && m.config.TUIAutoSelect < len(m.choices) {
m.logger.Info("TUI Auto-select enabled", "option", m.config.TUIAutoSelect, "label", m.choices[m.config.TUIAutoSelect])
// Return command to trigger auto-selection
return func() tea.Msg {
return autoSelectMsg{}
@@ -150,7 +150,7 @@ func (m MenuModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.config.TUIAutoSelect >= 0 && m.config.TUIAutoSelect < len(m.choices) {
m.cursor = m.config.TUIAutoSelect
m.logger.Info("Auto-selecting option", "cursor", m.cursor, "choice", m.choices[m.cursor])
// Trigger the selection based on cursor position
switch m.cursor {
case 0: // Single Database Backup
@@ -184,7 +184,7 @@ func (m MenuModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
}
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "q":
@@ -192,13 +192,13 @@ func (m MenuModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.cancel != nil {
m.cancel()
}
// Clean up any orphaned processes before exit
m.logger.Info("Cleaning up processes before exit")
if err := cleanup.KillOrphanedProcesses(m.logger); err != nil {
m.logger.Warn("Failed to clean up all processes", "error", err)
}
m.quitting = true
return m, tea.Quit

View File

@@ -269,11 +269,11 @@ func (s *SilentOperation) Fail(message string, args ...any) {}
// SilentProgressIndicator implements progress.Indicator but doesn't output anything
type SilentProgressIndicator struct{}
func (s *SilentProgressIndicator) Start(message string) {}
func (s *SilentProgressIndicator) Update(message string) {}
func (s *SilentProgressIndicator) Complete(message string) {}
func (s *SilentProgressIndicator) Fail(message string) {}
func (s *SilentProgressIndicator) Stop() {}
func (s *SilentProgressIndicator) Start(message string) {}
func (s *SilentProgressIndicator) Update(message string) {}
func (s *SilentProgressIndicator) Complete(message string) {}
func (s *SilentProgressIndicator) Fail(message string) {}
func (s *SilentProgressIndicator) Stop() {}
func (s *SilentProgressIndicator) SetEstimator(estimator *progress.ETAEstimator) {}
// RunBackupInTUI runs a backup operation with TUI-compatible progress reporting

View File

@@ -20,54 +20,54 @@ var spinnerFrames = []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
// RestoreExecutionModel handles restore execution with progress
type RestoreExecutionModel struct {
config *config.Config
logger logger.Logger
parent tea.Model
ctx context.Context
archive ArchiveInfo
targetDB string
cleanFirst bool
createIfMissing bool
restoreType string
config *config.Config
logger logger.Logger
parent tea.Model
ctx context.Context
archive ArchiveInfo
targetDB string
cleanFirst bool
createIfMissing bool
restoreType string
cleanClusterFirst bool // Drop all user databases before cluster restore
existingDBs []string // List of databases to drop
existingDBs []string // List of databases to drop
// Progress tracking
status string
phase string
progress int
details []string
startTime time.Time
spinnerFrame int
status string
phase string
progress int
details []string
startTime time.Time
spinnerFrame int
spinnerFrames []string
// Results
done bool
err error
result string
elapsed time.Duration
done bool
err error
result string
elapsed time.Duration
}
// NewRestoreExecution creates a new restore execution model
func NewRestoreExecution(cfg *config.Config, log logger.Logger, parent tea.Model, ctx context.Context, archive ArchiveInfo, targetDB string, cleanFirst, createIfMissing bool, restoreType string, cleanClusterFirst bool, existingDBs []string) RestoreExecutionModel {
return RestoreExecutionModel{
config: cfg,
logger: log,
parent: parent,
ctx: ctx,
archive: archive,
targetDB: targetDB,
cleanFirst: cleanFirst,
createIfMissing: createIfMissing,
restoreType: restoreType,
config: cfg,
logger: log,
parent: parent,
ctx: ctx,
archive: archive,
targetDB: targetDB,
cleanFirst: cleanFirst,
createIfMissing: createIfMissing,
restoreType: restoreType,
cleanClusterFirst: cleanClusterFirst,
existingDBs: existingDBs,
status: "Initializing...",
phase: "Starting",
startTime: time.Now(),
details: []string{},
spinnerFrames: spinnerFrames, // Use package-level constant
spinnerFrame: 0,
existingDBs: existingDBs,
status: "Initializing...",
phase: "Starting",
startTime: time.Now(),
details: []string{},
spinnerFrames: spinnerFrames, // Use package-level constant
spinnerFrame: 0,
}
}
@@ -123,7 +123,7 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
// STEP 1: Clean cluster if requested (drop all existing user databases)
if restoreType == "restore-cluster" && cleanClusterFirst && len(existingDBs) > 0 {
log.Info("Dropping existing user databases before cluster restore", "count", len(existingDBs))
// Drop databases using command-line psql (no connection required)
// This matches how cluster restore works - uses CLI tools, not database connections
droppedCount := 0
@@ -139,13 +139,13 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
}
dropCancel() // Clean up context
}
log.Info("Cluster cleanup completed", "dropped", droppedCount, "total", len(existingDBs))
}
// STEP 2: Create restore engine with silent progress (no stdout interference with TUI)
engine := restore.NewSilent(cfg, log, dbClient)
// Set up progress callback (but it won't work in goroutine - progress is already sent via logs)
// The TUI will just use spinner animation to show activity
@@ -186,11 +186,11 @@ func (m RestoreExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if !m.done {
m.spinnerFrame = (m.spinnerFrame + 1) % len(m.spinnerFrames)
m.elapsed = time.Since(m.startTime)
// Update status based on elapsed time to show progress
// This provides visual feedback even though we don't have real-time progress
elapsedSec := int(m.elapsed.Seconds())
if elapsedSec < 2 {
m.status = "Initializing restore..."
m.phase = "Starting"
@@ -222,7 +222,7 @@ func (m RestoreExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.phase = "Restore"
}
}
return m, restoreTickCmd()
}
return m, nil
@@ -245,7 +245,7 @@ func (m RestoreExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.err = msg.err
m.result = msg.result
m.elapsed = msg.elapsed
if m.err == nil {
m.status = "Restore completed successfully"
m.phase = "Done"
@@ -311,7 +311,7 @@ func (m RestoreExecutionModel) View() string {
} else {
// Show progress
s.WriteString(fmt.Sprintf("Phase: %s\n", m.phase))
// Show status with rotating spinner (unified indicator for all operations)
spinner := m.spinnerFrames[m.spinnerFrame]
s.WriteString(fmt.Sprintf("Status: %s %s\n", spinner, m.status))
@@ -339,10 +339,10 @@ func (m RestoreExecutionModel) View() string {
func renderProgressBar(percent int) string {
width := 40
filled := (percent * width) / 100
bar := strings.Repeat("█", filled)
empty := strings.Repeat("░", width-filled)
return successStyle.Render(bar) + infoStyle.Render(empty)
}
@@ -370,24 +370,23 @@ func dropDatabaseCLI(ctx context.Context, cfg *config.Config, dbName string) err
"-d", "postgres", // Connect to postgres maintenance DB
"-c", fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName),
}
// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
if cfg.Host != "localhost" && cfg.Host != "127.0.0.1" && cfg.Host != "" {
args = append([]string{"-h", cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "psql", args...)
// Set password if provided
if cfg.Password != "" {
cmd.Env = append(cmd.Environ(), fmt.Sprintf("PGPASSWORD=%s", cfg.Password))
}
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to drop database %s: %w\nOutput: %s", dbName, err, string(output))
}
return nil
}

View File

@@ -43,22 +43,22 @@ type SafetyCheck struct {
// RestorePreviewModel shows restore preview and safety checks
type RestorePreviewModel struct {
config *config.Config
logger logger.Logger
parent tea.Model
ctx context.Context
archive ArchiveInfo
mode string
targetDB string
cleanFirst bool
createIfMissing bool
cleanClusterFirst bool // For cluster restore: drop all user databases first
existingDBCount int // Number of existing user databases
existingDBs []string // List of existing user databases
safetyChecks []SafetyCheck
checking bool
canProceed bool
message string
config *config.Config
logger logger.Logger
parent tea.Model
ctx context.Context
archive ArchiveInfo
mode string
targetDB string
cleanFirst bool
createIfMissing bool
cleanClusterFirst bool // For cluster restore: drop all user databases first
existingDBCount int // Number of existing user databases
existingDBs []string // List of existing user databases
safetyChecks []SafetyCheck
checking bool
canProceed bool
message string
}
// NewRestorePreview creates a new restore preview
@@ -70,16 +70,16 @@ func NewRestorePreview(cfg *config.Config, log logger.Logger, parent tea.Model,
}
return RestorePreviewModel{
config: cfg,
logger: log,
parent: parent,
ctx: ctx,
archive: archive,
mode: mode,
targetDB: targetDB,
cleanFirst: false,
config: cfg,
logger: log,
parent: parent,
ctx: ctx,
archive: archive,
mode: mode,
targetDB: targetDB,
cleanFirst: false,
createIfMissing: true,
checking: true,
checking: true,
safetyChecks: []SafetyCheck{
{Name: "Archive integrity", Status: "pending", Critical: true},
{Name: "Disk space", Status: "pending", Critical: true},
@@ -156,7 +156,7 @@ func runSafetyChecks(cfg *config.Config, log logger.Logger, archive ArchiveInfo,
// 4. Target database check (skip for cluster restores)
existingDBCount := 0
existingDBs := []string{}
if !archive.Format.IsClusterBackup() {
check = SafetyCheck{Name: "Target database", Status: "checking", Critical: false}
exists, err := safety.CheckDatabaseExists(ctx, targetDB)
@@ -174,7 +174,7 @@ func runSafetyChecks(cfg *config.Config, log logger.Logger, archive ArchiveInfo,
} else {
// For cluster restores, detect existing user databases
check = SafetyCheck{Name: "Existing databases", Status: "checking", Critical: false}
// Get list of existing user databases (exclude templates and system DBs)
dbList, err := safety.ListUserDatabases(ctx)
if err != nil {
@@ -183,7 +183,7 @@ func runSafetyChecks(cfg *config.Config, log logger.Logger, archive ArchiveInfo,
} else {
existingDBCount = len(dbList)
existingDBs = dbList
if existingDBCount > 0 {
check.Status = "warning"
check.Message = fmt.Sprintf("Found %d existing user database(s) - can be cleaned before restore", existingDBCount)
@@ -288,13 +288,13 @@ func (m RestorePreviewModel) View() string {
s.WriteString("\n")
s.WriteString(fmt.Sprintf(" Database: %s\n", m.targetDB))
s.WriteString(fmt.Sprintf(" Host: %s:%d\n", m.config.Host, m.config.Port))
cleanIcon := "✗"
if m.cleanFirst {
cleanIcon = "✓"
}
s.WriteString(fmt.Sprintf(" Clean First: %s %v\n", cleanIcon, m.cleanFirst))
createIcon := "✗"
if m.createIfMissing {
createIcon = "✓"
@@ -305,10 +305,10 @@ func (m RestorePreviewModel) View() string {
s.WriteString(archiveHeaderStyle.Render("🎯 Cluster Restore Options"))
s.WriteString("\n")
s.WriteString(fmt.Sprintf(" Host: %s:%d\n", m.config.Host, m.config.Port))
if m.existingDBCount > 0 {
s.WriteString(fmt.Sprintf(" Existing Databases: %d found\n", m.existingDBCount))
// Show first few database names
maxShow := 5
for i, db := range m.existingDBs {
@@ -319,7 +319,7 @@ func (m RestorePreviewModel) View() string {
}
s.WriteString(fmt.Sprintf(" - %s\n", db))
}
cleanIcon := "✗"
cleanStyle := infoStyle
if m.cleanClusterFirst {
@@ -344,7 +344,7 @@ func (m RestorePreviewModel) View() string {
for _, check := range m.safetyChecks {
icon := "○"
style := checkPendingStyle
switch check.Status {
case "passed":
icon = "✓"

View File

@@ -75,7 +75,7 @@ func NewSettingsModel(cfg *config.Config, log logger.Logger, parent tea.Model) S
}
nextIdx := (currentIdx + 1) % len(workloads)
c.CPUWorkloadType = workloads[nextIdx]
// Recalculate Jobs and DumpJobs based on workload type
if c.CPUInfo != nil && c.AutoDetectCores {
switch c.CPUWorkloadType {
@@ -329,7 +329,7 @@ func NewSettingsModel(cfg *config.Config, log logger.Logger, parent tea.Model) S
{
Key: "cloud_access_key",
DisplayName: "Cloud Access Key",
Value: func(c *config.Config) string {
Value: func(c *config.Config) string {
if c.CloudAccessKey != "" {
return "***" + c.CloudAccessKey[len(c.CloudAccessKey)-4:]
}
@@ -624,7 +624,7 @@ func (m SettingsModel) saveSettings() (tea.Model, tea.Cmd) {
// cycleDatabaseType cycles through database type options
func (m SettingsModel) cycleDatabaseType() (tea.Model, tea.Cmd) {
dbTypes := []string{"postgres", "mysql", "mariadb"}
// Find current index
currentIdx := 0
for i, dbType := range dbTypes {
@@ -633,17 +633,17 @@ func (m SettingsModel) cycleDatabaseType() (tea.Model, tea.Cmd) {
break
}
}
// Cycle to next
nextIdx := (currentIdx + 1) % len(dbTypes)
newType := dbTypes[nextIdx]
// Update config
if err := m.config.SetDatabaseType(newType); err != nil {
m.message = errorStyle.Render(fmt.Sprintf("❌ Failed to set database type: %s", err.Error()))
return m, nil
}
m.message = successStyle.Render(fmt.Sprintf("✅ Database type set to %s", m.config.DisplayDatabaseType()))
return m, nil
}
@@ -726,7 +726,7 @@ func (m SettingsModel) View() string {
fmt.Sprintf("Compression: Level %d", m.config.CompressionLevel),
fmt.Sprintf("Jobs: %d parallel, %d dump", m.config.Jobs, m.config.DumpJobs),
}
if m.config.CloudEnabled {
cloudInfo := fmt.Sprintf("Cloud: %s (%s)", m.config.CloudProvider, m.config.CloudBucket)
if m.config.CloudAutoUpload {

View File

@@ -9,14 +9,14 @@ import (
// Result represents the outcome of a verification operation
type Result struct {
Valid bool
BackupFile string
ExpectedSHA256 string
Valid bool
BackupFile string
ExpectedSHA256 string
CalculatedSHA256 string
SizeMatch bool
FileExists bool
MetadataExists bool
Error error
SizeMatch bool
FileExists bool
MetadataExists bool
Error error
}
// Verify checks the integrity of a backup file
@@ -47,7 +47,7 @@ func Verify(backupFile string) (*Result, error) {
// Check size match
if info.Size() != meta.SizeBytes {
result.SizeMatch = false
result.Error = fmt.Errorf("size mismatch: expected %d bytes, got %d bytes",
result.Error = fmt.Errorf("size mismatch: expected %d bytes, got %d bytes",
meta.SizeBytes, info.Size())
return result, nil
}
@@ -64,7 +64,7 @@ func Verify(backupFile string) (*Result, error) {
// Compare checksums
if actualSHA256 != meta.SHA256 {
result.Valid = false
result.Error = fmt.Errorf("checksum mismatch: expected %s, got %s",
result.Error = fmt.Errorf("checksum mismatch: expected %s, got %s",
meta.SHA256, actualSHA256)
return result, nil
}
@@ -77,7 +77,7 @@ func Verify(backupFile string) (*Result, error) {
// VerifyMultiple verifies multiple backup files
func VerifyMultiple(backupFiles []string) ([]*Result, error) {
var results []*Result
for _, file := range backupFiles {
result, err := Verify(file)
if err != nil {
@@ -106,7 +106,7 @@ func QuickCheck(backupFile string) error {
// Check size
if info.Size() != meta.SizeBytes {
return fmt.Errorf("size mismatch: expected %d bytes, got %d bytes",
return fmt.Errorf("size mismatch: expected %d bytes, got %d bytes",
meta.SizeBytes, info.Size())
}

View File

@@ -21,26 +21,26 @@ type Archiver struct {
// ArchiveConfig holds WAL archiving configuration
type ArchiveConfig struct {
ArchiveDir string // Directory to store archived WAL files
CompressWAL bool // Compress WAL files with gzip
EncryptWAL bool // Encrypt WAL files
EncryptionKey []byte // 32-byte key for AES-256-GCM encryption
RetentionDays int // Days to keep WAL archives
VerifyChecksum bool // Verify WAL file checksums
ArchiveDir string // Directory to store archived WAL files
CompressWAL bool // Compress WAL files with gzip
EncryptWAL bool // Encrypt WAL files
EncryptionKey []byte // 32-byte key for AES-256-GCM encryption
RetentionDays int // Days to keep WAL archives
VerifyChecksum bool // Verify WAL file checksums
}
// WALArchiveInfo contains metadata about an archived WAL file
type WALArchiveInfo struct {
WALFileName string `json:"wal_filename"`
ArchivePath string `json:"archive_path"`
OriginalSize int64 `json:"original_size"`
ArchivedSize int64 `json:"archived_size"`
Checksum string `json:"checksum"`
Timeline uint32 `json:"timeline"`
Segment uint64 `json:"segment"`
ArchivedAt time.Time `json:"archived_at"`
Compressed bool `json:"compressed"`
Encrypted bool `json:"encrypted"`
WALFileName string `json:"wal_filename"`
ArchivePath string `json:"archive_path"`
OriginalSize int64 `json:"original_size"`
ArchivedSize int64 `json:"archived_size"`
Checksum string `json:"checksum"`
Timeline uint32 `json:"timeline"`
Segment uint64 `json:"segment"`
ArchivedAt time.Time `json:"archived_at"`
Compressed bool `json:"compressed"`
Encrypted bool `json:"encrypted"`
}
// NewArchiver creates a new WAL archiver
@@ -77,7 +77,7 @@ func (a *Archiver) ArchiveWALFile(ctx context.Context, walFilePath, walFileName
// Process WAL file: compression and/or encryption
var archivePath string
var archivedSize int64
if config.CompressWAL && config.EncryptWAL {
// Compress then encrypt
archivePath, archivedSize, err = a.compressAndEncryptWAL(walFilePath, walFileName, config)
@@ -150,7 +150,7 @@ func (a *Archiver) copyWAL(walFilePath, walFileName string, config ArchiveConfig
// compressWAL compresses a WAL file using gzip
func (a *Archiver) compressWAL(walFilePath, walFileName string, config ArchiveConfig) (string, int64, error) {
archivePath := filepath.Join(config.ArchiveDir, walFileName+".gz")
compressor := NewCompressor(a.log)
compressedSize, err := compressor.CompressWALFile(walFilePath, archivePath, 6) // gzip level 6 (balanced)
if err != nil {
@@ -163,12 +163,12 @@ func (a *Archiver) compressWAL(walFilePath, walFileName string, config ArchiveCo
// encryptWAL encrypts a WAL file
func (a *Archiver) encryptWAL(walFilePath, walFileName string, config ArchiveConfig) (string, int64, error) {
archivePath := filepath.Join(config.ArchiveDir, walFileName+".enc")
encryptor := NewEncryptor(a.log)
encOpts := EncryptionOptions{
Key: config.EncryptionKey,
}
encryptedSize, err := encryptor.EncryptWALFile(walFilePath, archivePath, encOpts)
if err != nil {
return "", 0, fmt.Errorf("WAL encryption failed: %w", err)
@@ -199,7 +199,7 @@ func (a *Archiver) compressAndEncryptWAL(walFilePath, walFileName string, config
encOpts := EncryptionOptions{
Key: config.EncryptionKey,
}
encryptedSize, err := encryptor.EncryptWALFile(tempCompressed, archivePath, encOpts)
if err != nil {
return "", 0, fmt.Errorf("WAL encryption failed: %w", err)
@@ -340,7 +340,7 @@ func (a *Archiver) GetArchiveStats(config ArchiveConfig) (*ArchiveStats, error)
for _, archive := range archives {
stats.TotalSize += archive.ArchivedSize
if archive.Compressed {
stats.CompressedFiles++
}

View File

@@ -11,6 +11,7 @@ import (
"path/filepath"
"dbbackup/internal/logger"
"golang.org/x/crypto/pbkdf2"
)

View File

@@ -23,14 +23,14 @@ type PITRManager struct {
// PITRConfig holds PITR settings
type PITRConfig struct {
Enabled bool
ArchiveMode string // "on", "off", "always"
ArchiveCommand string
ArchiveDir string
WALLevel string // "minimal", "replica", "logical"
MaxWALSenders int
WALKeepSize string // e.g., "1GB"
RestoreCommand string
Enabled bool
ArchiveMode string // "on", "off", "always"
ArchiveCommand string
ArchiveDir string
WALLevel string // "minimal", "replica", "logical"
MaxWALSenders int
WALKeepSize string // e.g., "1GB"
RestoreCommand string
}
// RecoveryTarget specifies the point-in-time to recover to
@@ -87,11 +87,11 @@ func (pm *PITRManager) EnablePITR(ctx context.Context, archiveDir string) error
// Settings to enable PITR
settings := map[string]string{
"wal_level": "replica", // Required for PITR
"archive_mode": "on",
"archive_command": archiveCommand,
"max_wal_senders": "3",
"wal_keep_size": "1GB", // Keep at least 1GB of WAL
"wal_level": "replica", // Required for PITR
"archive_mode": "on",
"archive_command": archiveCommand,
"max_wal_senders": "3",
"wal_keep_size": "1GB", // Keep at least 1GB of WAL
}
// Update postgresql.conf
@@ -156,7 +156,7 @@ func (pm *PITRManager) GetCurrentPITRConfig(ctx context.Context) (*PITRConfig, e
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Skip comments and empty lines
if line == "" || strings.HasPrefix(line, "#") {
continue
@@ -226,11 +226,11 @@ func (pm *PITRManager) createRecoverySignal(ctx context.Context, dataDir string,
// Recovery settings go in postgresql.auto.conf (PostgreSQL 12+)
autoConfPath := filepath.Join(dataDir, "postgresql.auto.conf")
// Build recovery settings
var settings []string
settings = append(settings, fmt.Sprintf("restore_command = 'cp %s/%%f %%p'", walArchiveDir))
if target.TargetTime != nil {
settings = append(settings, fmt.Sprintf("recovery_target_time = '%s'", target.TargetTime.Format("2006-01-02 15:04:05")))
} else if target.TargetXID != "" {
@@ -270,11 +270,11 @@ func (pm *PITRManager) createRecoverySignal(ctx context.Context, dataDir string,
// createLegacyRecoveryConf creates recovery.conf for PostgreSQL < 12
func (pm *PITRManager) createLegacyRecoveryConf(dataDir string, target RecoveryTarget, walArchiveDir string) error {
recoveryConfPath := filepath.Join(dataDir, "recovery.conf")
var content strings.Builder
content.WriteString("# Recovery Configuration (created by dbbackup)\n")
content.WriteString(fmt.Sprintf("restore_command = 'cp %s/%%f %%p'\n", walArchiveDir))
if target.TargetTime != nil {
content.WriteString(fmt.Sprintf("recovery_target_time = '%s'\n", target.TargetTime.Format("2006-01-02 15:04:05")))
}

View File

@@ -40,9 +40,9 @@ type TimelineInfo struct {
// TimelineHistory represents the complete timeline branching structure
type TimelineHistory struct {
Timelines []*TimelineInfo // All timelines sorted by ID
Timelines []*TimelineInfo // All timelines sorted by ID
CurrentTimeline uint32 // Current active timeline
TimelineMap map[uint32]*TimelineInfo // Quick lookup by timeline ID
TimelineMap map[uint32]*TimelineInfo // Quick lookup by timeline ID
}
// ParseTimelineHistory parses timeline history from an archive directory
@@ -74,10 +74,10 @@ func (tm *TimelineManager) ParseTimelineHistory(ctx context.Context, archiveDir
// Always add timeline 1 (base timeline) if not present
if _, exists := history.TimelineMap[1]; !exists {
baseTimeline := &TimelineInfo{
TimelineID: 1,
ParentTimeline: 0,
SwitchPoint: "0/0",
Reason: "Base timeline",
TimelineID: 1,
ParentTimeline: 0,
SwitchPoint: "0/0",
Reason: "Base timeline",
FirstWALSegment: 0,
}
history.Timelines = append(history.Timelines, baseTimeline)
@@ -201,7 +201,7 @@ func (tm *TimelineManager) scanWALSegments(archiveDir string, history *TimelineH
// Process each WAL file
for _, walFile := range walFiles {
filename := filepath.Base(walFile)
// Remove extensions
filename = strings.TrimSuffix(filename, ".gz.enc")
filename = strings.TrimSuffix(filename, ".enc")
@@ -255,7 +255,7 @@ func (tm *TimelineManager) ValidateTimelineConsistency(ctx context.Context, hist
parent, exists := history.TimelineMap[tl.ParentTimeline]
if !exists {
return fmt.Errorf("timeline %d references non-existent parent timeline %d",
return fmt.Errorf("timeline %d references non-existent parent timeline %d",
tl.TimelineID, tl.ParentTimeline)
}
@@ -274,29 +274,29 @@ func (tm *TimelineManager) ValidateTimelineConsistency(ctx context.Context, hist
// GetTimelinePath returns the path from timeline 1 to the target timeline
func (tm *TimelineManager) GetTimelinePath(history *TimelineHistory, targetTimeline uint32) ([]*TimelineInfo, error) {
path := make([]*TimelineInfo, 0)
currentTL := targetTimeline
for currentTL > 0 {
tl, exists := history.TimelineMap[currentTL]
if !exists {
return nil, fmt.Errorf("timeline %d not found in history", currentTL)
}
// Prepend to path (we're walking backwards)
path = append([]*TimelineInfo{tl}, path...)
// Move to parent
if currentTL == 1 {
break // Reached base timeline
}
currentTL = tl.ParentTimeline
// Prevent infinite loops
if len(path) > 100 {
return nil, fmt.Errorf("timeline path too long (possible cycle)")
}
}
return path, nil
}
@@ -305,13 +305,13 @@ func (tm *TimelineManager) FindTimelineAtPoint(history *TimelineHistory, targetL
// Start from current timeline and walk backwards
for i := len(history.Timelines) - 1; i >= 0; i-- {
tl := history.Timelines[i]
// Compare LSNs (simplified - in production would need proper LSN comparison)
if tl.SwitchPoint <= targetLSN || tl.SwitchPoint == "0/0" {
return tl.TimelineID, nil
}
}
// Default to timeline 1
return 1, nil
}
@@ -384,23 +384,23 @@ func (tm *TimelineManager) formatTimelineNode(sb *strings.Builder, history *Time
}
sb.WriteString(fmt.Sprintf("%s%s Timeline %d", indent, marker, tl.TimelineID))
if tl.TimelineID == history.CurrentTimeline {
sb.WriteString(" [CURRENT]")
}
if tl.SwitchPoint != "" && tl.SwitchPoint != "0/0" {
sb.WriteString(fmt.Sprintf(" (switched at %s)", tl.SwitchPoint))
}
if tl.FirstWALSegment > 0 {
sb.WriteString(fmt.Sprintf("\n%s WAL segments: %d files", indent, tl.LastWALSegment-tl.FirstWALSegment+1))
}
if tl.Reason != "" {
sb.WriteString(fmt.Sprintf("\n%s Reason: %s", indent, tl.Reason))
}
sb.WriteString("\n")
// Find and format children