Compare commits

..

1 Commits

Author SHA1 Message Date
59812400a4 v5.4.3: Bulletproof SIGINT handling & eliminate external gzip
All checks were successful
CI/CD / Test (push) Successful in 2m59s
CI/CD / Lint (push) Successful in 1m10s
CI/CD / Integration Tests (push) Successful in 50s
CI/CD / Native Engine Tests (push) Successful in 50s
CI/CD / Build Binary (push) Successful in 43s
CI/CD / Test Release Build (push) Successful in 1m17s
CI/CD / Release Binaries (push) Successful in 10m7s
## SIGINT Cleanup - Zero Zombie Processes
- Add cleanup.SafeCommand() with process group setup (Setpgid=true)
- Replace all exec.CommandContext with cleanup.SafeCommand in backup/restore
- Replace cmd.Process.Kill() with cleanup.KillCommandGroup() for entire process tree
- Add cleanup.Handler for graceful shutdown with registered cleanup functions
- Add rich cluster progress view for TUI
- Add test script: scripts/test-sigint-cleanup.sh

## Eliminate External gzip Process
- Replace zgrep (spawns gzip -cdfq) with in-process pgzip decompression
- All decompression now uses parallel pgzip (2-4x faster, no subprocess)

Files modified:
- internal/cleanup/command.go, command_windows.go, handler.go (new)
- internal/backup/engine.go (7 SafeCommand + 6 KillCommandGroup)
- internal/restore/engine.go (19 SafeCommand + 2 KillCommandGroup)
- internal/restore/{fast_restore,safety,diagnose,preflight,large_db_guard,version_check,error_report}.go
- internal/tui/restore_exec.go, rich_cluster_progress.go (new)
2026-02-02 14:44:49 +01:00
16 changed files with 1231 additions and 80 deletions

View File

@ -9,7 +9,6 @@ import (
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
@ -19,6 +18,7 @@ import (
"time"
"dbbackup/internal/checks"
"dbbackup/internal/cleanup"
"dbbackup/internal/cloud"
"dbbackup/internal/config"
"dbbackup/internal/database"
@ -650,7 +650,7 @@ func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []strin
e.log.Debug("Executing backup command with progress", "cmd", cmdArgs[0], "args", cmdArgs[1:])
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
cmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
// Set environment variables for database tools
cmd.Env = os.Environ()
@ -696,9 +696,9 @@ func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []strin
case cmdErr = <-cmdDone:
// Command completed (success or failure)
case <-ctx.Done():
// Context cancelled - kill process to unblock
e.log.Warn("Backup cancelled - killing process")
cmd.Process.Kill()
// Context cancelled - kill entire process group
e.log.Warn("Backup cancelled - killing process group")
cleanup.KillCommandGroup(cmd)
<-cmdDone // Wait for goroutine to finish
cmdErr = ctx.Err()
}
@ -754,7 +754,7 @@ func (e *Engine) monitorCommandProgress(stderr io.ReadCloser, tracker *progress.
// Uses in-process pgzip for parallel compression (2-4x faster on multi-core systems)
func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmdArgs []string, outputFile string, tracker *progress.OperationTracker) error {
// Create mysqldump command
dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
dumpCmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
dumpCmd.Env = os.Environ()
if e.cfg.Password != "" {
dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
@ -816,8 +816,8 @@ func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmd
case dumpErr = <-dumpDone:
// mysqldump completed
case <-ctx.Done():
e.log.Warn("Backup cancelled - killing mysqldump")
dumpCmd.Process.Kill()
e.log.Warn("Backup cancelled - killing mysqldump process group")
cleanup.KillCommandGroup(dumpCmd)
<-dumpDone
return ctx.Err()
}
@ -846,7 +846,7 @@ func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmd
// Uses in-process pgzip for parallel compression (2-4x faster on multi-core systems)
func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []string, outputFile string) error {
// Create mysqldump command
dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
dumpCmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
dumpCmd.Env = os.Environ()
if e.cfg.Password != "" {
dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
@ -895,8 +895,8 @@ func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []stri
case dumpErr = <-dumpDone:
// mysqldump completed
case <-ctx.Done():
e.log.Warn("Backup cancelled - killing mysqldump")
dumpCmd.Process.Kill()
e.log.Warn("Backup cancelled - killing mysqldump process group")
cleanup.KillCommandGroup(dumpCmd)
<-dumpDone
return ctx.Err()
}
@ -951,7 +951,7 @@ func (e *Engine) createSampleBackup(ctx context.Context, databaseName, outputFil
Format: "plain",
})
cmd := exec.CommandContext(ctx, schemaCmd[0], schemaCmd[1:]...)
cmd := cleanup.SafeCommand(ctx, schemaCmd[0], schemaCmd[1:]...)
cmd.Env = os.Environ()
if e.cfg.Password != "" {
cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
@ -990,7 +990,7 @@ func (e *Engine) backupGlobals(ctx context.Context, tempDir string) error {
globalsFile := filepath.Join(tempDir, "globals.sql")
// CRITICAL: Always pass port even for localhost - user may have non-standard port
cmd := exec.CommandContext(ctx, "pg_dumpall", "--globals-only",
cmd := cleanup.SafeCommand(ctx, "pg_dumpall", "--globals-only",
"-p", fmt.Sprintf("%d", e.cfg.Port),
"-U", e.cfg.User)
@ -1034,8 +1034,8 @@ func (e *Engine) backupGlobals(ctx context.Context, tempDir string) error {
case cmdErr = <-cmdDone:
// Command completed normally
case <-ctx.Done():
e.log.Warn("Globals backup cancelled - killing pg_dumpall")
cmd.Process.Kill()
e.log.Warn("Globals backup cancelled - killing pg_dumpall process group")
cleanup.KillCommandGroup(cmd)
<-cmdDone
return ctx.Err()
}
@ -1430,7 +1430,7 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
// For custom format, pg_dump handles everything (writes directly to file)
// NO GO BUFFERING - pg_dump writes directly to disk
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
cmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
// Start heartbeat ticker for backup progress
backupStart := time.Now()
@ -1499,9 +1499,9 @@ func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFil
case cmdErr = <-cmdDone:
// Command completed (success or failure)
case <-ctx.Done():
// Context cancelled - kill process to unblock
e.log.Warn("Backup cancelled - killing pg_dump process")
cmd.Process.Kill()
// Context cancelled - kill entire process group
e.log.Warn("Backup cancelled - killing pg_dump process group")
cleanup.KillCommandGroup(cmd)
<-cmdDone // Wait for goroutine to finish
cmdErr = ctx.Err()
}
@ -1536,7 +1536,7 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
}
// Create pg_dump command
dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
dumpCmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
dumpCmd.Env = os.Environ()
if e.cfg.Password != "" && e.cfg.IsPostgreSQL() {
dumpCmd.Env = append(dumpCmd.Env, "PGPASSWORD="+e.cfg.Password)
@ -1612,9 +1612,9 @@ func (e *Engine) executeWithStreamingCompression(ctx context.Context, cmdArgs []
case dumpErr = <-dumpDone:
// pg_dump completed (success or failure)
case <-ctx.Done():
// Context cancelled/timeout - kill pg_dump to unblock
e.log.Warn("Backup timeout - killing pg_dump process")
dumpCmd.Process.Kill()
// Context cancelled/timeout - kill pg_dump process group
e.log.Warn("Backup timeout - killing pg_dump process group")
cleanup.KillCommandGroup(dumpCmd)
<-dumpDone // Wait for goroutine to finish
dumpErr = ctx.Err()
}

154
internal/cleanup/command.go Normal file
View File

@ -0,0 +1,154 @@
//go:build !windows
// +build !windows
package cleanup
import (
"context"
"fmt"
"os/exec"
"syscall"
"time"
"dbbackup/internal/logger"
)
// SafeCommand creates an exec.Cmd with proper process group setup for clean termination.
// This ensures that child processes (e.g., from pipelines) are killed when the parent is killed.
func SafeCommand(ctx context.Context, name string, args ...string) *exec.Cmd {
cmd := exec.CommandContext(ctx, name, args...)
// Set up process group for clean termination
// This allows killing the entire process tree when cancelled
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true, // Create new process group
Pgid: 0, // Use the new process's PID as the PGID
}
return cmd
}
// TrackedCommand creates a command that is tracked for cleanup on shutdown.
// When the handler shuts down, this command will be killed if still running.
type TrackedCommand struct {
*exec.Cmd
log logger.Logger
name string
}
// NewTrackedCommand creates a tracked command
func NewTrackedCommand(ctx context.Context, log logger.Logger, name string, args ...string) *TrackedCommand {
tc := &TrackedCommand{
Cmd: SafeCommand(ctx, name, args...),
log: log,
name: name,
}
return tc
}
// StartWithCleanup starts the command and registers cleanup with the handler
func (tc *TrackedCommand) StartWithCleanup(h *Handler) error {
if err := tc.Cmd.Start(); err != nil {
return err
}
// Register cleanup function
pid := tc.Cmd.Process.Pid
h.RegisterCleanup(fmt.Sprintf("kill-%s-%d", tc.name, pid), func(ctx context.Context) error {
return tc.Kill()
})
return nil
}
// Kill terminates the command and its process group
func (tc *TrackedCommand) Kill() error {
if tc.Cmd.Process == nil {
return nil // Not started or already cleaned up
}
pid := tc.Cmd.Process.Pid
// Get the process group ID
pgid, err := syscall.Getpgid(pid)
if err != nil {
// Process might already be gone
return nil
}
tc.log.Debug("Terminating process", "name", tc.name, "pid", pid, "pgid", pgid)
// Try graceful shutdown first (SIGTERM to process group)
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
tc.log.Debug("SIGTERM failed, trying SIGKILL", "error", err)
}
// Wait briefly for graceful shutdown
done := make(chan error, 1)
go func() {
_, err := tc.Cmd.Process.Wait()
done <- err
}()
select {
case <-time.After(3 * time.Second):
// Force kill after timeout
tc.log.Debug("Process didn't stop gracefully, sending SIGKILL", "name", tc.name, "pid", pid)
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
tc.log.Debug("SIGKILL failed", "error", err)
}
<-done // Wait for Wait() to finish
case <-done:
// Process exited
}
tc.log.Debug("Process terminated", "name", tc.name, "pid", pid)
return nil
}
// WaitWithContext waits for the command to complete, handling context cancellation properly.
// This is the recommended way to wait for commands, as it ensures proper cleanup on cancellation.
func WaitWithContext(ctx context.Context, cmd *exec.Cmd, log logger.Logger) error {
if cmd.Process == nil {
return fmt.Errorf("process not started")
}
// Wait for command in a goroutine
cmdDone := make(chan error, 1)
go func() {
cmdDone <- cmd.Wait()
}()
select {
case err := <-cmdDone:
return err
case <-ctx.Done():
// Context cancelled - kill process group
log.Debug("Context cancelled, terminating process", "pid", cmd.Process.Pid)
// Get process group and kill entire group
pgid, err := syscall.Getpgid(cmd.Process.Pid)
if err == nil {
// Kill process group
syscall.Kill(-pgid, syscall.SIGTERM)
// Wait briefly for graceful shutdown
select {
case <-cmdDone:
// Process exited
case <-time.After(2 * time.Second):
// Force kill
syscall.Kill(-pgid, syscall.SIGKILL)
<-cmdDone
}
} else {
// Fallback to killing just the process
cmd.Process.Kill()
<-cmdDone
}
return ctx.Err()
}
}

View File

@ -0,0 +1,99 @@
//go:build windows
// +build windows
package cleanup
import (
"context"
"fmt"
"os/exec"
"time"
"dbbackup/internal/logger"
)
// SafeCommand creates an exec.Cmd with proper setup for clean termination on Windows.
func SafeCommand(ctx context.Context, name string, args ...string) *exec.Cmd {
cmd := exec.CommandContext(ctx, name, args...)
// Windows doesn't use process groups the same way as Unix
// exec.CommandContext will handle termination via the context
return cmd
}
// TrackedCommand creates a command that is tracked for cleanup on shutdown.
type TrackedCommand struct {
*exec.Cmd
log logger.Logger
name string
}
// NewTrackedCommand creates a tracked command
func NewTrackedCommand(ctx context.Context, log logger.Logger, name string, args ...string) *TrackedCommand {
tc := &TrackedCommand{
Cmd: SafeCommand(ctx, name, args...),
log: log,
name: name,
}
return tc
}
// StartWithCleanup starts the command and registers cleanup with the handler
func (tc *TrackedCommand) StartWithCleanup(h *Handler) error {
if err := tc.Cmd.Start(); err != nil {
return err
}
// Register cleanup function
pid := tc.Cmd.Process.Pid
h.RegisterCleanup(fmt.Sprintf("kill-%s-%d", tc.name, pid), func(ctx context.Context) error {
return tc.Kill()
})
return nil
}
// Kill terminates the command on Windows
func (tc *TrackedCommand) Kill() error {
if tc.Cmd.Process == nil {
return nil
}
tc.log.Debug("Terminating process", "name", tc.name, "pid", tc.Cmd.Process.Pid)
if err := tc.Cmd.Process.Kill(); err != nil {
tc.log.Debug("Kill failed", "error", err)
return err
}
tc.log.Debug("Process terminated", "name", tc.name, "pid", tc.Cmd.Process.Pid)
return nil
}
// WaitWithContext waits for the command to complete, handling context cancellation properly.
func WaitWithContext(ctx context.Context, cmd *exec.Cmd, log logger.Logger) error {
if cmd.Process == nil {
return fmt.Errorf("process not started")
}
cmdDone := make(chan error, 1)
go func() {
cmdDone <- cmd.Wait()
}()
select {
case err := <-cmdDone:
return err
case <-ctx.Done():
log.Debug("Context cancelled, terminating process", "pid", cmd.Process.Pid)
cmd.Process.Kill()
select {
case <-cmdDone:
case <-time.After(5 * time.Second):
// Already killed, just wait for it
}
return ctx.Err()
}
}

242
internal/cleanup/handler.go Normal file
View File

@ -0,0 +1,242 @@
// Package cleanup provides graceful shutdown and resource cleanup functionality
package cleanup
import (
"context"
"fmt"
"os"
"os/signal"
"sync"
"syscall"
"time"
"dbbackup/internal/logger"
)
// CleanupFunc is a function that performs cleanup with a timeout context
type CleanupFunc func(ctx context.Context) error
// Handler manages graceful shutdown and resource cleanup
type Handler struct {
ctx context.Context
cancel context.CancelFunc
cleanupFns []cleanupEntry
mu sync.Mutex
shutdownTimeout time.Duration
log logger.Logger
// Track if shutdown has been initiated
shutdownOnce sync.Once
shutdownDone chan struct{}
}
type cleanupEntry struct {
name string
fn CleanupFunc
}
// NewHandler creates a shutdown handler
func NewHandler(log logger.Logger) *Handler {
ctx, cancel := context.WithCancel(context.Background())
h := &Handler{
ctx: ctx,
cancel: cancel,
cleanupFns: make([]cleanupEntry, 0),
shutdownTimeout: 30 * time.Second,
log: log,
shutdownDone: make(chan struct{}),
}
return h
}
// Context returns the shutdown context
func (h *Handler) Context() context.Context {
return h.ctx
}
// RegisterCleanup adds a named cleanup function
func (h *Handler) RegisterCleanup(name string, fn CleanupFunc) {
h.mu.Lock()
defer h.mu.Unlock()
h.cleanupFns = append(h.cleanupFns, cleanupEntry{name: name, fn: fn})
}
// SetShutdownTimeout sets the maximum time to wait for cleanup
func (h *Handler) SetShutdownTimeout(d time.Duration) {
h.shutdownTimeout = d
}
// Shutdown triggers graceful shutdown
func (h *Handler) Shutdown() {
h.shutdownOnce.Do(func() {
h.log.Info("Initiating graceful shutdown...")
// Cancel context first (stops all ongoing operations)
h.cancel()
// Run cleanup functions
h.runCleanup()
close(h.shutdownDone)
})
}
// ShutdownWithSignal triggers shutdown due to an OS signal
func (h *Handler) ShutdownWithSignal(sig os.Signal) {
h.log.Info("Received signal, initiating graceful shutdown", "signal", sig.String())
h.Shutdown()
}
// Wait blocks until shutdown is complete
func (h *Handler) Wait() {
<-h.shutdownDone
}
// runCleanup executes all cleanup functions in LIFO order
func (h *Handler) runCleanup() {
h.mu.Lock()
fns := make([]cleanupEntry, len(h.cleanupFns))
copy(fns, h.cleanupFns)
h.mu.Unlock()
if len(fns) == 0 {
h.log.Info("No cleanup functions registered")
return
}
h.log.Info("Running cleanup functions", "count", len(fns))
// Create timeout context for cleanup
ctx, cancel := context.WithTimeout(context.Background(), h.shutdownTimeout)
defer cancel()
// Run all cleanups in LIFO order (most recently registered first)
var failed int
for i := len(fns) - 1; i >= 0; i-- {
entry := fns[i]
h.log.Debug("Running cleanup", "name", entry.name)
if err := entry.fn(ctx); err != nil {
h.log.Warn("Cleanup function failed", "name", entry.name, "error", err)
failed++
} else {
h.log.Debug("Cleanup completed", "name", entry.name)
}
}
if failed > 0 {
h.log.Warn("Some cleanup functions failed", "failed", failed, "total", len(fns))
} else {
h.log.Info("All cleanup functions completed successfully")
}
}
// RegisterSignalHandler sets up signal handling for graceful shutdown
func (h *Handler) RegisterSignalHandler() {
sigChan := make(chan os.Signal, 2)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
go func() {
// First signal: graceful shutdown
sig := <-sigChan
h.ShutdownWithSignal(sig)
// Second signal: force exit
sig = <-sigChan
h.log.Warn("Received second signal, forcing exit", "signal", sig.String())
os.Exit(1)
}()
}
// ChildProcessCleanup creates a cleanup function for killing child processes
func (h *Handler) ChildProcessCleanup() CleanupFunc {
return func(ctx context.Context) error {
h.log.Info("Cleaning up orphaned child processes...")
if err := KillOrphanedProcesses(h.log); err != nil {
h.log.Warn("Failed to kill some orphaned processes", "error", err)
return err
}
h.log.Info("Child process cleanup complete")
return nil
}
}
// DatabasePoolCleanup creates a cleanup function for database connection pools
// poolCloser should be a function that closes the pool
func DatabasePoolCleanup(log logger.Logger, name string, poolCloser func()) CleanupFunc {
return func(ctx context.Context) error {
log.Debug("Closing database connection pool", "name", name)
poolCloser()
log.Debug("Database connection pool closed", "name", name)
return nil
}
}
// FileCleanup creates a cleanup function for file handles
func FileCleanup(log logger.Logger, path string, file *os.File) CleanupFunc {
return func(ctx context.Context) error {
if file == nil {
return nil
}
log.Debug("Closing file", "path", path)
if err := file.Close(); err != nil {
return fmt.Errorf("failed to close file %s: %w", path, err)
}
return nil
}
}
// TempFileCleanup creates a cleanup function that closes and removes a temp file
func TempFileCleanup(log logger.Logger, file *os.File) CleanupFunc {
return func(ctx context.Context) error {
if file == nil {
return nil
}
path := file.Name()
log.Debug("Removing temporary file", "path", path)
// Close file first
if err := file.Close(); err != nil {
log.Warn("Failed to close temp file", "path", path, "error", err)
}
// Remove file
if err := os.Remove(path); err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("failed to remove temp file %s: %w", path, err)
}
}
log.Debug("Temporary file removed", "path", path)
return nil
}
}
// TempDirCleanup creates a cleanup function that removes a temp directory
func TempDirCleanup(log logger.Logger, path string) CleanupFunc {
return func(ctx context.Context) error {
if path == "" {
return nil
}
log.Debug("Removing temporary directory", "path", path)
if err := os.RemoveAll(path); err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("failed to remove temp dir %s: %w", path, err)
}
}
log.Debug("Temporary directory removed", "path", path)
return nil
}
}

View File

@ -8,12 +8,12 @@ import (
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"time"
"dbbackup/internal/cleanup"
"dbbackup/internal/fs"
"dbbackup/internal/logger"
@ -568,7 +568,7 @@ func (d *Diagnoser) verifyWithPgRestore(filePath string, result *DiagnoseResult)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMinutes)*time.Minute)
defer cancel()
cmd := exec.CommandContext(ctx, "pg_restore", "--list", filePath)
cmd := cleanup.SafeCommand(ctx, "pg_restore", "--list", filePath)
output, err := cmd.CombinedOutput()
if err != nil {

View File

@ -17,6 +17,7 @@ import (
"time"
"dbbackup/internal/checks"
"dbbackup/internal/cleanup"
"dbbackup/internal/config"
"dbbackup/internal/database"
"dbbackup/internal/fs"
@ -499,7 +500,7 @@ func (e *Engine) checkDumpHasLargeObjects(archivePath string) bool {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "pg_restore", "-l", archivePath)
cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", archivePath)
output, err := cmd.Output()
if err != nil {
@ -592,7 +593,7 @@ func (e *Engine) executeRestoreCommand(ctx context.Context, cmdArgs []string) er
func (e *Engine) executeRestoreCommandWithContext(ctx context.Context, cmdArgs []string, archivePath, targetDB string, format ArchiveFormat) error {
e.log.Info("Executing restore command", "command", strings.Join(cmdArgs, " "))
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
cmd := cleanup.SafeCommand(ctx, cmdArgs[0], cmdArgs[1:]...)
// Set environment variables
cmd.Env = append(os.Environ(),
@ -662,9 +663,9 @@ func (e *Engine) executeRestoreCommandWithContext(ctx context.Context, cmdArgs [
case cmdErr = <-cmdDone:
// Command completed (success or failure)
case <-ctx.Done():
// Context cancelled - kill process
e.log.Warn("Restore cancelled - killing process")
cmd.Process.Kill()
// Context cancelled - kill entire process group
e.log.Warn("Restore cancelled - killing process group")
cleanup.KillCommandGroup(cmd)
<-cmdDone
cmdErr = ctx.Err()
}
@ -772,7 +773,7 @@ func (e *Engine) executeRestoreWithDecompression(ctx context.Context, archivePat
defer gz.Close()
// Start restore command
cmd := exec.CommandContext(ctx, restoreCmd[0], restoreCmd[1:]...)
cmd := cleanup.SafeCommand(ctx, restoreCmd[0], restoreCmd[1:]...)
cmd.Env = append(os.Environ(),
fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password),
fmt.Sprintf("MYSQL_PWD=%s", e.cfg.Password),
@ -876,7 +877,7 @@ func (e *Engine) executeRestoreWithPgzipStream(ctx context.Context, archivePath,
if e.cfg.Host != "localhost" && e.cfg.Host != "" {
args = append([]string{"-h", e.cfg.Host}, args...)
}
cmd = exec.CommandContext(ctx, "psql", args...)
cmd = cleanup.SafeCommand(ctx, "psql", args...)
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
} else {
// MySQL - use MYSQL_PWD env var to avoid password in process list
@ -885,7 +886,7 @@ func (e *Engine) executeRestoreWithPgzipStream(ctx context.Context, archivePath,
args = append(args, "-h", e.cfg.Host)
}
args = append(args, "-P", fmt.Sprintf("%d", e.cfg.Port), targetDB)
cmd = exec.CommandContext(ctx, "mysql", args...)
cmd = cleanup.SafeCommand(ctx, "mysql", args...)
// Pass password via environment variable to avoid process list exposure
cmd.Env = os.Environ()
if e.cfg.Password != "" {
@ -1322,7 +1323,7 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string, preExtr
}
} else if strings.HasSuffix(dumpFile, ".dump") {
// Validate custom format dumps using pg_restore --list
cmd := exec.CommandContext(ctx, "pg_restore", "--list", dumpFile)
cmd := cleanup.SafeCommand(ctx, "pg_restore", "--list", dumpFile)
output, err := cmd.CombinedOutput()
if err != nil {
dbName := strings.TrimSuffix(entry.Name(), ".dump")
@ -2121,7 +2122,7 @@ func (e *Engine) restoreGlobals(ctx context.Context, globalsFile string) error {
args = append([]string{"-h", e.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "psql", args...)
cmd := cleanup.SafeCommand(ctx, "psql", args...)
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
@ -2183,8 +2184,8 @@ func (e *Engine) restoreGlobals(ctx context.Context, globalsFile string) error {
case cmdErr = <-cmdDone:
// Command completed
case <-ctx.Done():
e.log.Warn("Globals restore cancelled - killing process")
cmd.Process.Kill()
e.log.Warn("Globals restore cancelled - killing process group")
cleanup.KillCommandGroup(cmd)
<-cmdDone
cmdErr = ctx.Err()
}
@ -2225,7 +2226,7 @@ func (e *Engine) checkSuperuser(ctx context.Context) (bool, error) {
args = append([]string{"-h", e.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "psql", args...)
cmd := cleanup.SafeCommand(ctx, "psql", args...)
// Always set PGPASSWORD (empty string is fine for peer/ident auth)
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
@ -2260,7 +2261,7 @@ func (e *Engine) terminateConnections(ctx context.Context, dbName string) error
args = append([]string{"-h", e.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "psql", args...)
cmd := cleanup.SafeCommand(ctx, "psql", args...)
// Always set PGPASSWORD (empty string is fine for peer/ident auth)
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
@ -2296,7 +2297,7 @@ func (e *Engine) dropDatabaseIfExists(ctx context.Context, dbName string) error
if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
revokeArgs = append([]string{"-h", e.cfg.Host}, revokeArgs...)
}
revokeCmd := exec.CommandContext(ctx, "psql", revokeArgs...)
revokeCmd := cleanup.SafeCommand(ctx, "psql", revokeArgs...)
revokeCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
revokeCmd.Run() // Ignore errors - database might not exist
@ -2315,7 +2316,7 @@ func (e *Engine) dropDatabaseIfExists(ctx context.Context, dbName string) error
if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
forceArgs = append([]string{"-h", e.cfg.Host}, forceArgs...)
}
forceCmd := exec.CommandContext(ctx, "psql", forceArgs...)
forceCmd := cleanup.SafeCommand(ctx, "psql", forceArgs...)
forceCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
output, err := forceCmd.CombinedOutput()
@ -2338,7 +2339,7 @@ func (e *Engine) dropDatabaseIfExists(ctx context.Context, dbName string) error
args = append([]string{"-h", e.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "psql", args...)
cmd := cleanup.SafeCommand(ctx, "psql", args...)
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
output, err = cmd.CombinedOutput()
@ -2372,7 +2373,7 @@ func (e *Engine) ensureMySQLDatabaseExists(ctx context.Context, dbName string) e
"-e", fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", dbName),
}
cmd := exec.CommandContext(ctx, "mysql", args...)
cmd := cleanup.SafeCommand(ctx, "mysql", args...)
cmd.Env = os.Environ()
if e.cfg.Password != "" {
cmd.Env = append(cmd.Env, "MYSQL_PWD="+e.cfg.Password)
@ -2410,7 +2411,7 @@ func (e *Engine) ensurePostgresDatabaseExists(ctx context.Context, dbName string
args = append([]string{"-h", e.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "psql", args...)
cmd := cleanup.SafeCommand(ctx, "psql", args...)
// Always set PGPASSWORD (empty string is fine for peer/ident auth)
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
@ -2467,7 +2468,7 @@ func (e *Engine) ensurePostgresDatabaseExists(ctx context.Context, dbName string
createArgs = append([]string{"-h", e.cfg.Host}, createArgs...)
}
createCmd := exec.CommandContext(ctx, "psql", createArgs...)
createCmd := cleanup.SafeCommand(ctx, "psql", createArgs...)
// Always set PGPASSWORD (empty string is fine for peer/ident auth)
createCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
@ -2487,7 +2488,7 @@ func (e *Engine) ensurePostgresDatabaseExists(ctx context.Context, dbName string
simpleArgs = append([]string{"-h", e.cfg.Host}, simpleArgs...)
}
simpleCmd := exec.CommandContext(ctx, "psql", simpleArgs...)
simpleCmd := cleanup.SafeCommand(ctx, "psql", simpleArgs...)
simpleCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
output, err = simpleCmd.CombinedOutput()
@ -2552,7 +2553,7 @@ func (e *Engine) detectLargeObjectsInDumps(dumpsDir string, entries []os.DirEntr
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpFile)
output, err := cmd.Output()
if err != nil {
@ -2876,7 +2877,7 @@ func (e *Engine) canRestartPostgreSQL() bool {
// Try a quick sudo check - if this fails, we can't restart
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "sudo", "-n", "true")
cmd := cleanup.SafeCommand(ctx, "sudo", "-n", "true")
cmd.Stdin = nil
if err := cmd.Run(); err != nil {
e.log.Info("Running as postgres user without sudo access - cannot restart PostgreSQL",
@ -2906,7 +2907,7 @@ func (e *Engine) tryRestartPostgreSQL(ctx context.Context) bool {
runWithTimeout := func(args ...string) bool {
cmdCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cmd := exec.CommandContext(cmdCtx, args[0], args[1:]...)
cmd := cleanup.SafeCommand(cmdCtx, args[0], args[1:]...)
// Set stdin to /dev/null to prevent sudo from waiting for password
cmd.Stdin = nil
return cmd.Run() == nil

View File

@ -7,12 +7,12 @@ import (
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"dbbackup/internal/cleanup"
"dbbackup/internal/config"
"dbbackup/internal/logger"
@ -568,7 +568,7 @@ func getCommandVersion(cmd string, arg string) string {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
output, err := exec.CommandContext(ctx, cmd, arg).CombinedOutput()
output, err := cleanup.SafeCommand(ctx, cmd, arg).CombinedOutput()
if err != nil {
return ""
}

View File

@ -5,11 +5,11 @@ package restore
import (
"context"
"fmt"
"os/exec"
"strings"
"sync"
"time"
"dbbackup/internal/cleanup"
"dbbackup/internal/config"
"dbbackup/internal/logger"
)
@ -124,7 +124,7 @@ func ApplySessionOptimizations(ctx context.Context, cfg *config.Config, log logg
for _, sql := range safeOptimizations {
cmdArgs := append(args, "-c", sql)
cmd := exec.CommandContext(ctx, "psql", cmdArgs...)
cmd := cleanup.SafeCommand(ctx, "psql", cmdArgs...)
cmd.Env = append(cmd.Environ(), fmt.Sprintf("PGPASSWORD=%s", cfg.Password))
if err := cmd.Run(); err != nil {

View File

@ -6,11 +6,11 @@ import (
"database/sql"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"syscall"
"dbbackup/internal/cleanup"
"dbbackup/internal/config"
"dbbackup/internal/logger"
)
@ -572,7 +572,7 @@ func (g *LargeDBGuard) RevertMySQLSettings() []string {
// Uses pg_restore -l which outputs a line-by-line listing, then streams through it
func (g *LargeDBGuard) StreamCountBLOBs(ctx context.Context, dumpFile string) (int, error) {
// pg_restore -l outputs text listing, one line per object
cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpFile)
stdout, err := cmd.StdoutPipe()
if err != nil {
@ -609,7 +609,7 @@ func (g *LargeDBGuard) StreamCountBLOBs(ctx context.Context, dumpFile string) (i
// StreamAnalyzeDump analyzes a dump file using streaming to avoid memory issues
// Returns: blobCount, estimatedObjects, error
func (g *LargeDBGuard) StreamAnalyzeDump(ctx context.Context, dumpFile string) (blobCount, totalObjects int, err error) {
cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpFile)
stdout, err := cmd.StdoutPipe()
if err != nil {

View File

@ -1,18 +1,22 @@
package restore
import (
"bufio"
"context"
"database/sql"
"fmt"
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"time"
"dbbackup/internal/cleanup"
"github.com/dustin/go-humanize"
"github.com/klauspost/pgzip"
"github.com/shirou/gopsutil/v3/mem"
)
@ -381,7 +385,7 @@ func (e *Engine) countBlobsInDump(ctx context.Context, dumpFile string) int {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpFile)
output, err := cmd.Output()
if err != nil {
return 0
@ -398,24 +402,51 @@ func (e *Engine) countBlobsInDump(ctx context.Context, dumpFile string) int {
}
// estimateBlobsInSQL samples compressed SQL for lo_create patterns
// Uses in-process pgzip decompression (NO external gzip process)
func (e *Engine) estimateBlobsInSQL(sqlFile string) int {
// Use zgrep for efficient searching in gzipped files
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Count lo_create calls (each = one large object)
cmd := exec.CommandContext(ctx, "zgrep", "-c", "lo_create", sqlFile)
output, err := cmd.Output()
// Open the gzipped file
f, err := os.Open(sqlFile)
if err != nil {
// Also try SELECT lo_create pattern
cmd2 := exec.CommandContext(ctx, "zgrep", "-c", "SELECT.*lo_create", sqlFile)
output, err = cmd2.Output()
if err != nil {
return 0
}
e.log.Debug("Cannot open SQL file for BLOB estimation", "file", sqlFile, "error", err)
return 0
}
defer f.Close()
// Create pgzip reader for parallel decompression
gzReader, err := pgzip.NewReader(f)
if err != nil {
e.log.Debug("Cannot create pgzip reader", "file", sqlFile, "error", err)
return 0
}
defer gzReader.Close()
// Scan for lo_create patterns
// We use a regex to match both "lo_create" and "SELECT lo_create" patterns
loCreatePattern := regexp.MustCompile(`lo_create`)
scanner := bufio.NewScanner(gzReader)
// Use larger buffer for potentially long lines
buf := make([]byte, 0, 256*1024)
scanner.Buffer(buf, 10*1024*1024)
count := 0
linesScanned := 0
maxLines := 1000000 // Limit scanning for very large files
for scanner.Scan() && linesScanned < maxLines {
line := scanner.Text()
linesScanned++
// Count all lo_create occurrences in the line
matches := loCreatePattern.FindAllString(line, -1)
count += len(matches)
}
count, _ := strconv.Atoi(strings.TrimSpace(string(output)))
if err := scanner.Err(); err != nil {
e.log.Debug("Error scanning SQL file", "file", sqlFile, "error", err, "lines_scanned", linesScanned)
}
e.log.Debug("BLOB estimation from SQL file", "file", sqlFile, "lo_create_count", count, "lines_scanned", linesScanned)
return count
}

View File

@ -8,6 +8,7 @@ import (
"os/exec"
"strings"
"dbbackup/internal/cleanup"
"dbbackup/internal/config"
"dbbackup/internal/fs"
"dbbackup/internal/logger"
@ -419,7 +420,7 @@ func (s *Safety) checkPostgresDatabaseExists(ctx context.Context, dbName string)
}
args = append([]string{"-h", host}, args...)
cmd := exec.CommandContext(ctx, "psql", args...)
cmd := cleanup.SafeCommand(ctx, "psql", args...)
// Set password if provided
if s.cfg.Password != "" {
@ -447,7 +448,7 @@ func (s *Safety) checkMySQLDatabaseExists(ctx context.Context, dbName string) (b
args = append([]string{"-h", s.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "mysql", args...)
cmd := cleanup.SafeCommand(ctx, "mysql", args...)
if s.cfg.Password != "" {
cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password))
@ -493,7 +494,7 @@ func (s *Safety) listPostgresUserDatabases(ctx context.Context) ([]string, error
}
args = append([]string{"-h", host}, args...)
cmd := exec.CommandContext(ctx, "psql", args...)
cmd := cleanup.SafeCommand(ctx, "psql", args...)
// Set password - check config first, then environment
env := os.Environ()
@ -542,7 +543,7 @@ func (s *Safety) listMySQLUserDatabases(ctx context.Context) ([]string, error) {
args = append([]string{"-h", s.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "mysql", args...)
cmd := cleanup.SafeCommand(ctx, "mysql", args...)
if s.cfg.Password != "" {
cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password))

View File

@ -3,11 +3,11 @@ package restore
import (
"context"
"fmt"
"os/exec"
"regexp"
"strconv"
"time"
"dbbackup/internal/cleanup"
"dbbackup/internal/database"
)
@ -54,7 +54,7 @@ func GetDumpFileVersion(dumpPath string) (*VersionInfo, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpPath)
cmd := cleanup.SafeCommand(ctx, "pg_restore", "-l", dumpPath)
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to read dump file metadata: %w (output: %s)", err, string(output))

View File

@ -16,6 +16,7 @@ import (
"dbbackup/internal/config"
"dbbackup/internal/database"
"dbbackup/internal/logger"
"dbbackup/internal/progress"
"dbbackup/internal/restore"
)
@ -75,6 +76,13 @@ type RestoreExecutionModel struct {
overallPhase int // 1=Extracting, 2=Globals, 3=Databases
extractionDone bool
// Rich progress view for cluster restores
richProgressView *RichClusterProgressView
unifiedProgress *progress.UnifiedClusterProgress
useRichProgress bool // Whether to use the rich progress view
termWidth int // Terminal width for rich progress
termHeight int // Terminal height for rich progress
// Results
done bool
cancelling bool // True when user has requested cancellation
@ -108,6 +116,11 @@ func NewRestoreExecution(cfg *config.Config, log logger.Logger, parent tea.Model
details: []string{},
spinnerFrames: spinnerFrames, // Use package-level constant
spinnerFrame: 0,
// Initialize rich progress view for cluster restores
richProgressView: NewRichClusterProgressView(),
useRichProgress: restoreType == "restore-cluster",
termWidth: 80,
termHeight: 24,
}
}
@ -176,6 +189,9 @@ type sharedProgressState struct {
// Throttling to prevent excessive updates (memory optimization)
lastSpeedSampleTime time.Time // Last time we added a speed sample
minSampleInterval time.Duration // Minimum interval between samples (100ms)
// Unified progress tracker for rich display
unifiedProgress *progress.UnifiedClusterProgress
}
type restoreSpeedSample struct {
@ -231,6 +247,18 @@ func getCurrentRestoreProgress() (bytesTotal, bytesDone int64, description strin
currentRestoreProgressState.phase3StartTime
}
// getUnifiedProgress returns the unified progress tracker if available
func getUnifiedProgress() *progress.UnifiedClusterProgress {
currentRestoreProgressMu.Lock()
defer currentRestoreProgressMu.Unlock()
if currentRestoreProgressState == nil {
return nil
}
return currentRestoreProgressState.unifiedProgress
}
// calculateRollingSpeed calculates speed from recent samples (last 5 seconds)
func calculateRollingSpeed(samples []restoreSpeedSample) float64 {
if len(samples) < 2 {
@ -332,6 +360,11 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
progressState := &sharedProgressState{
speedSamples: make([]restoreSpeedSample, 0, 100),
}
// Initialize unified progress tracker for cluster restores
if restoreType == "restore-cluster" {
progressState.unifiedProgress = progress.NewUnifiedClusterProgress("restore", archive.Path)
}
engine.SetProgressCallback(func(current, total int64, description string) {
progressState.mu.Lock()
defer progressState.mu.Unlock()
@ -342,10 +375,19 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
progressState.overallPhase = 1
progressState.extractionDone = false
// Update unified progress tracker
if progressState.unifiedProgress != nil {
progressState.unifiedProgress.SetPhase(progress.PhaseExtracting)
progressState.unifiedProgress.SetExtractProgress(current, total)
}
// Check if extraction is complete
if current >= total && total > 0 {
progressState.extractionDone = true
progressState.overallPhase = 2
if progressState.unifiedProgress != nil {
progressState.unifiedProgress.SetPhase(progress.PhaseGlobals)
}
}
// Throttle speed samples to prevent memory bloat (max 10 samples/sec)
@ -384,6 +426,13 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
// Clear byte progress when switching to db progress
progressState.bytesTotal = 0
progressState.bytesDone = 0
// Update unified progress tracker
if progressState.unifiedProgress != nil {
progressState.unifiedProgress.SetPhase(progress.PhaseDatabases)
progressState.unifiedProgress.SetDatabasesTotal(total, nil)
progressState.unifiedProgress.StartDatabase(dbName, 0)
}
})
// Set up timing-aware database progress callback for cluster restore ETA
@ -406,6 +455,13 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
// Clear byte progress when switching to db progress
progressState.bytesTotal = 0
progressState.bytesDone = 0
// Update unified progress tracker
if progressState.unifiedProgress != nil {
progressState.unifiedProgress.SetPhase(progress.PhaseDatabases)
progressState.unifiedProgress.SetDatabasesTotal(total, nil)
progressState.unifiedProgress.StartDatabase(dbName, 0)
}
})
// Set up weighted (bytes-based) progress callback for accurate cluster restore progress
@ -424,6 +480,14 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
if progressState.phase3StartTime.IsZero() {
progressState.phase3StartTime = time.Now()
}
// Update unified progress tracker
if progressState.unifiedProgress != nil {
progressState.unifiedProgress.SetPhase(progress.PhaseDatabases)
progressState.unifiedProgress.SetDatabasesTotal(dbTotal, nil)
progressState.unifiedProgress.StartDatabase(dbName, bytesTotal)
progressState.unifiedProgress.UpdateDatabaseProgress(bytesDone)
}
})
// Store progress state in a package-level variable for the ticker to access
@ -489,11 +553,30 @@ func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config
func (m RestoreExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
// Update terminal dimensions for rich progress view
m.termWidth = msg.Width
m.termHeight = msg.Height
if m.richProgressView != nil {
m.richProgressView.SetSize(msg.Width, msg.Height)
}
return m, nil
case restoreTickMsg:
if !m.done {
m.spinnerFrame = (m.spinnerFrame + 1) % len(m.spinnerFrames)
m.elapsed = time.Since(m.startTime)
// Advance spinner for rich progress view
if m.richProgressView != nil {
m.richProgressView.AdvanceSpinner()
}
// Update unified progress reference
if m.useRichProgress && m.unifiedProgress == nil {
m.unifiedProgress = getUnifiedProgress()
}
// Poll shared progress state for real-time updates
// Note: dbPhaseElapsed is now calculated in realtime inside getCurrentRestoreProgress()
bytesTotal, bytesDone, description, hasUpdate, dbTotal, dbDone, speed, dbPhaseElapsed, dbAvgPerDB, currentDB, overallPhase, extractionDone, dbBytesTotal, dbBytesDone, _ := getCurrentRestoreProgress()
@ -782,7 +865,16 @@ func (m RestoreExecutionModel) View() string {
} else {
// Show unified progress for cluster restore
if m.restoreType == "restore-cluster" {
// Calculate overall progress across all phases
// Use rich progress view when we have unified progress data
if m.useRichProgress && m.unifiedProgress != nil {
// Render using the rich cluster progress view
s.WriteString(m.richProgressView.RenderUnified(m.unifiedProgress))
s.WriteString("\n")
s.WriteString(infoStyle.Render("[KEYS] Press Ctrl+C to cancel"))
return s.String()
}
// Fallback: Calculate overall progress across all phases
// Phase 1: Extraction (0-60%)
// Phase 2: Globals (60-65%)
// Phase 3: Databases (65-100%)

View File

@ -0,0 +1,339 @@
package tui
import (
"fmt"
"strings"
"time"
"dbbackup/internal/progress"
)
// RichClusterProgressView renders detailed cluster restore progress
type RichClusterProgressView struct {
width int
height int
spinnerFrames []string
spinnerFrame int
}
// NewRichClusterProgressView creates a new rich progress view
func NewRichClusterProgressView() *RichClusterProgressView {
return &RichClusterProgressView{
width: 80,
height: 24,
spinnerFrames: []string{
"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
},
}
}
// SetSize updates the terminal size
func (v *RichClusterProgressView) SetSize(width, height int) {
v.width = width
v.height = height
}
// AdvanceSpinner moves to the next spinner frame
func (v *RichClusterProgressView) AdvanceSpinner() {
v.spinnerFrame = (v.spinnerFrame + 1) % len(v.spinnerFrames)
}
// RenderUnified renders progress from UnifiedClusterProgress
func (v *RichClusterProgressView) RenderUnified(p *progress.UnifiedClusterProgress) string {
if p == nil {
return ""
}
snapshot := p.GetSnapshot()
return v.RenderSnapshot(&snapshot)
}
// RenderSnapshot renders progress from a ProgressSnapshot
func (v *RichClusterProgressView) RenderSnapshot(snapshot *progress.ProgressSnapshot) string {
if snapshot == nil {
return ""
}
var b strings.Builder
b.Grow(2048)
// Header with overall progress
b.WriteString(v.renderHeader(snapshot))
b.WriteString("\n\n")
// Overall progress bar
b.WriteString(v.renderOverallProgress(snapshot))
b.WriteString("\n\n")
// Phase-specific details
b.WriteString(v.renderPhaseDetails(snapshot))
// Performance metrics
if v.height > 15 {
b.WriteString("\n")
b.WriteString(v.renderMetricsFromSnapshot(snapshot))
}
return b.String()
}
func (v *RichClusterProgressView) renderHeader(snapshot *progress.ProgressSnapshot) string {
elapsed := time.Since(snapshot.StartTime)
// Calculate ETA based on progress
overall := v.calculateOverallPercent(snapshot)
var etaStr string
if overall > 0 && overall < 100 {
eta := time.Duration(float64(elapsed) / float64(overall) * float64(100-overall))
etaStr = fmt.Sprintf("ETA: %s", formatDuration(eta))
} else if overall >= 100 {
etaStr = "Complete!"
} else {
etaStr = "ETA: calculating..."
}
title := "Cluster Restore Progress"
separator := strings.Repeat("━", maxInt(0, v.width-len(title)-4))
return fmt.Sprintf("%s %s\n Elapsed: %s | %s",
title, separator,
formatDuration(elapsed), etaStr)
}
func (v *RichClusterProgressView) renderOverallProgress(snapshot *progress.ProgressSnapshot) string {
overall := v.calculateOverallPercent(snapshot)
// Phase indicator
phaseLabel := v.getPhaseLabel(snapshot)
// Progress bar
barWidth := v.width - 20
if barWidth < 20 {
barWidth = 20
}
bar := v.renderProgressBarWidth(overall, barWidth)
return fmt.Sprintf(" Overall: %s %3d%%\n Phase: %s", bar, overall, phaseLabel)
}
func (v *RichClusterProgressView) getPhaseLabel(snapshot *progress.ProgressSnapshot) string {
switch snapshot.Phase {
case progress.PhaseExtracting:
return fmt.Sprintf("📦 Extracting archive (%s / %s)",
FormatBytes(snapshot.ExtractBytes), FormatBytes(snapshot.ExtractTotal))
case progress.PhaseGlobals:
return "🔧 Restoring globals (roles, tablespaces)"
case progress.PhaseDatabases:
return fmt.Sprintf("🗄️ Databases (%d/%d) %s",
snapshot.DatabasesDone, snapshot.DatabasesTotal, snapshot.CurrentDB)
case progress.PhaseVerifying:
return fmt.Sprintf("✅ Verifying (%d/%d)", snapshot.VerifyDone, snapshot.VerifyTotal)
case progress.PhaseComplete:
return "🎉 Complete!"
case progress.PhaseFailed:
return "❌ Failed"
default:
return string(snapshot.Phase)
}
}
func (v *RichClusterProgressView) calculateOverallPercent(snapshot *progress.ProgressSnapshot) int {
// Use the same logic as UnifiedClusterProgress
phaseWeights := map[progress.Phase]int{
progress.PhaseExtracting: 20,
progress.PhaseGlobals: 5,
progress.PhaseDatabases: 70,
progress.PhaseVerifying: 5,
}
switch snapshot.Phase {
case progress.PhaseIdle:
return 0
case progress.PhaseExtracting:
if snapshot.ExtractTotal > 0 {
return int(float64(snapshot.ExtractBytes) / float64(snapshot.ExtractTotal) * float64(phaseWeights[progress.PhaseExtracting]))
}
return 0
case progress.PhaseGlobals:
return phaseWeights[progress.PhaseExtracting] + phaseWeights[progress.PhaseGlobals]
case progress.PhaseDatabases:
basePercent := phaseWeights[progress.PhaseExtracting] + phaseWeights[progress.PhaseGlobals]
if snapshot.DatabasesTotal == 0 {
return basePercent
}
dbProgress := float64(snapshot.DatabasesDone) / float64(snapshot.DatabasesTotal)
if snapshot.CurrentDBTotal > 0 {
currentProgress := float64(snapshot.CurrentDBBytes) / float64(snapshot.CurrentDBTotal)
dbProgress += currentProgress / float64(snapshot.DatabasesTotal)
}
return basePercent + int(dbProgress*float64(phaseWeights[progress.PhaseDatabases]))
case progress.PhaseVerifying:
basePercent := phaseWeights[progress.PhaseExtracting] + phaseWeights[progress.PhaseGlobals] + phaseWeights[progress.PhaseDatabases]
if snapshot.VerifyTotal > 0 {
verifyProgress := float64(snapshot.VerifyDone) / float64(snapshot.VerifyTotal)
return basePercent + int(verifyProgress*float64(phaseWeights[progress.PhaseVerifying]))
}
return basePercent
case progress.PhaseComplete:
return 100
default:
return 0
}
}
func (v *RichClusterProgressView) renderPhaseDetails(snapshot *progress.ProgressSnapshot) string {
var b strings.Builder
switch snapshot.Phase {
case progress.PhaseExtracting:
pct := 0
if snapshot.ExtractTotal > 0 {
pct = int(float64(snapshot.ExtractBytes) / float64(snapshot.ExtractTotal) * 100)
}
bar := v.renderMiniProgressBar(pct)
b.WriteString(fmt.Sprintf(" 📦 Extraction: %s %d%%\n", bar, pct))
b.WriteString(fmt.Sprintf(" %s / %s\n",
FormatBytes(snapshot.ExtractBytes), FormatBytes(snapshot.ExtractTotal)))
case progress.PhaseDatabases:
b.WriteString(" 📊 Databases:\n\n")
// Show completed databases if any
if snapshot.DatabasesDone > 0 {
avgTime := time.Duration(0)
if len(snapshot.DatabaseTimes) > 0 {
var total time.Duration
for _, t := range snapshot.DatabaseTimes {
total += t
}
avgTime = total / time.Duration(len(snapshot.DatabaseTimes))
}
b.WriteString(fmt.Sprintf(" ✓ %d completed (avg: %s)\n",
snapshot.DatabasesDone, formatDuration(avgTime)))
}
// Show current database
if snapshot.CurrentDB != "" {
spinner := v.spinnerFrames[v.spinnerFrame]
pct := 0
if snapshot.CurrentDBTotal > 0 {
pct = int(float64(snapshot.CurrentDBBytes) / float64(snapshot.CurrentDBTotal) * 100)
}
bar := v.renderMiniProgressBar(pct)
phaseElapsed := time.Since(snapshot.PhaseStartTime)
b.WriteString(fmt.Sprintf(" %s %-20s %s %3d%%\n",
spinner, truncateString(snapshot.CurrentDB, 20), bar, pct))
b.WriteString(fmt.Sprintf(" └─ %s / %s (running %s)\n",
FormatBytes(snapshot.CurrentDBBytes), FormatBytes(snapshot.CurrentDBTotal),
formatDuration(phaseElapsed)))
}
// Show remaining count
remaining := snapshot.DatabasesTotal - snapshot.DatabasesDone
if snapshot.CurrentDB != "" {
remaining--
}
if remaining > 0 {
b.WriteString(fmt.Sprintf(" ⏳ %d remaining\n", remaining))
}
case progress.PhaseVerifying:
pct := 0
if snapshot.VerifyTotal > 0 {
pct = snapshot.VerifyDone * 100 / snapshot.VerifyTotal
}
bar := v.renderMiniProgressBar(pct)
b.WriteString(fmt.Sprintf(" ✅ Verification: %s %d%%\n", bar, pct))
b.WriteString(fmt.Sprintf(" %d / %d databases verified\n",
snapshot.VerifyDone, snapshot.VerifyTotal))
case progress.PhaseComplete:
elapsed := time.Since(snapshot.StartTime)
b.WriteString(fmt.Sprintf(" 🎉 Restore complete!\n"))
b.WriteString(fmt.Sprintf(" %d databases restored in %s\n",
snapshot.DatabasesDone, formatDuration(elapsed)))
case progress.PhaseFailed:
b.WriteString(" ❌ Restore failed:\n")
for _, err := range snapshot.Errors {
b.WriteString(fmt.Sprintf(" • %s\n", truncateString(err, v.width-10)))
}
}
return b.String()
}
func (v *RichClusterProgressView) renderMetricsFromSnapshot(snapshot *progress.ProgressSnapshot) string {
var b strings.Builder
b.WriteString(" 📈 Performance:\n")
elapsed := time.Since(snapshot.StartTime)
if elapsed > 0 {
// Calculate throughput from extraction phase if we have data
if snapshot.ExtractBytes > 0 && elapsed.Seconds() > 0 {
throughput := float64(snapshot.ExtractBytes) / elapsed.Seconds()
b.WriteString(fmt.Sprintf(" Throughput: %s/s\n", FormatBytes(int64(throughput))))
}
// Database timing info
if len(snapshot.DatabaseTimes) > 0 {
var total time.Duration
for _, t := range snapshot.DatabaseTimes {
total += t
}
avg := total / time.Duration(len(snapshot.DatabaseTimes))
b.WriteString(fmt.Sprintf(" Avg DB time: %s\n", formatDuration(avg)))
}
}
return b.String()
}
// Helper functions
func (v *RichClusterProgressView) renderProgressBarWidth(pct, width int) string {
if width < 10 {
width = 10
}
filled := (pct * width) / 100
empty := width - filled
bar := strings.Repeat("█", filled) + strings.Repeat("░", empty)
return "[" + bar + "]"
}
func (v *RichClusterProgressView) renderMiniProgressBar(pct int) string {
width := 20
filled := (pct * width) / 100
empty := width - filled
return strings.Repeat("█", filled) + strings.Repeat("░", empty)
}
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
if maxLen < 4 {
return s[:maxLen]
}
return s[:maxLen-3] + "..."
}
func maxInt(a, b int) int {
if a > b {
return a
}
return b
}
func formatNumShort(n int64) string {
if n >= 1e9 {
return fmt.Sprintf("%.1fB", float64(n)/1e9)
} else if n >= 1e6 {
return fmt.Sprintf("%.1fM", float64(n)/1e6)
} else if n >= 1e3 {
return fmt.Sprintf("%.1fK", float64(n)/1e3)
}
return fmt.Sprintf("%d", n)
}

View File

@ -16,7 +16,7 @@ import (
// Build information (set by ldflags)
var (
version = "5.4.2"
version = "5.4.3"
buildTime = "unknown"
gitCommit = "unknown"
)

192
scripts/test-sigint-cleanup.sh Executable file
View File

@ -0,0 +1,192 @@
#!/bin/bash
# scripts/test-sigint-cleanup.sh
# Test script to verify clean shutdown on SIGINT (Ctrl+C)
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
BINARY="$PROJECT_DIR/dbbackup"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo "=== SIGINT Cleanup Test ==="
echo ""
echo "Project: $PROJECT_DIR"
echo "Binary: $BINARY"
echo ""
# Check if binary exists
if [ ! -f "$BINARY" ]; then
echo -e "${YELLOW}Binary not found, building...${NC}"
cd "$PROJECT_DIR"
go build -o dbbackup .
fi
# Create a test backup file if it doesn't exist
TEST_BACKUP="/tmp/test-sigint-backup.sql.gz"
if [ ! -f "$TEST_BACKUP" ]; then
echo -e "${YELLOW}Creating test backup file...${NC}"
echo "-- Test SQL file for SIGINT testing" | gzip > "$TEST_BACKUP"
fi
echo "=== Phase 1: Pre-test Cleanup ==="
echo "Killing any existing dbbackup processes..."
pkill -f "dbbackup" 2>/dev/null || true
sleep 1
echo ""
echo "=== Phase 2: Check Initial State ==="
echo "Checking for orphaned processes..."
INITIAL_PROCS=$(pgrep -f "pg_dump|pg_restore|dbbackup" 2>/dev/null | wc -l)
echo "Initial related processes: $INITIAL_PROCS"
echo ""
echo "Checking for temp files..."
INITIAL_TEMPS=$(ls /tmp/dbbackup-* 2>/dev/null | wc -l || echo "0")
echo "Initial temp files: $INITIAL_TEMPS"
echo ""
echo "=== Phase 3: Start Test Operation ==="
# Start a TUI operation that will hang (version is fast, but menu would wait)
echo "Starting dbbackup TUI (will be interrupted)..."
# Run in background with PTY simulation (needed for TUI)
cd "$PROJECT_DIR"
timeout 30 script -q -c "$BINARY" /dev/null &
PID=$!
echo "Process started: PID=$PID"
sleep 2
# Check if process is running
if ! kill -0 $PID 2>/dev/null; then
echo -e "${YELLOW}Process exited quickly (expected for non-interactive test)${NC}"
echo "This is normal - the TUI requires a real TTY"
PID=""
else
echo "Process is running"
echo ""
echo "=== Phase 4: Check Running State ==="
echo "Child processes of $PID:"
pgrep -P $PID 2>/dev/null | while read child; do
ps -p $child -o pid,ppid,cmd 2>/dev/null || true
done
echo ""
echo "=== Phase 5: Send SIGINT ==="
echo "Sending SIGINT to process $PID..."
kill -SIGINT $PID 2>/dev/null || true
echo "Waiting for cleanup (max 10 seconds)..."
for i in {1..10}; do
if ! kill -0 $PID 2>/dev/null; then
echo ""
echo -e "${GREEN}Process exited after ${i} seconds${NC}"
break
fi
sleep 1
echo -n "."
done
echo ""
# Check if still running
if kill -0 $PID 2>/dev/null; then
echo -e "${RED}Process still running after 10 seconds!${NC}"
echo "Force killing..."
kill -9 $PID 2>/dev/null || true
fi
fi
sleep 2 # Give OS time to clean up
echo ""
echo "=== Phase 6: Post-Shutdown Verification ==="
# Check for zombie processes
ZOMBIES=$(ps aux 2>/dev/null | grep -E "dbbackup|pg_dump|pg_restore" | grep -v grep | grep defunct | wc -l)
echo "Zombie processes: $ZOMBIES"
# Check for orphaned children
if [ -n "$PID" ]; then
ORPHANS=$(pgrep -P $PID 2>/dev/null | wc -l || echo "0")
echo "Orphaned children of original process: $ORPHANS"
else
ORPHANS=0
fi
# Check for leftover related processes
LEFTOVER_PROCS=$(pgrep -f "pg_dump|pg_restore" 2>/dev/null | wc -l || echo "0")
echo "Leftover pg_dump/pg_restore processes: $LEFTOVER_PROCS"
# Check for temp files
TEMP_FILES=$(ls /tmp/dbbackup-* 2>/dev/null | wc -l || echo "0")
echo "Temporary files: $TEMP_FILES"
# Database connections check (if psql available and configured)
if command -v psql &> /dev/null; then
echo ""
echo "Checking database connections..."
DB_CONNS=$(psql -t -c "SELECT count(*) FROM pg_stat_activity WHERE application_name LIKE '%dbbackup%';" 2>/dev/null | tr -d ' ' || echo "N/A")
echo "Database connections with 'dbbackup' in name: $DB_CONNS"
else
echo "psql not available - skipping database connection check"
DB_CONNS="N/A"
fi
echo ""
echo "=== Test Results ==="
PASSED=true
if [ "$ZOMBIES" -gt 0 ]; then
echo -e "${RED}❌ FAIL: $ZOMBIES zombie process(es) found${NC}"
PASSED=false
else
echo -e "${GREEN}✓ No zombie processes${NC}"
fi
if [ "$ORPHANS" -gt 0 ]; then
echo -e "${RED}❌ FAIL: $ORPHANS orphaned child process(es) found${NC}"
PASSED=false
else
echo -e "${GREEN}✓ No orphaned children${NC}"
fi
if [ "$LEFTOVER_PROCS" -gt 0 ]; then
echo -e "${YELLOW}⚠ WARNING: $LEFTOVER_PROCS leftover pg_dump/pg_restore process(es)${NC}"
echo " These may be from other operations"
fi
if [ "$TEMP_FILES" -gt "$INITIAL_TEMPS" ]; then
NEW_TEMPS=$((TEMP_FILES - INITIAL_TEMPS))
echo -e "${RED}❌ FAIL: $NEW_TEMPS new temporary file(s) left behind${NC}"
ls -la /tmp/dbbackup-* 2>/dev/null || true
PASSED=false
else
echo -e "${GREEN}✓ No new temporary files left behind${NC}"
fi
if [ "$DB_CONNS" != "N/A" ] && [ "$DB_CONNS" -gt 0 ]; then
echo -e "${RED}❌ FAIL: $DB_CONNS database connection(s) still active${NC}"
PASSED=false
elif [ "$DB_CONNS" != "N/A" ]; then
echo -e "${GREEN}✓ No lingering database connections${NC}"
fi
echo ""
if [ "$PASSED" = true ]; then
echo -e "${GREEN}=== ✓ ALL TESTS PASSED ===${NC}"
exit 0
else
echo -e "${RED}=== ✗ SOME TESTS FAILED ===${NC}"
exit 1
fi