Compare commits

3 Commits

| SHA1 | Author | Date |
|---|---|---|
| 2e7aa9fcdf | | |
| 59812400a4 | | |
| 48f922ef6c | | |
@@ -1,7 +1,6 @@
 package cmd
 
 import (
-	"compress/gzip"
 	"context"
 	"fmt"
 	"io"
@@ -12,6 +11,8 @@ import (
 	"dbbackup/internal/database"
 	"dbbackup/internal/engine/native"
 	"dbbackup/internal/notify"
+
+	"github.com/klauspost/pgzip"
 )
 
 // runNativeBackup executes backup using native Go engines
@@ -58,10 +59,13 @@ func runNativeBackup(ctx context.Context, db database.Database, databaseName, ba
 	}
 	defer file.Close()
 
-	// Wrap with compression if enabled
+	// Wrap with compression if enabled (use pgzip for parallel compression)
 	var writer io.Writer = file
 	if cfg.CompressionLevel > 0 {
-		gzWriter := gzip.NewWriter(file)
+		gzWriter, err := pgzip.NewWriterLevel(file, cfg.CompressionLevel)
+		if err != nil {
+			return fmt.Errorf("failed to create gzip writer: %w", err)
+		}
 		defer gzWriter.Close()
 		writer = gzWriter
 	}
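Note: pgzip.NewWriterLevel produces standard gzip output but compresses independent blocks on multiple goroutines. If more control is wanted, the writer's concurrency can also be tuned; a minimal sketch (the 1 MiB block size and per-CPU worker count are illustrative assumptions, not part of this change, and assume "runtime" is imported):

	gzWriter, err := pgzip.NewWriterLevel(file, cfg.CompressionLevel)
	if err != nil {
		return fmt.Errorf("failed to create gzip writer: %w", err)
	}
	// Compress 1 MiB blocks with one worker per CPU core (both values tunable).
	if err := gzWriter.SetConcurrency(1<<20, runtime.NumCPU()); err != nil {
		return fmt.Errorf("failed to configure compression concurrency: %w", err)
	}
	defer gzWriter.Close()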
93  cmd/native_restore.go  Normal file
@@ -0,0 +1,93 @@
package cmd

import (
	"context"
	"fmt"
	"io"
	"os"
	"time"

	"dbbackup/internal/database"
	"dbbackup/internal/engine/native"
	"dbbackup/internal/notify"

	"github.com/klauspost/pgzip"
)

// runNativeRestore executes restore using native Go engines
func runNativeRestore(ctx context.Context, db database.Database, archivePath, targetDB string, cleanFirst, createIfMissing bool, startTime time.Time, user string) error {
	// Initialize native engine manager
	engineManager := native.NewEngineManager(cfg, log)

	if err := engineManager.InitializeEngines(ctx); err != nil {
		return fmt.Errorf("failed to initialize native engines: %w", err)
	}
	defer engineManager.Close()

	// Check if native engine is available for this database type
	dbType := detectDatabaseTypeFromConfig()
	if !engineManager.IsNativeEngineAvailable(dbType) {
		return fmt.Errorf("native restore engine not available for database type: %s", dbType)
	}

	// Open archive file
	file, err := os.Open(archivePath)
	if err != nil {
		return fmt.Errorf("failed to open archive: %w", err)
	}
	defer file.Close()

	// Detect if file is gzip compressed
	var reader io.Reader = file
	if isGzipFile(archivePath) {
		gzReader, err := pgzip.NewReader(file)
		if err != nil {
			return fmt.Errorf("failed to create gzip reader: %w", err)
		}
		defer gzReader.Close()
		reader = gzReader
	}

	log.Info("Starting native restore",
		"archive", archivePath,
		"database", targetDB,
		"engine", dbType,
		"clean_first", cleanFirst,
		"create_if_missing", createIfMissing)

	// Perform restore using native engine
	if err := engineManager.RestoreWithNativeEngine(ctx, reader, targetDB); err != nil {
		auditLogger.LogRestoreFailed(user, targetDB, err)
		if notifyManager != nil {
			notifyManager.Notify(notify.NewEvent(notify.EventRestoreFailed, notify.SeverityError, "Native restore failed").
				WithDatabase(targetDB).
				WithError(err))
		}
		return fmt.Errorf("native restore failed: %w", err)
	}

	restoreDuration := time.Since(startTime)

	log.Info("Native restore completed successfully",
		"database", targetDB,
		"duration", restoreDuration,
		"engine", dbType)

	// Audit log: restore completed
	auditLogger.LogRestoreComplete(user, targetDB, restoreDuration)

	// Notify: restore completed
	if notifyManager != nil {
		notifyManager.Notify(notify.NewEvent(notify.EventRestoreCompleted, notify.SeverityInfo, "Native restore completed").
			WithDatabase(targetDB).
			WithDuration(restoreDuration).
			WithDetail("engine", dbType))
	}

	return nil
}

// isGzipFile checks if file has gzip extension
func isGzipFile(path string) bool {
	return len(path) > 3 && path[len(path)-3:] == ".gz"
}
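isGzipFile above keys off the ".gz" suffix only. A hedged alternative sketch that sniffs the gzip magic bytes instead of trusting the file name (isGzipStream is a hypothetical helper, not part of this change):

	// isGzipStream reports whether r begins with the gzip magic bytes 0x1f 0x8b,
	// rewinding r so the caller can still read the whole stream.
	func isGzipStream(r io.ReadSeeker) (bool, error) {
		var magic [2]byte
		if _, err := io.ReadFull(r, magic[:]); err != nil {
			return false, err
		}
		if _, err := r.Seek(0, io.SeekStart); err != nil {
			return false, err
		}
		return magic[0] == 0x1f && magic[1] == 0x8b, nil
	}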
@@ -720,6 +720,23 @@ func runRestoreSingle(cmd *cobra.Command, args []string) error {
 			WithDetail("archive", filepath.Base(archivePath)))
 	}
 
+	// Check if native engine should be used for restore
+	if cfg.UseNativeEngine {
+		log.Info("Using native engine for restore", "database", targetDB)
+		err = runNativeRestore(ctx, db, archivePath, targetDB, restoreClean, restoreCreate, startTime, user)
+
+		if err != nil && cfg.FallbackToTools {
+			log.Warn("Native engine restore failed, falling back to external tools", "error", err)
+			// Continue with tool-based restore below
+		} else {
+			// Native engine succeeded or no fallback configured
+			if err == nil {
+				log.Info("[OK] Restore completed successfully (native engine)", "database", targetDB)
+			}
+			return err
+		}
+	}
+
 	if err := engine.RestoreSingle(ctx, archivePath, targetDB, restoreClean, restoreCreate); err != nil {
 		auditLogger.LogRestoreFailed(user, targetDB, err)
 		// Notify: restore failed
@ -9,7 +9,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@ -19,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/checks"
|
||||
"dbbackup/internal/cleanup"
|
||||
"dbbackup/internal/cloud"
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/database"
|
||||
@ -650,7 +650,7 @@ func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []strin
|
||||
|
||||
e.log.Debug("Executing backup command with progress", "cmd", cmdArgs[0], "args", cmdArgs[1:])
|
||||
|
||||
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
cmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
|
||||
// Set environment variables for database tools
|
||||
cmd.Env = os.Environ()
|
||||
@ -696,9 +696,9 @@ func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []strin
|
||||
case cmdErr = <-cmdDone:
|
||||
// Command completed (success or failure)
|
||||
case <-ctx.Done():
|
||||
// Context cancelled - kill process to unblock
|
||||
e.log.Warn("Backup cancelled - killing process")
|
||||
cmd.Process.Kill()
|
||||
// Context cancelled - kill entire process group
|
||||
e.log.Warn("Backup cancelled - killing process group")
|
||||
cleanup.KillCommandGroup(cmd)
|
||||
<-cmdDone // Wait for goroutine to finish
|
||||
cmdErr = ctx.Err()
|
||||
}
|
||||
@ -754,7 +754,7 @@ func (e *Engine) monitorCommandProgress(stderr io.ReadCloser, tracker *progress.
|
||||
// Uses in-process pgzip for parallel compression (2-4x faster on multi-core systems)
|
||||
func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmdArgs []string, outputFile string, tracker *progress.OperationTracker) error {
|
||||
// Create mysqldump command
|
||||
dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
dumpCmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
dumpCmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
|
||||
@ -816,8 +816,8 @@ func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmd
|
||||
case dumpErr = <-dumpDone:
|
||||
// mysqldump completed
|
||||
case <-ctx.Done():
|
||||
e.log.Warn("Backup cancelled - killing mysqldump")
|
||||
dumpCmd.Process.Kill()
|
||||
e.log.Warn("Backup cancelled - killing mysqldump process group")
|
||||
cleanup.KillCommandGroup(dumpCmd)
|
||||
<-dumpDone
|
||||
return ctx.Err()
|
||||
}
|
||||
@ -846,7 +846,7 @@ func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmd
|
||||
// Uses in-process pgzip for parallel compression (2-4x faster on multi-core systems)
|
||||
func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []string, outputFile string) error {
|
||||
// Create mysqldump command
|
||||
dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
dumpCmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
dumpCmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
|
||||
@ -895,8 +895,8 @@ func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []stri
|
||||
case dumpErr = <-dumpDone:
|
||||
// mysqldump completed
|
||||
case <-ctx.Done():
|
||||
e.log.Warn("Backup cancelled - killing mysqldump")
|
||||
dumpCmd.Process.Kill()
|
||||
e.log.Warn("Backup cancelled - killing mysqldump process group")
|
||||
cleanup.KillCommandGroup(dumpCmd)
|
||||
<-dumpDone
|
||||
return ctx.Err()
|
||||
}
|
||||
@ -951,7 +951,7 @@ func (e *Engine) createSampleBackup(ctx context.Context, databaseName, outputFil
|
||||
Format: "plain",
|
||||
})
|
||||
|
||||
cmd := exec.CommandContext(ctx, schemaCmd[0], schemaCmd[1:]...)
|
||||
cmd := cleanup.SafeCommand(ctx, schemaCmd[0], schemaCmd[1:]...)
|
||||
cmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
|
||||
@ -990,7 +990,7 @@ func (e *Engine) backupGlobals(ctx context.Context, tempDir string) error {
|
||||
globalsFile := filepath.Join(tempDir, "globals.sql")
|
||||
|
||||
// CRITICAL: Always pass port even for localhost - user may have non-standard port
|
||||
cmd := exec.CommandContext(ctx, "pg_dumpall", "--globals-only",
|
||||
cmd := cleanup.SafeCommand(ctx, "pg_dumpall", "--globals-only",
|
||||
"-p", fmt.Sprintf("%d", e.cfg.Port),
|
||||
"-U", e.cfg.User)
|
||||
|
||||
@ -1034,8 +1034,8 @@ func (e *Engine) backupGlobals(ctx context.Context, tempDir string) error {
|
||||
case cmdErr = <-cmdDone:
|
||||
// Command completed normally
|
||||
case <-ctx.Done():
|
||||
e.log.Warn("Globals backup cancelled - killing pg_dumpall")
|
||||
cmd.Process.Kill()
|
||||
e.log.Warn("Globals backup cancelled - killing pg_dumpall process group")
|
||||
cleanup.KillCommandGroup(cmd)
|
||||
<-cmdDone
|
||||
return ctx.Err()
|
||||
}
|
||||
@ -1430,7 +1430,7 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
|
||||
|
||||
// For custom format, pg_dump handles everything (writes directly to file)
|
||||
// NO GO BUFFERING - pg_dump writes directly to disk
|
||||
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
cmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
|
||||
// Start heartbeat ticker for backup progress
|
||||
backupStart := time.Now()
|
||||
@ -1499,9 +1499,9 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
|
||||
case cmdErr = <-cmdDone:
|
||||
// Command completed (success or failure)
|
||||
case <-ctx.Done():
|
||||
// Context cancelled - kill process to unblock
|
||||
e.log.Warn("Backup cancelled - killing pg_dump process")
|
||||
cmd.Process.Kill()
|
||||
// Context cancelled - kill entire process group
|
||||
e.log.Warn("Backup cancelled - killing pg_dump process group")
|
||||
cleanup.KillCommandGroup(cmd)
|
||||
<-cmdDone // Wait for goroutine to finish
|
||||
cmdErr = ctx.Err()
|
||||
}
|
||||
@ -1536,7 +1536,7 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
|
||||
}
|
||||
|
||||
// Create pg_dump command
|
||||
dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
dumpCmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
dumpCmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" && e.cfg.IsPostgreSQL() {
|
||||
dumpCmd.Env = append(dumpCmd.Env, "PGPASSWORD="+e.cfg.Password)
|
||||
@ -1612,9 +1612,9 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
|
||||
case dumpErr = <-dumpDone:
|
||||
// pg_dump completed (success or failure)
|
||||
case <-ctx.Done():
|
||||
// Context cancelled/timeout - kill pg_dump to unblock
|
||||
e.log.Warn("Backup timeout - killing pg_dump process")
|
||||
dumpCmd.Process.Kill()
|
||||
// Context cancelled/timeout - kill pg_dump process group
|
||||
e.log.Warn("Backup timeout - killing pg_dump process group")
|
||||
cleanup.KillCommandGroup(dumpCmd)
|
||||
<-dumpDone // Wait for goroutine to finish
|
||||
dumpErr = ctx.Err()
|
||||
}
|
||||
|
||||
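The hunks above apply one recurring pattern: commands are created with cleanup.SafeCommand so they run in their own process group, and on context cancellation the whole group is killed with cleanup.KillCommandGroup before the wait goroutine is drained. A condensed sketch of that pattern (variable names are illustrative):

	cmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
	// ... start the command and wire up stdout/stderr ...
	cmdDone := make(chan error, 1)
	go func() { cmdDone <- cmd.Wait() }()

	var cmdErr error
	select {
	case cmdErr = <-cmdDone:
		// Command finished on its own (success or failure).
	case <-ctx.Done():
		// Cancelled: kill the entire process group, then drain the goroutine.
		cleanup.KillCommandGroup(cmd)
		<-cmdDone
		cmdErr = ctx.Err()
	}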
154  internal/cleanup/command.go  Normal file
@ -0,0 +1,154 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// SafeCommand creates an exec.Cmd with proper process group setup for clean termination.
|
||||
// This ensures that child processes (e.g., from pipelines) are killed when the parent is killed.
|
||||
func SafeCommand(ctx context.Context, name string, args ...string) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, name, args...)
|
||||
|
||||
// Set up process group for clean termination
|
||||
// This allows killing the entire process tree when cancelled
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true, // Create new process group
|
||||
Pgid: 0, // Use the new process's PID as the PGID
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// TrackedCommand creates a command that is tracked for cleanup on shutdown.
|
||||
// When the handler shuts down, this command will be killed if still running.
|
||||
type TrackedCommand struct {
|
||||
*exec.Cmd
|
||||
log logger.Logger
|
||||
name string
|
||||
}
|
||||
|
||||
// NewTrackedCommand creates a tracked command
|
||||
func NewTrackedCommand(ctx context.Context, log logger.Logger, name string, args ...string) *TrackedCommand {
|
||||
tc := &TrackedCommand{
|
||||
Cmd: SafeCommand(ctx, name, args...),
|
||||
log: log,
|
||||
name: name,
|
||||
}
|
||||
return tc
|
||||
}
|
||||
|
||||
// StartWithCleanup starts the command and registers cleanup with the handler
|
||||
func (tc *TrackedCommand) StartWithCleanup(h *Handler) error {
|
||||
if err := tc.Cmd.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register cleanup function
|
||||
pid := tc.Cmd.Process.Pid
|
||||
h.RegisterCleanup(fmt.Sprintf("kill-%s-%d", tc.name, pid), func(ctx context.Context) error {
|
||||
return tc.Kill()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Kill terminates the command and its process group
|
||||
func (tc *TrackedCommand) Kill() error {
|
||||
if tc.Cmd.Process == nil {
|
||||
return nil // Not started or already cleaned up
|
||||
}
|
||||
|
||||
pid := tc.Cmd.Process.Pid
|
||||
|
||||
// Get the process group ID
|
||||
pgid, err := syscall.Getpgid(pid)
|
||||
if err != nil {
|
||||
// Process might already be gone
|
||||
return nil
|
||||
}
|
||||
|
||||
tc.log.Debug("Terminating process", "name", tc.name, "pid", pid, "pgid", pgid)
|
||||
|
||||
// Try graceful shutdown first (SIGTERM to process group)
|
||||
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
|
||||
tc.log.Debug("SIGTERM failed, trying SIGKILL", "error", err)
|
||||
}
|
||||
|
||||
// Wait briefly for graceful shutdown
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := tc.Cmd.Process.Wait()
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(3 * time.Second):
|
||||
// Force kill after timeout
|
||||
tc.log.Debug("Process didn't stop gracefully, sending SIGKILL", "name", tc.name, "pid", pid)
|
||||
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
|
||||
tc.log.Debug("SIGKILL failed", "error", err)
|
||||
}
|
||||
<-done // Wait for Wait() to finish
|
||||
|
||||
case <-done:
|
||||
// Process exited
|
||||
}
|
||||
|
||||
tc.log.Debug("Process terminated", "name", tc.name, "pid", pid)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WaitWithContext waits for the command to complete, handling context cancellation properly.
|
||||
// This is the recommended way to wait for commands, as it ensures proper cleanup on cancellation.
|
||||
func WaitWithContext(ctx context.Context, cmd *exec.Cmd, log logger.Logger) error {
|
||||
if cmd.Process == nil {
|
||||
return fmt.Errorf("process not started")
|
||||
}
|
||||
|
||||
// Wait for command in a goroutine
|
||||
cmdDone := make(chan error, 1)
|
||||
go func() {
|
||||
cmdDone <- cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-cmdDone:
|
||||
return err
|
||||
|
||||
case <-ctx.Done():
|
||||
// Context cancelled - kill process group
|
||||
log.Debug("Context cancelled, terminating process", "pid", cmd.Process.Pid)
|
||||
|
||||
// Get process group and kill entire group
|
||||
pgid, err := syscall.Getpgid(cmd.Process.Pid)
|
||||
if err == nil {
|
||||
// Kill process group
|
||||
syscall.Kill(-pgid, syscall.SIGTERM)
|
||||
|
||||
// Wait briefly for graceful shutdown
|
||||
select {
|
||||
case <-cmdDone:
|
||||
// Process exited
|
||||
case <-time.After(2 * time.Second):
|
||||
// Force kill
|
||||
syscall.Kill(-pgid, syscall.SIGKILL)
|
||||
<-cmdDone
|
||||
}
|
||||
} else {
|
||||
// Fallback to killing just the process
|
||||
cmd.Process.Kill()
|
||||
<-cmdDone
|
||||
}
|
||||
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
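A minimal usage sketch for the helpers defined above; the pg_dump arguments and logger are placeholders, not taken from this change:

	cmd := cleanup.SafeCommand(ctx, "pg_dump", "--format=custom", "mydb")
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("failed to start pg_dump: %w", err)
	}
	// Blocks until the command exits, or kills its process group if ctx is cancelled.
	if err := cleanup.WaitWithContext(ctx, cmd, log); err != nil {
		return fmt.Errorf("pg_dump failed: %w", err)
	}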
99  internal/cleanup/command_windows.go  Normal file
@ -0,0 +1,99 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// SafeCommand creates an exec.Cmd with proper setup for clean termination on Windows.
|
||||
func SafeCommand(ctx context.Context, name string, args ...string) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, name, args...)
|
||||
// Windows doesn't use process groups the same way as Unix
|
||||
// exec.CommandContext will handle termination via the context
|
||||
return cmd
|
||||
}
|
||||
|
||||
// TrackedCommand creates a command that is tracked for cleanup on shutdown.
|
||||
type TrackedCommand struct {
|
||||
*exec.Cmd
|
||||
log logger.Logger
|
||||
name string
|
||||
}
|
||||
|
||||
// NewTrackedCommand creates a tracked command
|
||||
func NewTrackedCommand(ctx context.Context, log logger.Logger, name string, args ...string) *TrackedCommand {
|
||||
tc := &TrackedCommand{
|
||||
Cmd: SafeCommand(ctx, name, args...),
|
||||
log: log,
|
||||
name: name,
|
||||
}
|
||||
return tc
|
||||
}
|
||||
|
||||
// StartWithCleanup starts the command and registers cleanup with the handler
|
||||
func (tc *TrackedCommand) StartWithCleanup(h *Handler) error {
|
||||
if err := tc.Cmd.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register cleanup function
|
||||
pid := tc.Cmd.Process.Pid
|
||||
h.RegisterCleanup(fmt.Sprintf("kill-%s-%d", tc.name, pid), func(ctx context.Context) error {
|
||||
return tc.Kill()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Kill terminates the command on Windows
|
||||
func (tc *TrackedCommand) Kill() error {
|
||||
if tc.Cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tc.log.Debug("Terminating process", "name", tc.name, "pid", tc.Cmd.Process.Pid)
|
||||
|
||||
if err := tc.Cmd.Process.Kill(); err != nil {
|
||||
tc.log.Debug("Kill failed", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
tc.log.Debug("Process terminated", "name", tc.name, "pid", tc.Cmd.Process.Pid)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WaitWithContext waits for the command to complete, handling context cancellation properly.
|
||||
func WaitWithContext(ctx context.Context, cmd *exec.Cmd, log logger.Logger) error {
|
||||
if cmd.Process == nil {
|
||||
return fmt.Errorf("process not started")
|
||||
}
|
||||
|
||||
cmdDone := make(chan error, 1)
|
||||
go func() {
|
||||
cmdDone <- cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-cmdDone:
|
||||
return err
|
||||
|
||||
case <-ctx.Done():
|
||||
log.Debug("Context cancelled, terminating process", "pid", cmd.Process.Pid)
|
||||
cmd.Process.Kill()
|
||||
|
||||
select {
|
||||
case <-cmdDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
// Already killed, just wait for it
|
||||
}
|
||||
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
242  internal/cleanup/handler.go  Normal file
@ -0,0 +1,242 @@
|
||||
// Package cleanup provides graceful shutdown and resource cleanup functionality
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// CleanupFunc is a function that performs cleanup with a timeout context
|
||||
type CleanupFunc func(ctx context.Context) error
|
||||
|
||||
// Handler manages graceful shutdown and resource cleanup
|
||||
type Handler struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
cleanupFns []cleanupEntry
|
||||
mu sync.Mutex
|
||||
|
||||
shutdownTimeout time.Duration
|
||||
log logger.Logger
|
||||
|
||||
// Track if shutdown has been initiated
|
||||
shutdownOnce sync.Once
|
||||
shutdownDone chan struct{}
|
||||
}
|
||||
|
||||
type cleanupEntry struct {
|
||||
name string
|
||||
fn CleanupFunc
|
||||
}
|
||||
|
||||
// NewHandler creates a shutdown handler
|
||||
func NewHandler(log logger.Logger) *Handler {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
h := &Handler{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
cleanupFns: make([]cleanupEntry, 0),
|
||||
shutdownTimeout: 30 * time.Second,
|
||||
log: log,
|
||||
shutdownDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Context returns the shutdown context
|
||||
func (h *Handler) Context() context.Context {
|
||||
return h.ctx
|
||||
}
|
||||
|
||||
// RegisterCleanup adds a named cleanup function
|
||||
func (h *Handler) RegisterCleanup(name string, fn CleanupFunc) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.cleanupFns = append(h.cleanupFns, cleanupEntry{name: name, fn: fn})
|
||||
}
|
||||
|
||||
// SetShutdownTimeout sets the maximum time to wait for cleanup
|
||||
func (h *Handler) SetShutdownTimeout(d time.Duration) {
|
||||
h.shutdownTimeout = d
|
||||
}
|
||||
|
||||
// Shutdown triggers graceful shutdown
|
||||
func (h *Handler) Shutdown() {
|
||||
h.shutdownOnce.Do(func() {
|
||||
h.log.Info("Initiating graceful shutdown...")
|
||||
|
||||
// Cancel context first (stops all ongoing operations)
|
||||
h.cancel()
|
||||
|
||||
// Run cleanup functions
|
||||
h.runCleanup()
|
||||
|
||||
close(h.shutdownDone)
|
||||
})
|
||||
}
|
||||
|
||||
// ShutdownWithSignal triggers shutdown due to an OS signal
|
||||
func (h *Handler) ShutdownWithSignal(sig os.Signal) {
|
||||
h.log.Info("Received signal, initiating graceful shutdown", "signal", sig.String())
|
||||
h.Shutdown()
|
||||
}
|
||||
|
||||
// Wait blocks until shutdown is complete
|
||||
func (h *Handler) Wait() {
|
||||
<-h.shutdownDone
|
||||
}
|
||||
|
||||
// runCleanup executes all cleanup functions in LIFO order
|
||||
func (h *Handler) runCleanup() {
|
||||
h.mu.Lock()
|
||||
fns := make([]cleanupEntry, len(h.cleanupFns))
|
||||
copy(fns, h.cleanupFns)
|
||||
h.mu.Unlock()
|
||||
|
||||
if len(fns) == 0 {
|
||||
h.log.Info("No cleanup functions registered")
|
||||
return
|
||||
}
|
||||
|
||||
h.log.Info("Running cleanup functions", "count", len(fns))
|
||||
|
||||
// Create timeout context for cleanup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), h.shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Run all cleanups in LIFO order (most recently registered first)
|
||||
var failed int
|
||||
for i := len(fns) - 1; i >= 0; i-- {
|
||||
entry := fns[i]
|
||||
|
||||
h.log.Debug("Running cleanup", "name", entry.name)
|
||||
|
||||
if err := entry.fn(ctx); err != nil {
|
||||
h.log.Warn("Cleanup function failed", "name", entry.name, "error", err)
|
||||
failed++
|
||||
} else {
|
||||
h.log.Debug("Cleanup completed", "name", entry.name)
|
||||
}
|
||||
}
|
||||
|
||||
if failed > 0 {
|
||||
h.log.Warn("Some cleanup functions failed", "failed", failed, "total", len(fns))
|
||||
} else {
|
||||
h.log.Info("All cleanup functions completed successfully")
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterSignalHandler sets up signal handling for graceful shutdown
|
||||
func (h *Handler) RegisterSignalHandler() {
|
||||
sigChan := make(chan os.Signal, 2)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
// First signal: graceful shutdown
|
||||
sig := <-sigChan
|
||||
h.ShutdownWithSignal(sig)
|
||||
|
||||
// Second signal: force exit
|
||||
sig = <-sigChan
|
||||
h.log.Warn("Received second signal, forcing exit", "signal", sig.String())
|
||||
os.Exit(1)
|
||||
}()
|
||||
}
|
||||
|
||||
// ChildProcessCleanup creates a cleanup function for killing child processes
|
||||
func (h *Handler) ChildProcessCleanup() CleanupFunc {
|
||||
return func(ctx context.Context) error {
|
||||
h.log.Info("Cleaning up orphaned child processes...")
|
||||
|
||||
if err := KillOrphanedProcesses(h.log); err != nil {
|
||||
h.log.Warn("Failed to kill some orphaned processes", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
h.log.Info("Child process cleanup complete")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// DatabasePoolCleanup creates a cleanup function for database connection pools
|
||||
// poolCloser should be a function that closes the pool
|
||||
func DatabasePoolCleanup(log logger.Logger, name string, poolCloser func()) CleanupFunc {
|
||||
return func(ctx context.Context) error {
|
||||
log.Debug("Closing database connection pool", "name", name)
|
||||
poolCloser()
|
||||
log.Debug("Database connection pool closed", "name", name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FileCleanup creates a cleanup function for file handles
|
||||
func FileCleanup(log logger.Logger, path string, file *os.File) CleanupFunc {
|
||||
return func(ctx context.Context) error {
|
||||
if file == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("Closing file", "path", path)
|
||||
if err := file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close file %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// TempFileCleanup creates a cleanup function that closes and removes a temp file
|
||||
func TempFileCleanup(log logger.Logger, file *os.File) CleanupFunc {
|
||||
return func(ctx context.Context) error {
|
||||
if file == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
path := file.Name()
|
||||
log.Debug("Removing temporary file", "path", path)
|
||||
|
||||
// Close file first
|
||||
if err := file.Close(); err != nil {
|
||||
log.Warn("Failed to close temp file", "path", path, "error", err)
|
||||
}
|
||||
|
||||
// Remove file
|
||||
if err := os.Remove(path); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove temp file %s: %w", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Temporary file removed", "path", path)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// TempDirCleanup creates a cleanup function that removes a temp directory
|
||||
func TempDirCleanup(log logger.Logger, path string) CleanupFunc {
|
||||
return func(ctx context.Context) error {
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("Removing temporary directory", "path", path)
|
||||
|
||||
if err := os.RemoveAll(path); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove temp dir %s: %w", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Temporary directory removed", "path", path)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
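A minimal wiring sketch for the shutdown handler above; where it lives (e.g. main) and the runBackup/tempDir names are assumptions:

	h := cleanup.NewHandler(log)
	h.RegisterSignalHandler()

	// Cleanups run in LIFO order: registered last, cleaned up first.
	h.RegisterCleanup("temp-dir", cleanup.TempDirCleanup(log, tempDir))
	h.RegisterCleanup("child-processes", h.ChildProcessCleanup())

	// Run work off the handler's context so cancellation propagates.
	runBackup(h.Context())

	h.Shutdown()
	h.Wait()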
@@ -8,12 +8,12 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"os/exec"
 	"path/filepath"
 	"regexp"
 	"strings"
 	"time"
 
+	"dbbackup/internal/cleanup"
 	"dbbackup/internal/fs"
 	"dbbackup/internal/logger"
@@ -568,7 +568,7 @@ func (d *Diagnoser) verifyWithPgRestore(filePath string, result *DiagnoseResult)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMinutes)*time.Minute)
 	defer cancel()
 
-	cmd := exec.CommandContext(ctx, "pg_restore", "--list", filePath)
+	cmd := cleanup.SafeCommand(ctx, "pg_restore", "--list", filePath)
 	output, err := cmd.CombinedOutput()
 
 	if err != nil {
@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/checks"
|
||||
"dbbackup/internal/cleanup"
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/database"
|
||||
"dbbackup/internal/fs"
|
||||
@ -499,7 +500,7 @@ func (e *Engine) checkDumpHasLargeObjects(archivePath string) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_restore", "-l", archivePath)
|
||||
cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", archivePath)
|
||||
output, err := cmd.Output()
|
||||
|
||||
if err != nil {
|
||||
@ -592,7 +593,7 @@ func (e *Engine) executeRestoreCommand(ctx context.Context, cmdArgs []string) er
|
||||
func (e *Engine) executeRestoreCommandWithContext(ctx context.Context, cmdArgs []string, archivePath, targetDB string, format ArchiveFormat) error {
|
||||
e.log.Info("Executing restore command", "command", strings.Join(cmdArgs, " "))
|
||||
|
||||
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
cmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
|
||||
// Set environment variables
|
||||
cmd.Env = append(os.Environ(),
|
||||
@ -662,9 +663,9 @@ func (e *Engine) executeRestoreCommandWithContext(ctx context.Context, cmdArgs [
|
||||
case cmdErr = <-cmdDone:
|
||||
// Command completed (success or failure)
|
||||
case <-ctx.Done():
|
||||
// Context cancelled - kill process
|
||||
e.log.Warn("Restore cancelled - killing process")
|
||||
cmd.Process.Kill()
|
||||
// Context cancelled - kill entire process group
|
||||
e.log.Warn("Restore cancelled - killing process group")
|
||||
cleanup.KillCommandGroup(cmd)
|
||||
<-cmdDone
|
||||
cmdErr = ctx.Err()
|
||||
}
|
||||
@ -772,7 +773,7 @@ func (e *Engine) executeRestoreWithDecompression(ctx context.Context, archivePat
|
||||
defer gz.Close()
|
||||
|
||||
// Start restore command
|
||||
cmd := exec.CommandContext(ctx, restoreCmd[0], restoreCmd[1:]...)
|
||||
cmd := cleanup.SafeCommand(ctx, restoreCmd[0], restoreCmd[1:]...)
|
||||
cmd.Env = append(os.Environ(),
|
||||
fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password),
|
||||
fmt.Sprintf("MYSQL_PWD=%s", e.cfg.Password),
|
||||
@ -876,7 +877,7 @@ func (e *Engine) executeRestoreWithPgzipStream(ctx context.Context, archivePath,
|
||||
if e.cfg.Host != "localhost" && e.cfg.Host != "" {
|
||||
args = append([]string{"-h", e.cfg.Host}, args...)
|
||||
}
|
||||
cmd = exec.CommandContext(ctx, "psql", args...)
|
||||
cmd = cleanup.SafeCommand(ctx, "psql", args...)
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
} else {
|
||||
// MySQL - use MYSQL_PWD env var to avoid password in process list
|
||||
@ -885,7 +886,7 @@ func (e *Engine) executeRestoreWithPgzipStream(ctx context.Context, archivePath,
|
||||
args = append(args, "-h", e.cfg.Host)
|
||||
}
|
||||
args = append(args, "-P", fmt.Sprintf("%d", e.cfg.Port), targetDB)
|
||||
cmd = exec.CommandContext(ctx, "mysql", args...)
|
||||
cmd = cleanup.SafeCommand(ctx, "mysql", args...)
|
||||
// Pass password via environment variable to avoid process list exposure
|
||||
cmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
@ -1322,7 +1323,7 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string, preExtr
|
||||
}
|
||||
} else if strings.HasSuffix(dumpFile, ".dump") {
|
||||
// Validate custom format dumps using pg_restore --list
|
||||
cmd := exec.CommandContext(ctx, "pg_restore", "--list", dumpFile)
|
||||
cmd := cleanup.SafeCommand(ctx, "pg_restore", "--list", dumpFile)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
dbName := strings.TrimSuffix(entry.Name(), ".dump")
|
||||
@ -2121,7 +2122,7 @@ func (e *Engine) restoreGlobals(ctx context.Context, globalsFile string) error {
|
||||
args = append([]string{"-h", e.cfg.Host}, args...)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
cmd := cleanup.SafeCommand(ctx, "psql", args...)
|
||||
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
|
||||
@ -2183,8 +2184,8 @@ func (e *Engine) restoreGlobals(ctx context.Context, globalsFile string) error {
|
||||
case cmdErr = <-cmdDone:
|
||||
// Command completed
|
||||
case <-ctx.Done():
|
||||
e.log.Warn("Globals restore cancelled - killing process")
|
||||
cmd.Process.Kill()
|
||||
e.log.Warn("Globals restore cancelled - killing process group")
|
||||
cleanup.KillCommandGroup(cmd)
|
||||
<-cmdDone
|
||||
cmdErr = ctx.Err()
|
||||
}
|
||||
@ -2225,7 +2226,7 @@ func (e *Engine) checkSuperuser(ctx context.Context) (bool, error) {
|
||||
args = append([]string{"-h", e.cfg.Host}, args...)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
cmd := cleanup.SafeCommand(ctx, "psql", args...)
|
||||
|
||||
// Always set PGPASSWORD (empty string is fine for peer/ident auth)
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
@ -2260,7 +2261,7 @@ func (e *Engine) terminateConnections(ctx context.Context, dbName string) error
|
||||
args = append([]string{"-h", e.cfg.Host}, args...)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
cmd := cleanup.SafeCommand(ctx, "psql", args...)
|
||||
|
||||
// Always set PGPASSWORD (empty string is fine for peer/ident auth)
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
@ -2296,7 +2297,7 @@ func (e *Engine) dropDatabaseIfExists(ctx context.Context, dbName string) error
|
||||
if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
|
||||
revokeArgs = append([]string{"-h", e.cfg.Host}, revokeArgs...)
|
||||
}
|
||||
revokeCmd := exec.CommandContext(ctx, "psql", revokeArgs...)
|
||||
revokeCmd := cleanup.SafeCommand(ctx, "psql", revokeArgs...)
|
||||
revokeCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
revokeCmd.Run() // Ignore errors - database might not exist
|
||||
|
||||
@ -2315,7 +2316,7 @@ func (e *Engine) dropDatabaseIfExists(ctx context.Context, dbName string) error
|
||||
if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
|
||||
forceArgs = append([]string{"-h", e.cfg.Host}, forceArgs...)
|
||||
}
|
||||
forceCmd := exec.CommandContext(ctx, "psql", forceArgs...)
|
||||
forceCmd := cleanup.SafeCommand(ctx, "psql", forceArgs...)
|
||||
forceCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
|
||||
output, err := forceCmd.CombinedOutput()
|
||||
@ -2338,7 +2339,7 @@ func (e *Engine) dropDatabaseIfExists(ctx context.Context, dbName string) error
|
||||
args = append([]string{"-h", e.cfg.Host}, args...)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
cmd := cleanup.SafeCommand(ctx, "psql", args...)
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
|
||||
output, err = cmd.CombinedOutput()
|
||||
@ -2372,7 +2373,7 @@ func (e *Engine) ensureMySQLDatabaseExists(ctx context.Context, dbName string) e
|
||||
"-e", fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", dbName),
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "mysql", args...)
|
||||
cmd := cleanup.SafeCommand(ctx, "mysql", args...)
|
||||
cmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Env, "MYSQL_PWD="+e.cfg.Password)
|
||||
@ -2410,7 +2411,7 @@ func (e *Engine) ensurePostgresDatabaseExists(ctx context.Context, dbName string
|
||||
args = append([]string{"-h", e.cfg.Host}, args...)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
cmd := cleanup.SafeCommand(ctx, "psql", args...)
|
||||
|
||||
// Always set PGPASSWORD (empty string is fine for peer/ident auth)
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
@ -2467,7 +2468,7 @@ func (e *Engine) ensurePostgresDatabaseExists(ctx context.Context, dbName string
|
||||
createArgs = append([]string{"-h", e.cfg.Host}, createArgs...)
|
||||
}
|
||||
|
||||
createCmd := exec.CommandContext(ctx, "psql", createArgs...)
|
||||
createCmd := cleanup.SafeCommand(ctx, "psql", createArgs...)
|
||||
|
||||
// Always set PGPASSWORD (empty string is fine for peer/ident auth)
|
||||
createCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
@ -2487,7 +2488,7 @@ func (e *Engine) ensurePostgresDatabaseExists(ctx context.Context, dbName string
|
||||
simpleArgs = append([]string{"-h", e.cfg.Host}, simpleArgs...)
|
||||
}
|
||||
|
||||
simpleCmd := exec.CommandContext(ctx, "psql", simpleArgs...)
|
||||
simpleCmd := cleanup.SafeCommand(ctx, "psql", simpleArgs...)
|
||||
simpleCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
|
||||
|
||||
output, err = simpleCmd.CombinedOutput()
|
||||
@ -2552,7 +2553,7 @@ func (e *Engine) detectLargeObjectsInDumps(dumpsDir string, entries []os.DirEntr
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
|
||||
cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpFile)
|
||||
output, err := cmd.Output()
|
||||
|
||||
if err != nil {
|
||||
@ -2876,7 +2877,7 @@ func (e *Engine) canRestartPostgreSQL() bool {
|
||||
// Try a quick sudo check - if this fails, we can't restart
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
cmd := exec.CommandContext(ctx, "sudo", "-n", "true")
|
||||
cmd := cleanup.SafeCommand(ctx, "sudo", "-n", "true")
|
||||
cmd.Stdin = nil
|
||||
if err := cmd.Run(); err != nil {
|
||||
e.log.Info("Running as postgres user without sudo access - cannot restart PostgreSQL",
|
||||
@ -2906,7 +2907,7 @@ func (e *Engine) tryRestartPostgreSQL(ctx context.Context) bool {
|
||||
runWithTimeout := func(args ...string) bool {
|
||||
cmdCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
cmd := exec.CommandContext(cmdCtx, args[0], args[1:]...)
|
||||
cmd := cleanup.SafeCommand(cmdCtx, args[0], args[1:]...)
|
||||
// Set stdin to /dev/null to prevent sudo from waiting for password
|
||||
cmd.Stdin = nil
|
||||
return cmd.Run() == nil
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"os/exec"
 	"path/filepath"
 	"runtime"
 	"strings"
 	"time"
 
+	"dbbackup/internal/cleanup"
 	"dbbackup/internal/config"
 	"dbbackup/internal/logger"
@@ -568,7 +568,7 @@ func getCommandVersion(cmd string, arg string) string {
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
-	output, err := exec.CommandContext(ctx, cmd, arg).CombinedOutput()
+	output, err := cleanup.SafeCommand(ctx, cmd, arg).CombinedOutput()
 	if err != nil {
 		return ""
 	}
@@ -5,11 +5,11 @@ package restore
 import (
 	"context"
 	"fmt"
-	"os/exec"
 	"strings"
 	"sync"
 	"time"
 
+	"dbbackup/internal/cleanup"
 	"dbbackup/internal/config"
 	"dbbackup/internal/logger"
 )
@@ -124,7 +124,7 @@ func ApplySessionOptimizations(ctx context.Context, cfg *config.Config, log logg
 
 	for _, sql := range safeOptimizations {
 		cmdArgs := append(args, "-c", sql)
-		cmd := exec.CommandContext(ctx, "psql", cmdArgs...)
+		cmd := cleanup.SafeCommand(ctx, "psql", cmdArgs...)
 		cmd.Env = append(cmd.Environ(), fmt.Sprintf("PGPASSWORD=%s", cfg.Password))
 
 		if err := cmd.Run(); err != nil {
@@ -6,11 +6,11 @@ import (
 	"database/sql"
 	"fmt"
 	"os"
-	"os/exec"
 	"path/filepath"
 	"strings"
 	"syscall"
 
+	"dbbackup/internal/cleanup"
 	"dbbackup/internal/config"
 	"dbbackup/internal/logger"
 )
@@ -572,7 +572,7 @@ func (g *LargeDBGuard) RevertMySQLSettings() []string {
 // Uses pg_restore -l which outputs a line-by-line listing, then streams through it
 func (g *LargeDBGuard) StreamCountBLOBs(ctx context.Context, dumpFile string) (int, error) {
 	// pg_restore -l outputs text listing, one line per object
-	cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
+	cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpFile)
 
 	stdout, err := cmd.StdoutPipe()
 	if err != nil {
@@ -609,7 +609,7 @@ func (g *LargeDBGuard) StreamCountBLOBs(ctx context.Context, dumpFile string) (i
 // StreamAnalyzeDump analyzes a dump file using streaming to avoid memory issues
 // Returns: blobCount, estimatedObjects, error
 func (g *LargeDBGuard) StreamAnalyzeDump(ctx context.Context, dumpFile string) (blobCount, totalObjects int, err error) {
-	cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
+	cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpFile)
 
 	stdout, err := cmd.StdoutPipe()
 	if err != nil {
@@ -1,18 +1,22 @@
 package restore
 
 import (
+	"bufio"
 	"context"
 	"database/sql"
 	"fmt"
 	"os"
 	"os/exec"
 	"path/filepath"
+	"regexp"
 	"runtime"
 	"strconv"
 	"strings"
 	"time"
 
+	"dbbackup/internal/cleanup"
+
 	"github.com/dustin/go-humanize"
 	"github.com/klauspost/pgzip"
 	"github.com/shirou/gopsutil/v3/mem"
 )
@@ -381,7 +385,7 @@ func (e *Engine) countBlobsInDump(ctx context.Context, dumpFile string) int {
 	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
 	defer cancel()
 
-	cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
+	cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpFile)
 	output, err := cmd.Output()
 	if err != nil {
 		return 0
@@ -398,24 +402,51 @@ func (e *Engine) countBlobsInDump(ctx context.Context, dumpFile string) int {
 }
 
 // estimateBlobsInSQL samples compressed SQL for lo_create patterns
+// Uses in-process pgzip decompression (NO external gzip process)
 func (e *Engine) estimateBlobsInSQL(sqlFile string) int {
-	// Use zgrep for efficient searching in gzipped files
-	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-	defer cancel()
-
-	// Count lo_create calls (each = one large object)
-	cmd := exec.CommandContext(ctx, "zgrep", "-c", "lo_create", sqlFile)
-	output, err := cmd.Output()
+	// Open the gzipped file
+	f, err := os.Open(sqlFile)
 	if err != nil {
-		// Also try SELECT lo_create pattern
-		cmd2 := exec.CommandContext(ctx, "zgrep", "-c", "SELECT.*lo_create", sqlFile)
-		output, err = cmd2.Output()
-		if err != nil {
-			return 0
-		}
+		e.log.Debug("Cannot open SQL file for BLOB estimation", "file", sqlFile, "error", err)
+		return 0
 	}
+	defer f.Close()
+
+	// Create pgzip reader for parallel decompression
+	gzReader, err := pgzip.NewReader(f)
+	if err != nil {
+		e.log.Debug("Cannot create pgzip reader", "file", sqlFile, "error", err)
+		return 0
+	}
+	defer gzReader.Close()
+
+	// Scan for lo_create patterns
+	// We use a regex to match both "lo_create" and "SELECT lo_create" patterns
+	loCreatePattern := regexp.MustCompile(`lo_create`)
+
+	scanner := bufio.NewScanner(gzReader)
+	// Use larger buffer for potentially long lines
+	buf := make([]byte, 0, 256*1024)
+	scanner.Buffer(buf, 10*1024*1024)
+
+	count := 0
+	linesScanned := 0
+	maxLines := 1000000 // Limit scanning for very large files
+
+	for scanner.Scan() && linesScanned < maxLines {
+		line := scanner.Text()
+		linesScanned++
+
+		// Count all lo_create occurrences in the line
+		matches := loCreatePattern.FindAllString(line, -1)
+		count += len(matches)
+	}
 
-	count, _ := strconv.Atoi(strings.TrimSpace(string(output)))
+	if err := scanner.Err(); err != nil {
+		e.log.Debug("Error scanning SQL file", "file", sqlFile, "error", err, "lines_scanned", linesScanned)
+	}
+
+	e.log.Debug("BLOB estimation from SQL file", "file", sqlFile, "lo_create_count", count, "lines_scanned", linesScanned)
 	return count
 }
@@ -8,6 +8,7 @@ import (
 	"os/exec"
 	"strings"
 
+	"dbbackup/internal/cleanup"
 	"dbbackup/internal/config"
 	"dbbackup/internal/fs"
 	"dbbackup/internal/logger"
@@ -419,7 +420,7 @@ func (s *Safety) checkPostgresDatabaseExists(ctx context.Context, dbName string)
 	}
 	args = append([]string{"-h", host}, args...)
 
-	cmd := exec.CommandContext(ctx, "psql", args...)
+	cmd := cleanup.SafeCommand(ctx, "psql", args...)
 
 	// Set password if provided
 	if s.cfg.Password != "" {
@@ -447,7 +448,7 @@ func (s *Safety) checkMySQLDatabaseExists(ctx context.Context, dbName string) (b
 		args = append([]string{"-h", s.cfg.Host}, args...)
 	}
 
-	cmd := exec.CommandContext(ctx, "mysql", args...)
+	cmd := cleanup.SafeCommand(ctx, "mysql", args...)
 
 	if s.cfg.Password != "" {
 		cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password))
@@ -493,7 +494,7 @@ func (s *Safety) listPostgresUserDatabases(ctx context.Context) ([]string, error
 	}
 	args = append([]string{"-h", host}, args...)
 
-	cmd := exec.CommandContext(ctx, "psql", args...)
+	cmd := cleanup.SafeCommand(ctx, "psql", args...)
 
 	// Set password - check config first, then environment
 	env := os.Environ()
@@ -542,7 +543,7 @@ func (s *Safety) listMySQLUserDatabases(ctx context.Context) ([]string, error) {
 		args = append([]string{"-h", s.cfg.Host}, args...)
 	}
 
-	cmd := exec.CommandContext(ctx, "mysql", args...)
+	cmd := cleanup.SafeCommand(ctx, "mysql", args...)
 
 	if s.cfg.Password != "" {
 		cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password))
@@ -3,11 +3,11 @@ package restore
 import (
 	"context"
 	"fmt"
-	"os/exec"
 	"regexp"
 	"strconv"
 	"time"
 
+	"dbbackup/internal/cleanup"
 	"dbbackup/internal/database"
 )
@@ -54,7 +54,7 @@ func GetDumpFileVersion(dumpPath string) (*VersionInfo, error) {
 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()
 
-	cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpPath)
+	cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpPath)
 	output, err := cmd.CombinedOutput()
 	if err != nil {
 		return nil, fmt.Errorf("failed to read dump file metadata: %w (output: %s)", err, string(output))
@ -16,6 +16,7 @@ import (
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/database"
|
||||
"dbbackup/internal/logger"
|
||||
"dbbackup/internal/progress"
|
||||
"dbbackup/internal/restore"
|
||||
)
|
||||
|
||||
@ -75,6 +76,13 @@ type RestoreExecutionModel struct {
|
||||
overallPhase int // 1=Extracting, 2=Globals, 3=Databases
|
||||
extractionDone bool
|
||||
|
||||
// Rich progress view for cluster restores
|
||||
richProgressView *RichClusterProgressView
|
||||
unifiedProgress *progress.UnifiedClusterProgress
|
||||
useRichProgress bool // Whether to use the rich progress view
|
||||
termWidth int // Terminal width for rich progress
|
||||
termHeight int // Terminal height for rich progress
|
||||
|
||||
// Results
|
||||
done bool
|
||||
cancelling bool // True when user has requested cancellation
|
||||
@ -108,6 +116,11 @@ func NewRestoreExecution(cfg *config.Config, log logger.Logger, parent tea.Model
|
||||
details: []string{},
|
||||
spinnerFrames: spinnerFrames, // Use package-level constant
|
||||
spinnerFrame: 0,
|
||||
// Initialize rich progress view for cluster restores
|
||||
richProgressView: NewRichClusterProgressView(),
|
||||
useRichProgress: restoreType == "restore-cluster",
|
||||
termWidth: 80,
|
||||
termHeight: 24,
|
||||
}
|
||||
}
|
||||
|
||||
@ -176,6 +189,9 @@ type sharedProgressState struct {
|
||||
// Throttling to prevent excessive updates (memory optimization)
|
||||
lastSpeedSampleTime time.Time // Last time we added a speed sample
|
||||
minSampleInterval time.Duration // Minimum interval between samples (100ms)
|
||||
|
||||
// Unified progress tracker for rich display
|
||||
unifiedProgress *progress.UnifiedClusterProgress
|
||||
}
|
||||
|
||||
type restoreSpeedSample struct {
|
||||
@ -231,6 +247,18 @@ func getCurrentRestoreProgress() (bytesTotal, bytesDone int64, description strin
|
||||
currentRestoreProgressState.phase3StartTime
|
||||
}
|
||||
|
||||
// getUnifiedProgress returns the unified progress tracker if available
|
||||
func getUnifiedProgress() *progress.UnifiedClusterProgress {
|
||||
currentRestoreProgressMu.Lock()
|
||||
defer currentRestoreProgressMu.Unlock()
|
||||
|
||||
if currentRestoreProgressState == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return currentRestoreProgressState.unifiedProgress
|
||||
}
|
||||
|
||||
// calculateRollingSpeed calculates speed from recent samples (last 5 seconds)
|
||||
func calculateRollingSpeed(samples []restoreSpeedSample) float64 {
|
||||
if len(samples) < 2 {
|
||||
@ -332,6 +360,11 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
|
||||
progressState := &sharedProgressState{
|
||||
speedSamples: make([]restoreSpeedSample, 0, 100),
|
||||
}
|
||||
|
||||
// Initialize unified progress tracker for cluster restores
|
||||
if restoreType == "restore-cluster" {
|
||||
progressState.unifiedProgress = progress.NewUnifiedClusterProgress("restore", archive.Path)
|
||||
}
|
||||
engine.SetProgressCallback(func(current, total int64, description string) {
|
||||
progressState.mu.Lock()
|
||||
defer progressState.mu.Unlock()
|
||||
@ -342,10 +375,19 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
|
||||
progressState.overallPhase = 1
|
||||
progressState.extractionDone = false
|
||||
|
||||
// Update unified progress tracker
|
||||
if progressState.unifiedProgress != nil {
|
||||
progressState.unifiedProgress.SetPhase(progress.PhaseExtracting)
|
||||
progressState.unifiedProgress.SetExtractProgress(current, total)
|
||||
}
|
||||
|
||||
// Check if extraction is complete
|
||||
if current >= total && total > 0 {
|
||||
progressState.extractionDone = true
|
||||
progressState.overallPhase = 2
|
||||
if progressState.unifiedProgress != nil {
|
||||
progressState.unifiedProgress.SetPhase(progress.PhaseGlobals)
|
||||
}
|
||||
}
|
||||
|
||||
// Throttle speed samples to prevent memory bloat (max 10 samples/sec)
|
||||
@ -384,6 +426,13 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
|
||||
// Clear byte progress when switching to db progress
|
||||
progressState.bytesTotal = 0
|
||||
progressState.bytesDone = 0
|
||||
|
||||
// Update unified progress tracker
|
||||
if progressState.unifiedProgress != nil {
|
||||
progressState.unifiedProgress.SetPhase(progress.PhaseDatabases)
|
||||
progressState.unifiedProgress.SetDatabasesTotal(total, nil)
|
||||
progressState.unifiedProgress.StartDatabase(dbName, 0)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up timing-aware database progress callback for cluster restore ETA
|
||||
@ -406,6 +455,13 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
|
||||
// Clear byte progress when switching to db progress
|
||||
progressState.bytesTotal = 0
|
||||
progressState.bytesDone = 0
|
||||
|
||||
// Update unified progress tracker
|
||||
if progressState.unifiedProgress != nil {
|
||||
progressState.unifiedProgress.SetPhase(progress.PhaseDatabases)
|
||||
progressState.unifiedProgress.SetDatabasesTotal(total, nil)
|
||||
progressState.unifiedProgress.StartDatabase(dbName, 0)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up weighted (bytes-based) progress callback for accurate cluster restore progress
|
||||
@ -424,6 +480,14 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
|
||||
if progressState.phase3StartTime.IsZero() {
|
||||
progressState.phase3StartTime = time.Now()
|
||||
}
|
||||
|
||||
// Update unified progress tracker
|
||||
if progressState.unifiedProgress != nil {
|
||||
progressState.unifiedProgress.SetPhase(progress.PhaseDatabases)
|
||||
progressState.unifiedProgress.SetDatabasesTotal(dbTotal, nil)
|
||||
progressState.unifiedProgress.StartDatabase(dbName, bytesTotal)
|
||||
progressState.unifiedProgress.UpdateDatabaseProgress(bytesDone)
|
||||
}
|
||||
})
|
||||
|
||||
// Store progress state in a package-level variable for the ticker to access
|
||||
@ -489,11 +553,30 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
|
||||
|
||||
func (m RestoreExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
// Update terminal dimensions for rich progress view
|
||||
m.termWidth = msg.Width
|
||||
m.termHeight = msg.Height
|
||||
if m.richProgressView != nil {
|
||||
m.richProgressView.SetSize(msg.Width, msg.Height)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case restoreTickMsg:
|
||||
if !m.done {
|
||||
m.spinnerFrame = (m.spinnerFrame + 1) % len(m.spinnerFrames)
|
||||
m.elapsed = time.Since(m.startTime)
|
||||
|
||||
// Advance spinner for rich progress view
|
||||
if m.richProgressView != nil {
|
||||
m.richProgressView.AdvanceSpinner()
|
||||
}
|
||||
|
||||
// Update unified progress reference
|
||||
if m.useRichProgress && m.unifiedProgress == nil {
|
||||
m.unifiedProgress = getUnifiedProgress()
|
||||
}
|
||||
|
||||
// Poll shared progress state for real-time updates
|
||||
// Note: dbPhaseElapsed is now calculated in realtime inside getCurrentRestoreProgress()
|
||||
bytesTotal, bytesDone, description, hasUpdate, dbTotal, dbDone, speed, dbPhaseElapsed, dbAvgPerDB, currentDB, overallPhase, extractionDone, dbBytesTotal, dbBytesDone, _ := getCurrentRestoreProgress()
|
||||
@ -782,7 +865,16 @@ func (m RestoreExecutionModel) View() string {
|
||||
} else {
|
||||
// Show unified progress for cluster restore
|
||||
if m.restoreType == "restore-cluster" {
|
||||
// Calculate overall progress across all phases
|
||||
// Use rich progress view when we have unified progress data
|
||||
if m.useRichProgress && m.unifiedProgress != nil {
|
||||
// Render using the rich cluster progress view
|
||||
s.WriteString(m.richProgressView.RenderUnified(m.unifiedProgress))
|
||||
s.WriteString("\n")
|
||||
s.WriteString(infoStyle.Render("[KEYS] Press Ctrl+C to cancel"))
|
||||
return s.String()
|
||||
}
|
||||
|
||||
// Fallback: Calculate overall progress across all phases
|
||||
// Phase 1: Extraction (0-60%)
|
||||
// Phase 2: Globals (60-65%)
|
||||
// Phase 3: Databases (65-100%)
|
||||
|
||||
344  internal/tui/rich_cluster_progress.go  Normal file
@ -0,0 +1,344 @@
package tui

import (
    "fmt"
    "strings"
    "time"

    "dbbackup/internal/progress"
)

// RichClusterProgressView renders detailed cluster restore progress
type RichClusterProgressView struct {
    width         int
    height        int
    spinnerFrames []string
    spinnerFrame  int
}

// NewRichClusterProgressView creates a new rich progress view
func NewRichClusterProgressView() *RichClusterProgressView {
    return &RichClusterProgressView{
        width:  80,
        height: 24,
        spinnerFrames: []string{
            "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
        },
    }
}

// SetSize updates the terminal size
func (v *RichClusterProgressView) SetSize(width, height int) {
    v.width = width
    v.height = height
}

// AdvanceSpinner moves to the next spinner frame
func (v *RichClusterProgressView) AdvanceSpinner() {
    v.spinnerFrame = (v.spinnerFrame + 1) % len(v.spinnerFrames)
}

// RenderUnified renders progress from UnifiedClusterProgress
func (v *RichClusterProgressView) RenderUnified(p *progress.UnifiedClusterProgress) string {
    if p == nil {
        return ""
    }

    snapshot := p.GetSnapshot()
    return v.RenderSnapshot(&snapshot)
}

// RenderSnapshot renders progress from a ProgressSnapshot
func (v *RichClusterProgressView) RenderSnapshot(snapshot *progress.ProgressSnapshot) string {
    if snapshot == nil {
        return ""
    }

    var b strings.Builder
    b.Grow(2048)

    // Header with overall progress
    b.WriteString(v.renderHeader(snapshot))
    b.WriteString("\n\n")

    // Overall progress bar
    b.WriteString(v.renderOverallProgress(snapshot))
    b.WriteString("\n\n")

    // Phase-specific details
    b.WriteString(v.renderPhaseDetails(snapshot))

    // Performance metrics
    if v.height > 15 {
        b.WriteString("\n")
        b.WriteString(v.renderMetricsFromSnapshot(snapshot))
    }

    return b.String()
}
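As a usage sketch (only the constructor and methods above are real; the `p` handle is an assumed `*progress.UnifiedClusterProgress` shared with the restore goroutine, and the sizes are illustrative):

```go
// Minimal sketch: render the current cluster progress once.
view := NewRichClusterProgressView()
view.SetSize(120, 40)            // normally fed from tea.WindowSizeMsg
view.AdvanceSpinner()            // normally advanced on each tick
fmt.Print(view.RenderUnified(p)) // p: shared *progress.UnifiedClusterProgress (assumed)
```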

func (v *RichClusterProgressView) renderHeader(snapshot *progress.ProgressSnapshot) string {
    elapsed := time.Since(snapshot.StartTime)

    // Calculate ETA based on progress
    overall := v.calculateOverallPercent(snapshot)
    var etaStr string
    if overall > 0 && overall < 100 {
        eta := time.Duration(float64(elapsed) / float64(overall) * float64(100-overall))
        etaStr = fmt.Sprintf("ETA: %s", formatDuration(eta))
    } else if overall >= 100 {
        etaStr = "Complete!"
    } else {
        etaStr = "ETA: calculating..."
    }

    title := "Cluster Restore Progress"
    // Cap separator at 40 chars to avoid long lines on wide terminals
    sepLen := maxInt(0, v.width-len(title)-4)
    if sepLen > 40 {
        sepLen = 40
    }
    separator := strings.Repeat("━", sepLen)

    return fmt.Sprintf("%s %s\n Elapsed: %s | %s",
        title, separator,
        formatDuration(elapsed), etaStr)
}
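Worked example of the ETA formula: 3 minutes elapsed at 40% overall gives 3m / 40 × (100 − 40) = 4.5 minutes remaining; at 0% the expression is undefined, which is why the `overall > 0` guard falls back to "ETA: calculating...".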

func (v *RichClusterProgressView) renderOverallProgress(snapshot *progress.ProgressSnapshot) string {
    overall := v.calculateOverallPercent(snapshot)

    // Phase indicator
    phaseLabel := v.getPhaseLabel(snapshot)

    // Progress bar
    barWidth := v.width - 20
    if barWidth < 20 {
        barWidth = 20
    }
    bar := v.renderProgressBarWidth(overall, barWidth)

    return fmt.Sprintf(" Overall: %s %3d%%\n Phase: %s", bar, overall, phaseLabel)
}

func (v *RichClusterProgressView) getPhaseLabel(snapshot *progress.ProgressSnapshot) string {
    switch snapshot.Phase {
    case progress.PhaseExtracting:
        return fmt.Sprintf("📦 Extracting archive (%s / %s)",
            FormatBytes(snapshot.ExtractBytes), FormatBytes(snapshot.ExtractTotal))
    case progress.PhaseGlobals:
        return "🔧 Restoring globals (roles, tablespaces)"
    case progress.PhaseDatabases:
        return fmt.Sprintf("🗄️ Databases (%d/%d) %s",
            snapshot.DatabasesDone, snapshot.DatabasesTotal, snapshot.CurrentDB)
    case progress.PhaseVerifying:
        return fmt.Sprintf("✅ Verifying (%d/%d)", snapshot.VerifyDone, snapshot.VerifyTotal)
    case progress.PhaseComplete:
        return "🎉 Complete!"
    case progress.PhaseFailed:
        return "❌ Failed"
    default:
        return string(snapshot.Phase)
    }
}

func (v *RichClusterProgressView) calculateOverallPercent(snapshot *progress.ProgressSnapshot) int {
    // Use the same logic as UnifiedClusterProgress
    phaseWeights := map[progress.Phase]int{
        progress.PhaseExtracting: 20,
        progress.PhaseGlobals:    5,
        progress.PhaseDatabases:  70,
        progress.PhaseVerifying:  5,
    }

    switch snapshot.Phase {
    case progress.PhaseIdle:
        return 0
    case progress.PhaseExtracting:
        if snapshot.ExtractTotal > 0 {
            return int(float64(snapshot.ExtractBytes) / float64(snapshot.ExtractTotal) * float64(phaseWeights[progress.PhaseExtracting]))
        }
        return 0
    case progress.PhaseGlobals:
        return phaseWeights[progress.PhaseExtracting] + phaseWeights[progress.PhaseGlobals]
    case progress.PhaseDatabases:
        basePercent := phaseWeights[progress.PhaseExtracting] + phaseWeights[progress.PhaseGlobals]
        if snapshot.DatabasesTotal == 0 {
            return basePercent
        }
        dbProgress := float64(snapshot.DatabasesDone) / float64(snapshot.DatabasesTotal)
        if snapshot.CurrentDBTotal > 0 {
            currentProgress := float64(snapshot.CurrentDBBytes) / float64(snapshot.CurrentDBTotal)
            dbProgress += currentProgress / float64(snapshot.DatabasesTotal)
        }
        return basePercent + int(dbProgress*float64(phaseWeights[progress.PhaseDatabases]))
    case progress.PhaseVerifying:
        basePercent := phaseWeights[progress.PhaseExtracting] + phaseWeights[progress.PhaseGlobals] + phaseWeights[progress.PhaseDatabases]
        if snapshot.VerifyTotal > 0 {
            verifyProgress := float64(snapshot.VerifyDone) / float64(snapshot.VerifyTotal)
            return basePercent + int(verifyProgress*float64(phaseWeights[progress.PhaseVerifying]))
        }
        return basePercent
    case progress.PhaseComplete:
        return 100
    default:
        return 0
    }
}
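Worked example of the weighting: with the 20/5/70/5 weights above, a snapshot in the databases phase with 3 of 8 databases done and the current one at 50% reports 25 + int((3/8 + 0.5/8) × 70) = 25 + 30 = 55%.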

func (v *RichClusterProgressView) renderPhaseDetails(snapshot *progress.ProgressSnapshot) string {
    var b strings.Builder

    switch snapshot.Phase {
    case progress.PhaseExtracting:
        pct := 0
        if snapshot.ExtractTotal > 0 {
            pct = int(float64(snapshot.ExtractBytes) / float64(snapshot.ExtractTotal) * 100)
        }
        bar := v.renderMiniProgressBar(pct)
        b.WriteString(fmt.Sprintf(" 📦 Extraction: %s %d%%\n", bar, pct))
        b.WriteString(fmt.Sprintf(" %s / %s\n",
            FormatBytes(snapshot.ExtractBytes), FormatBytes(snapshot.ExtractTotal)))

    case progress.PhaseDatabases:
        b.WriteString(" 📊 Databases:\n\n")

        // Show completed databases if any
        if snapshot.DatabasesDone > 0 {
            avgTime := time.Duration(0)
            if len(snapshot.DatabaseTimes) > 0 {
                var total time.Duration
                for _, t := range snapshot.DatabaseTimes {
                    total += t
                }
                avgTime = total / time.Duration(len(snapshot.DatabaseTimes))
            }
            b.WriteString(fmt.Sprintf(" ✓ %d completed (avg: %s)\n",
                snapshot.DatabasesDone, formatDuration(avgTime)))
        }

        // Show current database
        if snapshot.CurrentDB != "" {
            spinner := v.spinnerFrames[v.spinnerFrame]
            pct := 0
            if snapshot.CurrentDBTotal > 0 {
                pct = int(float64(snapshot.CurrentDBBytes) / float64(snapshot.CurrentDBTotal) * 100)
            }
            bar := v.renderMiniProgressBar(pct)

            phaseElapsed := time.Since(snapshot.PhaseStartTime)
            b.WriteString(fmt.Sprintf(" %s %-20s %s %3d%%\n",
                spinner, truncateString(snapshot.CurrentDB, 20), bar, pct))
            b.WriteString(fmt.Sprintf(" └─ %s / %s (running %s)\n",
                FormatBytes(snapshot.CurrentDBBytes), FormatBytes(snapshot.CurrentDBTotal),
                formatDuration(phaseElapsed)))
        }

        // Show remaining count
        remaining := snapshot.DatabasesTotal - snapshot.DatabasesDone
        if snapshot.CurrentDB != "" {
            remaining--
        }
        if remaining > 0 {
            b.WriteString(fmt.Sprintf(" ⏳ %d remaining\n", remaining))
        }

    case progress.PhaseVerifying:
        pct := 0
        if snapshot.VerifyTotal > 0 {
            pct = snapshot.VerifyDone * 100 / snapshot.VerifyTotal
        }
        bar := v.renderMiniProgressBar(pct)
        b.WriteString(fmt.Sprintf(" ✅ Verification: %s %d%%\n", bar, pct))
        b.WriteString(fmt.Sprintf(" %d / %d databases verified\n",
            snapshot.VerifyDone, snapshot.VerifyTotal))

    case progress.PhaseComplete:
        elapsed := time.Since(snapshot.StartTime)
        b.WriteString(" 🎉 Restore complete!\n")
        b.WriteString(fmt.Sprintf(" %d databases restored in %s\n",
            snapshot.DatabasesDone, formatDuration(elapsed)))

    case progress.PhaseFailed:
        b.WriteString(" ❌ Restore failed:\n")
        for _, err := range snapshot.Errors {
            b.WriteString(fmt.Sprintf(" • %s\n", truncateString(err, v.width-10)))
        }
    }

    return b.String()
}

func (v *RichClusterProgressView) renderMetricsFromSnapshot(snapshot *progress.ProgressSnapshot) string {
    var b strings.Builder
    b.WriteString(" 📈 Performance:\n")

    elapsed := time.Since(snapshot.StartTime)
    if elapsed > 0 {
        // Calculate throughput from extraction phase if we have data
        if snapshot.ExtractBytes > 0 && elapsed.Seconds() > 0 {
            throughput := float64(snapshot.ExtractBytes) / elapsed.Seconds()
            b.WriteString(fmt.Sprintf(" Throughput: %s/s\n", FormatBytes(int64(throughput))))
        }

        // Database timing info
        if len(snapshot.DatabaseTimes) > 0 {
            var total time.Duration
            for _, t := range snapshot.DatabaseTimes {
                total += t
            }
            avg := total / time.Duration(len(snapshot.DatabaseTimes))
            b.WriteString(fmt.Sprintf(" Avg DB time: %s\n", formatDuration(avg)))
        }
    }

    return b.String()
}

// Helper functions

func (v *RichClusterProgressView) renderProgressBarWidth(pct, width int) string {
    if width < 10 {
        width = 10
    }
    filled := (pct * width) / 100
    empty := width - filled

    bar := strings.Repeat("█", filled) + strings.Repeat("░", empty)
    return "[" + bar + "]"
}

func (v *RichClusterProgressView) renderMiniProgressBar(pct int) string {
    width := 20
    filled := (pct * width) / 100
    empty := width - filled
    return strings.Repeat("█", filled) + strings.Repeat("░", empty)
}
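For instance, a value of 45% fills 45 × 20 / 100 = 9 of the mini bar's 20 cells; because of the integer division the bar rounds down, so it never overshoots the true percentage.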

func truncateString(s string, maxLen int) string {
    if len(s) <= maxLen {
        return s
    }
    if maxLen < 4 {
        return s[:maxLen]
    }
    return s[:maxLen-3] + "..."
}

func maxInt(a, b int) int {
    if a > b {
        return a
    }
    return b
}

func formatNumShort(n int64) string {
    if n >= 1e9 {
        return fmt.Sprintf("%.1fB", float64(n)/1e9)
    } else if n >= 1e6 {
        return fmt.Sprintf("%.1fM", float64(n)/1e6)
    } else if n >= 1e3 {
        return fmt.Sprintf("%.1fK", float64(n)/1e3)
    }
    return fmt.Sprintf("%d", n)
}
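So, for example, formatNumShort(1500000) returns "1.5M" and formatNumShort(950) returns "950".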
@ -94,6 +94,11 @@ func NewSettingsModel(cfg *config.Config, log logger.Logger, parent tea.Model) S
            c.CPUWorkloadType = workloads[nextIdx]

            // Recalculate Jobs and DumpJobs based on workload type
            // If CPUInfo is nil, try to detect it first
            if c.CPUInfo == nil && c.AutoDetectCores {
                _ = c.OptimizeForCPU() // This will detect CPU and set CPUInfo
            }

            if c.CPUInfo != nil && c.AutoDetectCores {
                switch c.CPUWorkloadType {
                case "cpu-intensive":
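The per-workload values themselves are cut off in this hunk; purely to illustrate the shape such a switch usually takes (the field name `PhysicalCores`, the "io-intensive" case, and the multipliers below are invented for illustration, not taken from this commit):

```go
// Hypothetical values only - the real assignments are not shown in the hunk above.
switch c.CPUWorkloadType {
case "cpu-intensive":
	c.Jobs = c.CPUInfo.PhysicalCores // one worker per physical core
	c.DumpJobs = c.CPUInfo.PhysicalCores
case "io-intensive":
	c.Jobs = c.CPUInfo.PhysicalCores * 2 // oversubscribe when workers mostly wait on I/O
	c.DumpJobs = c.CPUInfo.PhysicalCores
}
```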
2
main.go
@ -16,7 +16,7 @@ import (

// Build information (set by ldflags)
var (
    version   = "5.4.1"
    version   = "5.4.4"
    buildTime = "unknown"
    gitCommit = "unknown"
)
192
scripts/test-sigint-cleanup.sh
Executable file
@ -0,0 +1,192 @@
#!/bin/bash
# scripts/test-sigint-cleanup.sh
# Test script to verify clean shutdown on SIGINT (Ctrl+C)

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
BINARY="$PROJECT_DIR/dbbackup"

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

echo "=== SIGINT Cleanup Test ==="
echo ""
echo "Project: $PROJECT_DIR"
echo "Binary: $BINARY"
echo ""

# Check if binary exists
if [ ! -f "$BINARY" ]; then
    echo -e "${YELLOW}Binary not found, building...${NC}"
    cd "$PROJECT_DIR"
    go build -o dbbackup .
fi

# Create a test backup file if it doesn't exist
TEST_BACKUP="/tmp/test-sigint-backup.sql.gz"
if [ ! -f "$TEST_BACKUP" ]; then
    echo -e "${YELLOW}Creating test backup file...${NC}"
    echo "-- Test SQL file for SIGINT testing" | gzip > "$TEST_BACKUP"
fi

echo "=== Phase 1: Pre-test Cleanup ==="
echo "Killing any existing dbbackup processes..."
pkill -f "dbbackup" 2>/dev/null || true
sleep 1

echo ""
echo "=== Phase 2: Check Initial State ==="

echo "Checking for orphaned processes..."
INITIAL_PROCS=$(pgrep -f "pg_dump|pg_restore|dbbackup" 2>/dev/null | wc -l)
echo "Initial related processes: $INITIAL_PROCS"

echo ""
echo "Checking for temp files..."
INITIAL_TEMPS=$(ls /tmp/dbbackup-* 2>/dev/null | wc -l || echo "0")
echo "Initial temp files: $INITIAL_TEMPS"

echo ""
echo "=== Phase 3: Start Test Operation ==="

# Start a TUI operation that will hang (version is fast, but menu would wait)
echo "Starting dbbackup TUI (will be interrupted)..."

# Run in background with PTY simulation (needed for TUI)
cd "$PROJECT_DIR"
timeout 30 script -q -c "$BINARY" /dev/null &
PID=$!

echo "Process started: PID=$PID"
sleep 2

# Check if process is running
if ! kill -0 $PID 2>/dev/null; then
    echo -e "${YELLOW}Process exited quickly (expected for non-interactive test)${NC}"
    echo "This is normal - the TUI requires a real TTY"
    PID=""
else
    echo "Process is running"

    echo ""
    echo "=== Phase 4: Check Running State ==="

    echo "Child processes of $PID:"
    pgrep -P $PID 2>/dev/null | while read child; do
        ps -p $child -o pid,ppid,cmd 2>/dev/null || true
    done

    echo ""
    echo "=== Phase 5: Send SIGINT ==="
    echo "Sending SIGINT to process $PID..."
    kill -SIGINT $PID 2>/dev/null || true

    echo "Waiting for cleanup (max 10 seconds)..."
    for i in {1..10}; do
        if ! kill -0 $PID 2>/dev/null; then
            echo ""
            echo -e "${GREEN}Process exited after ${i} seconds${NC}"
            break
        fi
        sleep 1
        echo -n "."
    done
    echo ""

    # Check if still running
    if kill -0 $PID 2>/dev/null; then
        echo -e "${RED}Process still running after 10 seconds!${NC}"
        echo "Force killing..."
        kill -9 $PID 2>/dev/null || true
    fi
fi

sleep 2 # Give OS time to clean up

echo ""
echo "=== Phase 6: Post-Shutdown Verification ==="

# Check for zombie processes
ZOMBIES=$(ps aux 2>/dev/null | grep -E "dbbackup|pg_dump|pg_restore" | grep -v grep | grep defunct | wc -l)
echo "Zombie processes: $ZOMBIES"

# Check for orphaned children
if [ -n "$PID" ]; then
    ORPHANS=$(pgrep -P $PID 2>/dev/null | wc -l || echo "0")
    echo "Orphaned children of original process: $ORPHANS"
else
    ORPHANS=0
fi

# Check for leftover related processes
LEFTOVER_PROCS=$(pgrep -f "pg_dump|pg_restore" 2>/dev/null | wc -l || echo "0")
echo "Leftover pg_dump/pg_restore processes: $LEFTOVER_PROCS"

# Check for temp files
TEMP_FILES=$(ls /tmp/dbbackup-* 2>/dev/null | wc -l || echo "0")
echo "Temporary files: $TEMP_FILES"

# Database connections check (if psql available and configured)
if command -v psql &> /dev/null; then
    echo ""
    echo "Checking database connections..."
    DB_CONNS=$(psql -t -c "SELECT count(*) FROM pg_stat_activity WHERE application_name LIKE '%dbbackup%';" 2>/dev/null | tr -d ' ' || echo "N/A")
    echo "Database connections with 'dbbackup' in name: $DB_CONNS"
else
    echo "psql not available - skipping database connection check"
    DB_CONNS="N/A"
fi

echo ""
echo "=== Test Results ==="

PASSED=true

if [ "$ZOMBIES" -gt 0 ]; then
    echo -e "${RED}❌ FAIL: $ZOMBIES zombie process(es) found${NC}"
    PASSED=false
else
    echo -e "${GREEN}✓ No zombie processes${NC}"
fi

if [ "$ORPHANS" -gt 0 ]; then
    echo -e "${RED}❌ FAIL: $ORPHANS orphaned child process(es) found${NC}"
    PASSED=false
else
    echo -e "${GREEN}✓ No orphaned children${NC}"
fi

if [ "$LEFTOVER_PROCS" -gt 0 ]; then
    echo -e "${YELLOW}⚠ WARNING: $LEFTOVER_PROCS leftover pg_dump/pg_restore process(es)${NC}"
    echo " These may be from other operations"
fi

if [ "$TEMP_FILES" -gt "$INITIAL_TEMPS" ]; then
    NEW_TEMPS=$((TEMP_FILES - INITIAL_TEMPS))
    echo -e "${RED}❌ FAIL: $NEW_TEMPS new temporary file(s) left behind${NC}"
    ls -la /tmp/dbbackup-* 2>/dev/null || true
    PASSED=false
else
    echo -e "${GREEN}✓ No new temporary files left behind${NC}"
fi

if [ "$DB_CONNS" != "N/A" ] && [ "$DB_CONNS" -gt 0 ]; then
    echo -e "${RED}❌ FAIL: $DB_CONNS database connection(s) still active${NC}"
    PASSED=false
elif [ "$DB_CONNS" != "N/A" ]; then
    echo -e "${GREEN}✓ No lingering database connections${NC}"
fi

echo ""
if [ "$PASSED" = true ]; then
    echo -e "${GREEN}=== ✓ ALL TESTS PASSED ===${NC}"
    exit 0
else
    echo -e "${RED}=== ✗ SOME TESTS FAILED ===${NC}"
    exit 1
fi
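The behaviour this script checks for maps onto the common Go shutdown pattern of tying child processes and temp files to a signal-aware context, so that SIGINT cancels the work and deferred cleanup still runs. A minimal, self-contained sketch of that pattern (not the project's actual shutdown code):

```go
package main

import (
	"context"
	"log"
	"os"
	"os/exec"
	"os/signal"
	"syscall"
)

func main() {
	// Cancel the context on SIGINT/SIGTERM so children are killed and deferred cleanup runs.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	// Temp dir named like the files the test script greps for (/tmp/dbbackup-*).
	tmp, err := os.MkdirTemp("", "dbbackup-")
	if err != nil {
		log.Fatal(err)
	}
	defer os.RemoveAll(tmp) // no leftover temp files after Ctrl+C

	// A child started with CommandContext is killed when ctx is cancelled,
	// so no orphaned pg_dump/pg_restore processes remain.
	cmd := exec.CommandContext(ctx, "sleep", "30")
	if err := cmd.Run(); err != nil {
		log.Printf("stopped: %v", err)
	}
}
```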