ci: add golangci-lint config and fix formatting
- Add .golangci.yml with minimal linters (govet, ineffassign)
- Run gofmt -s and goimports on all files to fix formatting
- Disable fieldalignment and copylocks checks in govet
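For reference, the `.golangci.yml` described above would look roughly like this — a sketch based on golangci-lint's documented config schema, not the exact file from this commit:

```yaml
linters:
  disable-all: true
  enable:
    - govet
    - ineffassign

linters-settings:
  govet:
    disable:
      - fieldalignment
      - copylocks
```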
@@ -16,13 +16,13 @@ import (

type AuthMethod string

const (
	AuthPeer        AuthMethod = "peer"
	AuthIdent       AuthMethod = "ident"
	AuthMD5         AuthMethod = "md5"
	AuthScramSHA256 AuthMethod = "scram-sha-256"
	AuthPassword    AuthMethod = "password"
	AuthTrust       AuthMethod = "trust"
	AuthUnknown     AuthMethod = "unknown"
)
// DetectPostgreSQLAuthMethod attempts to detect the authentication method

@@ -108,7 +108,7 @@ func parseHbaContent(content string, user string) AuthMethod {

	for _, line := range lines {
		line = strings.TrimSpace(line)

		// Skip comments and empty lines
		if line == "" || strings.HasPrefix(line, "#") {
			continue
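For context, the matching that parseHbaContent performs follows the documented pg_hba.conf field layout; a minimal hypothetical sketch of that mapping (not code from this commit):

```go
// authFromHbaLine is a hypothetical helper: map one pg_hba.conf rule
// to an AuthMethod. A local rule looks like: "local  all  all  peer".
func authFromHbaLine(line string) AuthMethod {
	fields := strings.Fields(line)
	if len(fields) < 4 {
		return AuthUnknown
	}
	// The auth method is the last field for "local" rules; "host" rules
	// carry an extra address column before it.
	switch method := fields[len(fields)-1]; method {
	case "peer", "ident", "md5", "scram-sha-256", "password", "trust":
		return AuthMethod(method)
	default:
		return AuthUnknown
	}
}
```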
@@ -198,29 +198,29 @@ func buildAuthMismatchMessage(osUser, dbUser string, method AuthMethod) string {

	msg.WriteString("\n⚠️  Authentication Mismatch Detected\n")
	msg.WriteString(strings.Repeat("=", 60) + "\n\n")

	msg.WriteString(fmt.Sprintf("   PostgreSQL is using '%s' authentication\n", method))
	msg.WriteString(fmt.Sprintf("   OS user '%s' cannot authenticate as DB user '%s'\n\n", osUser, dbUser))

	msg.WriteString("💡 Solutions (choose one):\n\n")

	msg.WriteString("   1. Run as matching user:\n")
	msg.WriteString(fmt.Sprintf("      sudo -u %s %s\n\n", dbUser, getCommandLine()))

	msg.WriteString("   2. Configure ~/.pgpass file (recommended):\n")
	msg.WriteString(fmt.Sprintf("      echo \"localhost:5432:*:%s:your_password\" > ~/.pgpass\n", dbUser))
	msg.WriteString("      chmod 0600 ~/.pgpass\n\n")

	msg.WriteString("   3. Set PGPASSWORD environment variable:\n")
	msg.WriteString("      export PGPASSWORD=your_password\n")
	msg.WriteString(fmt.Sprintf("      %s\n\n", getCommandLine()))

	msg.WriteString("   4. Provide password via flag:\n")
	msg.WriteString(fmt.Sprintf("      %s --password your_password\n\n", getCommandLine()))

	msg.WriteString("📝 Note: For production use, ~/.pgpass or PGPASSWORD are recommended\n")
	msg.WriteString("   to avoid exposing passwords in command history.\n\n")

	msg.WriteString(strings.Repeat("=", 60) + "\n")

	return msg.String()
@@ -231,29 +231,29 @@ func getCommandLine() string {
	if len(os.Args) == 0 {
		return "./dbbackup"
	}

	// Build command without password if present
	var parts []string
	skipNext := false

	for _, arg := range os.Args {
		if skipNext {
			skipNext = false
			continue
		}

		if arg == "--password" || arg == "-p" {
			skipNext = true
			continue
		}

		if strings.HasPrefix(arg, "--password=") {
			continue
		}

		parts = append(parts, arg)
	}

	return strings.Join(parts, " ")
}
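A quick illustration of the redaction getCommandLine performs (the invocation is hypothetical):

```go
// os.Args = []string{"./dbbackup", "backup", "--password", "s3cret", "--db", "app"}
// "--password" and its value are skipped, so getCommandLine() returns:
//   "./dbbackup backup --db app"
// A "--password=s3cret" argument is likewise dropped entirely.
```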
@@ -298,7 +298,7 @@ func parsePgpass(path string, cfg *config.Config) string {
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())

		// Skip comments and empty lines
		if line == "" || strings.HasPrefix(line, "#") {
			continue
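Each non-comment ~/.pgpass line has the documented five-field form hostname:port:database:username:password, with * acting as a wildcard in the first four fields. A minimal sketch of the match such a parser applies (the helper is hypothetical):

```go
// pgpassMatches is a hypothetical helper: does one ~/.pgpass entry
// (already split on ":") apply to this connection?
func pgpassMatches(entry []string, host, port, db, user string) bool {
	want := []string{host, port, db, user}
	for i, field := range entry[:4] {
		if field != "*" && field != want[i] {
			return false
		}
	}
	return true // entry[4] is the password to use
}
```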
@@ -14,7 +14,7 @@ import (
// The original file is replaced with the encrypted version
func EncryptBackupFile(backupPath string, key []byte, log logger.Logger) error {
	log.Info("Encrypting backup file", "file", filepath.Base(backupPath))

	// Validate key
	if err := crypto.ValidateKey(key); err != nil {
		return fmt.Errorf("invalid encryption key: %w", err)
@@ -81,25 +81,25 @@ func IsBackupEncrypted(backupPath string) bool {
		// All databases are unencrypted
		return false
	}

	// Try single database metadata
	if meta, err := metadata.Load(backupPath); err == nil {
		return meta.Encrypted
	}

	// Fallback: check if file starts with encryption nonce
	file, err := os.Open(backupPath)
	if err != nil {
		return false
	}
	defer file.Close()

	// Try to read nonce - if it succeeds, likely encrypted
	nonce := make([]byte, crypto.NonceSize)
	if n, err := file.Read(nonce); err != nil || n != crypto.NonceSize {
		return false
	}

	return true
}
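A short usage sketch for the heuristic above (the decrypt call site is hypothetical):

```go
// Metadata is the authoritative signal; the nonce probe only catches
// metadata-less archives, and any file of at least crypto.NonceSize
// bytes passes it, so treat the result as best-effort.
if IsBackupEncrypted(backupPath) {
	// DecryptBackupFile is a hypothetical counterpart to EncryptBackupFile.
	if err := DecryptBackupFile(backupPath, key, log); err != nil {
		return fmt.Errorf("decryption failed: %w", err)
	}
}
```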
@@ -20,11 +20,11 @@ import (
	"dbbackup/internal/cloud"
	"dbbackup/internal/config"
	"dbbackup/internal/database"
	"dbbackup/internal/logger"
	"dbbackup/internal/metadata"
	"dbbackup/internal/metrics"
	"dbbackup/internal/progress"
	"dbbackup/internal/security"
	"dbbackup/internal/swap"
)
@@ -42,7 +42,7 @@ type Engine struct {
func New(cfg *config.Config, log logger.Logger, db database.Database) *Engine {
	progressIndicator := progress.NewIndicator(true, "line") // Use line-by-line indicator
	detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})

	return &Engine{
		cfg: cfg,
		log: log,

@@ -56,7 +56,7 @@ func New(cfg *config.Config, log logger.Logger, db database.Database) *Engine {
// NewWithProgress creates a new backup engine with a custom progress indicator
func NewWithProgress(cfg *config.Config, log logger.Logger, db database.Database, progressIndicator progress.Indicator) *Engine {
	detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})

	return &Engine{
		cfg: cfg,
		log: log,

@@ -73,9 +73,9 @@ func NewSilent(cfg *config.Config, log logger.Logger, db database.Database, prog
	if progressIndicator == nil {
		progressIndicator = progress.NewNullIndicator()
	}

	detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})

	return &Engine{
		cfg: cfg,
		log: log,
@@ -126,16 +126,16 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
	// Start detailed operation tracking
	operationID := generateOperationID()
	tracker := e.detailedReporter.StartOperation(operationID, databaseName, "backup")

	// Add operation details
	tracker.SetDetails("database", databaseName)
	tracker.SetDetails("type", "single")
	tracker.SetDetails("compression", strconv.Itoa(e.cfg.CompressionLevel))
	tracker.SetDetails("format", "custom")

	// Start preparing backup directory
	prepStep := tracker.AddStep("prepare", "Preparing backup directory")

	// Validate and sanitize backup directory path
	validBackupDir, err := security.ValidateBackupPath(e.cfg.BackupDir)
	if err != nil {
@@ -144,7 +144,7 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
		return fmt.Errorf("invalid backup directory path: %w", err)
	}
	e.cfg.BackupDir = validBackupDir

	if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
		err = fmt.Errorf("failed to create backup directory %s. Check write permissions or use --backup-dir to specify a writable location: %w", e.cfg.BackupDir, err)
		prepStep.Fail(err)
@@ -153,20 +153,20 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
	}
	prepStep.Complete("Backup directory prepared")
	tracker.UpdateProgress(10, "Backup directory prepared")

	// Generate timestamp and filename
	timestamp := time.Now().Format("20060102_150405")
	var outputFile string

	if e.cfg.IsPostgreSQL() {
		outputFile = filepath.Join(e.cfg.BackupDir, fmt.Sprintf("db_%s_%s.dump", databaseName, timestamp))
	} else {
		outputFile = filepath.Join(e.cfg.BackupDir, fmt.Sprintf("db_%s_%s.sql.gz", databaseName, timestamp))
	}

	tracker.SetDetails("output_file", outputFile)
	tracker.UpdateProgress(20, "Generated backup filename")

	// Build backup command
	cmdStep := tracker.AddStep("command", "Building backup command")
	options := database.BackupOptions{
@@ -177,15 +177,15 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
		NoOwner:      false,
		NoPrivileges: false,
	}

	cmd := e.db.BuildBackupCommand(databaseName, outputFile, options)
	cmdStep.Complete("Backup command prepared")
	tracker.UpdateProgress(30, "Backup command prepared")

	// Execute backup command with progress monitoring
	execStep := tracker.AddStep("execute", "Executing database backup")
	tracker.UpdateProgress(40, "Starting database backup...")

	if err := e.executeCommandWithProgress(ctx, cmd, outputFile, tracker); err != nil {
		err = fmt.Errorf("backup failed for %s: %w. Check database connectivity and disk space", databaseName, err)
		execStep.Fail(err)
@@ -194,7 +194,7 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
	}
	execStep.Complete("Database backup completed")
	tracker.UpdateProgress(80, "Database backup completed")

	// Verify backup file
	verifyStep := tracker.AddStep("verify", "Verifying backup file")
	if info, err := os.Stat(outputFile); err != nil {
@@ -209,7 +209,7 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
		verifyStep.Complete(fmt.Sprintf("Backup file verified: %s", size))
		tracker.UpdateProgress(90, fmt.Sprintf("Backup verified: %s", size))
	}

	// Calculate and save checksum
	checksumStep := tracker.AddStep("checksum", "Calculating SHA-256 checksum")
	if checksum, err := security.ChecksumFile(outputFile); err != nil {
@@ -223,7 +223,7 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
			e.log.Info("Backup checksum", "sha256", checksum)
		}
	}

	// Create metadata file
	metaStep := tracker.AddStep("metadata", "Creating metadata file")
	if err := e.createMetadata(outputFile, databaseName, "single", ""); err != nil {
@@ -232,12 +232,12 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
	} else {
		metaStep.Complete("Metadata file created")
	}

	// Record metrics for observability
	if info, err := os.Stat(outputFile); err == nil && metrics.GlobalMetrics != nil {
		metrics.GlobalMetrics.RecordOperation("backup_single", databaseName, time.Now().Add(-time.Minute), info.Size(), true, 0)
	}

	// Cloud upload if enabled
	if e.cfg.CloudEnabled && e.cfg.CloudAutoUpload {
		if err := e.uploadToCloud(ctx, outputFile, tracker); err != nil {
@@ -245,39 +245,39 @@ func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
			// Don't fail the backup if cloud upload fails
		}
	}

	// Complete operation
	tracker.UpdateProgress(100, "Backup operation completed successfully")
	tracker.Complete(fmt.Sprintf("Single database backup completed: %s", filepath.Base(outputFile)))

	return nil
}

// BackupSample performs a sample database backup
func (e *Engine) BackupSample(ctx context.Context, databaseName string) error {
	operation := e.log.StartOperation("Sample Database Backup")

	// Ensure backup directory exists
	if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
		operation.Fail("Failed to create backup directory")
		return fmt.Errorf("failed to create backup directory: %w", err)
	}

	// Generate timestamp and filename
	timestamp := time.Now().Format("20060102_150405")
	outputFile := filepath.Join(e.cfg.BackupDir,
		fmt.Sprintf("sample_%s_%s%d_%s.sql", databaseName, e.cfg.SampleStrategy, e.cfg.SampleValue, timestamp))

	operation.Update("Starting sample database backup")
	e.progress.Start(fmt.Sprintf("Creating sample backup of '%s' (%s=%d)", databaseName, e.cfg.SampleStrategy, e.cfg.SampleValue))

	// For sample backups, we need to get the schema first, then sample data
	if err := e.createSampleBackup(ctx, databaseName, outputFile); err != nil {
		e.progress.Fail(fmt.Sprintf("Sample backup failed: %v", err))
		operation.Fail("Sample backup failed")
		return fmt.Errorf("sample backup failed: %w", err)
	}

	// Check output file
	if info, err := os.Stat(outputFile); err != nil {
		e.progress.Fail("Sample backup file not created")
@@ -288,12 +288,12 @@ func (e *Engine) BackupSample(ctx context.Context, databaseName string) error {
		e.progress.Complete(fmt.Sprintf("Sample backup completed: %s (%s)", filepath.Base(outputFile), size))
		operation.Complete(fmt.Sprintf("Sample backup created: %s (%s)", outputFile, size))
	}

	// Create metadata file
	if err := e.createMetadata(outputFile, databaseName, "sample", e.cfg.SampleStrategy); err != nil {
		e.log.Warn("Failed to create metadata file", "error", err)
	}

	return nil
}
@@ -302,19 +302,19 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
	if !e.cfg.IsPostgreSQL() {
		return fmt.Errorf("cluster backup is only supported for PostgreSQL")
	}

	operation := e.log.StartOperation("Cluster Backup")

	// Setup swap file if configured
	var swapMgr *swap.Manager
	if e.cfg.AutoSwap && e.cfg.SwapFileSizeGB > 0 {
		swapMgr = swap.NewManager(e.cfg.SwapFilePath, e.cfg.SwapFileSizeGB, e.log)

		if swapMgr.IsSupported() {
			e.log.Info("Setting up temporary swap file for large backup",
				"path", e.cfg.SwapFilePath,
				"size_gb", e.cfg.SwapFileSizeGB)

			if err := swapMgr.Setup(); err != nil {
				e.log.Warn("Failed to setup swap file (continuing without it)", "error", err)
			} else {
@@ -329,7 +329,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
			e.log.Warn("Swap file management not supported on this platform", "os", runtime.GOOS)
		}
	}

	// Use appropriate progress indicator based on silent mode
	var quietProgress progress.Indicator
	if e.silent {
@@ -340,42 +340,42 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
		quietProgress = progress.NewQuietLineByLine()
		quietProgress.Start("Starting cluster backup (all databases)")
	}

	// Ensure backup directory exists
	if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
		operation.Fail("Failed to create backup directory")
		quietProgress.Fail("Failed to create backup directory")
		return fmt.Errorf("failed to create backup directory: %w", err)
	}

	// Check disk space before starting backup (cached for performance)
	e.log.Info("Checking disk space availability")
	spaceCheck := checks.CheckDiskSpaceCached(e.cfg.BackupDir)

	if !e.silent {
		// Show disk space status in CLI mode
		fmt.Println("\n" + checks.FormatDiskSpaceMessage(spaceCheck))
	}

	if spaceCheck.Critical {
		operation.Fail("Insufficient disk space")
		quietProgress.Fail("Insufficient disk space - free up space and try again")
		return fmt.Errorf("insufficient disk space: %.1f%% used, operation blocked", spaceCheck.UsedPercent)
	}

	if spaceCheck.Warning {
		e.log.Warn("Low disk space - backup may fail if database is large",
			"available_gb", float64(spaceCheck.AvailableBytes)/(1024*1024*1024),
			"used_percent", spaceCheck.UsedPercent)
	}

	// Generate timestamp and filename
	timestamp := time.Now().Format("20060102_150405")
	outputFile := filepath.Join(e.cfg.BackupDir, fmt.Sprintf("cluster_%s.tar.gz", timestamp))
	tempDir := filepath.Join(e.cfg.BackupDir, fmt.Sprintf(".cluster_%s", timestamp))

	operation.Update("Starting cluster backup")

	// Create temporary directory
	if err := os.MkdirAll(filepath.Join(tempDir, "dumps"), 0755); err != nil {
		operation.Fail("Failed to create temporary directory")
@@ -383,7 +383,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
		return fmt.Errorf("failed to create temp directory: %w", err)
	}
	defer os.RemoveAll(tempDir)

	// Backup globals
	e.printf("  Backing up global objects...\n")
	if err := e.backupGlobals(ctx, tempDir); err != nil {
@@ -391,7 +391,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
		operation.Fail("Global backup failed")
		return fmt.Errorf("failed to backup globals: %w", err)
	}

	// Get list of databases
	e.printf("  Getting database list...\n")
	databases, err := e.db.ListDatabases(ctx)
@@ -400,31 +400,31 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
		operation.Fail("Database listing failed")
		return fmt.Errorf("failed to list databases: %w", err)
	}

	// Create ETA estimator for database backups
	estimator := progress.NewETAEstimator("Backing up cluster", len(databases))
	quietProgress.SetEstimator(estimator)

	// Backup each database
	parallelism := e.cfg.ClusterParallelism
	if parallelism < 1 {
		parallelism = 1 // Ensure at least sequential
	}

	if parallelism == 1 {
		e.printf("  Backing up %d databases sequentially...\n", len(databases))
	} else {
		e.printf("  Backing up %d databases with %d parallel workers...\n", len(databases), parallelism)
	}

	// Use worker pool for parallel backup
	var successCount, failCount int32
	var mu sync.Mutex // Protect shared resources (printf, estimator)

	// Create semaphore to limit concurrency
	semaphore := make(chan struct{}, parallelism)
	var wg sync.WaitGroup

	for i, dbName := range databases {
		// Check if context is cancelled before starting new backup
		select {
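The concurrency cap used here is the standard buffered-channel semaphore; the same pattern reduced to a standalone sketch (Job, jobs, and run are placeholders, not names from this codebase):

```go
sem := make(chan struct{}, parallelism) // capacity = max workers in flight
var wg sync.WaitGroup
for _, job := range jobs {
	wg.Add(1)
	sem <- struct{}{} // acquire a slot; blocks while `parallelism` jobs run
	go func(j Job) {
		defer wg.Done()
		defer func() { <-sem }() // release the slot
		run(j)
	}(job)
}
wg.Wait()
```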
@@ -435,14 +435,14 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
			return fmt.Errorf("backup cancelled: %w", ctx.Err())
		default:
		}

		wg.Add(1)
		semaphore <- struct{}{} // Acquire

		go func(idx int, name string) {
			defer wg.Done()
			defer func() { <-semaphore }() // Release

			// Check for cancellation at start of goroutine
			select {
			case <-ctx.Done():
@@ -451,14 +451,14 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
				return
			default:
			}

			// Update estimator progress (thread-safe)
			mu.Lock()
			estimator.UpdateProgress(idx)
			e.printf("  [%d/%d] Backing up database: %s\n", idx+1, len(databases), name)
			quietProgress.Update(fmt.Sprintf("Backing up database %d/%d: %s", idx+1, len(databases), name))
			mu.Unlock()

			// Check database size and warn if very large
			if size, err := e.db.GetDatabaseSize(ctx, name); err == nil {
				sizeStr := formatBytes(size)
@@ -469,17 +469,17 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
				}
				mu.Unlock()
			}

			dumpFile := filepath.Join(tempDir, "dumps", name+".dump")

			compressionLevel := e.cfg.CompressionLevel
			if compressionLevel > 6 {
				compressionLevel = 6
			}

			format := "custom"
			parallel := e.cfg.DumpJobs

			if size, err := e.db.GetDatabaseSize(ctx, name); err == nil {
				if size > 5*1024*1024*1024 {
					format = "plain"
@@ -490,7 +490,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
					mu.Unlock()
				}
			}

			options := database.BackupOptions{
				Compression: compressionLevel,
				Parallel:    parallel,
@@ -499,14 +499,14 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
				NoOwner:      false,
				NoPrivileges: false,
			}

			cmd := e.db.BuildBackupCommand(name, dumpFile, options)

			dbCtx, cancel := context.WithTimeout(ctx, 2*time.Hour)
			defer cancel()
			err := e.executeCommand(dbCtx, cmd, dumpFile)
			cancel()

			if err != nil {
				e.log.Warn("Failed to backup database", "database", name, "error", err)
				mu.Lock()
@@ -526,15 +526,15 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
			}
		}(i, dbName)
	}

	// Wait for all backups to complete
	wg.Wait()

	successCountFinal := int(atomic.LoadInt32(&successCount))
	failCountFinal := int(atomic.LoadInt32(&failCount))

	e.printf("  Backup summary: %d succeeded, %d failed\n", successCountFinal, failCountFinal)

	// Create archive
	e.printf("  Creating compressed archive...\n")
	if err := e.createArchive(ctx, tempDir, outputFile); err != nil {

@@ -542,7 +542,7 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
		operation.Fail("Archive creation failed")
		return fmt.Errorf("failed to create archive: %w", err)
	}

	// Check output file
	if info, err := os.Stat(outputFile); err != nil {
		quietProgress.Fail("Cluster backup archive not created")

@@ -553,12 +553,12 @@ func (e *Engine) BackupCluster(ctx context.Context) error {
		quietProgress.Complete(fmt.Sprintf("Cluster backup completed: %s (%s)", filepath.Base(outputFile), size))
		operation.Complete(fmt.Sprintf("Cluster backup created: %s (%s)", outputFile, size))
	}

	// Create cluster metadata file
	if err := e.createClusterMetadata(outputFile, databases, successCountFinal, failCountFinal); err != nil {
		e.log.Warn("Failed to create cluster metadata file", "error", err)
	}

	return nil
}
@@ -567,11 +567,11 @@ func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []strin
	if len(cmdArgs) == 0 {
		return fmt.Errorf("empty command")
	}

	e.log.Debug("Executing backup command with progress", "cmd", cmdArgs[0], "args", cmdArgs[1:])

	cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)

	// Set environment variables for database tools
	cmd.Env = os.Environ()
	if e.cfg.Password != "" {

@@ -581,51 +581,51 @@ func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []strin
			cmd.Env = append(cmd.Env, "MYSQL_PWD="+e.cfg.Password)
		}
	}

	// For MySQL, handle compression and redirection differently
	if e.cfg.IsMySQL() && e.cfg.CompressionLevel > 0 {
		return e.executeMySQLWithProgressAndCompression(ctx, cmdArgs, outputFile, tracker)
	}

	// Get stderr pipe for progress monitoring
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to get stderr pipe: %w", err)
	}

	// Start the command
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("failed to start command: %w", err)
	}

	// Monitor progress via stderr
	go e.monitorCommandProgress(stderr, tracker)

	// Wait for command to complete
	if err := cmd.Wait(); err != nil {
		return fmt.Errorf("backup command failed: %w", err)
	}

	return nil
}

// monitorCommandProgress monitors command output for progress information
func (e *Engine) monitorCommandProgress(stderr io.ReadCloser, tracker *progress.OperationTracker) {
	defer stderr.Close()

	scanner := bufio.NewScanner(stderr)
	scanner.Buffer(make([]byte, 64*1024), 1024*1024) // 64KB initial, 1MB max for performance
	progressBase := 40 // Start from 40% since command preparation is done
	progressIncrement := 0

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		e.log.Debug("Command output", "line", line)

		// Increment progress gradually based on output
		if progressBase < 75 {
			progressIncrement++
@@ -634,7 +634,7 @@ func (e *Engine) monitorCommandProgress(stderr io.ReadCloser, tracker *progress.
				tracker.UpdateProgress(progressBase, "Processing data...")
			}
		}

		// Look for specific progress indicators
		if strings.Contains(line, "COPY") {
			tracker.UpdateProgress(progressBase+5, "Copying table data...")
@@ -654,55 +654,55 @@ func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmd
	if e.cfg.Password != "" {
		dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
	}

	// Create gzip command
	gzipCmd := exec.CommandContext(ctx, "gzip", fmt.Sprintf("-%d", e.cfg.CompressionLevel))

	// Create output file
	outFile, err := os.Create(outputFile)
	if err != nil {
		return fmt.Errorf("failed to create output file: %w", err)
	}
	defer outFile.Close()

	// Set up pipeline: mysqldump | gzip > outputfile
	pipe, err := dumpCmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("failed to create pipe: %w", err)
	}

	gzipCmd.Stdin = pipe
	gzipCmd.Stdout = outFile

	// Get stderr for progress monitoring
	stderr, err := dumpCmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to get stderr pipe: %w", err)
	}

	// Start monitoring progress
	go e.monitorCommandProgress(stderr, tracker)

	// Start both commands
	if err := gzipCmd.Start(); err != nil {
		return fmt.Errorf("failed to start gzip: %w", err)
	}

	if err := dumpCmd.Start(); err != nil {
		return fmt.Errorf("failed to start mysqldump: %w", err)
	}

	// Wait for mysqldump to complete
	if err := dumpCmd.Wait(); err != nil {
		return fmt.Errorf("mysqldump failed: %w", err)
	}

	// Close pipe and wait for gzip
	pipe.Close()
	if err := gzipCmd.Wait(); err != nil {
		return fmt.Errorf("gzip failed: %w", err)
	}

	return nil
}
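The mysqldump | gzip wiring above, reduced to a self-contained sketch (names illustrative; the real code also adds stderr monitoring and context cancellation):

```go
package pipeutil

import (
	"fmt"
	"os"
	"os/exec"
)

// pipeToGzip runs the equivalent of `producer | gzip -6 > path`.
func pipeToGzip(producer *exec.Cmd, path string) error {
	out, err := os.Create(path)
	if err != nil {
		return err
	}
	defer out.Close()

	gz := exec.Command("gzip", "-6")
	pipe, err := producer.StdoutPipe()
	if err != nil {
		return err
	}
	gz.Stdin = pipe
	gz.Stdout = out

	// Start the consumer first so the producer never blocks on a full pipe.
	if err := gz.Start(); err != nil {
		return err
	}
	// Run waits for the producer and closes its stdout, which gives gzip EOF.
	if err := producer.Run(); err != nil {
		return fmt.Errorf("producer failed: %w", err)
	}
	return gz.Wait()
}
```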
@@ -714,17 +714,17 @@ func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []stri
	if e.cfg.Password != "" {
		dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
	}

	// Create gzip command
	gzipCmd := exec.CommandContext(ctx, "gzip", fmt.Sprintf("-%d", e.cfg.CompressionLevel))

	// Create output file
	outFile, err := os.Create(outputFile)
	if err != nil {
		return fmt.Errorf("failed to create output file: %w", err)
	}
	defer outFile.Close()

	// Set up pipeline: mysqldump | gzip > outputfile
	stdin, err := dumpCmd.StdoutPipe()
	if err != nil {

@@ -732,20 +732,20 @@ func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []stri
	}
	gzipCmd.Stdin = stdin
	gzipCmd.Stdout = outFile

	// Start both commands
	if err := gzipCmd.Start(); err != nil {
		return fmt.Errorf("failed to start gzip: %w", err)
	}

	if err := dumpCmd.Run(); err != nil {
		return fmt.Errorf("mysqldump failed: %w", err)
	}

	if err := gzipCmd.Wait(); err != nil {
		return fmt.Errorf("gzip failed: %w", err)
	}

	return nil
}
@@ -757,23 +757,23 @@ func (e *Engine) createSampleBackup(ctx context.Context, databaseName, outputFil
	// 2. Get list of tables
	// 3. For each table, run sampling query
	// 4. Combine into single SQL file

	// For now, we'll use a simple approach with schema-only backup first
	// Then add sample data

	file, err := os.Create(outputFile)
	if err != nil {
		return fmt.Errorf("failed to create sample backup file: %w", err)
	}
	defer file.Close()

	// Write header
	fmt.Fprintf(file, "-- Sample Database Backup\n")
	fmt.Fprintf(file, "-- Database: %s\n", databaseName)
	fmt.Fprintf(file, "-- Strategy: %s = %d\n", e.cfg.SampleStrategy, e.cfg.SampleValue)
	fmt.Fprintf(file, "-- Created: %s\n", time.Now().Format(time.RFC3339))
	fmt.Fprintf(file, "-- WARNING: This backup may have referential integrity issues!\n\n")

	// For PostgreSQL, we can use pg_dump --schema-only first
	if e.cfg.IsPostgreSQL() {
		// Get schema
@@ -781,61 +781,61 @@ func (e *Engine) createSampleBackup(ctx context.Context, databaseName, outputFil
			SchemaOnly: true,
			Format:     "plain",
		})

		cmd := exec.CommandContext(ctx, schemaCmd[0], schemaCmd[1:]...)
		cmd.Env = os.Environ()
		if e.cfg.Password != "" {
			cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
		}
		cmd.Stdout = file

		if err := cmd.Run(); err != nil {
			return fmt.Errorf("failed to export schema: %w", err)
		}

		fmt.Fprintf(file, "\n-- Sample data follows\n\n")

		// Get tables and sample data
		tables, err := e.db.ListTables(ctx, databaseName)
		if err != nil {
			return fmt.Errorf("failed to list tables: %w", err)
		}

		strategy := database.SampleStrategy{
			Type:  e.cfg.SampleStrategy,
			Value: e.cfg.SampleValue,
		}

		for _, table := range tables {
			fmt.Fprintf(file, "-- Data for table: %s\n", table)
			sampleQuery := e.db.BuildSampleQuery(databaseName, table, strategy)
			fmt.Fprintf(file, "\\copy (%s) TO STDOUT\n\n", sampleQuery)
		}
	}

	return nil
}

// backupGlobals creates a backup of global PostgreSQL objects
func (e *Engine) backupGlobals(ctx context.Context, tempDir string) error {
	globalsFile := filepath.Join(tempDir, "globals.sql")

	cmd := exec.CommandContext(ctx, "pg_dumpall", "--globals-only")
	if e.cfg.Host != "localhost" {
		cmd.Args = append(cmd.Args, "-h", e.cfg.Host, "-p", fmt.Sprintf("%d", e.cfg.Port))
	}
	cmd.Args = append(cmd.Args, "-U", e.cfg.User)

	cmd.Env = os.Environ()
	if e.cfg.Password != "" {
		cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
	}

	output, err := cmd.Output()
	if err != nil {
		return fmt.Errorf("pg_dumpall failed: %w", err)
	}

	return os.WriteFile(globalsFile, output, 0644)
}
@@ -844,13 +844,13 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
	// Use pigz for faster parallel compression if available, otherwise use standard gzip
	compressCmd := "tar"
	compressArgs := []string{"-czf", outputFile, "-C", sourceDir, "."}

	// Check if pigz is available for faster parallel compression
	if _, err := exec.LookPath("pigz"); err == nil {
		// Use pigz with number of cores for parallel compression
		compressArgs = []string{"-cf", "-", "-C", sourceDir, "."}
		cmd := exec.CommandContext(ctx, "tar", compressArgs...)

		// Create output file
		outFile, err := os.Create(outputFile)
		if err != nil {
@@ -858,10 +858,10 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
			goto regularTar
		}
		defer outFile.Close()

		// Pipe to pigz for parallel compression
		pigzCmd := exec.CommandContext(ctx, "pigz", "-p", strconv.Itoa(e.cfg.Jobs))

		tarOut, err := cmd.StdoutPipe()
		if err != nil {
			outFile.Close()
@@ -870,7 +870,7 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
		}
		pigzCmd.Stdin = tarOut
		pigzCmd.Stdout = outFile

		// Start both commands
		if err := pigzCmd.Start(); err != nil {
			outFile.Close()
@@ -881,13 +881,13 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
			outFile.Close()
			goto regularTar
		}

		// Wait for tar to finish
		if err := cmd.Wait(); err != nil {
			pigzCmd.Process.Kill()
			return fmt.Errorf("tar failed: %w", err)
		}

		// Wait for pigz to finish
		if err := pigzCmd.Wait(); err != nil {
			return fmt.Errorf("pigz compression failed: %w", err)
@@ -898,7 +898,7 @@ func (e *Engine) createArchive(ctx context.Context, sourceDir, outputFile string
regularTar:
	// Standard tar with gzip (fallback)
	cmd := exec.CommandContext(ctx, compressCmd, compressArgs...)

	// Stream stderr to avoid memory issues
	// Use io.Copy to ensure goroutine completes when pipe closes
	stderr, err := cmd.StderrPipe()

@@ -914,7 +914,7 @@ regularTar:
			// Scanner will exit when stderr pipe closes after cmd.Wait()
		}()
	}

	if err := cmd.Run(); err != nil {
		return fmt.Errorf("tar failed: %w", err)
	}
@@ -925,26 +925,26 @@ regularTar:
// createMetadata creates a metadata file for the backup
func (e *Engine) createMetadata(backupFile, database, backupType, strategy string) error {
	startTime := time.Now()

	// Get backup file information
	info, err := os.Stat(backupFile)
	if err != nil {
		return fmt.Errorf("failed to stat backup file: %w", err)
	}

	// Calculate SHA-256 checksum
	sha256, err := metadata.CalculateSHA256(backupFile)
	if err != nil {
		return fmt.Errorf("failed to calculate checksum: %w", err)
	}

	// Get database version
	ctx := context.Background()
	dbVersion, _ := e.db.GetVersion(ctx)
	if dbVersion == "" {
		dbVersion = "unknown"
	}

	// Determine compression format
	compressionFormat := "none"
	if e.cfg.CompressionLevel > 0 {
@@ -954,7 +954,7 @@ func (e *Engine) createMetadata(backupFile, database, backupType, strategy strin
			compressionFormat = fmt.Sprintf("gzip-%d", e.cfg.CompressionLevel)
		}
	}

	// Create backup metadata
	meta := &metadata.BackupMetadata{
		Version: "2.0",
@@ -973,18 +973,18 @@ func (e *Engine) createMetadata(backupFile, database, backupType, strategy strin
		Duration:  time.Since(startTime).Seconds(),
		ExtraInfo: make(map[string]string),
	}

	// Add strategy for sample backups
	if strategy != "" {
		meta.ExtraInfo["sample_strategy"] = strategy
		meta.ExtraInfo["sample_value"] = fmt.Sprintf("%d", e.cfg.SampleValue)
	}

	// Save metadata
	if err := meta.Save(); err != nil {
		return fmt.Errorf("failed to save metadata: %w", err)
	}

	// Also save legacy .info file for backward compatibility
	legacyMetaFile := backupFile + ".info"
	legacyContent := fmt.Sprintf(`{
@@ -998,39 +998,39 @@ func (e *Engine) createMetadata(backupFile, database, backupType, strategy strin
  "compression": %d,
  "size_bytes": %d
}`, backupType, database, startTime.Format("20060102_150405"),
		e.cfg.Host, e.cfg.Port, e.cfg.User, e.cfg.DatabaseType,
		e.cfg.CompressionLevel, info.Size())

	if err := os.WriteFile(legacyMetaFile, []byte(legacyContent), 0644); err != nil {
		e.log.Warn("Failed to save legacy metadata file", "error", err)
	}

	return nil
}

// createClusterMetadata creates metadata for cluster backups
func (e *Engine) createClusterMetadata(backupFile string, databases []string, successCount, failCount int) error {
	startTime := time.Now()

	// Get backup file information
	info, err := os.Stat(backupFile)
	if err != nil {
		return fmt.Errorf("failed to stat backup file: %w", err)
	}

	// Calculate SHA-256 checksum for archive
	sha256, err := metadata.CalculateSHA256(backupFile)
	if err != nil {
		return fmt.Errorf("failed to calculate checksum: %w", err)
	}

	// Get database version
	ctx := context.Background()
	dbVersion, _ := e.db.GetVersion(ctx)
	if dbVersion == "" {
		dbVersion = "unknown"
	}

	// Create cluster metadata
	clusterMeta := &metadata.ClusterMetadata{
		Version: "2.0",
@@ -1050,7 +1050,7 @@ func (e *Engine) createClusterMetadata(backupFile string, databases []string, su
			"database_version": dbVersion,
		},
	}

	// Add database names to metadata
	for _, dbName := range databases {
		dbMeta := metadata.BackupMetadata{
@@ -1061,12 +1061,12 @@ func (e *Engine) createClusterMetadata(backupFile string, databases []string, su
		}
		clusterMeta.Databases = append(clusterMeta.Databases, dbMeta)
	}

	// Save cluster metadata
	if err := clusterMeta.Save(backupFile); err != nil {
		return fmt.Errorf("failed to save cluster metadata: %w", err)
	}

	// Also save legacy .info file for backward compatibility
	legacyMetaFile := backupFile + ".info"
	legacyContent := fmt.Sprintf(`{
@@ -1085,18 +1085,18 @@ func (e *Engine) createClusterMetadata(backupFile string, databases []string, su
}`, startTime.Format("20060102_150405"),
		e.cfg.Host, e.cfg.Port, e.cfg.User, e.cfg.DatabaseType,
		e.cfg.CompressionLevel, info.Size(), len(databases), successCount, failCount)

	if err := os.WriteFile(legacyMetaFile, []byte(legacyContent), 0644); err != nil {
		e.log.Warn("Failed to save legacy cluster metadata file", "error", err)
	}

	return nil
}

// uploadToCloud uploads a backup file to cloud storage
func (e *Engine) uploadToCloud(ctx context.Context, backupFile string, tracker *progress.OperationTracker) error {
	uploadStep := tracker.AddStep("cloud_upload", "Uploading to cloud storage")

	// Create cloud backend
	cloudCfg := &cloud.Config{
		Provider: e.cfg.CloudProvider,
@@ -1111,23 +1111,23 @@ func (e *Engine) uploadToCloud(ctx context.Context, backupFile string, tracker *
		Timeout:    300,
		MaxRetries: 3,
	}

	backend, err := cloud.NewBackend(cloudCfg)
	if err != nil {
		uploadStep.Fail(fmt.Errorf("failed to create cloud backend: %w", err))
		return err
	}

	// Get file info
	info, err := os.Stat(backupFile)
	if err != nil {
		uploadStep.Fail(fmt.Errorf("failed to stat backup file: %w", err))
		return err
	}

	filename := filepath.Base(backupFile)
	e.log.Info("Uploading backup to cloud", "file", filename, "size", cloud.FormatSize(info.Size()))

	// Progress callback
	var lastPercent int
	progressCallback := func(transferred, total int64) {
@@ -1137,14 +1137,14 @@ func (e *Engine) uploadToCloud(ctx context.Context, backupFile string, tracker *
			lastPercent = percent
		}
	}

	// Upload to cloud
	err = backend.Upload(ctx, backupFile, filename, progressCallback)
	if err != nil {
		uploadStep.Fail(fmt.Errorf("cloud upload failed: %w", err))
		return err
	}

	// Also upload metadata file
	metaFile := backupFile + ".meta.json"
	if _, err := os.Stat(metaFile); err == nil {
@@ -1154,10 +1154,10 @@ func (e *Engine) uploadToCloud(ctx context.Context, backupFile string, tracker *
			// Don't fail if metadata upload fails
		}
	}

	uploadStep.Complete(fmt.Sprintf("Uploaded to %s/%s/%s", backend.Name(), e.cfg.CloudBucket, filename))
	e.log.Info("Backup uploaded to cloud", "provider", backend.Name(), "bucket", e.cfg.CloudBucket, "file", filename)

	return nil
}
@@ -1166,9 +1166,9 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
	if len(cmdArgs) == 0 {
		return fmt.Errorf("empty command")
	}

	e.log.Debug("Executing backup command", "cmd", cmdArgs[0], "args", cmdArgs[1:])

	// Check if pg_dump will write to stdout (which means we need to handle piping to compressor).
	// BuildBackupCommand omits --file when format==plain AND compression==0, causing pg_dump
	// to write to stdout. In that case we must pipe to an external compressor.
@@ -1192,28 +1192,28 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
	if isPlainFormat && !hasFileFlag {
		usesStdout = true
	}

	e.log.Debug("Backup command analysis",
		"plain_format", isPlainFormat,
		"has_file_flag", hasFileFlag,
		"uses_stdout", usesStdout,
		"output_file", outputFile)

	// For MySQL, handle compression differently
	if e.cfg.IsMySQL() && e.cfg.CompressionLevel > 0 {
		return e.executeMySQLWithCompression(ctx, cmdArgs, outputFile)
	}

	// For plain format writing to stdout, use streaming compression
	if usesStdout {
		e.log.Debug("Using streaming compression for large database")
		return e.executeWithStreamingCompression(ctx, cmdArgs, outputFile)
	}

	// For custom format, pg_dump handles everything (writes directly to file)
	// NO GO BUFFERING - pg_dump writes directly to disk
	cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)

	// Set environment variables for database tools
	cmd.Env = os.Environ()
	if e.cfg.Password != "" {
@@ -1223,18 +1223,18 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
			cmd.Env = append(cmd.Env, "MYSQL_PWD="+e.cfg.Password)
		}
	}

	// Stream stderr to avoid memory issues with large databases
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to create stderr pipe: %w", err)
	}

	// Start the command
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("failed to start backup command: %w", err)
	}

	// Stream stderr output (don't buffer it all in memory)
	go func() {
		scanner := bufio.NewScanner(stderr)
@@ -1246,13 +1246,13 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
			}
		}
	}()

	// Wait for command to complete
	if err := cmd.Wait(); err != nil {
		e.log.Error("Backup command failed", "error", err, "database", filepath.Base(outputFile))
		return fmt.Errorf("backup command failed: %w", err)
	}

	return nil
}
@@ -1260,7 +1260,7 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
// Uses: pg_dump | pigz > file.sql.gz (zero-copy streaming)
func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []string, outputFile string) error {
	e.log.Debug("Using streaming compression for large database")

	// Derive compressed output filename. If the output was named *.dump we replace that
	// with *.sql.gz; otherwise append .gz to the provided output file so we don't
	// accidentally create unwanted double extensions.
@@ -1273,43 +1273,43 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
	} else {
		compressedFile = outputFile + ".gz"
	}

	// Create pg_dump command
	dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
	dumpCmd.Env = os.Environ()
	if e.cfg.Password != "" && e.cfg.IsPostgreSQL() {
		dumpCmd.Env = append(dumpCmd.Env, "PGPASSWORD="+e.cfg.Password)
	}

	// Check for pigz (parallel gzip)
	compressor := "gzip"
	compressorArgs := []string{"-c"}

	if _, err := exec.LookPath("pigz"); err == nil {
		compressor = "pigz"
		compressorArgs = []string{"-p", strconv.Itoa(e.cfg.Jobs), "-c"}
		e.log.Debug("Using pigz for parallel compression", "threads", e.cfg.Jobs)
	}

	// Create compression command
	compressCmd := exec.CommandContext(ctx, compressor, compressorArgs...)

	// Create output file
	outFile, err := os.Create(compressedFile)
	if err != nil {
		return fmt.Errorf("failed to create output file: %w", err)
	}
	defer outFile.Close()

	// Set up pipeline: pg_dump | pigz > file.sql.gz
	dumpStdout, err := dumpCmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("failed to create dump stdout pipe: %w", err)
	}

	compressCmd.Stdin = dumpStdout
	compressCmd.Stdout = outFile

	// Capture stderr from both commands
	dumpStderr, err := dumpCmd.StderrPipe()
	if err != nil {
@@ -1319,7 +1319,7 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
	if err != nil {
		e.log.Warn("Failed to capture compress stderr", "error", err)
	}

	// Stream stderr output
	if dumpStderr != nil {
		go func() {
@@ -1332,7 +1332,7 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
			}
		}()
	}

	if compressStderr != nil {
		go func() {
			scanner := bufio.NewScanner(compressStderr)
@@ -1344,30 +1344,30 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
			}
		}()
	}

	// Start compression first
	if err := compressCmd.Start(); err != nil {
		return fmt.Errorf("failed to start compressor: %w", err)
	}

	// Then start pg_dump
	if err := dumpCmd.Start(); err != nil {
		return fmt.Errorf("failed to start pg_dump: %w", err)
	}

	// Wait for pg_dump to complete
	if err := dumpCmd.Wait(); err != nil {
		return fmt.Errorf("pg_dump failed: %w", err)
	}

	// Close stdout pipe to signal compressor we're done
	dumpStdout.Close()

	// Wait for compression to complete
	if err := compressCmd.Wait(); err != nil {
		return fmt.Errorf("compression failed: %w", err)
	}

	e.log.Debug("Streaming compression completed", "output", compressedFile)
	return nil
}
@@ -1384,4 +1384,4 @@ func formatBytes(bytes int64) string {
		exp++
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
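formatBytes is the usual 1024-based pretty-printer; two illustrative values, assuming the standard div/exp loop shown above:

```go
// formatBytes(1536)        == "1.5 KB"
// formatBytes(5 * 1 << 30) == "5.0 GB" // the 5 GiB cutoff used when
//                                      // switching large cluster dumps to plain format
```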
@@ -17,19 +17,19 @@ const (
type IncrementalMetadata struct {
	// BaseBackupID is the SHA-256 checksum of the base backup this incremental depends on
	BaseBackupID string `json:"base_backup_id"`

	// BaseBackupPath is the filename of the base backup (e.g., "mydb_20250126_120000.tar.gz")
	BaseBackupPath string `json:"base_backup_path"`

	// BaseBackupTimestamp is when the base backup was created
	BaseBackupTimestamp time.Time `json:"base_backup_timestamp"`

	// IncrementalFiles is the number of changed files included in this backup
	IncrementalFiles int `json:"incremental_files"`

	// TotalSize is the total size of changed files (bytes)
	TotalSize int64 `json:"total_size"`

	// BackupChain is the list of all backups needed for restore (base + incrementals)
	// Ordered from oldest to newest: [base, incr1, incr2, ...]
	BackupChain []string `json:"backup_chain"`
@@ -39,16 +39,16 @@ type IncrementalMetadata struct {
type ChangedFile struct {
	// RelativePath is the path relative to PostgreSQL data directory
	RelativePath string

	// AbsolutePath is the full filesystem path
	AbsolutePath string

	// Size is the file size in bytes
	Size int64

	// ModTime is the last modification time
	ModTime time.Time

	// Checksum is the SHA-256 hash of the file content (optional)
	Checksum string
}
@@ -57,13 +57,13 @@ type ChangedFile struct {
type IncrementalBackupConfig struct {
	// BaseBackupPath is the path to the base backup archive
	BaseBackupPath string

	// DataDirectory is the PostgreSQL data directory to scan
	DataDirectory string

	// IncludeWAL determines if WAL files should be included
	IncludeWAL bool

	// CompressionLevel for the incremental archive (0-9)
	CompressionLevel int
}
@@ -72,11 +72,11 @@ type IncrementalBackupConfig struct {
type BackupChainResolver interface {
	// FindBaseBackup locates the base backup for an incremental backup
	FindBaseBackup(ctx context.Context, incrementalBackupID string) (*BackupInfo, error)

	// ResolveChain returns the complete chain of backups needed for restore
	// Returned in order: [base, incr1, incr2, ..., target]
	ResolveChain(ctx context.Context, targetBackupID string) ([]*BackupInfo, error)

	// ValidateChain verifies all backups in the chain exist and are valid
	ValidateChain(ctx context.Context, chain []*BackupInfo) error
}
@@ -85,10 +85,10 @@ type BackupChainResolver interface {
type IncrementalBackupEngine interface {
	// FindChangedFiles identifies files changed since the base backup
	FindChangedFiles(ctx context.Context, config *IncrementalBackupConfig) ([]ChangedFile, error)

	// CreateIncrementalBackup creates a new incremental backup
	CreateIncrementalBackup(ctx context.Context, config *IncrementalBackupConfig, changedFiles []ChangedFile) error

	// RestoreIncremental restores an incremental backup on top of a base backup
	RestoreIncremental(ctx context.Context, baseBackupPath, incrementalPath, targetDir string) error
}
@@ -101,8 +101,8 @@ type BackupInfo struct {
	Timestamp time.Time `json:"timestamp"`
	Size      int64     `json:"size"`
	Checksum  string    `json:"checksum"`

	// New fields for incremental support
	BackupType  BackupType           `json:"backup_type"`           // "full" or "incremental"
	Incremental *IncrementalMetadata `json:"incremental,omitempty"` // Only present for incremental backups
}
@@ -42,7 +42,7 @@ func (e *MySQLIncrementalEngine) FindChangedFiles(ctx context.Context, config *I
		return nil, fmt.Errorf("failed to load base backup info: %w", err)
	}

	// Validate base backup is a full backup
	if baseInfo.BackupType != "" && baseInfo.BackupType != "full" {
		return nil, fmt.Errorf("base backup must be a full backup, got: %s", baseInfo.BackupType)
	}
@@ -52,7 +52,7 @@ func (e *MySQLIncrementalEngine) FindChangedFiles(ctx context.Context, config *I

	// Scan data directory for changed files
	var changedFiles []ChangedFile

	err = filepath.Walk(config.DataDirectory, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
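A minimal sketch of the mtime-based selection such a walk typically performs (the cutoff comparison is assumed; dataDir, baseBackupTime, and changed are placeholders):

```go
err := filepath.Walk(dataDir, func(path string, info os.FileInfo, err error) error {
	if err != nil {
		return err
	}
	if info.IsDir() {
		return nil
	}
	// Keep files modified after the base backup was taken.
	if info.ModTime().After(baseBackupTime) {
		changed = append(changed, ChangedFile{
			AbsolutePath: path,
			Size:         info.Size(),
			ModTime:      info.ModTime(),
		})
	}
	return nil
})
```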
@@ -199,7 +199,7 @@ func (e *MySQLIncrementalEngine) CreateIncrementalBackup(ctx context.Context, co

	// Generate output filename: dbname_incr_TIMESTAMP.tar.gz
	timestamp := time.Now().Format("20060102_150405")
	outputFile := filepath.Join(filepath.Dir(config.BaseBackupPath),
		fmt.Sprintf("%s_incr_%s.tar.gz", baseInfo.Database, timestamp))

	e.log.Info("Creating incremental archive", "output", outputFile)
@@ -229,19 +229,19 @@ func (e *MySQLIncrementalEngine) CreateIncrementalBackup(ctx context.Context, co

	// Create incremental metadata
	metadata := &metadata.BackupMetadata{
		Version:      "2.3.0",
		Timestamp:    time.Now(),
		Database:     baseInfo.Database,
		DatabaseType: baseInfo.DatabaseType,
		Host:         baseInfo.Host,
		Port:         baseInfo.Port,
		User:         baseInfo.User,
		BackupFile:   outputFile,
		SizeBytes:    stat.Size(),
		SHA256:       checksum,
		Compression:  "gzip",
		BackupType:   "incremental",
		BaseBackup:   filepath.Base(config.BaseBackupPath),
		Incremental: &metadata.IncrementalMetadata{
			BaseBackupID:   baseInfo.SHA256,
			BaseBackupPath: filepath.Base(config.BaseBackupPath),
@@ -40,7 +40,7 @@ func (e *PostgresIncrementalEngine) FindChangedFiles(ctx context.Context, config
		return nil, fmt.Errorf("failed to load base backup info: %w", err)
	}

	// Validate base backup is a full backup
	if baseInfo.BackupType != "" && baseInfo.BackupType != "full" {
		return nil, fmt.Errorf("base backup must be a full backup, got: %s", baseInfo.BackupType)
	}
@@ -50,7 +50,7 @@ func (e *PostgresIncrementalEngine) FindChangedFiles(ctx context.Context, config
|
||||
|
||||
// Scan data directory for changed files
|
||||
var changedFiles []ChangedFile
|
||||
|
||||
|
||||
err = filepath.Walk(config.DataDirectory, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -160,7 +160,7 @@ func (e *PostgresIncrementalEngine) CreateIncrementalBackup(ctx context.Context,
|
||||
|
||||
// Generate output filename: dbname_incr_TIMESTAMP.tar.gz
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
outputFile := filepath.Join(filepath.Dir(config.BaseBackupPath),
|
||||
outputFile := filepath.Join(filepath.Dir(config.BaseBackupPath),
|
||||
fmt.Sprintf("%s_incr_%s.tar.gz", baseInfo.Database, timestamp))
|
||||
|
||||
e.log.Info("Creating incremental archive", "output", outputFile)
|
||||
@@ -190,19 +190,19 @@ func (e *PostgresIncrementalEngine) CreateIncrementalBackup(ctx context.Context,

	// Create incremental metadata
	metadata := &metadata.BackupMetadata{
		Version:      "2.2.0",
		Timestamp:    time.Now(),
		Database:     baseInfo.Database,
		DatabaseType: baseInfo.DatabaseType,
		Host:         baseInfo.Host,
		Port:         baseInfo.Port,
		User:         baseInfo.User,
		BackupFile:   outputFile,
		SizeBytes:    stat.Size(),
		SHA256:       checksum,
		Compression:  "gzip",
		BackupType:   "incremental",
		BaseBackup:   filepath.Base(config.BaseBackupPath),
		Incremental: &metadata.IncrementalMetadata{
			BaseBackupID:   baseInfo.SHA256,
			BaseBackupPath: filepath.Base(config.BaseBackupPath),
@@ -329,7 +329,7 @@ func (e *PostgresIncrementalEngine) CalculateFileChecksum(path string) (string,
// buildBackupChain constructs the backup chain from base backup to current incremental
func buildBackupChain(baseInfo *metadata.BackupMetadata, currentBackup string) []string {
	chain := []string{}

	// If base backup has a chain (is itself incremental), use that
	if baseInfo.Incremental != nil && len(baseInfo.Incremental.BackupChain) > 0 {
		chain = append(chain, baseInfo.Incremental.BackupChain...)
@@ -337,9 +337,9 @@ func buildBackupChain(baseInfo *metadata.BackupMetadata, currentBackup string) [
		// Base is a full backup, start chain with it
		chain = append(chain, filepath.Base(baseInfo.BackupFile))
	}

	// Add current incremental to chain
	chain = append(chain, currentBackup)

	return chain
}
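At restore time the chain built here would be replayed oldest-first; a sketch, assuming an extractArchive helper (hypothetical) that untars one archive over targetDir:

	// Apply a full backup and its incrementals in chain order.
	func applyChain(chain []string, backupDir, targetDir string) error {
		for _, name := range chain {
			archive := filepath.Join(backupDir, name)
			// Later archives overwrite files from earlier ones,
			// so order matters: full first, then each incremental.
			if err := extractArchive(archive, targetDir); err != nil {
				return fmt.Errorf("extract %s: %w", name, err)
			}
		}
		return nil
	}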
@@ -67,7 +67,7 @@ func TestIncrementalBackupRestore(t *testing.T) {
	// Step 2: Create base (full) backup
	t.Log("Step 2: Creating base backup...")
	baseBackupPath := filepath.Join(backupDir, "testdb_base.tar.gz")

	// Manually create base backup for testing
	baseConfig := &IncrementalBackupConfig{
		DataDirectory: dataDir,
@@ -192,7 +192,7 @@ func TestIncrementalBackupRestore(t *testing.T) {

	var incrementalBackupPath string
	for _, entry := range entries {
		if !entry.IsDir() && filepath.Ext(entry.Name()) == ".gz" &&
			entry.Name() != filepath.Base(baseBackupPath) {
			incrementalBackupPath = filepath.Join(backupDir, entry.Name())
			break
@@ -209,7 +209,7 @@ func TestIncrementalBackupRestore(t *testing.T) {
	incrStat, _ := os.Stat(incrementalBackupPath)
	t.Logf("Base backup size: %d bytes", baseStat.Size())
	t.Logf("Incremental backup size: %d bytes", incrStat.Size())

	// Note: For tiny test files, incremental might be larger due to tar.gz overhead
	// In real-world scenarios with larger files, incremental would be much smaller
	t.Logf("Incremental contains %d changed files out of %d total",
@@ -273,7 +273,7 @@ func TestIncrementalBackupErrors(t *testing.T) {
	// Create a dummy base backup
	baseBackupPath := filepath.Join(tempDir, "base.tar.gz")
	os.WriteFile(baseBackupPath, []byte("dummy"), 0644)

	// Create metadata with current timestamp
	baseMetadata := createTestMetadata("testdb", baseBackupPath, 100, "dummychecksum", "full", nil)
	saveTestMetadata(baseBackupPath, baseMetadata)
@@ -333,7 +333,7 @@ func saveTestMetadata(backupPath string, metadata map[string]interface{}) error
		metadata["timestamp"],
		metadata["backup_type"],
	)

	_, err = file.WriteString(content)
	return err
}
@@ -23,7 +23,7 @@ func NewDiskSpaceCache(ttl time.Duration) *DiskSpaceCache {
	if ttl <= 0 {
		ttl = 30 * time.Second // Default 30 second cache
	}

	return &DiskSpaceCache{
		cache:    make(map[string]*cacheEntry),
		cacheTTL: ttl,
@@ -40,17 +40,17 @@ func (c *DiskSpaceCache) Get(path string) *DiskSpaceCheck {
		}
	}
	c.mu.RUnlock()

	// Cache miss or expired - perform new check
	check := CheckDiskSpace(path)

	c.mu.Lock()
	c.cache[path] = &cacheEntry{
		check:     check,
		timestamp: time.Now(),
	}
	c.mu.Unlock()

	return check
}

@@ -65,7 +65,7 @@ func (c *DiskSpaceCache) Clear() {
func (c *DiskSpaceCache) Cleanup() {
	c.mu.Lock()
	defer c.mu.Unlock()

	now := time.Now()
	for path, entry := range c.cache {
		if now.Sub(entry.timestamp) >= c.cacheTTL {
@@ -80,4 +80,4 @@ var globalDiskCache = NewDiskSpaceCache(30 * time.Second)
// CheckDiskSpaceCached performs cached disk space check
func CheckDiskSpaceCached(path string) *DiskSpaceCheck {
	return globalDiskCache.Get(path)
}
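Usage is a one-liner; repeated checks within the TTL hit the map instead of the filesystem syscall, which matters when many paths are probed in a loop. A sketch (field names per the DiskSpaceCheck hunks below):

	// First call stats the filesystem; calls within the next 30s
	// are served from the cache.
	check := CheckDiskSpaceCached("/var/backups")
	if check.Critical {
		return fmt.Errorf("not enough disk space for backup")
	}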
@@ -54,7 +54,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
	check := CheckDiskSpace(path)
	requiredBytes := uint64(archiveSize) * 4 // Account for decompression

	// Override status based on required space
	if check.AvailableBytes < requiredBytes {
		check.Critical = true
@@ -64,7 +64,7 @@ func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
		check.Warning = true
		check.Sufficient = false
	}

	return check
}

@@ -134,7 +134,3 @@ func EstimateBackupSize(databaseSize uint64, compressionLevel int) uint64 {
	// Add 10% buffer for metadata, indexes, etc.
	return uint64(float64(estimated) * 1.1)
}
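A worked example of the two sizing rules above (the compression ratio is an assumption, not something this diff specifies):

	// Restore rule: 4x the archive size must be free.
	archiveSize := int64(2 << 30)       // 2 GiB archive
	required := uint64(archiveSize) * 4 // 8 GiB needed to restore

	// Backup rule: estimated compressed size plus a 10% buffer.
	dbSize := uint64(10 << 30)                    // 10 GiB database
	estimated := dbSize / 3                       // assume ~3:1 compression
	estimated = uint64(float64(estimated) * 1.1)  // ≈ 3.7 GiB with buffer
	_ = required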
@@ -54,7 +54,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
	check := CheckDiskSpace(path)
	requiredBytes := uint64(archiveSize) * 4 // Account for decompression

	// Override status based on required space
	if check.AvailableBytes < requiredBytes {
		check.Critical = true
@@ -64,7 +64,7 @@ func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
		check.Warning = true
		check.Sufficient = false
	}

	return check
}

@@ -108,4 +108,4 @@ func FormatDiskSpaceMessage(check *DiskSpaceCheck) string {
	}

	return msg
}

@@ -37,7 +37,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
	check := CheckDiskSpace(path)
	requiredBytes := uint64(archiveSize) * 4 // Account for decompression

	// Override status based on required space
	if check.AvailableBytes < requiredBytes {
		check.Critical = true
@@ -47,7 +47,7 @@ func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
		check.Warning = true
		check.Sufficient = false
	}

	return check
}
@@ -29,7 +29,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
		// If no volume, try current directory
		vol = "."
	}

	var freeBytesAvailable, totalNumberOfBytes, totalNumberOfFreeBytes uint64

	// Call Windows API
@@ -73,7 +73,7 @@ func CheckDiskSpace(path string) *DiskSpaceCheck {
func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
	check := CheckDiskSpace(path)
	requiredBytes := uint64(archiveSize) * 4 // Account for decompression

	// Override status based on required space
	if check.AvailableBytes < requiredBytes {
		check.Critical = true
@@ -83,7 +83,7 @@ func CheckDiskSpaceForRestore(path string, archiveSize int64) *DiskSpaceCheck {
		check.Warning = true
		check.Sufficient = false
	}

	return check
}

@@ -128,4 +128,3 @@ func FormatDiskSpaceMessage(check *DiskSpaceCheck) string {

	return msg
}
@@ -8,10 +8,10 @@ import (

// Compiled regex patterns for robust error matching
var errorPatterns = map[string]*regexp.Regexp{
	"already_exists":    regexp.MustCompile(`(?i)(already exists|duplicate key|unique constraint|relation.*exists)`),
	"disk_full":         regexp.MustCompile(`(?i)(no space left|disk.*full|write.*failed.*space|insufficient.*space)`),
	"lock_exhaustion":   regexp.MustCompile(`(?i)(max_locks_per_transaction|out of shared memory|lock.*exhausted|could not open large object)`),
	"syntax_error":      regexp.MustCompile(`(?i)syntax error at.*line \d+`),
	"permission_denied": regexp.MustCompile(`(?i)(permission denied|must be owner|access denied)`),
	"connection_failed": regexp.MustCompile(`(?i)(connection refused|could not connect|no pg_hba\.conf entry)`),
	"version_mismatch":  regexp.MustCompile(`(?i)(version mismatch|incompatible|unsupported version)`),
@@ -135,9 +135,9 @@ func ClassifyError(errorMsg string) *ErrorClassification {
	}

	// Lock exhaustion errors
	if strings.Contains(lowerMsg, "max_locks_per_transaction") ||
		strings.Contains(lowerMsg, "out of shared memory") ||
		strings.Contains(lowerMsg, "could not open large object") {
		return &ErrorClassification{
			Type:     "critical",
			Category: "locks",
@@ -173,9 +173,9 @@ func ClassifyError(errorMsg string) *ErrorClassification {
	}

	// Connection errors
	if strings.Contains(lowerMsg, "connection refused") ||
		strings.Contains(lowerMsg, "could not connect") ||
		strings.Contains(lowerMsg, "no pg_hba.conf entry") {
		return &ErrorClassification{
			Type:     "critical",
			Category: "network",
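The compiled map enables a data-driven first pass before the string-based fallbacks above; a sketch of how it might be consulted (the helper name is hypothetical):

	// Probe the compiled patterns first; the map key doubles as
	// the error category.
	func matchPattern(errorMsg string) (string, bool) {
		for category, re := range errorPatterns {
			if re.MatchString(errorMsg) {
				return category, true
			}
		}
		return "", false
	}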
@@ -26,4 +26,4 @@ func formatBytes(bytes uint64) string {
		exp++
	}
	return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
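Only the tail of formatBytes survives in this hunk; for reference, a self-contained version of the usual pattern that produces the same tail (a sketch, not necessarily byte-for-byte what the file contains):

	func formatBytes(bytes uint64) string {
		const unit = 1024
		if bytes < unit {
			return fmt.Sprintf("%d B", bytes)
		}
		div, exp := uint64(unit), 0
		for n := bytes / unit; n >= unit; n /= unit {
			div *= unit
			exp++
		}
		return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
	}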
@@ -41,7 +41,7 @@ func (pm *ProcessManager) Track(proc *os.Process) {
	pm.mu.Lock()
	defer pm.mu.Unlock()
	pm.processes[proc.Pid] = proc

	// Auto-cleanup when process exits
	go func() {
		proc.Wait()
@@ -59,14 +59,14 @@ func (pm *ProcessManager) KillAll() error {
		procs = append(procs, proc)
	}
	pm.mu.RUnlock()

	var errors []error
	for _, proc := range procs {
		if err := proc.Kill(); err != nil {
			errors = append(errors, err)
		}
	}

	if len(errors) > 0 {
		return fmt.Errorf("failed to kill %d processes: %v", len(errors), errors)
	}
@@ -82,18 +82,18 @@ func (pm *ProcessManager) Close() error {
// KillOrphanedProcesses finds and kills any orphaned pg_dump, pg_restore, gzip, or pigz processes
func KillOrphanedProcesses(log logger.Logger) error {
	processNames := []string{"pg_dump", "pg_restore", "gzip", "pigz", "gunzip"}

	myPID := os.Getpid()
	var killed []string
	var errors []error

	for _, procName := range processNames {
		pids, err := findProcessesByName(procName, myPID)
		if err != nil {
			log.Warn("Failed to search for processes", "process", procName, "error", err)
			continue
		}

		for _, pid := range pids {
			if err := killProcessGroup(pid); err != nil {
				errors = append(errors, fmt.Errorf("failed to kill %s (PID %d): %w", procName, pid, err))
@@ -102,15 +102,15 @@ func KillOrphanedProcesses(log logger.Logger) error {
			}
		}
	}

	if len(killed) > 0 {
		log.Info("Cleaned up orphaned processes", "count", len(killed), "processes", strings.Join(killed, ", "))
	}

	if len(errors) > 0 {
		return fmt.Errorf("some processes could not be killed: %v", errors)
	}

	return nil
}
@@ -126,27 +126,27 @@ func findProcessesByName(name string, excludePID int) ([]int, error) {
		}
		return nil, err
	}

	var pids []int
	lines := strings.Split(strings.TrimSpace(string(output)), "\n")
	for _, line := range lines {
		if line == "" {
			continue
		}

		pid, err := strconv.Atoi(line)
		if err != nil {
			continue
		}

		// Don't kill our own process
		if pid == excludePID {
			continue
		}

		pids = append(pids, pid)
	}

	return pids, nil
}
@@ -158,17 +158,17 @@ func killProcessGroup(pid int) error {
		// Process might already be gone
		return nil
	}

	// Kill the entire process group (negative PID kills the group)
	// This catches pipelines like "pg_dump | gzip"
	if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
		// If SIGTERM fails, try SIGKILL
		syscall.Kill(-pgid, syscall.SIGKILL)
	}

	// Also kill the specific PID in case it's not in a group
	syscall.Kill(pid, syscall.SIGTERM)

	return nil
}
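Killing -pgid only takes out the pipeline if children were started in their own process group; the usual counterpart when launching the command looks like this sketch (Unix-only, and assumed rather than shown in this diff):

	cmd := exec.Command("pg_dump", args...)
	// Place the child in a new process group so that a later
	// syscall.Kill(-pgid, ...) reaches the whole pipeline.
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
	if err := cmd.Start(); err != nil {
		return err
	}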
@@ -186,21 +186,21 @@ func KillCommandGroup(cmd *exec.Cmd) error {
	if cmd.Process == nil {
		return nil
	}

	pid := cmd.Process.Pid

	// Get the process group ID
	pgid, err := syscall.Getpgid(pid)
	if err != nil {
		// Process might already be gone
		return nil
	}

	// Kill the entire process group
	if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
		// If SIGTERM fails, use SIGKILL
		syscall.Kill(-pgid, syscall.SIGKILL)
	}

	return nil
}
@@ -17,18 +17,18 @@ import (
// KillOrphanedProcesses finds and kills any orphaned pg_dump, pg_restore, gzip, or pigz processes (Windows implementation)
func KillOrphanedProcesses(log logger.Logger) error {
	processNames := []string{"pg_dump.exe", "pg_restore.exe", "gzip.exe", "pigz.exe", "gunzip.exe"}

	myPID := os.Getpid()
	var killed []string
	var errors []error

	for _, procName := range processNames {
		pids, err := findProcessesByNameWindows(procName, myPID)
		if err != nil {
			log.Warn("Failed to search for processes", "process", procName, "error", err)
			continue
		}

		for _, pid := range pids {
			if err := killProcessWindows(pid); err != nil {
				errors = append(errors, fmt.Errorf("failed to kill %s (PID %d): %w", procName, pid, err))
@@ -37,15 +37,15 @@ func KillOrphanedProcesses(log logger.Logger) error {
			}
		}
	}

	if len(killed) > 0 {
		log.Info("Cleaned up orphaned processes", "count", len(killed), "processes", strings.Join(killed, ", "))
	}

	if len(errors) > 0 {
		return fmt.Errorf("some processes could not be killed: %v", errors)
	}

	return nil
}

@@ -58,35 +58,35 @@ func findProcessesByNameWindows(name string, excludePID int) ([]int, error) {
		// No processes found or command failed
		return []int{}, nil
	}

	var pids []int
	lines := strings.Split(strings.TrimSpace(string(output)), "\n")
	for _, line := range lines {
		if line == "" {
			continue
		}

		// Parse CSV output: "name","pid","session","mem"
		fields := strings.Split(line, ",")
		if len(fields) < 2 {
			continue
		}

		// Remove quotes from PID field
		pidStr := strings.Trim(fields[1], `"`)
		pid, err := strconv.Atoi(pidStr)
		if err != nil {
			continue
		}

		// Don't kill our own process
		if pid == excludePID {
			continue
		}

		pids = append(pids, pid)
	}

	return pids, nil
}

@@ -111,7 +111,7 @@ func KillCommandGroup(cmd *exec.Cmd) error {
	if cmd.Process == nil {
		return nil
	}

	// On Windows, just kill the process directly
	return cmd.Process.Kill()
}
@@ -11,22 +11,22 @@ import (
type Backend interface {
	// Upload uploads a file to cloud storage
	Upload(ctx context.Context, localPath, remotePath string, progress ProgressCallback) error

	// Download downloads a file from cloud storage
	Download(ctx context.Context, remotePath, localPath string, progress ProgressCallback) error

	// List lists all backup files in cloud storage
	List(ctx context.Context, prefix string) ([]BackupInfo, error)

	// Delete deletes a file from cloud storage
	Delete(ctx context.Context, remotePath string) error

	// Exists checks if a file exists in cloud storage
	Exists(ctx context.Context, remotePath string) (bool, error)

	// GetSize returns the size of a remote file
	GetSize(ctx context.Context, remotePath string) (int64, error)

	// Name returns the backend name (e.g., "s3", "azure", "gcs")
	Name() string
}
@@ -137,10 +137,10 @@ func (c *Config) Validate() error {

// ProgressReader wraps an io.Reader to track progress
type ProgressReader struct {
	reader     io.Reader
	total      int64
	read       int64
	callback   ProgressCallback
	lastReport time.Time
}

@@ -157,7 +157,7 @@ func NewProgressReader(r io.Reader, total int64, callback ProgressCallback) *Pro
func (pr *ProgressReader) Read(p []byte) (int, error) {
	n, err := pr.reader.Read(p)
	pr.read += int64(n)

	// Report progress every 100ms or when complete
	now := time.Now()
	if now.Sub(pr.lastReport) > 100*time.Millisecond || err == io.EOF {
@@ -166,6 +166,6 @@ func (pr *ProgressReader) Read(p []byte) (int, error) {
		}
		pr.lastReport = now
	}

	return n, err
}
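A sketch of wiring the reader into an upload, assuming ProgressCallback receives (transferred, total) byte counts — the exact signature is not shown in these hunks:

	file, err := os.Open(localPath)
	if err != nil {
		return err
	}
	defer file.Close()
	stat, _ := file.Stat()

	// Throttled to ~10 reports/second by the 100ms check above.
	reader := NewProgressReader(file, stat.Size(), func(done, total int64) {
		fmt.Printf("\rupload: %d%%", done*100/total)
	})
	// reader is then passed as the request body, e.g. to PutObject.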
@@ -30,11 +30,11 @@ func NewS3Backend(cfg *Config) (*S3Backend, error) {
	}

	ctx := context.Background()

	// Build AWS config
	var awsCfg aws.Config
	var err error

	if cfg.AccessKey != "" && cfg.SecretKey != "" {
		// Use explicit credentials
		credsProvider := credentials.NewStaticCredentialsProvider(
@@ -42,7 +42,7 @@ func NewS3Backend(cfg *Config) (*S3Backend, error) {
			cfg.SecretKey,
			"",
		)

		awsCfg, err = config.LoadDefaultConfig(ctx,
			config.WithCredentialsProvider(credsProvider),
			config.WithRegion(cfg.Region),
@@ -53,7 +53,7 @@ func NewS3Backend(cfg *Config) (*S3Backend, error) {
			config.WithRegion(cfg.Region),
		)
	}

	if err != nil {
		return nil, fmt.Errorf("failed to load AWS config: %w", err)
	}
@@ -69,7 +69,7 @@ func NewS3Backend(cfg *Config) (*S3Backend, error) {
			}
		},
	}

	client := s3.NewFromConfig(awsCfg, clientOptions...)

	return &S3Backend{
@@ -114,7 +114,7 @@ func (s *S3Backend) Upload(ctx context.Context, localPath, remotePath string, pr

	// Use multipart upload for files larger than 100MB
	const multipartThreshold = 100 * 1024 * 1024 // 100 MB

	if fileSize > multipartThreshold {
		return s.uploadMultipart(ctx, file, key, fileSize, progress)
	}
@@ -137,7 +137,7 @@ func (s *S3Backend) uploadSimple(ctx context.Context, file *os.File, key string,
		Key:    aws.String(key),
		Body:   reader,
	})

	if err != nil {
		return fmt.Errorf("failed to upload to S3: %w", err)
	}
@@ -151,10 +151,10 @@ func (s *S3Backend) uploadMultipart(ctx context.Context, file *os.File, key stri
	uploader := manager.NewUploader(s.client, func(u *manager.Uploader) {
		// Part size: 10MB
		u.PartSize = 10 * 1024 * 1024

		// Upload up to 10 parts concurrently
		u.Concurrency = 10

		// Clean up partial parts on failure (do not leave orphaned parts)
		u.LeavePartsOnError = false
	})
@@ -245,10 +245,10 @@ func (s *S3Backend) List(ctx context.Context, prefix string) ([]BackupInfo, erro
		if obj.Key == nil {
			continue
		}

		key := *obj.Key
		name := filepath.Base(key)

		// Skip if it's just a directory marker
		if strings.HasSuffix(key, "/") {
			continue
@@ -260,11 +260,11 @@ func (s *S3Backend) List(ctx context.Context, prefix string) ([]BackupInfo, erro
			Size:         *obj.Size,
			LastModified: *obj.LastModified,
		}

		if obj.ETag != nil {
			info.ETag = *obj.ETag
		}

		if obj.StorageClass != "" {
			info.StorageClass = string(obj.StorageClass)
		} else {
@@ -285,7 +285,7 @@ func (s *S3Backend) Delete(ctx context.Context, remotePath string) error {
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
	})

	if err != nil {
		return fmt.Errorf("failed to delete object: %w", err)
	}
@@ -301,7 +301,7 @@ func (s *S3Backend) Exists(ctx context.Context, remotePath string) (bool, error)
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
	})

	if err != nil {
		// Check if it's a "not found" error
		if strings.Contains(err.Error(), "NotFound") || strings.Contains(err.Error(), "404") {
@@ -321,7 +321,7 @@ func (s *S3Backend) GetSize(ctx context.Context, remotePath string) (int64, erro
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
	})

	if err != nil {
		return 0, fmt.Errorf("failed to get object metadata: %w", err)
	}
@@ -338,7 +338,7 @@ func (s *S3Backend) BucketExists(ctx context.Context) (bool, error) {
	_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
		Bucket: aws.String(s.bucket),
	})

	if err != nil {
		if strings.Contains(err.Error(), "NotFound") || strings.Contains(err.Error(), "404") {
			return false, nil
@@ -355,7 +355,7 @@ func (s *S3Backend) CreateBucket(ctx context.Context) error {
	if err != nil {
		return err
	}

	if exists {
		return nil
	}
@@ -363,7 +363,7 @@ func (s *S3Backend) CreateBucket(ctx context.Context) error {
	_, err = s.client.CreateBucket(ctx, &s3.CreateBucketInput{
		Bucket: aws.String(s.bucket),
	})

	if err != nil {
		return fmt.Errorf("failed to create bucket: %w", err)
	}
@@ -76,7 +76,7 @@ func ParseCloudURI(uri string) (*CloudURI, error) {
	if len(parts) >= 3 {
		// Extract bucket name (first part)
		bucket = parts[0]

		// Extract region if present
		// bucket.s3.us-west-2.amazonaws.com -> us-west-2
		// bucket.s3-us-west-2.amazonaws.com -> us-west-2
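A sketch of the region extraction those two comments describe, assuming parts holds the dot-separated host labels (so parts[1] is "s3", "s3-us-west-2", and so on):

	// bucket.s3.us-west-2.amazonaws.com  -> region is parts[2]
	// bucket.s3-us-west-2.amazonaws.com  -> region embedded in parts[1]
	region := ""
	if parts[1] == "s3" && len(parts) >= 4 {
		region = parts[2]
	} else if strings.HasPrefix(parts[1], "s3-") {
		region = strings.TrimPrefix(parts[1], "s3-")
	}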
@@ -45,11 +45,11 @@ type Config struct {
	SampleValue int

	// Output options
	NoColor   bool
	Debug     bool
	LogLevel  string
	LogFormat string

	// Config persistence
	NoSaveConfig bool
	NoLoadConfig bool
@@ -194,11 +194,11 @@ func New() *Config {
		AutoSwap: getEnvBool("AUTO_SWAP", false),

		// Security defaults (MEDIUM priority)
		RetentionDays:  getEnvInt("RETENTION_DAYS", 30),     // Keep backups for 30 days
		MinBackups:     getEnvInt("MIN_BACKUPS", 5),         // Keep at least 5 backups
		MaxRetries:     getEnvInt("MAX_RETRIES", 3),         // Maximum 3 retry attempts
		AllowRoot:      getEnvBool("ALLOW_ROOT", false),     // Disallow root by default
		CheckResources: getEnvBool("CHECK_RESOURCES", true), // Check resources by default

		// TUI automation defaults (for testing)
		TUIAutoSelect: getEnvInt("TUI_AUTO_SELECT", -1), // -1 = disabled
@@ -39,7 +39,7 @@ type LocalConfig struct {
// LoadLocalConfig loads configuration from .dbbackup.conf in current directory
func LoadLocalConfig() (*LocalConfig, error) {
	configPath := filepath.Join(".", ConfigFileName)

	data, err := os.ReadFile(configPath)
	if err != nil {
		if os.IsNotExist(err) {
@@ -54,7 +54,7 @@ func LoadLocalConfig() (*LocalConfig, error) {

	for _, line := range lines {
		line = strings.TrimSpace(line)

		// Skip empty lines and comments
		if line == "" || strings.HasPrefix(line, "#") {
			continue
@@ -143,7 +143,7 @@ func LoadLocalConfig() (*LocalConfig, error) {
// SaveLocalConfig saves configuration to .dbbackup.conf in current directory
func SaveLocalConfig(cfg *LocalConfig) error {
	var sb strings.Builder

	sb.WriteString("# dbbackup configuration\n")
	sb.WriteString("# This file is auto-generated. Edit with care.\n\n")
@@ -1,24 +1,24 @@
package cpu

import (
	"bufio"
	"fmt"
	"os"
	"os/exec"
	"runtime"
	"strconv"
	"strings"
)

// CPUInfo holds information about the system CPU
type CPUInfo struct {
	LogicalCores  int      `json:"logical_cores"`
	PhysicalCores int      `json:"physical_cores"`
	Architecture  string   `json:"architecture"`
	ModelName     string   `json:"model_name"`
	MaxFrequency  float64  `json:"max_frequency_mhz"`
	CacheSize     string   `json:"cache_size"`
	Vendor        string   `json:"vendor"`
	Features      []string `json:"features"`
}
@@ -78,7 +78,7 @@ func (d *Detector) detectLinux(info *CPUInfo) error {

	scanner := bufio.NewScanner(file)
	physicalCoreCount := make(map[string]bool)

	for scanner.Scan() {
		line := scanner.Text()
		if strings.TrimSpace(line) == "" {
@@ -324,11 +324,11 @@ func (d *Detector) GetCPUInfo() *CPUInfo {
// FormatCPUInfo returns a formatted string representation of CPU info
func (info *CPUInfo) FormatCPUInfo() string {
	var sb strings.Builder

	sb.WriteString(fmt.Sprintf("Architecture: %s\n", info.Architecture))
	sb.WriteString(fmt.Sprintf("Logical Cores: %d\n", info.LogicalCores))
	sb.WriteString(fmt.Sprintf("Physical Cores: %d\n", info.PhysicalCores))

	if info.ModelName != "" {
		sb.WriteString(fmt.Sprintf("Model: %s\n", info.ModelName))
	}
@@ -341,6 +341,6 @@ func (info *CPUInfo) FormatCPUInfo() string {
	if info.CacheSize != "" {
		sb.WriteString(fmt.Sprintf("Cache Size: %s\n", info.CacheSize))
	}

	return sb.String()
}
@@ -8,9 +8,9 @@ import (

	"dbbackup/internal/config"
	"dbbackup/internal/logger"

	_ "github.com/go-sql-driver/mysql"  // MySQL driver
	_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver (pgx - high performance)
)

// Database represents a database connection and operations
@@ -19,43 +19,43 @@ type Database interface {
	Connect(ctx context.Context) error
	Close() error
	Ping(ctx context.Context) error

	// Database discovery
	ListDatabases(ctx context.Context) ([]string, error)
	ListTables(ctx context.Context, database string) ([]string, error)

	// Database operations
	CreateDatabase(ctx context.Context, name string) error
	DropDatabase(ctx context.Context, name string) error
	DatabaseExists(ctx context.Context, name string) (bool, error)

	// Information
	GetVersion(ctx context.Context) (string, error)
	GetDatabaseSize(ctx context.Context, database string) (int64, error)
	GetTableRowCount(ctx context.Context, database, table string) (int64, error)

	// Backup/Restore command building
	BuildBackupCommand(database, outputFile string, options BackupOptions) []string
	BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string
	BuildSampleQuery(database, table string, strategy SampleStrategy) string

	// Validation
	ValidateBackupTools() error
}
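A sketch of driving the interface from a caller, assuming the New factory below selects the concrete engine from the config:

	db, err := database.New(cfg, log)
	if err != nil {
		return err
	}
	defer db.Close()

	if err := db.Connect(ctx); err != nil {
		return err
	}
	names, err := db.ListDatabases(ctx)
	if err != nil {
		return err
	}
	for _, name := range names {
		fmt.Println(name)
	}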

// BackupOptions holds options for backup operations
type BackupOptions struct {
	Compression  int
	Parallel     int
	Format       string // "custom", "plain", "directory"
	Blobs        bool
	SchemaOnly   bool
	DataOnly     bool
	NoOwner      bool
	NoPrivileges bool
	Clean        bool
	IfExists     bool
	Role         string
}

// RestoreOptions holds options for restore operations
@@ -77,12 +77,12 @@ type SampleStrategy struct {

// DatabaseInfo holds database metadata
type DatabaseInfo struct {
	Name      string
	Size      int64
	Owner     string
	Encoding  string
	Collation string
	Tables    []TableInfo
}

// TableInfo holds table metadata
@@ -105,10 +105,10 @@ func New(cfg *config.Config, log logger.Logger) (Database, error) {

// Common database implementation
type baseDatabase struct {
	cfg *config.Config
	log logger.Logger
	db  *sql.DB
	dsn string
}

func (b *baseDatabase) Close() error {
@@ -131,4 +131,4 @@ func buildTimeout(ctx context.Context, timeout time.Duration) (context.Context,
		timeout = 30 * time.Second
	}
	return context.WithTimeout(ctx, timeout)
}
@@ -387,7 +387,7 @@ func (m *MySQL) buildDSN() string {
		"/tmp/mysql.sock",
		"/var/lib/mysql/mysql.sock",
	}

	// Use the first available socket path, fallback to TCP if none found
	socketFound := false
	for _, socketPath := range socketPaths {
@@ -397,7 +397,7 @@ func (m *MySQL) buildDSN() string {
			break
		}
	}

	// If no socket found, use TCP localhost
	if !socketFound {
		dsn += "tcp(localhost:" + strconv.Itoa(m.cfg.Port) + ")"
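For reference, the two DSN shapes go-sql-driver/mysql accepts here, which is what the socket/TCP fallback chooses between (values illustrative):

	// Socket:  user:pass@unix(/path/to/mysql.sock)/dbname
	// TCP:     user:pass@tcp(host:port)/dbname
	dsnSocket := "root:secret@unix(/var/run/mysqld/mysqld.sock)/mydb"
	dsnTCP := "root:secret@tcp(localhost:3306)/mydb"
	_, _ = dsnSocket, dsnTCP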
@@ -12,7 +12,7 @@ import (
	"dbbackup/internal/auth"
	"dbbackup/internal/config"
	"dbbackup/internal/logger"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/jackc/pgx/v5/stdlib"
	_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver (pgx)
@@ -43,51 +43,51 @@ func (p *PostgreSQL) Connect(ctx context.Context) error {
			p.log.Debug("Loaded password from .pgpass file")
		}
	}

	// Check for authentication mismatch before attempting connection
	if mismatch, msg := auth.CheckAuthenticationMismatch(p.cfg); mismatch {
		fmt.Println(msg)
		return fmt.Errorf("authentication configuration required")
	}

	// Build PostgreSQL DSN (pgx format)
	dsn := p.buildPgxDSN()
	p.dsn = dsn

	p.log.Debug("Connecting to PostgreSQL with pgx", "dsn", sanitizeDSN(dsn))

	// Parse config with optimizations for large databases
	config, err := pgxpool.ParseConfig(dsn)
	if err != nil {
		return fmt.Errorf("failed to parse pgx config: %w", err)
	}

	// Optimize connection pool for backup workloads
	config.MaxConns = 10                       // Max concurrent connections
	config.MinConns = 2                        // Keep minimum connections ready
	config.MaxConnLifetime = 0                 // No limit on connection lifetime
	config.MaxConnIdleTime = 0                 // No idle timeout
	config.HealthCheckPeriod = 1 * time.Minute // Health check every minute

	// Optimize for large query results (BLOB data)
	config.ConnConfig.RuntimeParams["work_mem"] = "64MB"
	config.ConnConfig.RuntimeParams["maintenance_work_mem"] = "256MB"

	// Create connection pool
	pool, err := pgxpool.NewWithConfig(ctx, config)
	if err != nil {
		return fmt.Errorf("failed to create pgx pool: %w", err)
	}

	// Test connection
	if err := pool.Ping(ctx); err != nil {
		pool.Close()
		return fmt.Errorf("failed to ping PostgreSQL: %w", err)
	}

	// Also create stdlib connection for compatibility
	db := stdlib.OpenDBFromPool(pool)

	p.pool = pool
	p.db = db
	p.log.Info("Connected to PostgreSQL successfully", "driver", "pgx", "max_conns", config.MaxConns)
@@ -111,17 +111,17 @@ func (p *PostgreSQL) ListDatabases(ctx context.Context) ([]string, error) {
	if p.db == nil {
		return nil, fmt.Errorf("not connected to database")
	}

	query := `SELECT datname FROM pg_database
			  WHERE datistemplate = false
			  ORDER BY datname`

	rows, err := p.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to query databases: %w", err)
	}
	defer rows.Close()

	var databases []string
	for rows.Next() {
		var name string
@@ -130,7 +130,7 @@ func (p *PostgreSQL) ListDatabases(ctx context.Context) ([]string, error) {
		}
		databases = append(databases, name)
	}

	return databases, rows.Err()
}

@@ -139,18 +139,18 @@ func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string,
	if p.db == nil {
		return nil, fmt.Errorf("not connected to database")
	}

	query := `SELECT schemaname||'.'||tablename as full_name
			  FROM pg_tables
			  WHERE schemaname NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
			  ORDER BY schemaname, tablename`

	rows, err := p.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to query tables: %w", err)
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var name string
@@ -159,7 +159,7 @@ func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string,
		}
		tables = append(tables, name)
	}

	return tables, rows.Err()
}
@@ -168,14 +168,14 @@ func (p *PostgreSQL) CreateDatabase(ctx context.Context, name string) error {
	if p.db == nil {
		return fmt.Errorf("not connected to database")
	}

	// PostgreSQL doesn't support CREATE DATABASE in transactions or prepared statements
	query := fmt.Sprintf("CREATE DATABASE %s", name)
	_, err := p.db.ExecContext(ctx, query)
	if err != nil {
		return fmt.Errorf("failed to create database %s: %w", name, err)
	}

	p.log.Info("Created database", "name", name)
	return nil
}
@@ -185,14 +185,14 @@ func (p *PostgreSQL) DropDatabase(ctx context.Context, name string) error {
	if p.db == nil {
		return fmt.Errorf("not connected to database")
	}

	// Force drop connections and drop database
	query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", name)
	_, err := p.db.ExecContext(ctx, query)
	if err != nil {
		return fmt.Errorf("failed to drop database %s: %w", name, err)
	}

	p.log.Info("Dropped database", "name", name)
	return nil
}
@@ -202,14 +202,14 @@ func (p *PostgreSQL) DatabaseExists(ctx context.Context, name string) (bool, err
	if p.db == nil {
		return false, fmt.Errorf("not connected to database")
	}

	query := `SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)`
	var exists bool
	err := p.db.QueryRowContext(ctx, query, name).Scan(&exists)
	if err != nil {
		return false, fmt.Errorf("failed to check database existence: %w", err)
	}

	return exists, nil
}

@@ -218,13 +218,13 @@ func (p *PostgreSQL) GetVersion(ctx context.Context) (string, error) {
	if p.db == nil {
		return "", fmt.Errorf("not connected to database")
	}

	var version string
	err := p.db.QueryRowContext(ctx, "SELECT version()").Scan(&version)
	if err != nil {
		return "", fmt.Errorf("failed to get version: %w", err)
	}

	return version, nil
}

@@ -233,14 +233,14 @@ func (p *PostgreSQL) GetDatabaseSize(ctx context.Context, database string) (int6
	if p.db == nil {
		return 0, fmt.Errorf("not connected to database")
	}

	query := `SELECT pg_database_size($1)`
	var size int64
	err := p.db.QueryRowContext(ctx, query, database).Scan(&size)
	if err != nil {
		return 0, fmt.Errorf("failed to get database size: %w", err)
	}

	return size, nil
}

@@ -249,16 +249,16 @@ func (p *PostgreSQL) GetTableRowCount(ctx context.Context, database, table strin
	if p.db == nil {
		return 0, fmt.Errorf("not connected to database")
	}

	// Use pg_stat_user_tables for approximate count (faster)
	parts := strings.Split(table, ".")
	if len(parts) != 2 {
		return 0, fmt.Errorf("table name must be in format schema.table")
	}

	query := `SELECT COALESCE(n_tup_ins, 0) FROM pg_stat_user_tables
			  WHERE schemaname = $1 AND relname = $2`

	var count int64
	err := p.db.QueryRowContext(ctx, query, parts[0], parts[1]).Scan(&count)
	if err != nil {
@@ -269,14 +269,14 @@ func (p *PostgreSQL) GetTableRowCount(ctx context.Context, database, table strin
			return 0, fmt.Errorf("failed to get table row count: %w", err)
		}
	}

	return count, nil
}
// BuildBackupCommand builds pg_dump command
func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
	cmd := []string{"pg_dump"}

	// Connection parameters
	if p.cfg.Host != "localhost" {
		cmd = append(cmd, "-h", p.cfg.Host)
@@ -284,27 +284,27 @@ func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options Bac
		cmd = append(cmd, "--no-password")
	}
	cmd = append(cmd, "-U", p.cfg.User)

	// Format and compression
	if options.Format != "" {
		cmd = append(cmd, "--format="+options.Format)
	} else {
		cmd = append(cmd, "--format=custom")
	}

	// For plain format with compression==0, we want to stream to stdout so external
	// compression can be used. Set a marker flag so caller knows to pipe stdout.
	usesStdout := (options.Format == "plain" && options.Compression == 0)

	if options.Compression > 0 {
		cmd = append(cmd, "--compress="+strconv.Itoa(options.Compression))
	}

	// Parallel jobs (only for directory format)
	if options.Parallel > 1 && options.Format == "directory" {
		cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
	}

	// Options
	if options.Blobs {
		cmd = append(cmd, "--blobs")
@@ -324,23 +324,23 @@ func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options Bac
	if options.Role != "" {
		cmd = append(cmd, "--role="+options.Role)
	}

	// Database
	cmd = append(cmd, "--dbname="+database)

	// Output: For plain format with external compression, omit --file so pg_dump
	// writes to stdout (caller will pipe to compressor). Otherwise specify output file.
	if !usesStdout {
		cmd = append(cmd, "--file="+outputFile)
	}

	return cmd
}
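When usesStdout is true the returned argv has no --file, so the caller is expected to pipe; a sketch of the consuming side (gzip stands in here for whatever compressor is configured, and error handling is elided):

	dump := exec.Command(args[0], args[1:]...) // pg_dump ... (no --file)
	compress := exec.Command("gzip", "-c")

	pipe, _ := dump.StdoutPipe()
	compress.Stdin = pipe

	out, _ := os.Create(outputFile + ".gz")
	defer out.Close()
	compress.Stdout = out

	_ = compress.Start()
	_ = dump.Run()
	_ = compress.Wait()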
// BuildRestoreCommand builds pg_restore command
func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
	cmd := []string{"pg_restore"}

	// Connection parameters
	if p.cfg.Host != "localhost" {
		cmd = append(cmd, "-h", p.cfg.Host)
@@ -348,12 +348,12 @@ func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options Res
		cmd = append(cmd, "--no-password")
	}
	cmd = append(cmd, "-U", p.cfg.User)

	// Parallel jobs (incompatible with --single-transaction per PostgreSQL docs)
	if options.Parallel > 1 && !options.SingleTransaction {
		cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
	}

	// Options
	if options.Clean {
		cmd = append(cmd, "--clean")
@@ -370,23 +370,23 @@ func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options Res
	if options.SingleTransaction {
		cmd = append(cmd, "--single-transaction")
	}

	// NOTE: --exit-on-error removed because it causes entire restore to fail on
	// "already exists" errors. PostgreSQL continues on ignorable errors by default
	// and reports error count at the end, which is correct behavior for restores.

	// Skip data restore if table creation fails (prevents duplicate data errors)
	cmd = append(cmd, "--no-data-for-failed-tables")

	// Add verbose flag ONLY if requested (WARNING: can cause OOM on large cluster restores)
	if options.Verbose {
		cmd = append(cmd, "--verbose")
	}

	// Database and input
	cmd = append(cmd, "--dbname="+database)
	cmd = append(cmd, inputFile)

	return cmd
}
@@ -395,7 +395,7 @@ func (p *PostgreSQL) BuildSampleQuery(database, table string, strategy SampleStr
	switch strategy.Type {
	case "ratio":
		// Every Nth record using row_number
		return fmt.Sprintf("SELECT * FROM (SELECT *, row_number() OVER () as rn FROM %s) t WHERE rn %% %d = 1",
			table, strategy.Value)
	case "percent":
		// Percentage sampling using TABLESAMPLE (PostgreSQL 9.5+)
@@ -411,24 +411,24 @@ func (p *PostgreSQL) BuildSampleQuery(database, table string, strategy SampleStr
// ValidateBackupTools checks if required PostgreSQL tools are available
func (p *PostgreSQL) ValidateBackupTools() error {
	tools := []string{"pg_dump", "pg_restore", "pg_dumpall", "psql"}

	for _, tool := range tools {
		if _, err := exec.LookPath(tool); err != nil {
			return fmt.Errorf("required tool not found: %s", tool)
		}
	}

	return nil
}
// buildDSN constructs PostgreSQL connection string
func (p *PostgreSQL) buildDSN() string {
	dsn := fmt.Sprintf("user=%s dbname=%s", p.cfg.User, p.cfg.Database)

	if p.cfg.Password != "" {
		dsn += " password=" + p.cfg.Password
	}

	// For localhost connections, try socket first for peer auth
	if p.cfg.Host == "localhost" && p.cfg.Password == "" {
		// Try Unix socket connection for peer authentication
@@ -438,7 +438,7 @@ func (p *PostgreSQL) buildDSN() string {
			"/tmp",
			"/var/lib/pgsql",
		}

		for _, dir := range socketDirs {
			socketPath := fmt.Sprintf("%s/.s.PGSQL.%d", dir, p.cfg.Port)
			if _, err := os.Stat(socketPath); err == nil {
@@ -452,7 +452,7 @@ func (p *PostgreSQL) buildDSN() string {
		dsn += " host=" + p.cfg.Host
		dsn += " port=" + strconv.Itoa(p.cfg.Port)
	}

	if p.cfg.SSLMode != "" && !p.cfg.Insecure {
		// Map SSL modes to supported values for lib/pq
		switch strings.ToLower(p.cfg.SSLMode) {
@@ -472,7 +472,7 @@ func (p *PostgreSQL) buildDSN() string {
	} else if p.cfg.Insecure {
		dsn += " sslmode=disable"
	}

	return dsn
}
@@ -480,7 +480,7 @@ func (p *PostgreSQL) buildDSN() string {
func (p *PostgreSQL) buildPgxDSN() string {
	// pgx supports both URL and keyword=value formats
	// Use keyword format for Unix sockets, URL for TCP

	// Try Unix socket first for localhost without password
	if p.cfg.Host == "localhost" && p.cfg.Password == "" {
		socketDirs := []string{
@@ -488,7 +488,7 @@ func (p *PostgreSQL) buildPgxDSN() string {
			"/tmp",
			"/var/lib/pgsql",
		}

		for _, dir := range socketDirs {
			socketPath := fmt.Sprintf("%s/.s.PGSQL.%d", dir, p.cfg.Port)
			if _, err := os.Stat(socketPath); err == nil {
@@ -500,34 +500,34 @@ func (p *PostgreSQL) buildPgxDSN() string {
			}
		}
	}

	// Use URL format for TCP connections
	var dsn strings.Builder
	dsn.WriteString("postgres://")

	// User
	dsn.WriteString(p.cfg.User)

	// Password
	if p.cfg.Password != "" {
		dsn.WriteString(":")
		dsn.WriteString(p.cfg.Password)
	}

	dsn.WriteString("@")

	// Host and Port
	dsn.WriteString(p.cfg.Host)
	dsn.WriteString(":")
	dsn.WriteString(strconv.Itoa(p.cfg.Port))

	// Database
	dsn.WriteString("/")
	dsn.WriteString(p.cfg.Database)

	// Parameters
	params := make([]string, 0)

	// SSL Mode
	if p.cfg.Insecure {
		params = append(params, "sslmode=disable")
@@ -550,21 +550,21 @@ func (p *PostgreSQL) buildPgxDSN() string {
	} else {
		params = append(params, "sslmode=prefer")
	}

	// Connection pool settings
	params = append(params, "pool_max_conns=10")
	params = append(params, "pool_min_conns=2")

	// Performance tuning for large queries
	params = append(params, "application_name=dbbackup")
	params = append(params, "connect_timeout=30")

	// Add parameters to DSN
	if len(params) > 0 {
		dsn.WriteString("?")
		dsn.WriteString(strings.Join(params, "&"))
	}

	return dsn.String()
}
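For a TCP connection the builder above yields a URL along these lines (values illustrative):

	// postgres://postgres@db.example.com:5432/mydb?sslmode=prefer&pool_max_conns=10&pool_min_conns=2&application_name=dbbackup&connect_timeout=30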
@@ -573,7 +573,7 @@ func sanitizeDSN(dsn string) string {
	// Simple password removal for logging
	parts := strings.Split(dsn, " ")
	var sanitized []string

	for _, part := range parts {
		if strings.HasPrefix(part, "password=") {
			sanitized = append(sanitized, "password=***")
@@ -581,6 +581,6 @@ func sanitizeDSN(dsn string) string {
			sanitized = append(sanitized, part)
		}
	}

	return strings.Join(sanitized, " ")
}
@@ -14,38 +14,38 @@ import (
const (
	// AES-256 requires 32-byte keys
	KeySize = 32

	// Nonce size for GCM
	NonceSize = 12

	// Salt size for key derivation
	SaltSize = 32

	// PBKDF2 iterations (100,000 is recommended minimum)
	PBKDF2Iterations = 100000

	// Magic header to identify encrypted files
	EncryptedFileMagic = "DBBACKUP_ENCRYPTED_V1"
)

// EncryptionHeader stores metadata for encrypted files
type EncryptionHeader struct {
	Magic     [22]byte // "DBBACKUP_ENCRYPTED_V1" (21 bytes + null)
	Version   uint8    // Version number (1)
	Algorithm uint8    // Algorithm ID (1 = AES-256-GCM)
	Salt      [32]byte // Salt for key derivation
	Nonce     [12]byte // GCM nonce
	Reserved  [32]byte // Reserved for future use
}

// EncryptionOptions configures encryption behavior
type EncryptionOptions struct {
	// Key is the encryption key (32 bytes for AES-256)
	Key []byte

	// Passphrase for key derivation (alternative to direct key)
	Passphrase string

	// Salt for key derivation (if empty, will be generated)
	Salt []byte
}
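The fixed header is 100 bytes; the offsets used by writeHeader at the end of this file follow directly from the field sizes above:

	// Magic     [0:22)   22 bytes
	// Version   [22:23)   1 byte
	// Algorithm [23:24)   1 byte
	// Salt      [24:56)  32 bytes
	// Nonce     [56:68)  12 bytes
	// Reserved  [68:100) 32 bytes
	const headerSize = 22 + 1 + 1 + 32 + 12 + 32 // = 100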
@@ -79,7 +79,7 @@ func NewEncryptionWriter(w io.Writer, opts EncryptionOptions) (*EncryptionWriter
	// Derive or validate key
	var key []byte
	var salt []byte

	if opts.Passphrase != "" {
		// Derive key from passphrase
		if len(opts.Salt) == 0 {
@@ -106,25 +106,25 @@ func NewEncryptionWriter(w io.Writer, opts EncryptionOptions) (*EncryptionWriter
	} else {
		return nil, fmt.Errorf("either Key or Passphrase must be provided")
	}

	// Create AES cipher
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}

	// Create GCM mode
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	// Generate nonce
	nonce := make([]byte, NonceSize)
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}

	// Write header
	header := EncryptionHeader{
		Version: 1,
@@ -133,11 +133,11 @@ func NewEncryptionWriter(w io.Writer, opts EncryptionOptions) (*EncryptionWriter
	copy(header.Magic[:], []byte(EncryptedFileMagic))
	copy(header.Salt[:], salt)
	copy(header.Nonce[:], nonce)

	if err := writeHeader(w, &header); err != nil {
		return nil, fmt.Errorf("failed to write header: %w", err)
	}

	return &EncryptionWriter{
		writer: w,
		gcm:    gcm,
@@ -160,16 +160,16 @@ func (ew *EncryptionWriter) Write(p []byte) (n int, err error) {
|
||||
if ew.closed {
|
||||
return 0, fmt.Errorf("writer is closed")
|
||||
}
|
||||
|
||||
|
||||
// Accumulate data in buffer
|
||||
ew.buffer = append(ew.buffer, p...)
|
||||
|
||||
|
||||
// If buffer is large enough, encrypt and write
|
||||
const chunkSize = 64 * 1024 // 64KB chunks
|
||||
for len(ew.buffer) >= chunkSize {
|
||||
chunk := ew.buffer[:chunkSize]
|
||||
encrypted := ew.gcm.Seal(nil, ew.nonce, chunk, nil)
|
||||
|
||||
|
||||
// Write encrypted chunk size (4 bytes) then chunk
|
||||
size := uint32(len(encrypted))
|
||||
sizeBytes := []byte{
|
||||
@@ -184,15 +184,15 @@ func (ew *EncryptionWriter) Write(p []byte) (n int, err error) {
|
||||
if _, err := ew.writer.Write(encrypted); err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
|
||||
// Move remaining data to start of buffer
|
||||
ew.buffer = ew.buffer[chunkSize:]
|
||||
n += chunkSize
|
||||
|
||||
|
||||
// Increment nonce for next chunk
|
||||
incrementNonce(ew.nonce)
|
||||
}
|
||||
|
||||
|
||||
return len(p), nil
|
||||
}
@@ -202,11 +202,11 @@ func (ew *EncryptionWriter) Close() error {
		return nil
	}
	ew.closed = true

	// Encrypt and write remaining buffer
	if len(ew.buffer) > 0 {
		encrypted := ew.gcm.Seal(nil, ew.nonce, ew.buffer, nil)

		size := uint32(len(encrypted))
		sizeBytes := []byte{
			byte(size >> 24),
@@ -221,12 +221,12 @@ func (ew *EncryptionWriter) Close() error {
			return err
		}
	}

	// Write final zero-length chunk to signal end
	if _, err := ew.writer.Write([]byte{0, 0, 0, 0}); err != nil {
		return err
	}

	return nil
}
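
// Editor's sketch (inferred from Write and Close above, not part of this
// diff): the resulting stream is a fixed header followed by length-prefixed
// GCM chunks and a zero-length terminator:
//
//	[header][len|4B BE][ciphertext+tag]...[len|4B BE][ciphertext+tag][00 00 00 00]
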
@@ -237,22 +237,22 @@ func NewDecryptionReader(r io.Reader, opts EncryptionOptions) (*DecryptionReader
	if err != nil {
		return nil, fmt.Errorf("failed to read header: %w", err)
	}

	// Verify magic
	if string(header.Magic[:len(EncryptedFileMagic)]) != EncryptedFileMagic {
		return nil, fmt.Errorf("not an encrypted backup file")
	}

	// Verify version
	if header.Version != 1 {
		return nil, fmt.Errorf("unsupported encryption version: %d", header.Version)
	}

	// Verify algorithm
	if header.Algorithm != 1 {
		return nil, fmt.Errorf("unsupported encryption algorithm: %d", header.Algorithm)
	}

	// Derive or validate key
	var key []byte
	if opts.Passphrase != "" {
@@ -265,22 +265,22 @@ func NewDecryptionReader(r io.Reader, opts EncryptionOptions) (*DecryptionReader
	} else {
		return nil, fmt.Errorf("either Key or Passphrase must be provided")
	}

	// Create AES cipher
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}

	// Create GCM mode
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	nonce := make([]byte, NonceSize)
	copy(nonce, header.Nonce[:])

	return &DecryptionReader{
		reader: r,
		gcm:    gcm,
@@ -306,12 +306,12 @@ func (dr *DecryptionReader) Read(p []byte) (n int, err error) {
		dr.buffer = dr.buffer[n:]
		return n, nil
	}

	// If EOF reached, return EOF
	if dr.eof {
		return 0, io.EOF
	}

	// Read next chunk size
	sizeBytes := make([]byte, 4)
	if _, err := io.ReadFull(dr.reader, sizeBytes); err != nil {
@@ -321,36 +321,36 @@ func (dr *DecryptionReader) Read(p []byte) (n int, err error) {
		}
		return 0, err
	}

	size := uint32(sizeBytes[0])<<24 | uint32(sizeBytes[1])<<16 | uint32(sizeBytes[2])<<8 | uint32(sizeBytes[3])

	// Zero-length chunk signals end of stream
	if size == 0 {
		dr.eof = true
		return 0, io.EOF
	}

	// Read encrypted chunk
	encrypted := make([]byte, size)
	if _, err := io.ReadFull(dr.reader, encrypted); err != nil {
		return 0, err
	}

	// Decrypt chunk
	decrypted, err := dr.gcm.Open(nil, dr.nonce, encrypted, nil)
	if err != nil {
		return 0, fmt.Errorf("decryption failed (wrong key?): %w", err)
	}

	// Increment nonce for next chunk
	incrementNonce(dr.nonce)

	// Return as much as fits in p, buffer the rest
	n = copy(p, decrypted)
	if n < len(decrypted) {
		dr.buffer = decrypted[n:]
	}

	return n, nil
}
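
// Editor's sketch (illustrative usage, not part of this diff), assuming the
// API shown above and exercised by the tests below:
//
//	var buf bytes.Buffer
//	w, _ := NewEncryptionWriter(&buf, EncryptionOptions{Passphrase: "s3cret"})
//	_, _ = w.Write(plaintext)
//	_ = w.Close() // flushes the tail chunk and the zero-length terminator
//	r, _ := NewDecryptionReader(&buf, EncryptionOptions{Passphrase: "s3cret"})
//	out, _ := io.ReadAll(r) // out == plaintext
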
@@ -364,7 +364,7 @@ func writeHeader(w io.Writer, h *EncryptionHeader) error {
	copy(data[24:56], h.Salt[:])
	copy(data[56:68], h.Nonce[:])
	copy(data[68:100], h.Reserved[:])

	_, err := w.Write(data)
	return err
}
@@ -374,7 +374,7 @@ func readHeader(r io.Reader) (*EncryptionHeader, error) {
	if _, err := io.ReadFull(r, data); err != nil {
		return nil, err
	}

	header := &EncryptionHeader{
		Version:   data[22],
		Algorithm: data[23],
@@ -383,7 +383,7 @@ func readHeader(r io.Reader) (*EncryptionHeader, error) {
	copy(header.Salt[:], data[24:56])
	copy(header.Nonce[:], data[56:68])
	copy(header.Reserved[:], data[68:100])

	return header, nil
}
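
// Editor's sketch (inferred from the offsets in writeHeader/readHeader
// above, not part of this diff): the 100-byte header layout appears to be:
//
//	bytes  0..21  Magic (EncryptedFileMagic, padded)
//	byte   22     Version (1)
//	byte   23     Algorithm (1 = AES-GCM)
//	bytes 24..55  Salt (32 bytes, for key derivation)
//	bytes 56..67  Nonce (12 bytes, initial GCM nonce)
//	bytes 68..99  Reserved
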
@@ -9,11 +9,11 @@ import (
func TestEncryptDecrypt(t *testing.T) {
	// Test data
	original := []byte("This is a secret database backup that needs encryption! 🔒")

	// Test with passphrase
	t.Run("Passphrase", func(t *testing.T) {
		var encrypted bytes.Buffer

		// Encrypt
		writer, err := NewEncryptionWriter(&encrypted, EncryptionOptions{
			Passphrase: "super-secret-password",
@@ -21,23 +21,23 @@ func TestEncryptDecrypt(t *testing.T) {
		if err != nil {
			t.Fatalf("Failed to create encryption writer: %v", err)
		}

		if _, err := writer.Write(original); err != nil {
			t.Fatalf("Failed to write data: %v", err)
		}

		if err := writer.Close(); err != nil {
			t.Fatalf("Failed to close writer: %v", err)
		}

		t.Logf("Original size: %d bytes", len(original))
		t.Logf("Encrypted size: %d bytes", encrypted.Len())

		// Verify encrypted data is different from original
		if bytes.Contains(encrypted.Bytes(), original) {
			t.Error("Encrypted data contains plaintext - encryption failed!")
		}

		// Decrypt
		reader, err := NewDecryptionReader(&encrypted, EncryptionOptions{
			Passphrase: "super-secret-password",
@@ -45,30 +45,30 @@ func TestEncryptDecrypt(t *testing.T) {
		if err != nil {
			t.Fatalf("Failed to create decryption reader: %v", err)
		}

		decrypted, err := io.ReadAll(reader)
		if err != nil {
			t.Fatalf("Failed to read decrypted data: %v", err)
		}

		// Verify decrypted matches original
		if !bytes.Equal(decrypted, original) {
			t.Errorf("Decrypted data doesn't match original\nOriginal: %s\nDecrypted: %s",
				string(original), string(decrypted))
		}

		t.Log("✅ Encryption/decryption successful")
	})

	// Test with direct key
	t.Run("DirectKey", func(t *testing.T) {
		key, err := GenerateKey()
		if err != nil {
			t.Fatalf("Failed to generate key: %v", err)
		}

		var encrypted bytes.Buffer

		// Encrypt
		writer, err := NewEncryptionWriter(&encrypted, EncryptionOptions{
			Key: key,
@@ -76,15 +76,15 @@ func TestEncryptDecrypt(t *testing.T) {
		if err != nil {
			t.Fatalf("Failed to create encryption writer: %v", err)
		}

		if _, err := writer.Write(original); err != nil {
			t.Fatalf("Failed to write data: %v", err)
		}

		if err := writer.Close(); err != nil {
			t.Fatalf("Failed to close writer: %v", err)
		}

		// Decrypt
		reader, err := NewDecryptionReader(&encrypted, EncryptionOptions{
			Key: key,
@@ -92,23 +92,23 @@ func TestEncryptDecrypt(t *testing.T) {
		if err != nil {
			t.Fatalf("Failed to create decryption reader: %v", err)
		}

		decrypted, err := io.ReadAll(reader)
		if err != nil {
			t.Fatalf("Failed to read decrypted data: %v", err)
		}

		if !bytes.Equal(decrypted, original) {
			t.Errorf("Decrypted data doesn't match original")
		}

		t.Log("✅ Direct key encryption/decryption successful")
	})

	// Test wrong password
	t.Run("WrongPassword", func(t *testing.T) {
		var encrypted bytes.Buffer

		// Encrypt
		writer, err := NewEncryptionWriter(&encrypted, EncryptionOptions{
			Passphrase: "correct-password",
@@ -116,10 +116,10 @@ func TestEncryptDecrypt(t *testing.T) {
		if err != nil {
			t.Fatalf("Failed to create encryption writer: %v", err)
		}

		writer.Write(original)
		writer.Close()

		// Try to decrypt with wrong password
		reader, err := NewDecryptionReader(&encrypted, EncryptionOptions{
			Passphrase: "wrong-password",
@@ -127,12 +127,12 @@ func TestEncryptDecrypt(t *testing.T) {
		if err != nil {
			t.Fatalf("Failed to create decryption reader: %v", err)
		}

		_, err = io.ReadAll(reader)
		if err == nil {
			t.Error("Expected decryption to fail with wrong password, but it succeeded")
		}

		t.Logf("✅ Wrong password correctly rejected: %v", err)
	})
}
@@ -143,9 +143,9 @@ func TestLargeData(t *testing.T) {
	for i := range original {
		original[i] = byte(i % 256)
	}

	var encrypted bytes.Buffer

	// Encrypt
	writer, err := NewEncryptionWriter(&encrypted, EncryptionOptions{
		Passphrase: "test-password",
@@ -153,19 +153,19 @@ func TestLargeData(t *testing.T) {
	if err != nil {
		t.Fatalf("Failed to create encryption writer: %v", err)
	}

	if _, err := writer.Write(original); err != nil {
		t.Fatalf("Failed to write data: %v", err)
	}

	if err := writer.Close(); err != nil {
		t.Fatalf("Failed to close writer: %v", err)
	}

	t.Logf("Original size: %d bytes", len(original))
	t.Logf("Encrypted size: %d bytes", encrypted.Len())
	t.Logf("Overhead: %.2f%%", float64(encrypted.Len()-len(original))/float64(len(original))*100)

	// Decrypt
	reader, err := NewDecryptionReader(&encrypted, EncryptionOptions{
		Passphrase: "test-password",
@@ -173,16 +173,16 @@ func TestLargeData(t *testing.T) {
	if err != nil {
		t.Fatalf("Failed to create decryption reader: %v", err)
	}

	decrypted, err := io.ReadAll(reader)
	if err != nil {
		t.Fatalf("Failed to read decrypted data: %v", err)
	}

	if !bytes.Equal(decrypted, original) {
		t.Errorf("Large data decryption failed")
	}

	t.Log("✅ Large data encryption/decryption successful")
}
@@ -192,43 +192,43 @@ func TestKeyGeneration(t *testing.T) {
	if err != nil {
		t.Fatalf("Failed to generate key: %v", err)
	}

	if len(key1) != KeySize {
		t.Errorf("Key size mismatch: expected %d, got %d", KeySize, len(key1))
	}

	// Generate another key and verify it's different
	key2, err := GenerateKey()
	if err != nil {
		t.Fatalf("Failed to generate second key: %v", err)
	}

	if bytes.Equal(key1, key2) {
		t.Error("Generated keys are identical - randomness broken!")
	}

	t.Log("✅ Key generation successful")
}

func TestKeyDerivation(t *testing.T) {
	passphrase := "my-secret-passphrase"
	salt1, _ := GenerateSalt()

	// Derive key twice with same salt - should be identical
	key1 := DeriveKey(passphrase, salt1)
	key2 := DeriveKey(passphrase, salt1)

	if !bytes.Equal(key1, key2) {
		t.Error("Key derivation not deterministic")
	}

	// Derive with different salt - should be different
	salt2, _ := GenerateSalt()
	key3 := DeriveKey(passphrase, salt2)

	if bytes.Equal(key1, key3) {
		t.Error("Different salts produced same key")
	}

	t.Log("✅ Key derivation successful")
}
@@ -16,7 +16,7 @@ type Logger interface {
	Info(msg string, keysAndValues ...interface{})
	Warn(msg string, keysAndValues ...interface{})
	Error(msg string, keysAndValues ...interface{})

	// Structured logging methods
	WithFields(fields map[string]interface{}) Logger
	WithField(key string, value interface{}) Logger
@@ -113,7 +113,7 @@ func (l *logger) Error(msg string, args ...any) {
}

func (l *logger) Time(msg string, args ...any) {
	// Time logs are always at info level with special formatting
	// Time logs are always at info level with special formatting
	l.logWithFields(logrus.InfoLevel, "[TIME] "+msg, args...)
}
@@ -225,7 +225,7 @@ type CleanFormatter struct{}
// Format implements logrus.Formatter interface
func (f *CleanFormatter) Format(entry *logrus.Entry) ([]byte, error) {
	timestamp := entry.Time.Format("2006-01-02T15:04:05")

	// Color codes for different log levels
	var levelColor, levelText string
	switch entry.Level {
@@ -246,24 +246,24 @@ func (f *CleanFormatter) Format(entry *logrus.Entry) ([]byte, error) {
		levelText = "INFO "
	}
	resetColor := "\033[0m"

	// Build the message with perfectly aligned columns
	var output strings.Builder

	// Column 1: Level (with color, fixed width 5 chars)
	output.WriteString(levelColor)
	output.WriteString(levelText)
	output.WriteString(resetColor)
	output.WriteString(" ")

	// Column 2: Timestamp (fixed format)
	output.WriteString("[")
	output.WriteString(timestamp)
	output.WriteString("] ")

	// Column 3: Message
	output.WriteString(entry.Message)

	// Append important fields in a clean format (skip internal/redundant fields)
	if len(entry.Data) > 0 {
		// Only show truly important fields, skip verbose ones
@@ -272,7 +272,7 @@ func (f *CleanFormatter) Format(entry *logrus.Entry) ([]byte, error) {
			if k == "elapsed" || k == "operation_id" || k == "step" || k == "timestamp" || k == "message" {
				continue
			}

			// Format duration nicely at the end
			if k == "duration" {
				if str, ok := v.(string); ok {
@@ -280,14 +280,14 @@ func (f *CleanFormatter) Format(entry *logrus.Entry) ([]byte, error) {
				}
				continue
			}

			// Only show critical fields (driver, errors, etc)
			if k == "driver" || k == "max_conns" || k == "error" || k == "database" {
				output.WriteString(fmt.Sprintf(" %s=%v", k, v))
			}
		}
	}

	output.WriteString("\n")
	return []byte(output.String()), nil
}
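
// Editor's sketch (illustrative output, not part of this diff): with the
// column layout above, a formatted entry looks roughly like:
//
//	INFO  [2025-01-02T15:04:05] Connected to database driver=postgres max_conns=10
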
@@ -29,11 +29,11 @@ type BackupMetadata struct {
	BaseBackup string            `json:"base_backup,omitempty"`
	Duration   float64           `json:"duration_seconds"`
	ExtraInfo  map[string]string `json:"extra_info,omitempty"`

	// Encryption fields (v2.3+)
	Encrypted           bool   `json:"encrypted"`                      // Whether backup is encrypted
	EncryptionAlgorithm string `json:"encryption_algorithm,omitempty"` // e.g., "aes-256-gcm"

	// Incremental backup fields (v2.2+)
	Incremental *IncrementalMetadata `json:"incremental,omitempty"` // Only present for incremental backups
}
@@ -50,16 +50,16 @@ type IncrementalMetadata struct {

// ClusterMetadata contains metadata for cluster backups
type ClusterMetadata struct {
	Version string `json:"version"`
	Timestamp time.Time `json:"timestamp"`
	ClusterName string `json:"cluster_name"`
	DatabaseType string `json:"database_type"`
	Host string `json:"host"`
	Port int `json:"port"`
	Databases []BackupMetadata `json:"databases"`
	TotalSize int64 `json:"total_size_bytes"`
	Duration float64 `json:"duration_seconds"`
	ExtraInfo map[string]string `json:"extra_info,omitempty"`
	Version      string            `json:"version"`
	Timestamp    time.Time         `json:"timestamp"`
	ClusterName  string            `json:"cluster_name"`
	DatabaseType string            `json:"database_type"`
	Host         string            `json:"host"`
	Port         int               `json:"port"`
	Databases    []BackupMetadata  `json:"databases"`
	TotalSize    int64             `json:"total_size_bytes"`
	Duration     float64           `json:"duration_seconds"`
	ExtraInfo    map[string]string `json:"extra_info,omitempty"`
}

// CalculateSHA256 computes the SHA-256 checksum of a file
@@ -81,7 +81,7 @@ func CalculateSHA256(filePath string) (string, error) {
// Save writes metadata to a .meta.json file
func (m *BackupMetadata) Save() error {
	metaPath := m.BackupFile + ".meta.json"

	data, err := json.MarshalIndent(m, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal metadata: %w", err)
@@ -97,7 +97,7 @@ func (m *BackupMetadata) Save() error {
// Load reads metadata from a .meta.json file
func Load(backupFile string) (*BackupMetadata, error) {
	metaPath := backupFile + ".meta.json"

	data, err := os.ReadFile(metaPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read metadata file: %w", err)
@@ -114,7 +114,7 @@ func Load(backupFile string) (*BackupMetadata, error) {
// SaveCluster writes cluster metadata to a .meta.json file
func (m *ClusterMetadata) Save(targetFile string) error {
	metaPath := targetFile + ".meta.json"

	data, err := json.MarshalIndent(m, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal cluster metadata: %w", err)
@@ -130,7 +130,7 @@ func (m *ClusterMetadata) Save(targetFile string) error {
// LoadCluster reads cluster metadata from a .meta.json file
func LoadCluster(targetFile string) (*ClusterMetadata, error) {
	metaPath := targetFile + ".meta.json"

	data, err := os.ReadFile(metaPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read cluster metadata file: %w", err)
@@ -156,13 +156,13 @@ func ListBackups(dir string) ([]*BackupMetadata, error) {
	for _, metaFile := range matches {
		// Extract backup file path (remove .meta.json suffix)
		backupFile := metaFile[:len(metaFile)-len(".meta.json")]

		meta, err := Load(backupFile)
		if err != nil {
			// Skip invalid metadata files
			continue
		}

		backups = append(backups, meta)
	}
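
// Editor's sketch (illustrative usage, not part of this diff), assuming the
// API above; paths are hypothetical:
//
//	meta := &BackupMetadata{BackupFile: "/backups/mydb.dump", Encrypted: true}
//	_ = meta.Save()                   // writes /backups/mydb.dump.meta.json
//	loaded, _ := Load("/backups/mydb.dump")
//	all, _ := ListBackups("/backups") // loads every *.meta.json in the dir
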
@@ -39,7 +39,7 @@ func NewMetricsCollector(log logger.Logger) *MetricsCollector {
func (mc *MetricsCollector) RecordOperation(operation, database string, start time.Time, sizeBytes int64, success bool, errorCount int) {
	duration := time.Since(start)
	throughput := calculateThroughput(sizeBytes, duration)

	metric := OperationMetrics{
		Operation: operation,
		Database:  database,
@@ -50,11 +50,11 @@ func (mc *MetricsCollector) RecordOperation(operation, database string, start ti
		ErrorCount: errorCount,
		Success:    success,
	}

	mc.mu.Lock()
	mc.metrics = append(mc.metrics, metric)
	mc.mu.Unlock()

	// Log structured metrics
	if mc.logger != nil {
		fields := map[string]interface{}{
@@ -67,7 +67,7 @@ func (mc *MetricsCollector) RecordOperation(operation, database string, start ti
			"error_count": errorCount,
			"success":     success,
		}

		if success {
			mc.logger.WithFields(fields).Info("Operation completed successfully")
		} else {
@@ -80,7 +80,7 @@ func (mc *MetricsCollector) RecordOperation(operation, database string, start ti
func (mc *MetricsCollector) RecordCompressionRatio(operation, database string, ratio float64) {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	// Find and update the most recent matching operation
	for i := len(mc.metrics) - 1; i >= 0; i-- {
		if mc.metrics[i].Operation == operation && mc.metrics[i].Database == database {
@@ -94,7 +94,7 @@ func (mc *MetricsCollector) RecordCompressionRatio(operation, database string, r
func (mc *MetricsCollector) GetMetrics() []OperationMetrics {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	result := make([]OperationMetrics, len(mc.metrics))
	copy(result, mc.metrics)
	return result
@@ -104,15 +104,15 @@ func (mc *MetricsCollector) GetMetrics() []OperationMetrics {
func (mc *MetricsCollector) GetAverages() map[string]interface{} {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	if len(mc.metrics) == 0 {
		return map[string]interface{}{}
	}

	var totalDuration time.Duration
	var totalSize, totalThroughput float64
	var successCount, errorCount int

	for _, m := range mc.metrics {
		totalDuration += m.Duration
		totalSize += float64(m.SizeBytes)
@@ -122,15 +122,15 @@ func (mc *MetricsCollector) GetAverages() map[string]interface{} {
		}
		errorCount += m.ErrorCount
	}

	count := len(mc.metrics)
	return map[string]interface{}{
		"total_operations": count,
		"success_rate": float64(successCount) / float64(count) * 100,
		"avg_duration_ms": totalDuration.Milliseconds() / int64(count),
		"avg_size_mb": totalSize / float64(count) / 1024 / 1024,
		"avg_throughput_mbps": totalThroughput / float64(count),
		"total_errors": errorCount,
		"total_operations":    count,
		"success_rate":        float64(successCount) / float64(count) * 100,
		"avg_duration_ms":     totalDuration.Milliseconds() / int64(count),
		"avg_size_mb":         totalSize / float64(count) / 1024 / 1024,
		"avg_throughput_mbps": totalThroughput / float64(count),
		"total_errors":        errorCount,
	}
}
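
// Editor's sketch (hypothetical result, not part of this diff): for 12
// recorded operations with 11 successes, GetAverages would return a map like
//
//	map[string]interface{}{
//		"total_operations": 12, "success_rate": 91.67, "avg_duration_ms": 5400,
//		"avg_size_mb": 812.4, "avg_throughput_mbps": 150.2, "total_errors": 1,
//	}
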
@@ -159,4 +159,4 @@ var GlobalMetrics *MetricsCollector
// InitGlobalMetrics initializes the global metrics collector
func InitGlobalMetrics(log logger.Logger) {
	GlobalMetrics = NewMetricsCollector(log)
}
@@ -24,18 +24,18 @@ func NewRecoveryConfigGenerator(log logger.Logger) *RecoveryConfigGenerator {
// RecoveryConfig holds all recovery configuration parameters
type RecoveryConfig struct {
	// Core recovery settings
	Target *RecoveryTarget
	WALArchiveDir string
	Target         *RecoveryTarget
	WALArchiveDir  string
	RestoreCommand string

	// PostgreSQL version
	PostgreSQLVersion int // Major version (12, 13, 14, etc.)

	// Additional settings
	PrimaryConnInfo string // For standby mode
	PrimarySlotName string // Replication slot name
	PrimaryConnInfo       string // For standby mode
	PrimarySlotName       string // Replication slot name
	RecoveryMinApplyDelay string // Min delay for replay

	// Paths
	DataDir string // PostgreSQL data directory
}
@@ -61,7 +61,7 @@ func (rcg *RecoveryConfigGenerator) generateModernRecoveryConfig(config *Recover
	// Create recovery.signal file (empty file that triggers recovery mode)
	recoverySignalPath := filepath.Join(config.DataDir, "recovery.signal")
	rcg.log.Info("Creating recovery.signal file", "path", recoverySignalPath)

	signalFile, err := os.Create(recoverySignalPath)
	if err != nil {
		return fmt.Errorf("failed to create recovery.signal: %w", err)
@@ -180,7 +180,7 @@ func (rcg *RecoveryConfigGenerator) generateLegacyRecoveryConfig(config *Recover
func (rcg *RecoveryConfigGenerator) generateRestoreCommand(walArchiveDir string) string {
	// The restore_command is executed by PostgreSQL to fetch WAL files
	// %f = WAL filename, %p = full path to copy WAL file to

	// Try multiple extensions (.gz.enc, .enc, .gz, plain)
	// This handles compressed and/or encrypted WAL files
	return fmt.Sprintf(`bash -c 'for ext in .gz.enc .enc .gz ""; do [ -f "%s/%%f$ext" ] && { [ -z "$ext" ] && cp "%s/%%f$ext" "%%p" || case "$ext" in *.gz.enc) gpg -d "%s/%%f$ext" | gunzip > "%%p" ;; *.enc) gpg -d "%s/%%f$ext" > "%%p" ;; *.gz) gunzip -c "%s/%%f$ext" > "%%p" ;; esac; exit 0; }; done; exit 1'`,
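
	// Editor's note (illustration, not part of this diff): with a hypothetical
	// walArchiveDir of /var/lib/dbbackup/wal, the generated restore_command
	// probes %f.gz.enc, %f.enc, %f.gz, then a plain %f, and respectively
	// gpg-decrypts and/or gunzips the segment into %p, e.g.:
	//
	//	gpg -d /var/lib/dbbackup/wal/000000010000000000000042.gz.enc | gunzip > %p
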
@@ -232,14 +232,14 @@ func (rcg *RecoveryConfigGenerator) ValidateDataDirectory(dataDir string) error
// DetectPostgreSQLVersion detects the PostgreSQL version from the data directory
func (rcg *RecoveryConfigGenerator) DetectPostgreSQLVersion(dataDir string) (int, error) {
	pgVersionPath := filepath.Join(dataDir, "PG_VERSION")

	content, err := os.ReadFile(pgVersionPath)
	if err != nil {
		return 0, fmt.Errorf("failed to read PG_VERSION: %w", err)
	}

	versionStr := strings.TrimSpace(string(content))

	// Parse major version (e.g., "14" or "14.2")
	parts := strings.Split(versionStr, ".")
	if len(parts) == 0 {
@@ -10,10 +10,10 @@ import (

// RecoveryTarget represents a PostgreSQL recovery target
type RecoveryTarget struct {
	Type string // "time", "xid", "lsn", "name", "immediate"
	Value string // The target value (timestamp, XID, LSN, or restore point name)
	Action string // "promote", "pause", "shutdown"
	Timeline string // Timeline to follow ("latest" or timeline ID)
	Type      string // "time", "xid", "lsn", "name", "immediate"
	Value     string // The target value (timestamp, XID, LSN, or restore point name)
	Action    string // "promote", "pause", "shutdown"
	Timeline  string // Timeline to follow ("latest" or timeline ID)
	Inclusive bool   // Whether target is inclusive (default: true)
}
@@ -128,13 +128,13 @@ func (rt *RecoveryTarget) validateTime() error {

	// Try parsing various timestamp formats
	formats := []string{
		"2006-01-02 15:04:05", // Standard format
		"2006-01-02 15:04:05.999999", // With microseconds
		"2006-01-02T15:04:05", // ISO 8601
		"2006-01-02T15:04:05Z", // ISO 8601 with UTC
		"2006-01-02T15:04:05-07:00", // ISO 8601 with timezone
		time.RFC3339, // RFC3339
		time.RFC3339Nano, // RFC3339 with nanoseconds
		"2006-01-02 15:04:05",        // Standard format
		"2006-01-02 15:04:05.999999", // With microseconds
		"2006-01-02T15:04:05",        // ISO 8601
		"2006-01-02T15:04:05Z",       // ISO 8601 with UTC
		"2006-01-02T15:04:05-07:00",  // ISO 8601 with timezone
		time.RFC3339,                 // RFC3339
		time.RFC3339Nano,             // RFC3339 with nanoseconds
	}

	var parseErr error
@@ -283,24 +283,24 @@ func FormatConfigLine(key, value string) string {
// String returns a human-readable representation of the recovery target
func (rt *RecoveryTarget) String() string {
	var sb strings.Builder

	sb.WriteString("Recovery Target:\n")
	sb.WriteString(fmt.Sprintf(" Type: %s\n", rt.Type))

	if rt.Type != TargetTypeImmediate {
		sb.WriteString(fmt.Sprintf(" Value: %s\n", rt.Value))
	}

	sb.WriteString(fmt.Sprintf(" Action: %s\n", rt.Action))

	if rt.Timeline != "" {
		sb.WriteString(fmt.Sprintf(" Timeline: %s\n", rt.Timeline))
	}

	if rt.Type != TargetTypeImmediate && rt.Type != TargetTypeName {
		sb.WriteString(fmt.Sprintf(" Inclusive: %v\n", rt.Inclusive))
	}

	return sb.String()
}
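
// Editor's sketch (illustrative usage, not part of this diff), assuming the
// String method above:
//
//	rt := &RecoveryTarget{Type: "time", Value: "2025-01-02 15:04:05",
//		Action: "promote", Inclusive: true}
//	fmt.Print(rt.String())
//	// Recovery Target:
//	//  Type: time
//	//  Value: 2025-01-02 15:04:05
//	//  Action: promote
//	//  Inclusive: true
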
@@ -284,7 +284,7 @@ func (ro *RestoreOrchestrator) startPostgreSQL(ctx context.Context, opts *Restor
	}

	cmd := exec.CommandContext(ctx, pgCtl, "-D", opts.TargetDataDir, "-l", filepath.Join(opts.TargetDataDir, "logfile"), "start")

	output, err := cmd.CombinedOutput()
	if err != nil {
		ro.log.Error("PostgreSQL startup failed", "output", string(output))
@@ -321,18 +321,18 @@ func (ro *RestoreOrchestrator) monitorRecovery(ctx context.Context, opts *Restor
		pidFile := filepath.Join(opts.TargetDataDir, "postmaster.pid")
		if _, err := os.Stat(pidFile); err == nil {
			ro.log.Info("✅ PostgreSQL is running")

			// Check if recovery files still exist
			recoverySignal := filepath.Join(opts.TargetDataDir, "recovery.signal")
			recoveryConf := filepath.Join(opts.TargetDataDir, "recovery.conf")

			if _, err := os.Stat(recoverySignal); os.IsNotExist(err) {
				if _, err := os.Stat(recoveryConf); os.IsNotExist(err) {
					ro.log.Info("✅ Recovery completed - PostgreSQL promoted to primary")
					return nil
				}
			}

			ro.log.Info("Recovery in progress...")
		} else {
			ro.log.Info("PostgreSQL not yet started or crashed")
@@ -17,32 +17,32 @@ type DetailedReporter struct {

// OperationStatus represents the status of a backup/restore operation
type OperationStatus struct {
	ID string `json:"id"`
	Name string `json:"name"`
	Type string `json:"type"` // "backup", "restore", "verify"
	Status string `json:"status"` // "running", "completed", "failed"
	StartTime time.Time `json:"start_time"`
	EndTime *time.Time `json:"end_time,omitempty"`
	Duration time.Duration `json:"duration"`
	Progress int `json:"progress"` // 0-100
	Message string `json:"message"`
	Details map[string]string `json:"details"`
	Steps []StepStatus `json:"steps"`
	BytesTotal int64 `json:"bytes_total"`
	BytesDone int64 `json:"bytes_done"`
	FilesTotal int `json:"files_total"`
	FilesDone int `json:"files_done"`
	Errors []string `json:"errors,omitempty"`
	ID         string            `json:"id"`
	Name       string            `json:"name"`
	Type       string            `json:"type"`   // "backup", "restore", "verify"
	Status     string            `json:"status"` // "running", "completed", "failed"
	StartTime  time.Time         `json:"start_time"`
	EndTime    *time.Time        `json:"end_time,omitempty"`
	Duration   time.Duration     `json:"duration"`
	Progress   int               `json:"progress"` // 0-100
	Message    string            `json:"message"`
	Details    map[string]string `json:"details"`
	Steps      []StepStatus      `json:"steps"`
	BytesTotal int64             `json:"bytes_total"`
	BytesDone  int64             `json:"bytes_done"`
	FilesTotal int               `json:"files_total"`
	FilesDone  int               `json:"files_done"`
	Errors     []string          `json:"errors,omitempty"`
}

// StepStatus represents individual steps within an operation
type StepStatus struct {
	Name string `json:"name"`
	Status string `json:"status"`
	StartTime time.Time `json:"start_time"`
	EndTime *time.Time `json:"end_time,omitempty"`
	Name      string        `json:"name"`
	Status    string        `json:"status"`
	StartTime time.Time     `json:"start_time"`
	EndTime   *time.Time    `json:"end_time,omitempty"`
	Duration time.Duration `json:"duration"`
	Message string `json:"message"`
	Message   string        `json:"message"`
}

// Logger interface for detailed reporting
@@ -79,7 +79,7 @@ func (dr *DetailedReporter) StartOperation(id, name, opType string) *OperationTr
	}

	dr.operations = append(dr.operations, operation)

	if dr.startTime.IsZero() {
		dr.startTime = time.Now()
	}
@@ -90,9 +90,9 @@ func (dr *DetailedReporter) StartOperation(id, name, opType string) *OperationTr
	}

	// Log operation start
	dr.logger.Info("Operation started",
		"id", id,
		"name", name,
	dr.logger.Info("Operation started",
		"id", id,
		"name", name,
		"type", opType,
		"timestamp", operation.StartTime.Format(time.RFC3339))
@@ -117,7 +117,7 @@ func (ot *OperationTracker) UpdateProgress(progress int, message string) {
		if ot.reporter.operations[i].ID == ot.operationID {
			ot.reporter.operations[i].Progress = progress
			ot.reporter.operations[i].Message = message

			// Update visual indicator
			if ot.reporter.indicator != nil {
				progressMsg := fmt.Sprintf("[%d%%] %s", progress, message)
@@ -150,7 +150,7 @@ func (ot *OperationTracker) AddStep(name, message string) *StepTracker {
	for i := range ot.reporter.operations {
		if ot.reporter.operations[i].ID == ot.operationID {
			ot.reporter.operations[i].Steps = append(ot.reporter.operations[i].Steps, step)

			// Log step start
			ot.reporter.logger.Info("Step started",
				"operation_id", ot.operationID,
@@ -190,7 +190,7 @@ func (ot *OperationTracker) SetFileProgress(filesDone, filesTotal int) {
		if ot.reporter.operations[i].ID == ot.operationID {
			ot.reporter.operations[i].FilesDone = filesDone
			ot.reporter.operations[i].FilesTotal = filesTotal

			if filesTotal > 0 {
				progress := (filesDone * 100) / filesTotal
				ot.reporter.operations[i].Progress = progress
@@ -209,25 +209,25 @@ func (ot *OperationTracker) SetByteProgress(bytesDone, bytesTotal int64) {
		if ot.reporter.operations[i].ID == ot.operationID {
			ot.reporter.operations[i].BytesDone = bytesDone
			ot.reporter.operations[i].BytesTotal = bytesTotal

			if bytesTotal > 0 {
				progress := int((bytesDone * 100) / bytesTotal)
				ot.reporter.operations[i].Progress = progress

				// Calculate ETA and speed
				elapsed := time.Since(ot.reporter.operations[i].StartTime).Seconds()
				if elapsed > 0 && bytesDone > 0 {
					speed := float64(bytesDone) / elapsed // bytes/sec
					remaining := bytesTotal - bytesDone
					eta := time.Duration(float64(remaining)/speed) * time.Second

					// Update progress message with ETA and speed
					if ot.reporter.indicator != nil {
						speedStr := formatSpeed(int64(speed))
						etaStr := formatDuration(eta)
						progressMsg := fmt.Sprintf("[%d%%] %s / %s (%s/s, ETA: %s)",
							progress,
							formatBytes(bytesDone),
						progressMsg := fmt.Sprintf("[%d%%] %s / %s (%s/s, ETA: %s)",
							progress,
							formatBytes(bytesDone),
							formatBytes(bytesTotal),
							speedStr,
							etaStr)
@@ -253,7 +253,7 @@ func (ot *OperationTracker) Complete(message string) {
			ot.reporter.operations[i].EndTime = &now
			ot.reporter.operations[i].Duration = now.Sub(ot.reporter.operations[i].StartTime)
			ot.reporter.operations[i].Message = message

			// Complete visual indicator
			if ot.reporter.indicator != nil {
				ot.reporter.indicator.Complete(fmt.Sprintf("✅ %s", message))
@@ -283,7 +283,7 @@ func (ot *OperationTracker) Fail(err error) {
			ot.reporter.operations[i].Duration = now.Sub(ot.reporter.operations[i].StartTime)
			ot.reporter.operations[i].Message = err.Error()
			ot.reporter.operations[i].Errors = append(ot.reporter.operations[i].Errors, err.Error())

			// Fail visual indicator
			if ot.reporter.indicator != nil {
				ot.reporter.indicator.Fail(fmt.Sprintf("❌ %s", err.Error()))
@@ -321,7 +321,7 @@ func (st *StepTracker) Complete(message string) {
				st.reporter.operations[i].Steps[j].EndTime = &now
				st.reporter.operations[i].Steps[j].Duration = now.Sub(st.reporter.operations[i].Steps[j].StartTime)
				st.reporter.operations[i].Steps[j].Message = message

				// Log step completion
				st.reporter.logger.Info("Step completed",
					"operation_id", st.operationID,
@@ -351,7 +351,7 @@ func (st *StepTracker) Fail(err error) {
				st.reporter.operations[i].Steps[j].EndTime = &now
				st.reporter.operations[i].Steps[j].Duration = now.Sub(st.reporter.operations[i].Steps[j].StartTime)
				st.reporter.operations[i].Steps[j].Message = err.Error()

				// Log step failure
				st.reporter.logger.Error("Step failed",
					"operation_id", st.operationID,
@@ -428,8 +428,8 @@ type OperationSummary struct {
func (os *OperationSummary) FormatSummary() string {
	return fmt.Sprintf(
		"📊 Operations Summary:\n"+
		" Total: %d | Completed: %d | Failed: %d | Running: %d\n"+
		" Total Duration: %s",
			" Total: %d | Completed: %d | Failed: %d | Running: %d\n"+
			" Total Duration: %s",
		os.TotalOperations,
		os.CompletedOperations,
		os.FailedOperations,
@@ -461,7 +461,7 @@ func formatBytes(bytes int64) string {
		GB = 1024 * MB
		TB = 1024 * GB
	)

	switch {
	case bytes >= TB:
		return fmt.Sprintf("%.2f TB", float64(bytes)/float64(TB))
@@ -483,7 +483,7 @@ func formatSpeed(bytesPerSec int64) string {
		MB = 1024 * KB
		GB = 1024 * MB
	)

	switch {
	case bytesPerSec >= GB:
		return fmt.Sprintf("%.2f GB", float64(bytesPerSec)/float64(GB))
@@ -494,4 +494,4 @@ func formatSpeed(bytesPerSec int64) string {
	default:
		return fmt.Sprintf("%d B", bytesPerSec)
	}
}
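
// Editor's sketch (worked example, not part of this diff): combining the
// helpers above, SetByteProgress renders lines such as
//
//	[42%] 3.50 GB / 8.33 GB (150.00 MB/s, ETA: 33s)
//
// where speed = bytesDone/elapsed and ETA = (bytesTotal-bytesDone)/speed.
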
@@ -42,11 +42,11 @@ func (e *ETAEstimator) GetETA() time.Duration {
	if e.itemsComplete == 0 || e.totalItems == 0 {
		return 0
	}

	elapsed := e.GetElapsed()
	avgTimePerItem := elapsed / time.Duration(e.itemsComplete)
	remainingItems := e.totalItems - e.itemsComplete

	return avgTimePerItem * time.Duration(remainingItems)
}
@@ -83,12 +83,12 @@ func (e *ETAEstimator) GetFullStatus(baseMessage string) string {
		// No items to track, just show elapsed
		return fmt.Sprintf("%s | Elapsed: %s", baseMessage, e.FormatElapsed())
	}

	if e.itemsComplete == 0 {
		// Just started
		return fmt.Sprintf("%s | 0/%d | Starting...", baseMessage, e.totalItems)
	}

	// Full status with progress and ETA
	return fmt.Sprintf("%s | %s | Elapsed: %s | ETA: %s",
		baseMessage,
@@ -102,44 +102,44 @@ func FormatDuration(d time.Duration) string {
	if d < time.Second {
		return "< 1s"
	}

	hours := int(d.Hours())
	minutes := int(d.Minutes()) % 60
	seconds := int(d.Seconds()) % 60

	if hours > 0 {
		if minutes > 0 {
			return fmt.Sprintf("%dh %dm", hours, minutes)
		}
		return fmt.Sprintf("%dh", hours)
	}

	if minutes > 0 {
		if seconds > 5 { // Only show seconds if > 5
			return fmt.Sprintf("%dm %ds", minutes, seconds)
		}
		return fmt.Sprintf("%dm", minutes)
	}

	return fmt.Sprintf("%ds", seconds)
}

// EstimateSizeBasedDuration estimates duration based on size (fallback when no progress tracking)
func EstimateSizeBasedDuration(sizeBytes int64, cores int) time.Duration {
	sizeMB := float64(sizeBytes) / (1024 * 1024)

	// Base estimate: ~100MB per minute on average hardware
	baseMinutes := sizeMB / 100.0

	// Adjust for CPU cores (more cores = faster, but not linear)
	// Scale sub-linearly (30% per extra core) to represent diminishing returns
	if cores > 1 {
		speedup := 1.0 + (0.3 * (float64(cores) - 1)) // 30% improvement per core
		baseMinutes = baseMinutes / speedup
	}

	// Add 20% buffer for safety
	baseMinutes = baseMinutes * 1.2

	return time.Duration(baseMinutes * float64(time.Minute))
}
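
// Editor's note (worked example, not part of this diff): for a 1 GB dump on
// 4 cores, base = 1024/100 = 10.24 min; speedup = 1 + 0.3*(4-1) = 1.9, so
// 10.24/1.9 ≈ 5.4 min; adding the 20% buffer gives ≈ 6.5 minutes estimated.
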
@@ -7,19 +7,19 @@ import (

func TestNewETAEstimator(t *testing.T) {
	estimator := NewETAEstimator("Test Operation", 10)

	if estimator.operation != "Test Operation" {
		t.Errorf("Expected operation 'Test Operation', got '%s'", estimator.operation)
	}

	if estimator.totalItems != 10 {
		t.Errorf("Expected totalItems 10, got %d", estimator.totalItems)
	}

	if estimator.itemsComplete != 0 {
		t.Errorf("Expected itemsComplete 0, got %d", estimator.itemsComplete)
	}

	if estimator.startTime.IsZero() {
		t.Error("Expected startTime to be set")
	}
@@ -27,12 +27,12 @@ func TestNewETAEstimator(t *testing.T) {

func TestUpdateProgress(t *testing.T) {
	estimator := NewETAEstimator("Test", 10)

	estimator.UpdateProgress(5)
	if estimator.itemsComplete != 5 {
		t.Errorf("Expected itemsComplete 5, got %d", estimator.itemsComplete)
	}

	estimator.UpdateProgress(8)
	if estimator.itemsComplete != 8 {
		t.Errorf("Expected itemsComplete 8, got %d", estimator.itemsComplete)
@@ -41,24 +41,24 @@ func TestUpdateProgress(t *testing.T) {

func TestGetProgress(t *testing.T) {
	estimator := NewETAEstimator("Test", 10)

	// Test 0% progress
	if progress := estimator.GetProgress(); progress != 0 {
		t.Errorf("Expected 0%%, got %.2f%%", progress)
	}

	// Test 50% progress
	estimator.UpdateProgress(5)
	if progress := estimator.GetProgress(); progress != 50.0 {
		t.Errorf("Expected 50%%, got %.2f%%", progress)
	}

	// Test 100% progress
	estimator.UpdateProgress(10)
	if progress := estimator.GetProgress(); progress != 100.0 {
		t.Errorf("Expected 100%%, got %.2f%%", progress)
	}

	// Test zero division
	zeroEstimator := NewETAEstimator("Test", 0)
	if progress := zeroEstimator.GetProgress(); progress != 0 {
@@ -68,10 +68,10 @@ func TestGetProgress(t *testing.T) {

func TestGetElapsed(t *testing.T) {
	estimator := NewETAEstimator("Test", 10)

	// Wait a bit
	time.Sleep(100 * time.Millisecond)

	elapsed := estimator.GetElapsed()
	if elapsed < 100*time.Millisecond {
		t.Errorf("Expected elapsed time >= 100ms, got %v", elapsed)
@@ -80,16 +80,16 @@ func TestGetElapsed(t *testing.T) {

func TestGetETA(t *testing.T) {
	estimator := NewETAEstimator("Test", 10)

	// No progress yet, ETA should be 0
	if eta := estimator.GetETA(); eta != 0 {
		t.Errorf("Expected ETA 0 for no progress, got %v", eta)
	}

	// Simulate 5 items completed in 5 seconds
	estimator.startTime = time.Now().Add(-5 * time.Second)
	estimator.UpdateProgress(5)

	eta := estimator.GetETA()
	// Should be approximately 5 seconds (5 items remaining at 1 sec/item)
	if eta < 4*time.Second || eta > 6*time.Second {
@@ -99,18 +99,18 @@ func TestGetETA(t *testing.T) {

func TestFormatProgress(t *testing.T) {
	estimator := NewETAEstimator("Test", 13)

	// Test at 0%
	if result := estimator.FormatProgress(); result != "0/13 (0%)" {
		t.Errorf("Expected '0/13 (0%%)', got '%s'", result)
	}

	// Test at 38%
	estimator.UpdateProgress(5)
	if result := estimator.FormatProgress(); result != "5/13 (38%)" {
		t.Errorf("Expected '5/13 (38%%)', got '%s'", result)
	}

	// Test at 100%
	estimator.UpdateProgress(13)
	if result := estimator.FormatProgress(); result != "13/13 (100%)" {
@@ -125,16 +125,16 @@ func TestFormatDuration(t *testing.T) {
	}{
		{500 * time.Millisecond, "< 1s"},
		{5 * time.Second, "5s"},
		{65 * time.Second, "1m"}, // 5 seconds not shown (<=5)
		{125 * time.Second, "2m"}, // 5 seconds not shown (<=5)
		{65 * time.Second, "1m"},  // 5 seconds not shown (<=5)
		{125 * time.Second, "2m"}, // 5 seconds not shown (<=5)
		{3 * time.Minute, "3m"},
		{3*time.Minute + 3*time.Second, "3m"}, // < 5 seconds not shown
		{3*time.Minute + 10*time.Second, "3m 10s"}, // > 5 seconds shown
		{3*time.Minute + 3*time.Second, "3m"},      // < 5 seconds not shown
		{3*time.Minute + 10*time.Second, "3m 10s"}, // > 5 seconds shown
		{90 * time.Minute, "1h 30m"},
		{120 * time.Minute, "2h"},
		{150 * time.Minute, "2h 30m"},
	}

	for _, tt := range tests {
		result := FormatDuration(tt.duration)
		if result != tt.expected {
@@ -145,16 +145,16 @@ func TestFormatDuration(t *testing.T) {

func TestFormatETA(t *testing.T) {
	estimator := NewETAEstimator("Test", 10)

	// No progress - should show "calculating..."
	if result := estimator.FormatETA(); result != "calculating..." {
		t.Errorf("Expected 'calculating...', got '%s'", result)
	}

	// With progress
	estimator.startTime = time.Now().Add(-10 * time.Second)
	estimator.UpdateProgress(5)

	result := estimator.FormatETA()
	if result != "~10s remaining" {
		t.Errorf("Expected '~10s remaining', got '%s'", result)
@@ -164,7 +164,7 @@ func TestFormatETA(t *testing.T) {
func TestFormatElapsed(t *testing.T) {
	estimator := NewETAEstimator("Test", 10)
	estimator.startTime = time.Now().Add(-45 * time.Second)

	result := estimator.FormatElapsed()
	if result != "45s" {
		t.Errorf("Expected '45s', got '%s'", result)
@@ -173,23 +173,23 @@ func TestFormatElapsed(t *testing.T) {

func TestGetFullStatus(t *testing.T) {
	estimator := NewETAEstimator("Backing up cluster", 13)

	// Just started (0 items)
	result := estimator.GetFullStatus("Backing up cluster")
	if result != "Backing up cluster | 0/13 | Starting..." {
		t.Errorf("Unexpected result for 0 items: '%s'", result)
	}

	// With progress
	estimator.startTime = time.Now().Add(-30 * time.Second)
	estimator.UpdateProgress(5)

	result = estimator.GetFullStatus("Backing up cluster")
	// Should contain all components
	if len(result) < 50 { // Reasonable minimum length
		t.Errorf("Result too short: '%s'", result)
	}

	// Check it contains key elements (format may vary slightly)
	if !contains(result, "5/13") {
		t.Errorf("Result missing progress '5/13': '%s'", result)
@@ -208,7 +208,7 @@ func TestGetFullStatus(t *testing.T) {
func TestGetFullStatusWithZeroItems(t *testing.T) {
	estimator := NewETAEstimator("Test Operation", 0)
	estimator.startTime = time.Now().Add(-5 * time.Second)

	result := estimator.GetFullStatus("Test Operation")
	// Should only show elapsed time when no items to track
	if !contains(result, "Test Operation") || !contains(result, "Elapsed:") {
@@ -226,13 +226,13 @@ func TestEstimateSizeBasedDuration(t *testing.T) {
	if duration < 60*time.Second || duration > 90*time.Second {
		t.Errorf("Expected ~1.2 minutes for 100MB/1core, got %v", duration)
	}

	// Test 100MB with 8 cores (should be faster)
	duration8cores := EstimateSizeBasedDuration(100*1024*1024, 8)
	if duration8cores >= duration {
		t.Errorf("Expected faster with more cores: %v vs %v", duration8cores, duration)
	}

	// Test larger file
	duration1GB := EstimateSizeBasedDuration(1024*1024*1024, 1)
	if duration1GB <= duration {
@@ -242,9 +242,8 @@ func TestEstimateSizeBasedDuration(t *testing.T) {

// Helper function
func contains(s, substr string) bool {
	return len(s) >= len(substr) && (s == substr ||
		len(s) > len(substr) && (
		s[:len(substr)] == substr ||
	return len(s) >= len(substr) && (s == substr ||
		len(s) > len(substr) && (s[:len(substr)] == substr ||
			s[len(s)-len(substr):] == substr ||
			indexHelper(s, substr) >= 0))
}
@@ -43,11 +43,11 @@ func NewSpinner() *Spinner {
func (s *Spinner) Start(message string) {
	s.message = message
	s.active = true

	go func() {
		ticker := time.NewTicker(s.interval)
		defer ticker.Stop()

		i := 0
		lastMessage := ""
		for {
@@ -57,12 +57,12 @@ func (s *Spinner) Start(message string) {
			case <-ticker.C:
				if s.active {
					displayMsg := s.message

					// Add ETA info if estimator is available
					if s.estimator != nil {
						displayMsg = s.estimator.GetFullStatus(s.message)
					}

					currentFrame := fmt.Sprintf("%s %s", s.frames[i%len(s.frames)], displayMsg)
					if s.message != lastMessage {
						// Print new line for new messages
@@ -130,13 +130,13 @@ func NewDots() *Dots {
func (d *Dots) Start(message string) {
	d.message = message
	d.active = true

	fmt.Fprint(d.writer, message)

	go func() {
		ticker := time.NewTicker(500 * time.Millisecond)
		defer ticker.Stop()

		count := 0
		for {
			select {
@@ -191,13 +191,13 @@ func (d *Dots) SetEstimator(estimator *ETAEstimator) {

// ProgressBar creates a visual progress bar
type ProgressBar struct {
	writer io.Writer
	message string
	total int
	current int
	width int
	active bool
	stopCh chan bool
	writer  io.Writer
	message string
	total   int
	current int
	width   int
	active  bool
	stopCh  chan bool
}

// NewProgressBar creates a new progress bar
@@ -265,12 +265,12 @@ func (p *ProgressBar) render() {
	if !p.active {
		return
	}

	percent := float64(p.current) / float64(p.total)
	filled := int(percent * float64(p.width))

	bar := strings.Repeat("█", filled) + strings.Repeat("░", p.width-filled)

	fmt.Fprintf(p.writer, "\n%s [%s] %d%%", p.message, bar, int(percent*100))
}
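
// Editor's sketch (illustrative output, not part of this diff): render
// produces a bar such as
//
//	Restoring mydb [██████████░░░░░░░░░░] 50%
//
// for current = total/2 with a width of 20 (width value assumed).
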
@@ -432,7 +432,7 @@ func NewIndicator(interactive bool, indicatorType string) Indicator {
	if !interactive {
		return NewLineByLine() // Use line-by-line for non-interactive mode
	}

	switch indicatorType {
	case "spinner":
		return NewSpinner()
@@ -457,9 +457,9 @@ func NewNullIndicator() *NullIndicator {
	return &NullIndicator{}
}

func (n *NullIndicator) Start(message string) {}
func (n *NullIndicator) Update(message string) {}
func (n *NullIndicator) Complete(message string) {}
func (n *NullIndicator) Fail(message string) {}
func (n *NullIndicator) Stop() {}
func (n *NullIndicator) Start(message string)    {}
func (n *NullIndicator) Update(message string)   {}
func (n *NullIndicator) Complete(message string) {}
func (n *NullIndicator) Fail(message string)     {}
func (n *NullIndicator) Stop()                   {}
func (n *NullIndicator) SetEstimator(estimator *ETAEstimator) {}
@@ -1,3 +1,4 @@
//go:build openbsd
// +build openbsd

package restore
@@ -1,3 +1,4 @@
//go:build netbsd
// +build netbsd

package restore
@@ -1,3 +1,4 @@
//go:build !windows && !openbsd && !netbsd
// +build !windows,!openbsd,!netbsd

package restore
@@ -1,3 +1,4 @@
//go:build windows
// +build windows

package restore
@@ -358,21 +358,21 @@ func (e *Engine) executeRestoreCommand(ctx context.Context, cmdArgs []string) er
|
||||
e.log.Warn("Restore completed with ignorable errors", "error_count", errorCount, "last_error", lastError)
|
||||
return nil // Success despite ignorable errors
|
||||
}
|
||||
|
||||
|
||||
// Classify error and provide helpful hints
|
||||
if lastError != "" {
|
||||
classification := checks.ClassifyError(lastError)
|
||||
e.log.Error("Restore command failed",
|
||||
"error", err,
|
||||
"last_stderr", lastError,
|
||||
e.log.Error("Restore command failed",
|
||||
"error", err,
|
||||
"last_stderr", lastError,
|
||||
"error_count", errorCount,
|
||||
"error_type", classification.Type,
|
||||
"hint", classification.Hint,
|
||||
"action", classification.Action)
|
||||
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
|
||||
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
|
||||
err, lastError, errorCount, classification.Hint)
|
||||
}
|
||||
|
||||
|
||||
e.log.Error("Restore command failed", "error", err, "last_stderr", lastError, "error_count", errorCount)
|
||||
return fmt.Errorf("restore failed: %w", err)
|
||||
}
|
||||
@@ -440,21 +440,21 @@ func (e *Engine) executeRestoreWithDecompression(ctx context.Context, archivePat
|
||||
e.log.Warn("Restore with decompression completed with ignorable errors", "error_count", errorCount, "last_error", lastError)
|
||||
return nil // Success despite ignorable errors
|
||||
}
|
||||
|
||||
|
||||
// Classify error and provide helpful hints
|
||||
if lastError != "" {
|
||||
classification := checks.ClassifyError(lastError)
|
||||
e.log.Error("Restore with decompression failed",
|
||||
"error", err,
|
||||
"last_stderr", lastError,
|
||||
e.log.Error("Restore with decompression failed",
|
||||
"error", err,
|
||||
"last_stderr", lastError,
|
||||
"error_count", errorCount,
|
||||
"error_type", classification.Type,
|
||||
"hint", classification.Hint,
|
||||
"action", classification.Action)
|
||||
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
|
||||
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
|
||||
err, lastError, errorCount, classification.Hint)
|
||||
}
|
||||
|
||||
|
||||
e.log.Error("Restore with decompression failed", "error", err, "last_stderr", lastError, "error_count", errorCount)
|
||||
return fmt.Errorf("restore failed: %w", err)
|
||||
}
|
||||
@@ -530,20 +530,20 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
|
||||
operation.Fail("Invalid cluster archive format")
|
||||
return fmt.Errorf("not a cluster archive: %s (detected format: %s)", archivePath, format)
|
||||
}
|
||||
|
||||
|
||||
// Check disk space before starting restore
|
||||
e.log.Info("Checking disk space for restore")
|
||||
archiveInfo, err := os.Stat(archivePath)
|
||||
if err == nil {
|
||||
spaceCheck := checks.CheckDiskSpaceForRestore(e.cfg.BackupDir, archiveInfo.Size())
|
||||
|
||||
|
||||
if spaceCheck.Critical {
|
||||
operation.Fail("Insufficient disk space")
|
||||
return fmt.Errorf("insufficient disk space for restore: %.1f%% used - need at least 4x archive size", spaceCheck.UsedPercent)
|
||||
}
|
||||
|
||||
|
||||
if spaceCheck.Warning {
|
||||
e.log.Warn("Low disk space - restore may fail",
|
||||
e.log.Warn("Low disk space - restore may fail",
|
||||
"available_gb", float64(spaceCheck.AvailableBytes)/(1024*1024*1024),
|
||||
"used_percent", spaceCheck.UsedPercent)
|
||||
}
|
||||
@@ -638,13 +638,13 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
|
||||
|
||||
// Check for large objects in dump files and adjust parallelism
|
||||
hasLargeObjects := e.detectLargeObjectsInDumps(dumpsDir, entries)
|
||||
|
||||
|
||||
// Use worker pool for parallel restore
|
||||
parallelism := e.cfg.ClusterParallelism
|
||||
if parallelism < 1 {
|
||||
parallelism = 1 // Ensure at least sequential
|
||||
}
|
||||
|
||||
|
||||
// Automatically reduce parallelism if large objects detected
|
||||
if hasLargeObjects && parallelism > 1 {
|
||||
e.log.Warn("Large objects detected in dump files - reducing parallelism to avoid lock contention",
|
||||
@@ -731,13 +731,13 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
			mu.Lock()
			e.log.Error("Failed to restore database", "name", dbName, "file", dumpFile, "error", restoreErr)
			mu.Unlock()

			// Check for specific recoverable errors
			errMsg := restoreErr.Error()
			if strings.Contains(errMsg, "max_locks_per_transaction") {
				mu.Lock()
				e.log.Warn("Database restore failed due to insufficient locks - this is a PostgreSQL configuration issue",
					"database", dbName,
					"solution", "increase max_locks_per_transaction in postgresql.conf")
				mu.Unlock()
			} else if strings.Contains(errMsg, "total errors:") && strings.Contains(errMsg, "2562426") {
@@ -747,7 +747,7 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
					"errors", "2562426")
				mu.Unlock()
			}

			failedDBsMu.Lock()
			// Include more context in the error message
			failedDBs = append(failedDBs, fmt.Sprintf("%s: restore failed: %v", dbName, restoreErr))
@@ -770,16 +770,16 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {

	if failCountFinal > 0 {
		failedList := strings.Join(failedDBs, "\n ")

		// Log summary
		e.log.Info("Cluster restore completed with failures",
			"succeeded", successCountFinal,
			"failed", failCountFinal,
			"total", totalDBs)

		e.progress.Fail(fmt.Sprintf("Cluster restore: %d succeeded, %d failed out of %d total", successCountFinal, failCountFinal, totalDBs))
		operation.Complete(fmt.Sprintf("Partial restore: %d/%d databases succeeded", successCountFinal, totalDBs))

		return fmt.Errorf("cluster restore completed with %d failures:\n %s", failCountFinal, failedList)
	}
@@ -1079,48 +1079,48 @@ func (e *Engine) detectLargeObjectsInDumps(dumpsDir string, entries []os.DirEntr
	hasLargeObjects := false
	checkedCount := 0
	maxChecks := 5 // Only check first 5 dumps to avoid slowdown

	for _, entry := range entries {
		if entry.IsDir() || checkedCount >= maxChecks {
			continue
		}

		dumpFile := filepath.Join(dumpsDir, entry.Name())

		// Skip compressed SQL files (can't easily check without decompressing)
		if strings.HasSuffix(dumpFile, ".sql.gz") {
			continue
		}

		// Use pg_restore -l to list contents (fast, doesn't restore data)
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
		output, err := cmd.Output()

		if err != nil {
			// If pg_restore -l fails, it might not be custom format - skip
			continue
		}

		checkedCount++

		// Check if output contains "BLOB" or "LARGE OBJECT" entries
		outputStr := string(output)
		if strings.Contains(outputStr, "BLOB") ||
			strings.Contains(outputStr, "LARGE OBJECT") ||
			strings.Contains(outputStr, " BLOBS ") {
			e.log.Info("Large objects detected in dump file", "file", entry.Name())
			hasLargeObjects = true
			// Don't break - log all files with large objects
		}
	}

	if hasLargeObjects {
		e.log.Warn("Cluster contains databases with large objects - parallel restore may cause lock contention")
	}

	return hasLargeObjects
}
@@ -1128,13 +1128,13 @@ func (e *Engine) detectLargeObjectsInDumps(dumpsDir string, entries []os.DirEntr
func (e *Engine) isIgnorableError(errorMsg string) bool {
	// Convert to lowercase for case-insensitive matching
	lowerMsg := strings.ToLower(errorMsg)

	// CRITICAL: Syntax errors are NOT ignorable - indicates corrupted dump
	if strings.Contains(lowerMsg, "syntax error") {
		e.log.Error("CRITICAL: Syntax error in dump file - dump may be corrupted", "error", errorMsg)
		return false
	}

	// CRITICAL: If error count is extremely high (>100k), dump is likely corrupted
	if strings.Contains(errorMsg, "total errors:") {
		// Extract error count if present in message
@@ -1149,21 +1149,21 @@ func (e *Engine) isIgnorableError(errorMsg string) bool {
			}
		}
	}

	// List of ignorable error patterns (objects that already exist)
	ignorablePatterns := []string{
		"already exists",
		"duplicate key",
		"does not exist, skipping", // For DROP IF EXISTS
		"no pg_hba.conf entry",     // Permission warnings (not fatal)
	}

	for _, pattern := range ignorablePatterns {
		if strings.Contains(lowerMsg, pattern) {
			return true
		}
	}

	return false
}
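The error-count extraction inside the "total errors:" branch is elided by the hunk above. As a rough illustration of what such a check can look like (a hypothetical sketch, not this repository's actual code; it assumes "regexp" and "strconv" are imported):

var totalErrorsRe = regexp.MustCompile(`total errors: (\d+)`)

// errorCountTooHigh reports whether an embedded "total errors: N" count
// exceeds the >100k corruption heuristic mentioned in the comment above.
func errorCountTooHigh(errorMsg string) bool {
	m := totalErrorsRe.FindStringSubmatch(errorMsg)
	if m == nil {
		return false
	}
	n, err := strconv.Atoi(m[1])
	if err != nil {
		return false
	}
	return n > 100000
}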
@@ -1,24 +1,24 @@
package restore

import (
	"compress/gzip"
	"io"
	"os"
	"strings"
)

// ArchiveFormat represents the type of backup archive
type ArchiveFormat string

const (
	FormatPostgreSQLDump   ArchiveFormat = "PostgreSQL Dump (.dump)"
	FormatPostgreSQLDumpGz ArchiveFormat = "PostgreSQL Dump Compressed (.dump.gz)"
	FormatPostgreSQLSQL    ArchiveFormat = "PostgreSQL SQL (.sql)"
	FormatPostgreSQLSQLGz  ArchiveFormat = "PostgreSQL SQL Compressed (.sql.gz)"
	FormatMySQLSQL         ArchiveFormat = "MySQL SQL (.sql)"
	FormatMySQLSQLGz       ArchiveFormat = "MySQL SQL Compressed (.sql.gz)"
	FormatClusterTarGz     ArchiveFormat = "Cluster Archive (.tar.gz)"
	FormatUnknown          ArchiveFormat = "Unknown"
)

// DetectArchiveFormat detects the format of a backup archive from its filename and content
@@ -37,7 +37,7 @@ func DetectArchiveFormat(filename string) ArchiveFormat {
	result := isCustomFormat(filename, true)
	// If file doesn't exist or we can't read it, trust the extension
	// If file exists and has PGDMP signature, it's custom format
	// If file exists but doesn't have signature, it might be SQL named as .dump
	if result == formatCheckCustom || result == formatCheckFileNotFound {
		return FormatPostgreSQLDumpGz
	}
@@ -81,9 +81,9 @@ func DetectArchiveFormat(filename string) ArchiveFormat {
type formatCheckResult int

const (
	formatCheckFileNotFound formatCheckResult = iota
	formatCheckCustom
	formatCheckNotCustom
)

// isCustomFormat checks if a file is PostgreSQL custom format (has PGDMP signature)
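isCustomFormat itself is only referenced in these hunks. A minimal sketch of such a probe, assuming it maps onto the formatCheckResult constants above (illustrative, not the project's actual body; pg_dump custom-format archives begin with the magic bytes "PGDMP"):

func probeCustomFormat(path string) formatCheckResult {
	f, err := os.Open(path)
	if err != nil {
		return formatCheckFileNotFound
	}
	defer f.Close()

	// Read the first five bytes and compare against the PGDMP signature.
	sig := make([]byte, 5)
	if _, err := io.ReadFull(f, sig); err != nil {
		return formatCheckNotCustom
	}
	if string(sig) == "PGDMP" {
		return formatCheckCustom
	}
	return formatCheckNotCustom
}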
@@ -242,7 +242,7 @@ func (s *Safety) CheckDiskSpaceAt(archivePath string, checkDir string, multiplie
	}

	archiveSize := stat.Size()

	// Estimate required space (archive size * multiplier for decompression/extraction)
	requiredSpace := int64(float64(archiveSize) * multiplier)

@@ -323,12 +323,12 @@ func (s *Safety) checkPostgresDatabaseExists(ctx context.Context, dbName string)
		"-d", "postgres",
		"-tAc", fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname='%s'", dbName),
	}

	// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
	if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
		args = append([]string{"-h", s.cfg.Host}, args...)
	}

	cmd := exec.CommandContext(ctx, "psql", args...)

	// Set password if provided
@@ -351,12 +351,12 @@ func (s *Safety) checkMySQLDatabaseExists(ctx context.Context, dbName string) (b
		"-u", s.cfg.User,
		"-e", fmt.Sprintf("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME='%s'", dbName),
	}

	// Only add -h flag if host is not localhost (to use Unix socket)
	if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
		args = append([]string{"-h", s.cfg.Host}, args...)
	}

	cmd := exec.CommandContext(ctx, "mysql", args...)

	if s.cfg.Password != "" {
@@ -386,7 +386,7 @@ func (s *Safety) ListUserDatabases(ctx context.Context) ([]string, error) {
func (s *Safety) listPostgresUserDatabases(ctx context.Context) ([]string, error) {
	// Query to get non-template databases excluding 'postgres' system DB
	query := "SELECT datname FROM pg_database WHERE datistemplate = false AND datname != 'postgres' ORDER BY datname"

	args := []string{
		"-p", fmt.Sprintf("%d", s.cfg.Port),
		"-U", s.cfg.User,
@@ -394,12 +394,12 @@ func (s *Safety) listPostgresUserDatabases(ctx context.Context) ([]string, error
		"-tA", // Tuples only, unaligned
		"-c", query,
	}

	// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
	if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
		args = append([]string{"-h", s.cfg.Host}, args...)
	}

	cmd := exec.CommandContext(ctx, "psql", args...)

	// Set password if provided
@@ -429,19 +429,19 @@ func (s *Safety) listPostgresUserDatabases(ctx context.Context) ([]string, error
func (s *Safety) listMySQLUserDatabases(ctx context.Context) ([]string, error) {
	// Exclude system databases
	query := "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') ORDER BY SCHEMA_NAME"

	args := []string{
		"-P", fmt.Sprintf("%d", s.cfg.Port),
		"-u", s.cfg.User,
		"-N", // Skip column names
		"-e", query,
	}

	// Only add -h flag if host is not localhost (to use Unix socket)
	if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
		args = append([]string{"-h", s.cfg.Host}, args...)
	}

	cmd := exec.CommandContext(ctx, "mysql", args...)

	if s.cfg.Password != "" {
@@ -23,7 +23,7 @@ func TestValidateArchive_FileNotFound(t *testing.T) {
func TestValidateArchive_EmptyFile(t *testing.T) {
	tmpDir := t.TempDir()
	emptyFile := filepath.Join(tmpDir, "empty.dump")

	if err := os.WriteFile(emptyFile, []byte{}, 0644); err != nil {
		t.Fatalf("Failed to create empty file: %v", err)
	}
@@ -43,7 +43,7 @@ func TestCheckDiskSpace_InsufficientSpace(t *testing.T) {
	// Just ensure the function doesn't panic
	tmpDir := t.TempDir()
	testFile := filepath.Join(tmpDir, "test.dump")

	// Create a small test file
	if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil {
		t.Fatalf("Failed to create test file: %v", err)
@@ -23,21 +23,21 @@ func ParsePostgreSQLVersion(versionStr string) (*VersionInfo, error) {
	// Match patterns like "PostgreSQL 17.7", "PostgreSQL 13.11", "PostgreSQL 10.23"
	re := regexp.MustCompile(`PostgreSQL\s+(\d+)\.(\d+)`)
	matches := re.FindStringSubmatch(versionStr)

	if len(matches) < 3 {
		return nil, fmt.Errorf("could not parse PostgreSQL version from: %s", versionStr)
	}

	major, err := strconv.Atoi(matches[1])
	if err != nil {
		return nil, fmt.Errorf("invalid major version: %s", matches[1])
	}

	minor, err := strconv.Atoi(matches[2])
	if err != nil {
		return nil, fmt.Errorf("invalid minor version: %s", matches[2])
	}

	return &VersionInfo{
		Major: major,
		Minor: minor,
@@ -53,24 +53,24 @@ func GetDumpFileVersion(dumpPath string) (*VersionInfo, error) {
	if err != nil {
		return nil, fmt.Errorf("failed to read dump file metadata: %w (output: %s)", err, string(output))
	}

	// Look for "Dumped from database version: X.Y.Z" in output
	re := regexp.MustCompile(`Dumped from database version:\s+(\d+)\.(\d+)`)
	matches := re.FindStringSubmatch(string(output))

	if len(matches) < 3 {
		// Try alternate format in some dumps
		re = regexp.MustCompile(`PostgreSQL database dump.*(\d+)\.(\d+)`)
		matches = re.FindStringSubmatch(string(output))
	}

	if len(matches) < 3 {
		return nil, fmt.Errorf("could not find version information in dump file")
	}

	major, _ := strconv.Atoi(matches[1])
	minor, _ := strconv.Atoi(matches[2])

	return &VersionInfo{
		Major: major,
		Minor: minor,
@@ -81,18 +81,18 @@ func GetDumpFileVersion(dumpPath string) (*VersionInfo, error) {
// CheckVersionCompatibility checks if restoring from source version to target version is safe
func CheckVersionCompatibility(sourceVer, targetVer *VersionInfo) *VersionCompatibilityResult {
	result := &VersionCompatibilityResult{
		Compatible:    true,
		SourceVersion: sourceVer,
		TargetVersion: targetVer,
	}

	// Same major version - always compatible
	if sourceVer.Major == targetVer.Major {
		result.Level = CompatibilityLevelSafe
		result.Message = "Same major version - fully compatible"
		return result
	}

	// Downgrade - not supported
	if sourceVer.Major > targetVer.Major {
		result.Compatible = false
@@ -101,10 +101,10 @@ func CheckVersionCompatibility(sourceVer, targetVer *VersionInfo) *VersionCompat
		result.Warnings = append(result.Warnings, "Database downgrades require pg_dump from the target version")
		return result
	}

	// Upgrade - check how many major versions
	versionDiff := targetVer.Major - sourceVer.Major

	if versionDiff == 1 {
		// One major version upgrade - generally safe
		result.Level = CompatibilityLevelSafe
@@ -113,7 +113,7 @@ func CheckVersionCompatibility(sourceVer, targetVer *VersionInfo) *VersionCompat
		// 2-3 major versions - should work but review release notes
		result.Level = CompatibilityLevelWarning
		result.Message = fmt.Sprintf("Upgrading from PostgreSQL %d to %d - supported but review release notes", sourceVer.Major, targetVer.Major)
		result.Warnings = append(result.Warnings,
			fmt.Sprintf("You are jumping %d major versions - some features may have changed", versionDiff))
		result.Warnings = append(result.Warnings,
			"Review release notes for deprecated features or behavior changes")
@@ -134,13 +134,13 @@ func CheckVersionCompatibility(sourceVer, targetVer *VersionInfo) *VersionCompat
		result.Recommendations = append(result.Recommendations,
			"Review PostgreSQL release notes for versions "+strconv.Itoa(sourceVer.Major)+" through "+strconv.Itoa(targetVer.Major))
	}

	// Add general upgrade advice
	if versionDiff > 0 {
		result.Recommendations = append(result.Recommendations,
			"Run ANALYZE on all tables after restore for optimal query performance")
	}

	return result
}

@@ -189,33 +189,33 @@ func (e *Engine) CheckRestoreVersionCompatibility(ctx context.Context, dumpPath
		e.log.Warn("Could not determine dump file version", "error", err)
		return nil, nil
	}

	// Get target database version
	targetVerStr, err := e.db.GetVersion(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get target database version: %w", err)
	}

	targetVer, err := ParsePostgreSQLVersion(targetVerStr)
	if err != nil {
		return nil, fmt.Errorf("failed to parse target version: %w", err)
	}

	// Check compatibility
	result := CheckVersionCompatibility(dumpVer, targetVer)

	// Log the results
	e.log.Info("Version compatibility check",
		"source", dumpVer.Full,
		"target", targetVer.Full,
		"level", result.Level.String())

	if len(result.Warnings) > 0 {
		for _, warning := range result.Warnings {
			e.log.Warn(warning)
		}
	}

	return result, nil
}
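A quick usage sketch of the helpers above (illustrative values; error handling trimmed):

srcVer, _ := ParsePostgreSQLVersion("PostgreSQL 13.11 on x86_64-pc-linux-gnu")
dstVer, _ := ParsePostgreSQLVersion("PostgreSQL 17.7")
res := CheckVersionCompatibility(srcVer, dstVer)
fmt.Println(res.Compatible, res.Message)
// Jumping 13 -> 17 crosses several major versions, so res.Warnings and
// res.Recommendations carry the release-notes and post-restore ANALYZE advice.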
@@ -19,12 +19,12 @@ type Policy struct {

// CleanupResult contains information about cleanup operations
type CleanupResult struct {
	TotalBackups        int
	EligibleForDeletion int
	Deleted             []string
	Kept                []string
	SpaceFreed          int64
	Errors              []error
}

// ApplyPolicy enforces the retention policy on backups in a directory
@@ -63,13 +63,13 @@ func ApplyPolicy(backupDir string, policy Policy) (*CleanupResult, error) {
		// Check if backup is older than retention period
		if backup.Timestamp.Before(cutoffDate) {
			result.EligibleForDeletion++

			if policy.DryRun {
				result.Deleted = append(result.Deleted, backup.BackupFile)
			} else {
				// Delete backup file and associated metadata
				if err := deleteBackup(backup.BackupFile); err != nil {
					result.Errors = append(result.Errors,
						fmt.Errorf("failed to delete %s: %w", backup.BackupFile, err))
				} else {
					result.Deleted = append(result.Deleted, backup.BackupFile)
@@ -204,7 +204,7 @@ func CleanupByPattern(backupDir, pattern string, policy Policy) (*CleanupResult,

		if backup.Timestamp.Before(cutoffDate) {
			result.EligibleForDeletion++

			if policy.DryRun {
				result.Deleted = append(result.Deleted, backup.BackupFile)
			} else {
@@ -9,18 +9,18 @@ import (

// AuditEvent represents an auditable event
type AuditEvent struct {
	Timestamp time.Time
	User      string
	Action    string
	Resource  string
	Result    string
	Details   map[string]interface{}
}

// AuditLogger provides audit logging functionality
type AuditLogger struct {
	log     logger.Logger
	enabled bool
}

// NewAuditLogger creates a new audit logger
@@ -42,7 +42,7 @@ func VerifyChecksum(path string, expectedChecksum string) error {
func SaveChecksum(archivePath string, checksum string) error {
	checksumPath := archivePath + ".sha256"
	content := fmt.Sprintf("%s %s\n", checksum, archivePath)

	if err := os.WriteFile(checksumPath, []byte(content), 0644); err != nil {
		return fmt.Errorf("failed to save checksum: %w", err)
	}
@@ -53,7 +53,7 @@ func SaveChecksum(archivePath string, checksum string) error {
// LoadChecksum loads checksum from a .sha256 file
func LoadChecksum(archivePath string) (string, error) {
	checksumPath := archivePath + ".sha256"

	data, err := os.ReadFile(checksumPath)
	if err != nil {
		return "", fmt.Errorf("failed to read checksum file: %w", err)
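Both helpers assume a hex-encoded SHA-256 of the whole archive. A self-contained sketch of that calculation (the actual hashing code is outside this hunk; assumes "crypto/sha256", "encoding/hex", "io", and "os" are imported):

// fileSHA256 streams the file through a SHA-256 hash so large archives
// are never read into memory at once.
func fileSHA256(path string) (string, error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}
	return hex.EncodeToString(h.Sum(nil)), nil
}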
@@ -49,7 +49,7 @@ func ValidateArchivePath(path string) (string, error) {
	// Must have a valid archive extension
	ext := strings.ToLower(filepath.Ext(cleaned))
	validExtensions := []string{".dump", ".sql", ".gz", ".tar"}

	valid := false
	for _, validExt := range validExtensions {
		if strings.HasSuffix(cleaned, validExt) {
@@ -23,20 +23,20 @@ func NewPrivilegeChecker(log logger.Logger) *PrivilegeChecker {
// CheckAndWarn checks if running with elevated privileges and warns
func (pc *PrivilegeChecker) CheckAndWarn(allowRoot bool) error {
	isRoot, user := pc.isRunningAsRoot()

	if isRoot {
		pc.log.Warn("⚠️  Running with elevated privileges (root/Administrator)")
		pc.log.Warn("Security recommendation: Create a dedicated backup user with minimal privileges")

		if !allowRoot {
			return fmt.Errorf("running as root is not recommended, use --allow-root to override")
		}

		pc.log.Warn("Proceeding with root privileges (--allow-root specified)")
	} else {
		pc.log.Debug("Running as non-privileged user", "user", user)
	}

	return nil
}

@@ -52,7 +52,7 @@ func (pc *PrivilegeChecker) isRunningAsRoot() (bool, string) {
func (pc *PrivilegeChecker) isUnixRoot() (bool, string) {
	uid := os.Getuid()
	user := GetCurrentUser()

	isRoot := uid == 0 || user == "root"
	return isRoot, user
}
@@ -62,10 +62,10 @@ func (pc *PrivilegeChecker) isWindowsAdmin() (bool, string) {
	// Check if running as Administrator on Windows
	// This is a simplified check - full implementation would use Windows API
	user := GetCurrentUser()

	// Common admin user patterns on Windows
	isAdmin := user == "Administrator" || user == "SYSTEM"

	return isAdmin, user
}

@@ -89,11 +89,11 @@ func (pc *PrivilegeChecker) GetSecurityRecommendations() []string {
		"Regularly rotate database passwords",
		"Monitor audit logs for unauthorized access attempts",
	}

	if runtime.GOOS != "windows" {
		recommendations = append(recommendations,
			fmt.Sprintf("Run as non-root user: sudo -u %s dbbackup ...", pc.GetRecommendedUser()))
	}

	return recommendations
}
@@ -1,4 +1,5 @@
// go:build !linux
//go:build !linux
// +build !linux

package security
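This hunk is the substance of the build-tag fix: "// go:build !linux" (with a space after the slashes) is an ordinary comment that the Go toolchain ignores, so the constraint was silently dropped and the file compiled on every platform. The recognized form has no space, and the legacy "+build" line is kept for pre-Go-1.17 toolchains. A correct header looks like:

//go:build !linux
// +build !linux

package security

Both constraint lines must appear before the package clause, separated from it by a blank line.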
@@ -1,3 +1,4 @@
//go:build !windows
// +build !windows

package security
@@ -19,7 +20,7 @@ func (rc *ResourceChecker) checkPlatformLimits() (*ResourceLimits, error) {
	if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err == nil {
		limits.MaxOpenFiles = uint64(rLimit.Cur)
		rc.log.Debug("Resource limit: max open files", "limit", rLimit.Cur, "max", rLimit.Max)

		if rLimit.Cur < 1024 {
			rc.log.Warn("⚠️  Low file descriptor limit detected",
				"current", rLimit.Cur,

@@ -1,3 +1,4 @@
//go:build windows
// +build windows

package security
@@ -23,5 +24,3 @@ func (rc *ResourceChecker) checkPlatformLimits() (*ResourceLimits, error) {

	return limits, nil
}
@@ -46,13 +46,13 @@ func (rp *RetentionPolicy) CleanupOldBackups(backupDir string) (int, int64, erro
	}

	if len(archives) <= rp.MinBackups {
		rp.log.Debug("Keeping all backups (below minimum threshold)",
			"count", len(archives), "min_backups", rp.MinBackups)
		return 0, 0, nil
	}

	cutoffTime := time.Now().AddDate(0, 0, -rp.RetentionDays)

	// Sort by modification time (oldest first)
	sort.Slice(archives, func(i, j int) bool {
		return archives[i].ModTime.Before(archives[j].ModTime)
@@ -65,14 +65,14 @@ func (rp *RetentionPolicy) CleanupOldBackups(backupDir string) (int, int64, erro
		// Keep minimum number of backups
		remaining := len(archives) - i
		if remaining <= rp.MinBackups {
			rp.log.Debug("Stopped cleanup to maintain minimum backups",
				"remaining", remaining, "min_backups", rp.MinBackups)
			break
		}

		// Delete if older than retention period
		if archive.ModTime.Before(cutoffTime) {
			rp.log.Info("Removing old backup",
				"file", filepath.Base(archive.Path),
				"age_days", int(time.Since(archive.ModTime).Hours()/24),
				"size_mb", archive.Size/1024/1024)
@@ -100,7 +100,7 @@ func (rp *RetentionPolicy) CleanupOldBackups(backupDir string) (int, int64, erro
	}

	if deletedCount > 0 {
		rp.log.Info("Cleanup completed",
			"deleted_backups", deletedCount,
			"freed_space_mb", freedSpace/1024/1024,
			"retention_days", rp.RetentionDays)
@@ -124,7 +124,7 @@ func (rp *RetentionPolicy) scanBackupArchives(backupDir string) ([]ArchiveInfo,
		}

		name := entry.Name()

		// Skip non-backup files
		if !isBackupArchive(name) {
			continue
@@ -161,7 +161,7 @@ func isBackupArchive(name string) bool {
// extractDatabaseName extracts database name from archive filename
func extractDatabaseName(filename string) string {
	base := filepath.Base(filename)

	// Remove extensions
	for {
		oldBase := base
@@ -170,7 +170,7 @@ func extractDatabaseName(filename string) string {
			break
		}
	}

	// Remove timestamp patterns
	if len(base) > 20 {
		// Typically: db_name_20240101_120000
@@ -184,7 +184,7 @@ func extractDatabaseName(filename string) string {
			}
		}
	}

	return base
}
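For instance, with the conventions above, a call like extractDatabaseName("myapp_20240101_120000.dump.gz") would strip the ".gz" and ".dump" extensions in the loop and then the "_20240101_120000" timestamp suffix, returning "myapp" (an illustrative input; part of the timestamp-stripping logic is elided by the hunks).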
@@ -171,9 +171,9 @@ func (m *Manager) Setup() error {

	// Log current swap status
	if total, used, free, err := m.GetCurrentSwap(); err == nil {
		m.log.Info("Swap status after setup",
			"total_mb", total,
			"used_mb", used,
			"free_mb", free,
			"added_gb", m.sizeGB)
	}
@@ -41,13 +41,13 @@ var (

// ArchiveInfo holds information about a backup archive
type ArchiveInfo struct {
	Name          string
	Path          string
	Format        restore.ArchiveFormat
	Size          int64
	Modified      time.Time
	DatabaseName  string
	Valid         bool
	ValidationMsg string
}

@@ -132,13 +132,13 @@ func loadArchives(cfg *config.Config, log logger.Logger) tea.Cmd {
		}

		archives = append(archives, ArchiveInfo{
			Name:          name,
			Path:          fullPath,
			Format:        format,
			Size:          info.Size(),
			Modified:      info.ModTime(),
			DatabaseName:  dbName,
			Valid:         valid,
			ValidationMsg: validationMsg,
		})
	}
@@ -196,13 +196,13 @@ func (m ArchiveBrowserModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	case "enter", " ":
		if len(m.archives) > 0 && m.cursor < len(m.archives) {
			selected := m.archives[m.cursor]

			// Validate selection based on mode
			if m.mode == "restore-cluster" && !selected.Format.IsClusterBackup() {
				m.message = errorStyle.Render("❌ Please select a cluster backup (.tar.gz)")
				return m, nil
			}

			if m.mode == "restore-single" && selected.Format.IsClusterBackup() {
				m.message = errorStyle.Render("❌ Please select a single database backup")
				return m, nil
@@ -239,7 +239,7 @@ func (m ArchiveBrowserModel) View() string {
	} else if m.mode == "restore-cluster" {
		title = "📦 Select Archive to Restore (Cluster)"
	}

	s.WriteString(titleStyle.Render(title))
	s.WriteString("\n\n")
@@ -78,10 +78,10 @@ type backupCompleteMsg struct {

func executeBackupWithTUIProgress(parentCtx context.Context, cfg *config.Config, log logger.Logger, backupType, dbName string, ratio int) tea.Cmd {
	return func() tea.Msg {
		// Use configurable cluster timeout (minutes) from config; default set in config.New()
		// Use parent context to inherit cancellation from TUI
		clusterTimeout := time.Duration(cfg.ClusterTimeoutMinutes) * time.Minute
		ctx, cancel := context.WithTimeout(parentCtx, clusterTimeout)
		defer cancel()

		start := time.Now()
@@ -151,10 +151,10 @@ func (m BackupExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
		if !m.done {
			// Increment spinner frame for smooth animation
			m.spinnerFrame = (m.spinnerFrame + 1) % len(spinnerFrames)

			// Update status based on elapsed time to show progress
			elapsedSec := int(time.Since(m.startTime).Seconds())

			if elapsedSec < 2 {
				m.status = "Initializing backup..."
			} else if elapsedSec < 5 {
@@ -180,7 +180,7 @@ func (m BackupExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
					m.status = fmt.Sprintf("Backing up database '%s'...", m.databaseName)
				}
			}

			return m, backupTickCmd()
		}
		return m, nil
@@ -239,7 +239,7 @@ func (m BackupExecutionModel) View() string {
		s.WriteString(fmt.Sprintf(" %s %s\n", spinnerFrames[m.spinnerFrame], m.status))
	} else {
		s.WriteString(fmt.Sprintf(" %s\n\n", m.status))

		if m.err != nil {
			s.WriteString(fmt.Sprintf(" ❌ Error: %v\n", m.err))
		} else if m.result != "" {
@@ -52,13 +52,13 @@ func (m BackupManagerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			return m, nil
		}
		m.archives = msg.archives

		// Calculate total size
		m.totalSize = 0
		for _, archive := range m.archives {
			m.totalSize += archive.Size
		}

		// Get free space (simplified - just show message)
		m.message = fmt.Sprintf("Loaded %d archive(s)", len(m.archives))
		return m, nil

@@ -84,7 +84,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			m.databases = []string{"Error loading databases"}
		} else {
			m.databases = msg.databases

			// Auto-select database if specified
			if m.config.TUIAutoDatabase != "" {
				for i, db := range m.databases {
@@ -92,7 +92,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
						m.cursor = i
						m.selected = db
						m.logger.Info("Auto-selected database", "database", db)

						// If sample backup, ask for ratio (or auto-use default)
						if m.backupType == "sample" {
							if m.config.TUIDryRun {
@@ -107,7 +107,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
								ValidateInt(1, 100))
							return inputModel, nil
						}

						// For single backup, go directly to execution
						executor := NewBackupExecution(m.config, m.logger, m.parent, m.ctx, m.backupType, m.selected, 0)
						return executor, executor.Init()
@@ -136,7 +136,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
		case "enter":
			if !m.loading && m.err == nil && len(m.databases) > 0 {
				m.selected = m.databases[m.cursor]

				// If sample backup, ask for ratio first
				if m.backupType == "sample" {
					inputModel := NewInputModel(m.config, m.logger, m,
@@ -146,7 +146,7 @@ func (m DatabaseSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
						ValidateInt(1, 100))
					return inputModel, nil
				}

				// For single backup, go directly to execution
				executor := NewBackupExecution(m.config, m.logger, m.parent, m.ctx, m.backupType, m.selected, 0)
				return executor, executor.Init()
@@ -111,7 +111,7 @@ func (db *DirectoryBrowser) Render() string {
	}

	var lines []string

	// Header
	lines = append(lines, fmt.Sprintf(" Current: %s", db.CurrentPath))
	lines = append(lines, fmt.Sprintf(" Found %d directories (cursor: %d)", len(db.items), db.cursor))
@@ -121,7 +121,7 @@ func (db *DirectoryBrowser) Render() string {
	maxItems := 5 // Show max 5 items to keep it compact
	start := 0
	end := len(db.items)

	if len(db.items) > maxItems {
		// Center the cursor in the view
		start = db.cursor - maxItems/2
@@ -144,14 +144,14 @@ func (db *DirectoryBrowser) Render() string {
		if i == db.cursor {
			prefix = " >> "
		}

		displayName := item
		if item == ".." {
			displayName = "../ (parent directory)"
		} else if item != "[Error reading directory]" {
			displayName = item + "/"
		}

		lines = append(lines, prefix+displayName)
	}

@@ -164,4 +164,4 @@ func (db *DirectoryBrowser) Render() string {
	lines = append(lines, " ↑/↓: Navigate | Enter/→: Open | ←: Parent | Space: Select | Esc: Cancel")

	return strings.Join(lines, "\n")
}

@@ -14,12 +14,12 @@ import (

// DirectoryPicker is a simple, fast directory and file picker
type DirectoryPicker struct {
	currentPath string
	items       []FileItem
	cursor      int
	callback    func(string)
	allowFiles  bool // Allow file selection for restore operations
	styles      DirectoryPickerStyles
}

type FileItem struct {
@@ -98,26 +98,26 @@ func (dp *DirectoryPicker) loadItems() {
	// Collect directories and optionally files
	var dirs []FileItem
	var files []FileItem

	for _, entry := range entries {
		if strings.HasPrefix(entry.Name(), ".") {
			continue // Skip hidden files
		}

		item := FileItem{
			Name:  entry.Name(),
			IsDir: entry.IsDir(),
			Path:  filepath.Join(dp.currentPath, entry.Name()),
		}

		if entry.IsDir() {
			dirs = append(dirs, item)
		} else if dp.allowFiles {
			// Only include backup-related files
			if strings.HasSuffix(entry.Name(), ".sql") ||
				strings.HasSuffix(entry.Name(), ".dump") ||
				strings.HasSuffix(entry.Name(), ".gz") ||
				strings.HasSuffix(entry.Name(), ".tar") {
				files = append(files, item)
			}
		}
@@ -242,4 +242,4 @@ func (dp *DirectoryPicker) View() string {
	content.WriteString(dp.styles.Help.Render(help))

	return dp.styles.Container.Render(content.String())
}
@@ -37,14 +37,14 @@ func NewHistoryView(cfg *config.Config, log logger.Logger, parent tea.Model) His
	if lastIndex < 0 {
		lastIndex = 0
	}

	// Calculate initial viewport to show the last item
	maxVisible := 15
	viewOffset := lastIndex - maxVisible + 1
	if viewOffset < 0 {
		viewOffset = 0
	}

	return HistoryViewModel{
		config: cfg,
		logger: log,
@@ -112,7 +112,7 @@ func (m HistoryViewModel) Init() tea.Cmd {

func (m HistoryViewModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	maxVisible := 15 // Show max 15 items at once

	switch msg := msg.(type) {
	case tea.KeyMsg:
		switch msg.String() {
@@ -136,7 +136,7 @@ func (m HistoryViewModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
				m.viewOffset = m.cursor - maxVisible + 1
			}
		}

		case "pgup":
			// Page up - jump by maxVisible items
			m.cursor -= maxVisible
@@ -147,7 +147,7 @@ func (m HistoryViewModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			if m.cursor < m.viewOffset {
				m.viewOffset = m.cursor
			}

		case "pgdown":
			// Page down - jump by maxVisible items
			m.cursor += maxVisible
@@ -158,12 +158,12 @@ func (m HistoryViewModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			if m.cursor >= m.viewOffset+maxVisible {
				m.viewOffset = m.cursor - maxVisible + 1
			}

		case "home", "g":
			// Jump to first item
			m.cursor = 0
			m.viewOffset = 0

		case "end", "G":
			// Jump to last item
			m.cursor = len(m.history) - 1
@@ -187,15 +187,15 @@ func (m HistoryViewModel) View() string {
		s.WriteString("📭 No backup history found\n\n")
	} else {
		maxVisible := 15 // Show max 15 items at once

		// Calculate visible range
		start := m.viewOffset
		end := start + maxVisible
		if end > len(m.history) {
			end = len(m.history)
		}

		s.WriteString(fmt.Sprintf("Found %d backup operations (Viewing %d/%d):\n\n",
			len(m.history), m.cursor+1, len(m.history)))

		// Show scroll indicators
@@ -219,12 +219,12 @@ func (m HistoryViewModel) View() string {
				s.WriteString(fmt.Sprintf(" %s\n", line))
			}
		}

		// Show scroll indicator if more entries below
		if end < len(m.history) {
			s.WriteString(fmt.Sprintf(" ▼ %d more entries below...\n", len(m.history)-end))
		}

		s.WriteString("\n")
	}
@@ -61,7 +61,7 @@ func (m InputModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			}
		}
		m.done = true

		// If this is from database selector, execute backup with ratio
		if selector, ok := m.parent.(DatabaseSelectorModel); ok {
			ratio, _ := strconv.Atoi(m.value)

@@ -53,14 +53,14 @@ type dbTypeOption struct {

// MenuModel represents the simple menu state
type MenuModel struct {
	choices      []string
	cursor       int
	config       *config.Config
	logger       logger.Logger
	quitting     bool
	message      string
	dbTypes      []dbTypeOption
	dbTypeCursor int

	// Background operations
	ctx context.Context
@@ -133,7 +133,7 @@ func (m MenuModel) Init() tea.Cmd {
	// Auto-select menu option if specified
	if m.config.TUIAutoSelect >= 0 && m.config.TUIAutoSelect < len(m.choices) {
		m.logger.Info("TUI Auto-select enabled", "option", m.config.TUIAutoSelect, "label", m.choices[m.config.TUIAutoSelect])

		// Return command to trigger auto-selection
		return func() tea.Msg {
			return autoSelectMsg{}
@@ -150,7 +150,7 @@ func (m MenuModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
		if m.config.TUIAutoSelect >= 0 && m.config.TUIAutoSelect < len(m.choices) {
			m.cursor = m.config.TUIAutoSelect
			m.logger.Info("Auto-selecting option", "cursor", m.cursor, "choice", m.choices[m.cursor])

			// Trigger the selection based on cursor position
			switch m.cursor {
			case 0: // Single Database Backup
@@ -184,7 +184,7 @@ func (m MenuModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			}
		}
		return m, nil

	case tea.KeyMsg:
		switch msg.String() {
		case "ctrl+c", "q":
@@ -192,13 +192,13 @@ func (m MenuModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
			if m.cancel != nil {
				m.cancel()
			}

			// Clean up any orphaned processes before exit
			m.logger.Info("Cleaning up processes before exit")
			if err := cleanup.KillOrphanedProcesses(m.logger); err != nil {
				m.logger.Warn("Failed to clean up all processes", "error", err)
			}

			m.quitting = true
			return m, tea.Quit

@@ -269,11 +269,11 @@ func (s *SilentOperation) Fail(message string, args ...any) {}
// SilentProgressIndicator implements progress.Indicator but doesn't output anything
type SilentProgressIndicator struct{}

func (s *SilentProgressIndicator) Start(message string)                          {}
func (s *SilentProgressIndicator) Update(message string)                         {}
func (s *SilentProgressIndicator) Complete(message string)                       {}
func (s *SilentProgressIndicator) Fail(message string)                           {}
func (s *SilentProgressIndicator) Stop()                                         {}
func (s *SilentProgressIndicator) SetEstimator(estimator *progress.ETAEstimator) {}

// RunBackupInTUI runs a backup operation with TUI-compatible progress reporting
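A common companion to a no-op implementation like SilentProgressIndicator is a compile-time interface assertion; a sketch, assuming progress.Indicator is the interface these methods satisfy:

// Fails to compile if *SilentProgressIndicator ever drifts from the interface.
var _ progress.Indicator = (*SilentProgressIndicator)(nil)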
@@ -20,54 +20,54 @@ var spinnerFrames = []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "

// RestoreExecutionModel handles restore execution with progress
type RestoreExecutionModel struct {
	config            *config.Config
	logger            logger.Logger
	parent            tea.Model
	ctx               context.Context
	archive           ArchiveInfo
	targetDB          string
	cleanFirst        bool
	createIfMissing   bool
	restoreType       string
	cleanClusterFirst bool     // Drop all user databases before cluster restore
	existingDBs       []string // List of databases to drop

	// Progress tracking
	status        string
	phase         string
	progress      int
	details       []string
	startTime     time.Time
	spinnerFrame  int
	spinnerFrames []string

	// Results
	done    bool
	err     error
	result  string
	elapsed time.Duration
}

// NewRestoreExecution creates a new restore execution model
func NewRestoreExecution(cfg *config.Config, log logger.Logger, parent tea.Model, ctx context.Context, archive ArchiveInfo, targetDB string, cleanFirst, createIfMissing bool, restoreType string, cleanClusterFirst bool, existingDBs []string) RestoreExecutionModel {
	return RestoreExecutionModel{
		config:            cfg,
		logger:            log,
		parent:            parent,
		ctx:               ctx,
		archive:           archive,
		targetDB:          targetDB,
		cleanFirst:        cleanFirst,
		createIfMissing:   createIfMissing,
		restoreType:       restoreType,
		cleanClusterFirst: cleanClusterFirst,
		existingDBs:       existingDBs,
		status:            "Initializing...",
		phase:             "Starting",
		startTime:         time.Now(),
		details:           []string{},
		spinnerFrames:     spinnerFrames, // Use package-level constant
		spinnerFrame:      0,
	}
}

@@ -123,7 +123,7 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
	// STEP 1: Clean cluster if requested (drop all existing user databases)
	if restoreType == "restore-cluster" && cleanClusterFirst && len(existingDBs) > 0 {
		log.Info("Dropping existing user databases before cluster restore", "count", len(existingDBs))

		// Drop databases using command-line psql (no connection required)
		// This matches how cluster restore works - uses CLI tools, not database connections
		droppedCount := 0
@@ -139,13 +139,13 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
			}
			dropCancel() // Clean up context
		}

		log.Info("Cluster cleanup completed", "dropped", droppedCount, "total", len(existingDBs))
	}

	// STEP 2: Create restore engine with silent progress (no stdout interference with TUI)
	engine := restore.NewSilent(cfg, log, dbClient)

	// Set up progress callback (but it won't work in goroutine - progress is already sent via logs)
	// The TUI will just use spinner animation to show activity

@@ -186,11 +186,11 @@ func (m RestoreExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
		if !m.done {
			m.spinnerFrame = (m.spinnerFrame + 1) % len(m.spinnerFrames)
			m.elapsed = time.Since(m.startTime)

			// Update status based on elapsed time to show progress
			// This provides visual feedback even though we don't have real-time progress
			elapsedSec := int(m.elapsed.Seconds())

			if elapsedSec < 2 {
				m.status = "Initializing restore..."
				m.phase = "Starting"
@@ -222,7 +222,7 @@ func (m RestoreExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
				m.phase = "Restore"
			}
		}

		return m, restoreTickCmd()
	}
	return m, nil
@@ -245,7 +245,7 @@ func (m RestoreExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
		m.err = msg.err
		m.result = msg.result
		m.elapsed = msg.elapsed

		if m.err == nil {
			m.status = "Restore completed successfully"
			m.phase = "Done"
@@ -311,7 +311,7 @@ func (m RestoreExecutionModel) View() string {
	} else {
		// Show progress
		s.WriteString(fmt.Sprintf("Phase: %s\n", m.phase))

		// Show status with rotating spinner (unified indicator for all operations)
		spinner := m.spinnerFrames[m.spinnerFrame]
		s.WriteString(fmt.Sprintf("Status: %s %s\n", spinner, m.status))
@@ -339,10 +339,10 @@ func (m RestoreExecutionModel) View() string {
func renderProgressBar(percent int) string {
	width := 40
	filled := (percent * width) / 100

	bar := strings.Repeat("█", filled)
	empty := strings.Repeat("░", width-filled)

	return successStyle.Render(bar) + infoStyle.Render(empty)
}

@@ -370,24 +370,23 @@ func dropDatabaseCLI(ctx context.Context, cfg *config.Config, dbName string) err
		"-d", "postgres", // Connect to postgres maintenance DB
		"-c", fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName),
	}

	// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
	if cfg.Host != "localhost" && cfg.Host != "127.0.0.1" && cfg.Host != "" {
		args = append([]string{"-h", cfg.Host}, args...)
	}

	cmd := exec.CommandContext(ctx, "psql", args...)

	// Set password if provided
	if cfg.Password != "" {
		cmd.Env = append(cmd.Environ(), fmt.Sprintf("PGPASSWORD=%s", cfg.Password))
	}

	output, err := cmd.CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to drop database %s: %w\nOutput: %s", dbName, err, string(output))
	}

	return nil
}
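The "only pass -h when the host is remote" pattern in dropDatabaseCLI repeats across the psql and mysql call sites in this change. A hypothetical consolidation (a sketch, not part of this diff) could look like:

// withHostFlag prepends -h only for remote hosts, so local connections
// fall back to the Unix socket and peer authentication.
func withHostFlag(host string, args []string) []string {
	if host != "localhost" && host != "127.0.0.1" && host != "" {
		return append([]string{"-h", host}, args...)
	}
	return args
}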
@@ -43,22 +43,22 @@ type SafetyCheck struct {

// RestorePreviewModel shows restore preview and safety checks
type RestorePreviewModel struct {
	config            *config.Config
	logger            logger.Logger
	parent            tea.Model
	ctx               context.Context
	archive           ArchiveInfo
	mode              string
	targetDB          string
	cleanFirst        bool
	createIfMissing   bool
	cleanClusterFirst bool     // For cluster restore: drop all user databases first
	existingDBCount   int      // Number of existing user databases
	existingDBs       []string // List of existing user databases
	safetyChecks      []SafetyCheck
	checking          bool
	canProceed        bool
	message           string
}

// NewRestorePreview creates a new restore preview
@@ -70,16 +70,16 @@ func NewRestorePreview(cfg *config.Config, log logger.Logger, parent tea.Model,
	}

	return RestorePreviewModel{
		config:          cfg,
		logger:          log,
		parent:          parent,
		ctx:             ctx,
		archive:         archive,
		mode:            mode,
		targetDB:        targetDB,
		cleanFirst:      false,
		createIfMissing: true,
		checking:        true,
		safetyChecks: []SafetyCheck{
			{Name: "Archive integrity", Status: "pending", Critical: true},
			{Name: "Disk space", Status: "pending", Critical: true},
@@ -156,7 +156,7 @@ func runSafetyChecks(cfg *config.Config, log logger.Logger, archive ArchiveInfo,
	// 4. Target database check (skip for cluster restores)
	existingDBCount := 0
	existingDBs := []string{}

	if !archive.Format.IsClusterBackup() {
		check = SafetyCheck{Name: "Target database", Status: "checking", Critical: false}
		exists, err := safety.CheckDatabaseExists(ctx, targetDB)
@@ -174,7 +174,7 @@ func runSafetyChecks(cfg *config.Config, log logger.Logger, archive ArchiveInfo,
	} else {
		// For cluster restores, detect existing user databases
		check = SafetyCheck{Name: "Existing databases", Status: "checking", Critical: false}

		// Get list of existing user databases (exclude templates and system DBs)
		dbList, err := safety.ListUserDatabases(ctx)
		if err != nil {
@@ -183,7 +183,7 @@ func runSafetyChecks(cfg *config.Config, log logger.Logger, archive ArchiveInfo,
		} else {
			existingDBCount = len(dbList)
			existingDBs = dbList

			if existingDBCount > 0 {
				check.Status = "warning"
				check.Message = fmt.Sprintf("Found %d existing user database(s) - can be cleaned before restore", existingDBCount)
@@ -288,13 +288,13 @@ func (m RestorePreviewModel) View() string {
		s.WriteString("\n")
		s.WriteString(fmt.Sprintf(" Database: %s\n", m.targetDB))
		s.WriteString(fmt.Sprintf(" Host: %s:%d\n", m.config.Host, m.config.Port))

		cleanIcon := "✗"
		if m.cleanFirst {
			cleanIcon = "✓"
		}
		s.WriteString(fmt.Sprintf(" Clean First: %s %v\n", cleanIcon, m.cleanFirst))

		createIcon := "✗"
		if m.createIfMissing {
			createIcon = "✓"
@@ -305,10 +305,10 @@ func (m RestorePreviewModel) View() string {
		s.WriteString(archiveHeaderStyle.Render("🎯 Cluster Restore Options"))
		s.WriteString("\n")
		s.WriteString(fmt.Sprintf(" Host: %s:%d\n", m.config.Host, m.config.Port))

		if m.existingDBCount > 0 {
			s.WriteString(fmt.Sprintf(" Existing Databases: %d found\n", m.existingDBCount))

			// Show first few database names
			maxShow := 5
			for i, db := range m.existingDBs {
@@ -319,7 +319,7 @@ func (m RestorePreviewModel) View() string {
				}
				s.WriteString(fmt.Sprintf(" - %s\n", db))
			}

		cleanIcon := "✗"
		cleanStyle := infoStyle
		if m.cleanClusterFirst {
@@ -344,7 +344,7 @@ func (m RestorePreviewModel) View() string {
	for _, check := range m.safetyChecks {
		icon := "○"
		style := checkPendingStyle

		switch check.Status {
		case "passed":
			icon = "✓"
@@ -75,7 +75,7 @@ func NewSettingsModel(cfg *config.Config, log logger.Logger, parent tea.Model) S
	}
	nextIdx := (currentIdx + 1) % len(workloads)
	c.CPUWorkloadType = workloads[nextIdx]

	// Recalculate Jobs and DumpJobs based on workload type
	if c.CPUInfo != nil && c.AutoDetectCores {
		switch c.CPUWorkloadType {
@@ -329,7 +329,7 @@ func NewSettingsModel(cfg *config.Config, log logger.Logger, parent tea.Model) S
	{
		Key:         "cloud_access_key",
		DisplayName: "Cloud Access Key",
		Value: func(c *config.Config) string {
			if c.CloudAccessKey != "" {
				return "***" + c.CloudAccessKey[len(c.CloudAccessKey)-4:]
			}
@@ -624,7 +624,7 @@ func (m SettingsModel) saveSettings() (tea.Model, tea.Cmd) {
// cycleDatabaseType cycles through database type options
func (m SettingsModel) cycleDatabaseType() (tea.Model, tea.Cmd) {
	dbTypes := []string{"postgres", "mysql", "mariadb"}

	// Find current index
	currentIdx := 0
	for i, dbType := range dbTypes {
@@ -633,17 +633,17 @@ func (m SettingsModel) cycleDatabaseType() (tea.Model, tea.Cmd) {
			break
		}
	}

	// Cycle to next
	nextIdx := (currentIdx + 1) % len(dbTypes)
	newType := dbTypes[nextIdx]

	// Update config
	if err := m.config.SetDatabaseType(newType); err != nil {
		m.message = errorStyle.Render(fmt.Sprintf("❌ Failed to set database type: %s", err.Error()))
		return m, nil
	}

	m.message = successStyle.Render(fmt.Sprintf("✅ Database type set to %s", m.config.DisplayDatabaseType()))
	return m, nil
}
@@ -726,7 +726,7 @@ func (m SettingsModel) View() string {
		fmt.Sprintf("Compression: Level %d", m.config.CompressionLevel),
		fmt.Sprintf("Jobs: %d parallel, %d dump", m.config.Jobs, m.config.DumpJobs),
	}

	if m.config.CloudEnabled {
		cloudInfo := fmt.Sprintf("Cloud: %s (%s)", m.config.CloudProvider, m.config.CloudBucket)
		if m.config.CloudAutoUpload {
@@ -9,14 +9,14 @@ import (

// Result represents the outcome of a verification operation
type Result struct {
	Valid bool
	BackupFile string
	ExpectedSHA256 string
	Valid            bool
	BackupFile       string
	ExpectedSHA256   string
	CalculatedSHA256 string
	SizeMatch bool
	FileExists bool
	MetadataExists bool
	Error error
	SizeMatch        bool
	FileExists       bool
	MetadataExists   bool
	Error            error
}

// Verify checks the integrity of a backup file
@@ -47,7 +47,7 @@ func Verify(backupFile string) (*Result, error) {
	// Check size match
	if info.Size() != meta.SizeBytes {
		result.SizeMatch = false
		result.Error = fmt.Errorf("size mismatch: expected %d bytes, got %d bytes",
		result.Error = fmt.Errorf("size mismatch: expected %d bytes, got %d bytes",
			meta.SizeBytes, info.Size())
		return result, nil
	}
@@ -64,7 +64,7 @@ func Verify(backupFile string) (*Result, error) {
	// Compare checksums
	if actualSHA256 != meta.SHA256 {
		result.Valid = false
		result.Error = fmt.Errorf("checksum mismatch: expected %s, got %s",
		result.Error = fmt.Errorf("checksum mismatch: expected %s, got %s",
			meta.SHA256, actualSHA256)
		return result, nil
	}
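The comparison above presumes actualSHA256 was produced by streaming the backup file through SHA-256; the hunk does not show that helper, but a minimal self-contained sketch of the computation (the function name is an assumption):

	import (
		"crypto/sha256"
		"encoding/hex"
		"io"
		"os"
	)

	// fileSHA256 streams a file through SHA-256 and returns the hex digest.
	func fileSHA256(path string) (string, error) {
		f, err := os.Open(path)
		if err != nil {
			return "", err
		}
		defer f.Close()
		h := sha256.New()
		if _, err := io.Copy(h, f); err != nil {
			return "", err
		}
		return hex.EncodeToString(h.Sum(nil)), nil
	}
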
@@ -77,7 +77,7 @@ func Verify(backupFile string) (*Result, error) {
// VerifyMultiple verifies multiple backup files
func VerifyMultiple(backupFiles []string) ([]*Result, error) {
	var results []*Result

	for _, file := range backupFiles {
		result, err := Verify(file)
		if err != nil {
@@ -106,7 +106,7 @@ func QuickCheck(backupFile string) error {

	// Check size
	if info.Size() != meta.SizeBytes {
		return fmt.Errorf("size mismatch: expected %d bytes, got %d bytes",
		return fmt.Errorf("size mismatch: expected %d bytes, got %d bytes",
			meta.SizeBytes, info.Size())
	}

@@ -21,26 +21,26 @@ type Archiver struct {

// ArchiveConfig holds WAL archiving configuration
type ArchiveConfig struct {
	ArchiveDir string // Directory to store archived WAL files
	CompressWAL bool // Compress WAL files with gzip
	EncryptWAL bool // Encrypt WAL files
	EncryptionKey []byte // 32-byte key for AES-256-GCM encryption
	RetentionDays int // Days to keep WAL archives
	VerifyChecksum bool // Verify WAL file checksums
	ArchiveDir     string // Directory to store archived WAL files
	CompressWAL    bool   // Compress WAL files with gzip
	EncryptWAL     bool   // Encrypt WAL files
	EncryptionKey  []byte // 32-byte key for AES-256-GCM encryption
	RetentionDays  int    // Days to keep WAL archives
	VerifyChecksum bool   // Verify WAL file checksums
}

// WALArchiveInfo contains metadata about an archived WAL file
type WALArchiveInfo struct {
	WALFileName string `json:"wal_filename"`
	ArchivePath string `json:"archive_path"`
	OriginalSize int64 `json:"original_size"`
	ArchivedSize int64 `json:"archived_size"`
	Checksum string `json:"checksum"`
	Timeline uint32 `json:"timeline"`
	Segment uint64 `json:"segment"`
	ArchivedAt time.Time `json:"archived_at"`
	Compressed bool `json:"compressed"`
	Encrypted bool `json:"encrypted"`
	WALFileName  string    `json:"wal_filename"`
	ArchivePath  string    `json:"archive_path"`
	OriginalSize int64     `json:"original_size"`
	ArchivedSize int64     `json:"archived_size"`
	Checksum     string    `json:"checksum"`
	Timeline     uint32    `json:"timeline"`
	Segment      uint64    `json:"segment"`
	ArchivedAt   time.Time `json:"archived_at"`
	Compressed   bool      `json:"compressed"`
	Encrypted    bool      `json:"encrypted"`
}

// NewArchiver creates a new WAL archiver
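For orientation, a config literal wiring these fields together might look like this (the path and values are illustrative, not defaults from this codebase):

	cfg := ArchiveConfig{
		ArchiveDir:     "/var/lib/dbbackup/wal_archive",
		CompressWAL:    true,
		EncryptWAL:     true,
		EncryptionKey:  key, // 32 bytes, e.g. derived from a passphrase
		RetentionDays:  7,
		VerifyChecksum: true,
	}
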
@@ -77,7 +77,7 @@ func (a *Archiver) ArchiveWALFile(ctx context.Context, walFilePath, walFileName
	// Process WAL file: compression and/or encryption
	var archivePath string
	var archivedSize int64

	if config.CompressWAL && config.EncryptWAL {
		// Compress then encrypt
		archivePath, archivedSize, err = a.compressAndEncryptWAL(walFilePath, walFileName, config)
@@ -150,7 +150,7 @@ func (a *Archiver) copyWAL(walFilePath, walFileName string, config ArchiveConfig
// compressWAL compresses a WAL file using gzip
func (a *Archiver) compressWAL(walFilePath, walFileName string, config ArchiveConfig) (string, int64, error) {
	archivePath := filepath.Join(config.ArchiveDir, walFileName+".gz")

	compressor := NewCompressor(a.log)
	compressedSize, err := compressor.CompressWALFile(walFilePath, archivePath, 6) // gzip level 6 (balanced)
	if err != nil {
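CompressWALFile is project-internal; the core of what it must do can be sketched with the standard library alone (a simplified stand-in, without the size accounting and fsync a production path would want):

	import (
		"compress/gzip"
		"io"
		"os"
	)

	// gzipFile compresses src into dst at the given gzip level (1-9).
	func gzipFile(src, dst string, level int) error {
		in, err := os.Open(src)
		if err != nil {
			return err
		}
		defer in.Close()
		out, err := os.Create(dst)
		if err != nil {
			return err
		}
		defer out.Close()
		zw, err := gzip.NewWriterLevel(out, level)
		if err != nil {
			return err
		}
		if _, err := io.Copy(zw, in); err != nil {
			return err
		}
		return zw.Close() // flushes the gzip footer
	}
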
@@ -163,12 +163,12 @@ func (a *Archiver) compressWAL(walFilePath, walFileName string, config ArchiveCo
// encryptWAL encrypts a WAL file
func (a *Archiver) encryptWAL(walFilePath, walFileName string, config ArchiveConfig) (string, int64, error) {
	archivePath := filepath.Join(config.ArchiveDir, walFileName+".enc")

	encryptor := NewEncryptor(a.log)
	encOpts := EncryptionOptions{
		Key: config.EncryptionKey,
	}

	encryptedSize, err := encryptor.EncryptWALFile(walFilePath, archivePath, encOpts)
	if err != nil {
		return "", 0, fmt.Errorf("WAL encryption failed: %w", err)
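ArchiveConfig documents the key as 32 bytes for AES-256-GCM; a minimal sealing sketch with the standard library (whole-buffer and in-memory, whereas the real EncryptWALFile presumably streams and defines its own on-disk format):

	import (
		"crypto/aes"
		"crypto/cipher"
		"crypto/rand"
	)

	// sealAESGCM encrypts plaintext with a 32-byte key, prepending the nonce.
	func sealAESGCM(key, plaintext []byte) ([]byte, error) {
		block, err := aes.NewCipher(key) // a 32-byte key selects AES-256
		if err != nil {
			return nil, err
		}
		gcm, err := cipher.NewGCM(block)
		if err != nil {
			return nil, err
		}
		nonce := make([]byte, gcm.NonceSize())
		if _, err := rand.Read(nonce); err != nil {
			return nil, err
		}
		return gcm.Seal(nonce, nonce, plaintext, nil), nil
	}
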
@@ -199,7 +199,7 @@ func (a *Archiver) compressAndEncryptWAL(walFilePath, walFileName string, config
	encOpts := EncryptionOptions{
		Key: config.EncryptionKey,
	}

	encryptedSize, err := encryptor.EncryptWALFile(tempCompressed, archivePath, encOpts)
	if err != nil {
		return "", 0, fmt.Errorf("WAL encryption failed: %w", err)
@@ -340,7 +340,7 @@ func (a *Archiver) GetArchiveStats(config ArchiveConfig) (*ArchiveStats, error)

	for _, archive := range archives {
		stats.TotalSize += archive.ArchivedSize

		if archive.Compressed {
			stats.CompressedFiles++
		}

@@ -11,6 +11,7 @@ import (
	"path/filepath"

	"dbbackup/internal/logger"

	"golang.org/x/crypto/pbkdf2"
)

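The pbkdf2 import suggests the 32-byte AES key is derived from a passphrase. The canonical call from golang.org/x/crypto/pbkdf2 looks like this (the iteration count and salt handling are placeholders, not values taken from this codebase):

	import (
		"crypto/sha256"

		"golang.org/x/crypto/pbkdf2"
	)

	// deriveKey stretches a passphrase into a 32-byte AES-256 key.
	func deriveKey(passphrase, salt []byte) []byte {
		return pbkdf2.Key(passphrase, salt, 100000, 32, sha256.New)
	}
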
@@ -23,14 +23,14 @@ type PITRManager struct {

// PITRConfig holds PITR settings
type PITRConfig struct {
	Enabled bool
	ArchiveMode string // "on", "off", "always"
	ArchiveCommand string
	ArchiveDir string
	WALLevel string // "minimal", "replica", "logical"
	MaxWALSenders int
	WALKeepSize string // e.g., "1GB"
	RestoreCommand string
	Enabled        bool
	ArchiveMode    string // "on", "off", "always"
	ArchiveCommand string
	ArchiveDir     string
	WALLevel       string // "minimal", "replica", "logical"
	MaxWALSenders  int
	WALKeepSize    string // e.g., "1GB"
	RestoreCommand string
}

// RecoveryTarget specifies the point-in-time to recover to
@@ -87,11 +87,11 @@ func (pm *PITRManager) EnablePITR(ctx context.Context, archiveDir string) error

	// Settings to enable PITR
	settings := map[string]string{
		"wal_level": "replica", // Required for PITR
		"archive_mode": "on",
		"archive_command": archiveCommand,
		"max_wal_senders": "3",
		"wal_keep_size": "1GB", // Keep at least 1GB of WAL
		"wal_level":       "replica", // Required for PITR
		"archive_mode":    "on",
		"archive_command": archiveCommand,
		"max_wal_senders": "3",
		"wal_keep_size":   "1GB", // Keep at least 1GB of WAL
	}

	// Update postgresql.conf
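Applied to postgresql.conf, those settings come out roughly as follows (the archive_command shown is the illustrative shape; the real one is built from archiveDir at runtime):

	wal_level = replica
	archive_mode = on
	archive_command = 'cp %p /path/to/archive/%f'
	max_wal_senders = 3
	wal_keep_size = 1GB
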
@@ -156,7 +156,7 @@ func (pm *PITRManager) GetCurrentPITRConfig(ctx context.Context) (*PITRConfig, e

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())

		// Skip comments and empty lines
		if line == "" || strings.HasPrefix(line, "#") {
			continue
@@ -226,11 +226,11 @@ func (pm *PITRManager) createRecoverySignal(ctx context.Context, dataDir string,

	// Recovery settings go in postgresql.auto.conf (PostgreSQL 12+)
	autoConfPath := filepath.Join(dataDir, "postgresql.auto.conf")

	// Build recovery settings
	var settings []string
	settings = append(settings, fmt.Sprintf("restore_command = 'cp %s/%%f %%p'", walArchiveDir))

	if target.TargetTime != nil {
		settings = append(settings, fmt.Sprintf("recovery_target_time = '%s'", target.TargetTime.Format("2006-01-02 15:04:05")))
	} else if target.TargetXID != "" {
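For a time-based target, the appended settings land in postgresql.auto.conf looking roughly like this (path and timestamp are illustrative):

	restore_command = 'cp /var/lib/dbbackup/wal_archive/%f %p'
	recovery_target_time = '2025-01-15 12:00:00'

alongside an empty recovery.signal file in the data directory, which is what actually puts PostgreSQL 12+ into targeted recovery on the next start.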
@@ -270,11 +270,11 @@ func (pm *PITRManager) createRecoverySignal(ctx context.Context, dataDir string,
// createLegacyRecoveryConf creates recovery.conf for PostgreSQL < 12
func (pm *PITRManager) createLegacyRecoveryConf(dataDir string, target RecoveryTarget, walArchiveDir string) error {
	recoveryConfPath := filepath.Join(dataDir, "recovery.conf")

	var content strings.Builder
	content.WriteString("# Recovery Configuration (created by dbbackup)\n")
	content.WriteString(fmt.Sprintf("restore_command = 'cp %s/%%f %%p'\n", walArchiveDir))

	if target.TargetTime != nil {
		content.WriteString(fmt.Sprintf("recovery_target_time = '%s'\n", target.TargetTime.Format("2006-01-02 15:04:05")))
	}

@@ -40,9 +40,9 @@ type TimelineInfo struct {

// TimelineHistory represents the complete timeline branching structure
type TimelineHistory struct {
	Timelines []*TimelineInfo // All timelines sorted by ID
	Timelines       []*TimelineInfo          // All timelines sorted by ID
	CurrentTimeline uint32                   // Current active timeline
	TimelineMap map[uint32]*TimelineInfo // Quick lookup by timeline ID
	TimelineMap     map[uint32]*TimelineInfo // Quick lookup by timeline ID
}

// ParseTimelineHistory parses timeline history from an archive directory
@@ -74,10 +74,10 @@ func (tm *TimelineManager) ParseTimelineHistory(ctx context.Context, archiveDir
	// Always add timeline 1 (base timeline) if not present
	if _, exists := history.TimelineMap[1]; !exists {
		baseTimeline := &TimelineInfo{
			TimelineID: 1,
			ParentTimeline: 0,
			SwitchPoint: "0/0",
			Reason: "Base timeline",
			TimelineID:      1,
			ParentTimeline:  0,
			SwitchPoint:     "0/0",
			Reason:          "Base timeline",
			FirstWALSegment: 0,
		}
		history.Timelines = append(history.Timelines, baseTimeline)
@@ -201,7 +201,7 @@ func (tm *TimelineManager) scanWALSegments(archiveDir string, history *TimelineH
	// Process each WAL file
	for _, walFile := range walFiles {
		filename := filepath.Base(walFile)

		// Remove extensions
		filename = strings.TrimSuffix(filename, ".gz.enc")
		filename = strings.TrimSuffix(filename, ".enc")
@@ -255,7 +255,7 @@ func (tm *TimelineManager) ValidateTimelineConsistency(ctx context.Context, hist

		parent, exists := history.TimelineMap[tl.ParentTimeline]
		if !exists {
			return fmt.Errorf("timeline %d references non-existent parent timeline %d",
			return fmt.Errorf("timeline %d references non-existent parent timeline %d",
				tl.TimelineID, tl.ParentTimeline)
		}

@@ -274,29 +274,29 @@ func (tm *TimelineManager) ValidateTimelineConsistency(ctx context.Context, hist
// GetTimelinePath returns the path from timeline 1 to the target timeline
func (tm *TimelineManager) GetTimelinePath(history *TimelineHistory, targetTimeline uint32) ([]*TimelineInfo, error) {
	path := make([]*TimelineInfo, 0)

	currentTL := targetTimeline
	for currentTL > 0 {
		tl, exists := history.TimelineMap[currentTL]
		if !exists {
			return nil, fmt.Errorf("timeline %d not found in history", currentTL)
		}

		// Prepend to path (we're walking backwards)
		path = append([]*TimelineInfo{tl}, path...)

		// Move to parent
		if currentTL == 1 {
			break // Reached base timeline
		}
		currentTL = tl.ParentTimeline

		// Prevent infinite loops
		if len(path) > 100 {
			return nil, fmt.Errorf("timeline path too long (possible cycle)")
		}
	}

	return path, nil
}

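A quick usage sketch (history comes from ParseTimelineHistory; the printed fields are taken from TimelineInfo above):

	path, err := tm.GetTimelinePath(history, history.CurrentTimeline)
	if err != nil {
		return err
	}
	for _, tl := range path {
		fmt.Printf("timeline %d (switch at %s)\n", tl.TimelineID, tl.SwitchPoint)
	}
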
@@ -305,13 +305,13 @@ func (tm *TimelineManager) FindTimelineAtPoint(history *TimelineHistory, targetL
	// Start from current timeline and walk backwards
	for i := len(history.Timelines) - 1; i >= 0; i-- {
		tl := history.Timelines[i]

		// Compare LSNs (simplified - in production would need proper LSN comparison)
		if tl.SwitchPoint <= targetLSN || tl.SwitchPoint == "0/0" {
			return tl.TimelineID, nil
		}
	}

	// Default to timeline 1
	return 1, nil
}
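As the comment concedes, comparing LSNs as strings is only an approximation: lexically "F/0" sorts above "10/0", but numerically 0xF is below 0x10. A proper comparison parses the "X/Y" form into a 64-bit position first; a sketch:

	import (
		"fmt"
		"strconv"
		"strings"
	)

	// parseLSN converts a PostgreSQL LSN like "16/B374D848" to a uint64.
	func parseLSN(lsn string) (uint64, error) {
		parts := strings.Split(lsn, "/")
		if len(parts) != 2 {
			return 0, fmt.Errorf("malformed LSN: %s", lsn)
		}
		hi, err := strconv.ParseUint(parts[0], 16, 32)
		if err != nil {
			return 0, err
		}
		lo, err := strconv.ParseUint(parts[1], 16, 32)
		if err != nil {
			return 0, err
		}
		return hi<<32 | lo, nil
	}
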
@@ -384,23 +384,23 @@ func (tm *TimelineManager) formatTimelineNode(sb *strings.Builder, history *Time
	}

	sb.WriteString(fmt.Sprintf("%s%s Timeline %d", indent, marker, tl.TimelineID))

	if tl.TimelineID == history.CurrentTimeline {
		sb.WriteString(" [CURRENT]")
	}

	if tl.SwitchPoint != "" && tl.SwitchPoint != "0/0" {
		sb.WriteString(fmt.Sprintf(" (switched at %s)", tl.SwitchPoint))
	}

	if tl.FirstWALSegment > 0 {
		sb.WriteString(fmt.Sprintf("\n%s WAL segments: %d files", indent, tl.LastWALSegment-tl.FirstWALSegment+1))
	}

	if tl.Reason != "" {
		sb.WriteString(fmt.Sprintf("\n%s Reason: %s", indent, tl.Reason))
	}

	sb.WriteString("\n")

	// Find and format children