diff --git a/cmd/backup_impl.go b/cmd/backup_impl.go index 105bf92..9653e07 100644 --- a/cmd/backup_impl.go +++ b/cmd/backup_impl.go @@ -24,6 +24,20 @@ func runClusterBackup(ctx context.Context) error { return fmt.Errorf("configuration error: %w", err) } + // Check privileges + privChecker := security.NewPrivilegeChecker(log) + if err := privChecker.CheckAndWarn(cfg.AllowRoot); err != nil { + return err + } + + // Check resource limits + if cfg.CheckResources { + resChecker := security.NewResourceChecker(log) + if _, err := resChecker.CheckResourceLimits(); err != nil { + log.Warn("Failed to check resource limits", "error", err) + } + } + log.Info("Starting cluster backup", "host", cfg.Host, "port", cfg.Port, @@ -33,6 +47,13 @@ func runClusterBackup(ctx context.Context) error { user := security.GetCurrentUser() auditLogger.LogBackupStart(user, "all_databases", "cluster") + // Rate limit connection attempts + host := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + if err := rateLimiter.CheckAndWait(host); err != nil { + auditLogger.LogBackupFailed(user, "all_databases", err) + return fmt.Errorf("rate limit exceeded: %w", err) + } + // Create database instance db, err := database.New(cfg, log) if err != nil { @@ -43,9 +64,11 @@ func runClusterBackup(ctx context.Context) error { // Connect to database if err := db.Connect(ctx); err != nil { + rateLimiter.RecordFailure(host) auditLogger.LogBackupFailed(user, "all_databases", err) return fmt.Errorf("failed to connect to database: %w", err) } + rateLimiter.RecordSuccess(host) // Create backup engine engine := backup.New(cfg, log, db) @@ -59,6 +82,16 @@ func runClusterBackup(ctx context.Context) error { // Audit log: backup success auditLogger.LogBackupComplete(user, "all_databases", cfg.BackupDir, 0) + // Cleanup old backups if retention policy is enabled + if cfg.RetentionDays > 0 { + retentionPolicy := security.NewRetentionPolicy(cfg.RetentionDays, cfg.MinBackups, log) + if deleted, freed, err := retentionPolicy.CleanupOldBackups(cfg.BackupDir); err != nil { + log.Warn("Failed to cleanup old backups", "error", err) + } else if deleted > 0 { + log.Info("Cleaned up old backups", "deleted", deleted, "freed_mb", freed/1024/1024) + } + } + // Save configuration for future use (unless disabled) if !cfg.NoSaveConfig { localCfg := config.ConfigFromConfig(cfg) @@ -83,6 +116,12 @@ func runSingleBackup(ctx context.Context, databaseName string) error { return fmt.Errorf("configuration error: %w", err) } + // Check privileges + privChecker := security.NewPrivilegeChecker(log) + if err := privChecker.CheckAndWarn(cfg.AllowRoot); err != nil { + return err + } + log.Info("Starting single database backup", "database", databaseName, "db_type", cfg.DatabaseType, @@ -94,6 +133,13 @@ func runSingleBackup(ctx context.Context, databaseName string) error { user := security.GetCurrentUser() auditLogger.LogBackupStart(user, databaseName, "single") + // Rate limit connection attempts + host := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + if err := rateLimiter.CheckAndWait(host); err != nil { + auditLogger.LogBackupFailed(user, databaseName, err) + return fmt.Errorf("rate limit exceeded: %w", err) + } + // Create database instance db, err := database.New(cfg, log) if err != nil { @@ -104,9 +150,11 @@ func runSingleBackup(ctx context.Context, databaseName string) error { // Connect to database if err := db.Connect(ctx); err != nil { + rateLimiter.RecordFailure(host) auditLogger.LogBackupFailed(user, databaseName, err) return fmt.Errorf("failed to connect to database: %w", err) } + rateLimiter.RecordSuccess(host) // Verify database exists exists, err := db.DatabaseExists(ctx, databaseName) @@ -132,6 +180,16 @@ func runSingleBackup(ctx context.Context, databaseName string) error { // Audit log: backup success auditLogger.LogBackupComplete(user, databaseName, cfg.BackupDir, 0) + // Cleanup old backups if retention policy is enabled + if cfg.RetentionDays > 0 { + retentionPolicy := security.NewRetentionPolicy(cfg.RetentionDays, cfg.MinBackups, log) + if deleted, freed, err := retentionPolicy.CleanupOldBackups(cfg.BackupDir); err != nil { + log.Warn("Failed to cleanup old backups", "error", err) + } else if deleted > 0 { + log.Info("Cleaned up old backups", "deleted", deleted, "freed_mb", freed/1024/1024) + } + } + // Save configuration for future use (unless disabled) if !cfg.NoSaveConfig { localCfg := config.ConfigFromConfig(cfg) @@ -156,6 +214,12 @@ func runSampleBackup(ctx context.Context, databaseName string) error { return fmt.Errorf("configuration error: %w", err) } + // Check privileges + privChecker := security.NewPrivilegeChecker(log) + if err := privChecker.CheckAndWarn(cfg.AllowRoot); err != nil { + return err + } + // Validate sample parameters if cfg.SampleValue <= 0 { return fmt.Errorf("sample value must be greater than 0") @@ -189,6 +253,13 @@ func runSampleBackup(ctx context.Context, databaseName string) error { user := security.GetCurrentUser() auditLogger.LogBackupStart(user, databaseName, "sample") + // Rate limit connection attempts + host := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + if err := rateLimiter.CheckAndWait(host); err != nil { + auditLogger.LogBackupFailed(user, databaseName, err) + return fmt.Errorf("rate limit exceeded: %w", err) + } + // Create database instance db, err := database.New(cfg, log) if err != nil { @@ -199,9 +270,11 @@ func runSampleBackup(ctx context.Context, databaseName string) error { // Connect to database if err := db.Connect(ctx); err != nil { + rateLimiter.RecordFailure(host) auditLogger.LogBackupFailed(user, databaseName, err) return fmt.Errorf("failed to connect to database: %w", err) } + rateLimiter.RecordSuccess(host) // Verify database exists exists, err := db.DatabaseExists(ctx, databaseName) diff --git a/cmd/root.go b/cmd/root.go index 465ff7a..d62fe59 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -14,6 +14,7 @@ var ( cfg *config.Config log logger.Logger auditLogger *security.AuditLogger + rateLimiter *security.RateLimiter ) // rootCmd represents the base command when called without any subcommands @@ -62,6 +63,9 @@ func Execute(ctx context.Context, config *config.Config, logger logger.Logger) e // Initialize audit logger auditLogger = security.NewAuditLogger(logger, true) + + // Initialize rate limiter + rateLimiter = security.NewRateLimiter(config.MaxRetries, logger) // Set version info rootCmd.Version = fmt.Sprintf("%s (built: %s, commit: %s)", @@ -87,6 +91,13 @@ func Execute(ctx context.Context, config *config.Config, logger logger.Logger) e rootCmd.PersistentFlags().IntVar(&cfg.CompressionLevel, "compression", cfg.CompressionLevel, "Compression level (0-9)") rootCmd.PersistentFlags().BoolVar(&cfg.NoSaveConfig, "no-save-config", false, "Don't save configuration after successful operations") rootCmd.PersistentFlags().BoolVar(&cfg.NoLoadConfig, "no-config", false, "Don't load configuration from .dbbackup.conf") + + // Security flags (MEDIUM priority) + rootCmd.PersistentFlags().IntVar(&cfg.RetentionDays, "retention-days", cfg.RetentionDays, "Backup retention period in days (0=disabled)") + rootCmd.PersistentFlags().IntVar(&cfg.MinBackups, "min-backups", cfg.MinBackups, "Minimum number of backups to keep") + rootCmd.PersistentFlags().IntVar(&cfg.MaxRetries, "max-retries", cfg.MaxRetries, "Maximum connection retry attempts") + rootCmd.PersistentFlags().BoolVar(&cfg.AllowRoot, "allow-root", cfg.AllowRoot, "Allow running as root/Administrator") + rootCmd.PersistentFlags().BoolVar(&cfg.CheckResources, "check-resources", cfg.CheckResources, "Check system resource limits") return rootCmd.ExecuteContext(ctx) } diff --git a/internal/config/config.go b/internal/config/config.go index 5f920fc..d884efb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -68,6 +68,13 @@ type Config struct { SwapFilePath string // Path to temporary swap file SwapFileSizeGB int // Size in GB (0 = disabled) AutoSwap bool // Automatically manage swap for large backups + + // Security options (MEDIUM priority) + RetentionDays int // Backup retention in days (0 = disabled) + MinBackups int // Minimum backups to keep regardless of age + MaxRetries int // Maximum connection retry attempts + AllowRoot bool // Allow running as root/Administrator + CheckResources bool // Check resource limits before operations } // New creates a new configuration with default values @@ -158,6 +165,13 @@ func New() *Config { SwapFilePath: getEnvString("SWAP_FILE_PATH", "/tmp/dbbackup_swap"), SwapFileSizeGB: getEnvInt("SWAP_FILE_SIZE_GB", 0), // 0 = disabled by default AutoSwap: getEnvBool("AUTO_SWAP", false), + + // Security defaults (MEDIUM priority) + RetentionDays: getEnvInt("RETENTION_DAYS", 30), // Keep backups for 30 days + MinBackups: getEnvInt("MIN_BACKUPS", 5), // Keep at least 5 backups + MaxRetries: getEnvInt("MAX_RETRIES", 3), // Maximum 3 retry attempts + AllowRoot: getEnvBool("ALLOW_ROOT", false), // Disallow root by default + CheckResources: getEnvBool("CHECK_RESOURCES", true), // Check resources by default } // Ensure canonical defaults are enforced diff --git a/internal/config/persist.go b/internal/config/persist.go index 732c79b..176129d 100644 --- a/internal/config/persist.go +++ b/internal/config/persist.go @@ -29,6 +29,11 @@ type LocalConfig struct { // Performance settings CPUWorkload string MaxCores int + + // Security settings + RetentionDays int + MinBackups int + MaxRetries int } // LoadLocalConfig loads configuration from .dbbackup.conf in current directory @@ -114,6 +119,21 @@ func LoadLocalConfig() (*LocalConfig, error) { cfg.MaxCores = mc } } + case "security": + switch key { + case "retention_days": + if rd, err := strconv.Atoi(value); err == nil { + cfg.RetentionDays = rd + } + case "min_backups": + if mb, err := strconv.Atoi(value); err == nil { + cfg.MinBackups = mb + } + case "max_retries": + if mr, err := strconv.Atoi(value); err == nil { + cfg.MaxRetries = mr + } + } } } @@ -173,6 +193,19 @@ func SaveLocalConfig(cfg *LocalConfig) error { if cfg.MaxCores != 0 { sb.WriteString(fmt.Sprintf("max_cores = %d\n", cfg.MaxCores)) } + sb.WriteString("\n") + + // Security section + sb.WriteString("[security]\n") + if cfg.RetentionDays != 0 { + sb.WriteString(fmt.Sprintf("retention_days = %d\n", cfg.RetentionDays)) + } + if cfg.MinBackups != 0 { + sb.WriteString(fmt.Sprintf("min_backups = %d\n", cfg.MinBackups)) + } + if cfg.MaxRetries != 0 { + sb.WriteString(fmt.Sprintf("max_retries = %d\n", cfg.MaxRetries)) + } configPath := filepath.Join(".", ConfigFileName) // Use 0600 permissions for security (readable/writable only by owner) @@ -226,22 +259,34 @@ func ApplyLocalConfig(cfg *Config, local *LocalConfig) { if local.MaxCores != 0 { cfg.MaxCores = local.MaxCores } + if cfg.RetentionDays == 30 && local.RetentionDays != 0 { + cfg.RetentionDays = local.RetentionDays + } + if cfg.MinBackups == 5 && local.MinBackups != 0 { + cfg.MinBackups = local.MinBackups + } + if cfg.MaxRetries == 3 && local.MaxRetries != 0 { + cfg.MaxRetries = local.MaxRetries + } } // ConfigFromConfig creates a LocalConfig from a Config func ConfigFromConfig(cfg *Config) *LocalConfig { return &LocalConfig{ - DBType: cfg.DatabaseType, - Host: cfg.Host, - Port: cfg.Port, - User: cfg.User, - Database: cfg.Database, - SSLMode: cfg.SSLMode, - BackupDir: cfg.BackupDir, - Compression: cfg.CompressionLevel, - Jobs: cfg.Jobs, - DumpJobs: cfg.DumpJobs, - CPUWorkload: cfg.CPUWorkloadType, - MaxCores: cfg.MaxCores, + DBType: cfg.DatabaseType, + Host: cfg.Host, + Port: cfg.Port, + User: cfg.User, + Database: cfg.Database, + SSLMode: cfg.SSLMode, + BackupDir: cfg.BackupDir, + Compression: cfg.CompressionLevel, + Jobs: cfg.Jobs, + DumpJobs: cfg.DumpJobs, + CPUWorkload: cfg.CPUWorkloadType, + MaxCores: cfg.MaxCores, + RetentionDays: cfg.RetentionDays, + MinBackups: cfg.MinBackups, + MaxRetries: cfg.MaxRetries, } } diff --git a/internal/security/privileges.go b/internal/security/privileges.go new file mode 100644 index 0000000..aaa1ea7 --- /dev/null +++ b/internal/security/privileges.go @@ -0,0 +1,99 @@ +package security + +import ( + "fmt" + "os" + "runtime" + + "dbbackup/internal/logger" +) + +// PrivilegeChecker checks for elevated privileges +type PrivilegeChecker struct { + log logger.Logger +} + +// NewPrivilegeChecker creates a new privilege checker +func NewPrivilegeChecker(log logger.Logger) *PrivilegeChecker { + return &PrivilegeChecker{ + log: log, + } +} + +// CheckAndWarn checks if running with elevated privileges and warns +func (pc *PrivilegeChecker) CheckAndWarn(allowRoot bool) error { + isRoot, user := pc.isRunningAsRoot() + + if isRoot { + pc.log.Warn("⚠️ Running with elevated privileges (root/Administrator)") + pc.log.Warn("Security recommendation: Create a dedicated backup user with minimal privileges") + + if !allowRoot { + return fmt.Errorf("running as root is not recommended, use --allow-root to override") + } + + pc.log.Warn("Proceeding with root privileges (--allow-root specified)") + } else { + pc.log.Debug("Running as non-privileged user", "user", user) + } + + return nil +} + +// isRunningAsRoot checks if current process has root/admin privileges +func (pc *PrivilegeChecker) isRunningAsRoot() (bool, string) { + if runtime.GOOS == "windows" { + return pc.isWindowsAdmin() + } + return pc.isUnixRoot() +} + +// isUnixRoot checks for root on Unix-like systems +func (pc *PrivilegeChecker) isUnixRoot() (bool, string) { + uid := os.Getuid() + user := GetCurrentUser() + + isRoot := uid == 0 || user == "root" + return isRoot, user +} + +// isWindowsAdmin checks for Administrator on Windows +func (pc *PrivilegeChecker) isWindowsAdmin() (bool, string) { + // Check if running as Administrator on Windows + // This is a simplified check - full implementation would use Windows API + user := GetCurrentUser() + + // Common admin user patterns on Windows + isAdmin := user == "Administrator" || user == "SYSTEM" + + return isAdmin, user +} + +// GetRecommendedUser returns recommended non-privileged username +func (pc *PrivilegeChecker) GetRecommendedUser() string { + if runtime.GOOS == "windows" { + return "BackupUser" + } + return "dbbackup" +} + +// GetSecurityRecommendations returns security best practices +func (pc *PrivilegeChecker) GetSecurityRecommendations() []string { + recommendations := []string{ + "Create a dedicated backup user with minimal database privileges", + "Grant only necessary permissions (SELECT, LOCK TABLES for MySQL)", + "Use connection strings instead of environment variables in production", + "Store credentials in secure credential management systems", + "Enable SSL/TLS for database connections", + "Restrict backup directory permissions (chmod 700)", + "Regularly rotate database passwords", + "Monitor audit logs for unauthorized access attempts", + } + + if runtime.GOOS != "windows" { + recommendations = append(recommendations, + fmt.Sprintf("Run as non-root user: sudo -u %s dbbackup ...", pc.GetRecommendedUser())) + } + + return recommendations +} diff --git a/internal/security/ratelimit.go b/internal/security/ratelimit.go new file mode 100644 index 0000000..d672e79 --- /dev/null +++ b/internal/security/ratelimit.go @@ -0,0 +1,176 @@ +package security + +import ( + "fmt" + "sync" + "time" + + "dbbackup/internal/logger" +) + +// RateLimiter tracks connection attempts and enforces rate limiting +type RateLimiter struct { + attempts map[string]*attemptTracker + mu sync.RWMutex + maxRetries int + baseDelay time.Duration + maxDelay time.Duration + resetInterval time.Duration + log logger.Logger +} + +// attemptTracker tracks connection attempts for a specific host +type attemptTracker struct { + count int + lastAttempt time.Time + nextAllowed time.Time +} + +// NewRateLimiter creates a new rate limiter for connection attempts +func NewRateLimiter(maxRetries int, log logger.Logger) *RateLimiter { + return &RateLimiter{ + attempts: make(map[string]*attemptTracker), + maxRetries: maxRetries, + baseDelay: 1 * time.Second, + maxDelay: 60 * time.Second, + resetInterval: 5 * time.Minute, + log: log, + } +} + +// CheckAndWait checks if connection is allowed and waits if rate limited +// Returns error if max retries exceeded +func (rl *RateLimiter) CheckAndWait(host string) error { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + tracker, exists := rl.attempts[host] + + if !exists { + // First attempt, allow immediately + rl.attempts[host] = &attemptTracker{ + count: 1, + lastAttempt: now, + nextAllowed: now, + } + return nil + } + + // Reset counter if enough time has passed + if now.Sub(tracker.lastAttempt) > rl.resetInterval { + rl.log.Debug("Resetting rate limit counter", "host", host) + tracker.count = 1 + tracker.lastAttempt = now + tracker.nextAllowed = now + return nil + } + + // Check if max retries exceeded + if tracker.count >= rl.maxRetries { + return fmt.Errorf("max connection retries (%d) exceeded for host %s, try again in %v", + rl.maxRetries, host, rl.resetInterval) + } + + // Calculate exponential backoff delay + delay := rl.calculateDelay(tracker.count) + tracker.nextAllowed = tracker.lastAttempt.Add(delay) + + // Wait if necessary + if now.Before(tracker.nextAllowed) { + waitTime := tracker.nextAllowed.Sub(now) + rl.log.Info("Rate limiting connection attempt", + "host", host, + "attempt", tracker.count, + "wait_seconds", int(waitTime.Seconds())) + + rl.mu.Unlock() + time.Sleep(waitTime) + rl.mu.Lock() + } + + // Update tracker + tracker.count++ + tracker.lastAttempt = time.Now() + + return nil +} + +// RecordSuccess resets the attempt counter for successful connections +func (rl *RateLimiter) RecordSuccess(host string) { + rl.mu.Lock() + defer rl.mu.Unlock() + + if tracker, exists := rl.attempts[host]; exists { + rl.log.Debug("Connection successful, resetting rate limit", "host", host) + tracker.count = 0 + tracker.lastAttempt = time.Now() + tracker.nextAllowed = time.Now() + } +} + +// RecordFailure increments the failure counter +func (rl *RateLimiter) RecordFailure(host string) { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + tracker, exists := rl.attempts[host] + + if !exists { + rl.attempts[host] = &attemptTracker{ + count: 1, + lastAttempt: now, + nextAllowed: now.Add(rl.baseDelay), + } + return + } + + tracker.count++ + tracker.lastAttempt = now + tracker.nextAllowed = now.Add(rl.calculateDelay(tracker.count)) + + rl.log.Warn("Connection failed", + "host", host, + "attempt", tracker.count, + "max_retries", rl.maxRetries) +} + +// calculateDelay calculates exponential backoff delay +func (rl *RateLimiter) calculateDelay(attempt int) time.Duration { + // Exponential backoff: 1s, 2s, 4s, 8s, 16s, 32s, max 60s + delay := rl.baseDelay * time.Duration(1< rl.maxDelay { + delay = rl.maxDelay + } + return delay +} + +// GetStatus returns current rate limit status for a host +func (rl *RateLimiter) GetStatus(host string) (attempts int, nextAllowed time.Time, isLimited bool) { + rl.mu.RLock() + defer rl.mu.RUnlock() + + tracker, exists := rl.attempts[host] + if !exists { + return 0, time.Now(), false + } + + now := time.Now() + isLimited = now.Before(tracker.nextAllowed) + + return tracker.count, tracker.nextAllowed, isLimited +} + +// Cleanup removes old entries from rate limiter +func (rl *RateLimiter) Cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for host, tracker := range rl.attempts { + if now.Sub(tracker.lastAttempt) > rl.resetInterval*2 { + delete(rl.attempts, host) + } + } +} diff --git a/internal/security/resources.go b/internal/security/resources.go new file mode 100644 index 0000000..a5e154b --- /dev/null +++ b/internal/security/resources.go @@ -0,0 +1,169 @@ +package security + +import ( + "fmt" + "runtime" + "syscall" + + "dbbackup/internal/logger" +) + +// ResourceChecker checks system resource limits +type ResourceChecker struct { + log logger.Logger +} + +// NewResourceChecker creates a new resource checker +func NewResourceChecker(log logger.Logger) *ResourceChecker { + return &ResourceChecker{ + log: log, + } +} + +// ResourceLimits holds system resource limit information +type ResourceLimits struct { + MaxOpenFiles uint64 + MaxProcesses uint64 + MaxMemory uint64 + MaxAddressSpace uint64 + Available bool + Platform string +} + +// CheckResourceLimits checks and reports system resource limits +func (rc *ResourceChecker) CheckResourceLimits() (*ResourceLimits, error) { + if runtime.GOOS == "windows" { + return rc.checkWindowsLimits() + } + return rc.checkUnixLimits() +} + +// checkUnixLimits checks resource limits on Unix-like systems +func (rc *ResourceChecker) checkUnixLimits() (*ResourceLimits, error) { + limits := &ResourceLimits{ + Available: true, + Platform: runtime.GOOS, + } + + // Check max open files (RLIMIT_NOFILE) + var rLimit syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err == nil { + limits.MaxOpenFiles = rLimit.Cur + rc.log.Debug("Resource limit: max open files", "limit", rLimit.Cur, "max", rLimit.Max) + + if rLimit.Cur < 1024 { + rc.log.Warn("⚠️ Low file descriptor limit detected", + "current", rLimit.Cur, + "recommended", 4096, + "hint", "Increase with: ulimit -n 4096") + } + } + + // Check max processes (RLIMIT_NPROC) - Linux/BSD only + if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" || runtime.GOOS == "openbsd" { + // RLIMIT_NPROC may not be available on all platforms + const RLIMIT_NPROC = 6 // Linux value + if err := syscall.Getrlimit(RLIMIT_NPROC, &rLimit); err == nil { + limits.MaxProcesses = rLimit.Cur + rc.log.Debug("Resource limit: max processes", "limit", rLimit.Cur) + } + } + + // Check max memory (RLIMIT_AS - address space) + if err := syscall.Getrlimit(syscall.RLIMIT_AS, &rLimit); err == nil { + limits.MaxAddressSpace = rLimit.Cur + // Check if unlimited (max value indicates unlimited) + if rLimit.Cur < ^uint64(0)-1024 { + rc.log.Debug("Resource limit: max address space", "limit_mb", rLimit.Cur/1024/1024) + } + } + + // Check available memory + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + limits.MaxMemory = memStats.Sys + + rc.log.Debug("Memory stats", + "alloc_mb", memStats.Alloc/1024/1024, + "sys_mb", memStats.Sys/1024/1024, + "num_gc", memStats.NumGC) + + return limits, nil +} + +// checkWindowsLimits checks resource limits on Windows +func (rc *ResourceChecker) checkWindowsLimits() (*ResourceLimits, error) { + limits := &ResourceLimits{ + Available: true, + Platform: "windows", + MaxOpenFiles: 2048, // Windows default + } + + // Get memory stats + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + limits.MaxMemory = memStats.Sys + + rc.log.Debug("Windows memory stats", + "alloc_mb", memStats.Alloc/1024/1024, + "sys_mb", memStats.Sys/1024/1024) + + return limits, nil +} + +// ValidateResourcesForBackup validates resources are sufficient for backup operation +func (rc *ResourceChecker) ValidateResourcesForBackup(estimatedSize int64) error { + limits, err := rc.CheckResourceLimits() + if err != nil { + return fmt.Errorf("failed to check resource limits: %w", err) + } + + var warnings []string + + // Check file descriptor limit on Unix + if runtime.GOOS != "windows" && limits.MaxOpenFiles < 1024 { + warnings = append(warnings, + fmt.Sprintf("Low file descriptor limit (%d), recommended: 4096+", limits.MaxOpenFiles)) + } + + // Check memory (warn if backup size might exceed available memory) + estimatedMemory := estimatedSize / 10 // Rough estimate: 10% of backup size + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + availableMemory := memStats.Sys - memStats.Alloc + + if estimatedMemory > int64(availableMemory) { + warnings = append(warnings, + fmt.Sprintf("Backup may require more memory than available (estimated: %dMB, available: %dMB)", + estimatedMemory/1024/1024, availableMemory/1024/1024)) + } + + if len(warnings) > 0 { + for _, warning := range warnings { + rc.log.Warn("⚠️ Resource constraint: " + warning) + } + rc.log.Info("Continuing backup operation (warnings are informational)") + } + + return nil +} + +// GetResourceRecommendations returns recommendations for resource limits +func (rc *ResourceChecker) GetResourceRecommendations() []string { + if runtime.GOOS == "windows" { + return []string{ + "Ensure sufficient disk space (3-4x backup size)", + "Monitor memory usage during large backups", + "Close unnecessary applications before backup", + } + } + + return []string{ + "Set file descriptor limit: ulimit -n 4096", + "Set max processes: ulimit -u 4096", + "Monitor disk space: df -h", + "Check memory: free -h", + "For large backups, consider increasing limits in /etc/security/limits.conf", + "Example limits.conf entry: dbbackup soft nofile 8192", + } +} diff --git a/internal/security/retention.go b/internal/security/retention.go new file mode 100644 index 0000000..06a4da7 --- /dev/null +++ b/internal/security/retention.go @@ -0,0 +1,197 @@ +package security + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "time" + + "dbbackup/internal/logger" +) + +// RetentionPolicy defines backup retention rules +type RetentionPolicy struct { + RetentionDays int + MinBackups int // Minimum backups to keep regardless of age + log logger.Logger +} + +// NewRetentionPolicy creates a new retention policy +func NewRetentionPolicy(retentionDays, minBackups int, log logger.Logger) *RetentionPolicy { + return &RetentionPolicy{ + RetentionDays: retentionDays, + MinBackups: minBackups, + log: log, + } +} + +// ArchiveInfo holds information about a backup archive +type ArchiveInfo struct { + Path string + ModTime time.Time + Size int64 + Database string +} + +// CleanupOldBackups removes backups older than retention period +func (rp *RetentionPolicy) CleanupOldBackups(backupDir string) (int, int64, error) { + if rp.RetentionDays <= 0 { + return 0, 0, nil // Retention disabled + } + + archives, err := rp.scanBackupArchives(backupDir) + if err != nil { + return 0, 0, fmt.Errorf("failed to scan backup directory: %w", err) + } + + if len(archives) <= rp.MinBackups { + rp.log.Debug("Keeping all backups (below minimum threshold)", + "count", len(archives), "min_backups", rp.MinBackups) + return 0, 0, nil + } + + cutoffTime := time.Now().AddDate(0, 0, -rp.RetentionDays) + + // Sort by modification time (oldest first) + sort.Slice(archives, func(i, j int) bool { + return archives[i].ModTime.Before(archives[j].ModTime) + }) + + var deletedCount int + var freedSpace int64 + + for i, archive := range archives { + // Keep minimum number of backups + remaining := len(archives) - i + if remaining <= rp.MinBackups { + rp.log.Debug("Stopped cleanup to maintain minimum backups", + "remaining", remaining, "min_backups", rp.MinBackups) + break + } + + // Delete if older than retention period + if archive.ModTime.Before(cutoffTime) { + rp.log.Info("Removing old backup", + "file", filepath.Base(archive.Path), + "age_days", int(time.Since(archive.ModTime).Hours()/24), + "size_mb", archive.Size/1024/1024) + + if err := os.Remove(archive.Path); err != nil { + rp.log.Warn("Failed to remove old backup", "file", archive.Path, "error", err) + continue + } + + // Also remove checksum file if exists + checksumPath := archive.Path + ".sha256" + if _, err := os.Stat(checksumPath); err == nil { + os.Remove(checksumPath) + } + + // Also remove metadata file if exists + metadataPath := archive.Path + ".meta" + if _, err := os.Stat(metadataPath); err == nil { + os.Remove(metadataPath) + } + + deletedCount++ + freedSpace += archive.Size + } + } + + if deletedCount > 0 { + rp.log.Info("Cleanup completed", + "deleted_backups", deletedCount, + "freed_space_mb", freedSpace/1024/1024, + "retention_days", rp.RetentionDays) + } + + return deletedCount, freedSpace, nil +} + +// scanBackupArchives scans directory for backup archives +func (rp *RetentionPolicy) scanBackupArchives(backupDir string) ([]ArchiveInfo, error) { + var archives []ArchiveInfo + + entries, err := os.ReadDir(backupDir) + if err != nil { + return nil, err + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + + // Skip non-backup files + if !isBackupArchive(name) { + continue + } + + path := filepath.Join(backupDir, name) + info, err := entry.Info() + if err != nil { + rp.log.Warn("Failed to get file info", "file", name, "error", err) + continue + } + + archives = append(archives, ArchiveInfo{ + Path: path, + ModTime: info.ModTime(), + Size: info.Size(), + Database: extractDatabaseName(name), + }) + } + + return archives, nil +} + +// isBackupArchive checks if filename is a backup archive +func isBackupArchive(name string) bool { + return (filepath.Ext(name) == ".dump" || + filepath.Ext(name) == ".sql" || + filepath.Ext(name) == ".gz" || + filepath.Ext(name) == ".tar") && + name != ".sha256" && + name != ".meta" +} + +// extractDatabaseName extracts database name from archive filename +func extractDatabaseName(filename string) string { + base := filepath.Base(filename) + + // Remove extensions + for { + oldBase := base + base = removeExtension(base) + if base == oldBase { + break + } + } + + // Remove timestamp patterns + if len(base) > 20 { + // Typically: db_name_20240101_120000 + underscoreCount := 0 + for i := len(base) - 1; i >= 0; i-- { + if base[i] == '_' { + underscoreCount++ + if underscoreCount >= 2 { + return base[:i] + } + } + } + } + + return base +} + +// removeExtension removes one extension from filename +func removeExtension(name string) string { + if ext := filepath.Ext(name); ext != "" { + return name[:len(name)-len(ext)] + } + return name +}