Files
dbbackup/internal/tui/restore_exec.go
Alexander Renz eeacbfa007
All checks were successful
CI/CD / Test (push) Successful in 1m28s
CI/CD / Lint (push) Successful in 1m41s
CI/CD / Build & Release (push) Successful in 4m9s
TUI: Add detailed progress tracking with rolling speed and database count
- Add TUI-native detailed progress component (detailed_progress.go)
- Hide spinner when progress bar is shown for cleaner display
- Implement rolling window speed calculation (5-sec window, 100 samples)
- Add database count tracking (X/Y) for cluster restore operations
- Wire DatabaseProgressCallback to restore engine for multi-db progress
2026-01-15 15:16:21 +01:00

725 lines
22 KiB
Go
Executable File

package tui
import (
"context"
"fmt"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
tea "github.com/charmbracelet/bubbletea"
"dbbackup/internal/config"
"dbbackup/internal/database"
"dbbackup/internal/logger"
"dbbackup/internal/restore"
)
// Shared spinner frames for consistent animation across all TUI operations
var spinnerFrames = []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
// RestoreExecutionModel handles restore execution with progress
type RestoreExecutionModel struct {
config *config.Config
logger logger.Logger
parent tea.Model
ctx context.Context
cancel context.CancelFunc // Cancel function to stop the operation
archive ArchiveInfo
targetDB string
cleanFirst bool
createIfMissing bool
restoreType string
cleanClusterFirst bool // Drop all user databases before cluster restore
existingDBs []string // List of databases to drop
saveDebugLog bool // Save detailed error report on failure
workDir string // Custom work directory for extraction
// Progress tracking
status string
phase string
progress int
details []string
startTime time.Time
spinnerFrame int
spinnerFrames []string
// Detailed byte progress for schollz-style display
bytesTotal int64
bytesDone int64
description string
showBytes bool // True when we have real byte progress to show
speed float64 // Rolling window speed in bytes/sec
// Database count progress (for cluster restore)
dbTotal int
dbDone int
// Results
done bool
cancelling bool // True when user has requested cancellation
err error
result string
elapsed time.Duration
}
// NewRestoreExecution creates a new restore execution model
func NewRestoreExecution(cfg *config.Config, log logger.Logger, parent tea.Model, ctx context.Context, archive ArchiveInfo, targetDB string, cleanFirst, createIfMissing bool, restoreType string, cleanClusterFirst bool, existingDBs []string, saveDebugLog bool, workDir string) RestoreExecutionModel {
// Create a cancellable context derived from parent
childCtx, cancel := context.WithCancel(ctx)
return RestoreExecutionModel{
config: cfg,
logger: log,
parent: parent,
ctx: childCtx,
cancel: cancel,
archive: archive,
targetDB: targetDB,
cleanFirst: cleanFirst,
createIfMissing: createIfMissing,
restoreType: restoreType,
cleanClusterFirst: cleanClusterFirst,
existingDBs: existingDBs,
saveDebugLog: saveDebugLog,
workDir: workDir,
status: "Initializing...",
phase: "Starting",
startTime: time.Now(),
details: []string{},
spinnerFrames: spinnerFrames, // Use package-level constant
spinnerFrame: 0,
}
}
func (m RestoreExecutionModel) Init() tea.Cmd {
return tea.Batch(
executeRestoreWithTUIProgress(m.ctx, m.config, m.logger, m.archive, m.targetDB, m.cleanFirst, m.createIfMissing, m.restoreType, m.cleanClusterFirst, m.existingDBs, m.saveDebugLog),
restoreTickCmd(),
)
}
type restoreTickMsg time.Time
func restoreTickCmd() tea.Cmd {
return tea.Tick(time.Millisecond*100, func(t time.Time) tea.Msg {
return restoreTickMsg(t)
})
}
type restoreProgressMsg struct {
status string
phase string
progress int
detail string
bytesTotal int64
bytesDone int64
description string
}
type restoreCompleteMsg struct {
result string
err error
elapsed time.Duration
}
// sharedProgressState holds progress state that can be safely accessed from callbacks
type sharedProgressState struct {
mu sync.Mutex
bytesTotal int64
bytesDone int64
description string
hasUpdate bool
// Database count progress (for cluster restore)
dbTotal int
dbDone int
// Rolling window for speed calculation
speedSamples []restoreSpeedSample
}
type restoreSpeedSample struct {
timestamp time.Time
bytes int64
}
// Package-level shared progress state for restore operations
var (
currentRestoreProgressMu sync.Mutex
currentRestoreProgressState *sharedProgressState
)
func setCurrentRestoreProgress(state *sharedProgressState) {
currentRestoreProgressMu.Lock()
defer currentRestoreProgressMu.Unlock()
currentRestoreProgressState = state
}
func clearCurrentRestoreProgress() {
currentRestoreProgressMu.Lock()
defer currentRestoreProgressMu.Unlock()
currentRestoreProgressState = nil
}
func getCurrentRestoreProgress() (bytesTotal, bytesDone int64, description string, hasUpdate bool, dbTotal, dbDone int, speed float64) {
currentRestoreProgressMu.Lock()
defer currentRestoreProgressMu.Unlock()
if currentRestoreProgressState == nil {
return 0, 0, "", false, 0, 0, 0
}
currentRestoreProgressState.mu.Lock()
defer currentRestoreProgressState.mu.Unlock()
// Calculate rolling window speed
speed = calculateRollingSpeed(currentRestoreProgressState.speedSamples)
return currentRestoreProgressState.bytesTotal, currentRestoreProgressState.bytesDone,
currentRestoreProgressState.description, currentRestoreProgressState.hasUpdate,
currentRestoreProgressState.dbTotal, currentRestoreProgressState.dbDone, speed
}
// calculateRollingSpeed calculates speed from recent samples (last 5 seconds)
func calculateRollingSpeed(samples []restoreSpeedSample) float64 {
if len(samples) < 2 {
return 0
}
// Use samples from last 5 seconds for smoothed speed
now := time.Now()
cutoff := now.Add(-5 * time.Second)
var firstInWindow, lastInWindow *restoreSpeedSample
for i := range samples {
if samples[i].timestamp.After(cutoff) {
if firstInWindow == nil {
firstInWindow = &samples[i]
}
lastInWindow = &samples[i]
}
}
// Fall back to first and last if window is empty
if firstInWindow == nil || lastInWindow == nil || firstInWindow == lastInWindow {
firstInWindow = &samples[0]
lastInWindow = &samples[len(samples)-1]
}
elapsed := lastInWindow.timestamp.Sub(firstInWindow.timestamp).Seconds()
if elapsed <= 0 {
return 0
}
bytesTransferred := lastInWindow.bytes - firstInWindow.bytes
return float64(bytesTransferred) / elapsed
}
// restoreProgressChannel allows sending progress updates from the restore goroutine
type restoreProgressChannel chan restoreProgressMsg
func executeRestoreWithTUIProgress(parentCtx context.Context, cfg *config.Config, log logger.Logger, archive ArchiveInfo, targetDB string, cleanFirst, createIfMissing bool, restoreType string, cleanClusterFirst bool, existingDBs []string, saveDebugLog bool) tea.Cmd {
return func() tea.Msg {
// NO TIMEOUT for restore operations - a restore takes as long as it takes
// Large databases with large objects can take many hours
// Only manual cancellation (Ctrl+C) should stop the restore
ctx, cancel := context.WithCancel(parentCtx)
defer cancel()
start := time.Now()
// Create database instance
dbClient, err := database.New(cfg, log)
if err != nil {
return restoreCompleteMsg{
result: "",
err: fmt.Errorf("failed to create database client: %w", err),
elapsed: time.Since(start),
}
}
defer dbClient.Close()
// STEP 1: Clean cluster if requested (drop all existing user databases)
if restoreType == "restore-cluster" && cleanClusterFirst && len(existingDBs) > 0 {
log.Info("Dropping existing user databases before cluster restore", "count", len(existingDBs))
// Drop databases using command-line psql (no connection required)
// This matches how cluster restore works - uses CLI tools, not database connections
droppedCount := 0
for _, dbName := range existingDBs {
// Create timeout context for each database drop (5 minutes per DB - large DBs take time)
dropCtx, dropCancel := context.WithTimeout(ctx, 5*time.Minute)
if err := dropDatabaseCLI(dropCtx, cfg, dbName); err != nil {
log.Warn("Failed to drop database", "name", dbName, "error", err)
// Continue with other databases
} else {
droppedCount++
log.Info("Dropped database", "name", dbName)
}
dropCancel() // Clean up context
}
log.Info("Cluster cleanup completed", "dropped", droppedCount, "total", len(existingDBs))
}
// STEP 2: Create restore engine with silent progress (no stdout interference with TUI)
engine := restore.NewSilent(cfg, log, dbClient)
// Set up progress callback for detailed progress reporting
// We use a shared pointer that can be queried by the TUI ticker
progressState := &sharedProgressState{
speedSamples: make([]restoreSpeedSample, 0, 100),
}
engine.SetProgressCallback(func(current, total int64, description string) {
progressState.mu.Lock()
defer progressState.mu.Unlock()
progressState.bytesDone = current
progressState.bytesTotal = total
progressState.description = description
progressState.hasUpdate = true
// Add speed sample for rolling window calculation
progressState.speedSamples = append(progressState.speedSamples, restoreSpeedSample{
timestamp: time.Now(),
bytes: current,
})
// Keep only last 100 samples
if len(progressState.speedSamples) > 100 {
progressState.speedSamples = progressState.speedSamples[len(progressState.speedSamples)-100:]
}
})
// Set up database progress callback for cluster restore
engine.SetDatabaseProgressCallback(func(done, total int, dbName string) {
progressState.mu.Lock()
defer progressState.mu.Unlock()
progressState.dbDone = done
progressState.dbTotal = total
progressState.description = fmt.Sprintf("Restoring %s", dbName)
progressState.hasUpdate = true
// Clear byte progress when switching to db progress
progressState.bytesTotal = 0
progressState.bytesDone = 0
})
// Store progress state in a package-level variable for the ticker to access
// This is a workaround because tea messages can't be sent from callbacks
setCurrentRestoreProgress(progressState)
defer clearCurrentRestoreProgress()
// Enable debug logging if requested
if saveDebugLog {
// Generate debug log path using configured WorkDir
workDir := cfg.GetEffectiveWorkDir()
debugLogPath := filepath.Join(workDir, fmt.Sprintf("dbbackup-restore-debug-%s.json", time.Now().Format("20060102-150405")))
engine.SetDebugLogPath(debugLogPath)
log.Info("Debug logging enabled", "path", debugLogPath)
}
// STEP 3: Execute restore based on type
var restoreErr error
if restoreType == "restore-cluster" {
restoreErr = engine.RestoreCluster(ctx, archive.Path)
} else {
restoreErr = engine.RestoreSingle(ctx, archive.Path, targetDB, cleanFirst, createIfMissing)
}
if restoreErr != nil {
return restoreCompleteMsg{
result: "",
err: restoreErr,
elapsed: time.Since(start),
}
}
result := fmt.Sprintf("Successfully restored from %s", archive.Name)
if restoreType == "restore-single" {
result = fmt.Sprintf("Successfully restored '%s' from %s", targetDB, archive.Name)
} else if restoreType == "restore-cluster" && cleanClusterFirst {
result = fmt.Sprintf("Successfully restored cluster from %s (cleaned %d existing database(s) first)", archive.Name, len(existingDBs))
}
return restoreCompleteMsg{
result: result,
err: nil,
elapsed: time.Since(start),
}
}
}
func (m RestoreExecutionModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case restoreTickMsg:
if !m.done {
m.spinnerFrame = (m.spinnerFrame + 1) % len(m.spinnerFrames)
m.elapsed = time.Since(m.startTime)
// Poll shared progress state for real-time updates
bytesTotal, bytesDone, description, hasUpdate, dbTotal, dbDone, speed := getCurrentRestoreProgress()
if hasUpdate && bytesTotal > 0 {
m.bytesTotal = bytesTotal
m.bytesDone = bytesDone
m.description = description
m.showBytes = true
m.speed = speed
// Update status to reflect actual progress
m.status = description
m.phase = "Extracting"
m.progress = int((bytesDone * 100) / bytesTotal)
} else if hasUpdate && dbTotal > 0 {
// Database count progress for cluster restore
m.dbTotal = dbTotal
m.dbDone = dbDone
m.showBytes = false
m.status = fmt.Sprintf("Restoring database %d of %d...", dbDone+1, dbTotal)
m.phase = "Restore"
m.progress = int((dbDone * 100) / dbTotal)
} else {
// Fallback: Update status based on elapsed time to show progress
// This provides visual feedback even though we don't have real-time progress
elapsedSec := int(m.elapsed.Seconds())
if elapsedSec < 2 {
m.status = "Initializing restore..."
m.phase = "Starting"
} else if elapsedSec < 5 {
if m.cleanClusterFirst && len(m.existingDBs) > 0 {
m.status = fmt.Sprintf("Cleaning %d existing database(s)...", len(m.existingDBs))
m.phase = "Cleanup"
} else if m.restoreType == "restore-cluster" {
m.status = "Extracting cluster archive..."
m.phase = "Extraction"
} else {
m.status = "Preparing restore..."
m.phase = "Preparation"
}
} else if elapsedSec < 10 {
if m.restoreType == "restore-cluster" {
m.status = "Restoring global objects..."
m.phase = "Globals"
} else {
m.status = fmt.Sprintf("Restoring database '%s'...", m.targetDB)
m.phase = "Restore"
}
} else {
if m.restoreType == "restore-cluster" {
m.status = "Restoring cluster databases..."
m.phase = "Restore"
} else {
m.status = fmt.Sprintf("Restoring database '%s'...", m.targetDB)
m.phase = "Restore"
}
}
}
return m, restoreTickCmd()
}
return m, nil
case restoreProgressMsg:
m.status = msg.status
m.phase = msg.phase
m.progress = msg.progress
// Update byte-level progress if available
if msg.bytesTotal > 0 {
m.bytesTotal = msg.bytesTotal
m.bytesDone = msg.bytesDone
m.description = msg.description
m.showBytes = true
}
if msg.detail != "" {
m.details = append(m.details, msg.detail)
// Keep only last 5 details
if len(m.details) > 5 {
m.details = m.details[len(m.details)-5:]
}
}
return m, nil
case restoreCompleteMsg:
m.done = true
m.err = msg.err
m.result = msg.result
m.elapsed = msg.elapsed
if m.err == nil {
m.status = "Restore completed successfully"
m.phase = "Done"
m.progress = 100
} else {
m.status = "Failed"
m.phase = "Error"
}
// Auto-forward in auto-confirm mode when done
if m.config.TUIAutoConfirm && m.done {
return m.parent, tea.Quit
}
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "esc":
if !m.done && !m.cancelling {
// User requested cancellation - cancel the context
m.cancelling = true
m.status = "[STOP] Cancelling restore... (please wait)"
m.phase = "Cancelling"
if m.cancel != nil {
m.cancel()
}
return m, nil
} else if m.done {
return m.parent, nil
}
case "q":
if !m.done && !m.cancelling {
m.cancelling = true
m.status = "[STOP] Cancelling restore... (please wait)"
m.phase = "Cancelling"
if m.cancel != nil {
m.cancel()
}
return m, nil
} else if m.done {
return m.parent, tea.Quit
}
case "enter", " ":
if m.done {
return m.parent, nil
}
}
}
return m, nil
}
func (m RestoreExecutionModel) View() string {
var s strings.Builder
s.Grow(512) // Pre-allocate estimated capacity for better performance
// Title
title := "[EXEC] Restoring Database"
if m.restoreType == "restore-cluster" {
title = "[EXEC] Restoring Cluster"
}
s.WriteString(titleStyle.Render(title))
s.WriteString("\n\n")
// Archive info
s.WriteString(fmt.Sprintf("Archive: %s\n", m.archive.Name))
if m.restoreType == "restore-single" {
s.WriteString(fmt.Sprintf("Target: %s\n", m.targetDB))
}
s.WriteString("\n")
if m.done {
// Show result
if m.err != nil {
s.WriteString(errorStyle.Render("[FAIL] Restore Failed"))
s.WriteString("\n\n")
s.WriteString(errorStyle.Render(fmt.Sprintf("Error: %v", m.err)))
s.WriteString("\n")
} else {
s.WriteString(successStyle.Render("[OK] Restore Completed Successfully"))
s.WriteString("\n\n")
s.WriteString(successStyle.Render(m.result))
s.WriteString("\n")
}
s.WriteString(fmt.Sprintf("\nElapsed Time: %s\n", formatDuration(m.elapsed)))
s.WriteString("\n")
s.WriteString(infoStyle.Render("[KEYS] Press Enter to continue"))
} else {
// Show progress
s.WriteString(fmt.Sprintf("Phase: %s\n", m.phase))
// Show detailed progress bar when we have byte-level information
// In this case, hide the spinner for cleaner display
if m.showBytes && m.bytesTotal > 0 {
// Status line without spinner (progress bar provides activity indication)
s.WriteString(fmt.Sprintf("Status: %s\n", m.status))
s.WriteString("\n")
// Render schollz-style progress bar with bytes, rolling speed, ETA
s.WriteString(renderDetailedProgressBarWithSpeed(m.bytesDone, m.bytesTotal, m.speed))
s.WriteString("\n\n")
} else if m.dbTotal > 0 {
// Database count progress for cluster restore
spinner := m.spinnerFrames[m.spinnerFrame]
s.WriteString(fmt.Sprintf("Status: %s %s\n", spinner, m.status))
s.WriteString("\n")
// Show database progress bar
s.WriteString(renderDatabaseProgressBar(m.dbDone, m.dbTotal))
s.WriteString("\n\n")
} else {
// Show status with rotating spinner (for phases without detailed progress)
spinner := m.spinnerFrames[m.spinnerFrame]
s.WriteString(fmt.Sprintf("Status: %s %s\n", spinner, m.status))
s.WriteString("\n")
if m.restoreType == "restore-single" {
// Fallback to simple progress bar for single database restore
progressBar := renderProgressBar(m.progress)
s.WriteString(progressBar)
s.WriteString(fmt.Sprintf(" %d%%\n", m.progress))
s.WriteString("\n")
}
}
// Elapsed time
s.WriteString(fmt.Sprintf("Elapsed: %s\n", formatDuration(m.elapsed)))
s.WriteString("\n")
s.WriteString(infoStyle.Render("[KEYS] Press Ctrl+C to cancel"))
}
return s.String()
}
// renderProgressBar renders a text progress bar
func renderProgressBar(percent int) string {
width := 40
filled := (percent * width) / 100
bar := strings.Repeat("█", filled)
empty := strings.Repeat("░", width-filled)
return successStyle.Render(bar) + infoStyle.Render(empty)
}
// renderDetailedProgressBar renders a schollz-style progress bar with bytes, speed, and ETA
// Uses elapsed time for speed calculation (fallback)
func renderDetailedProgressBar(done, total int64, elapsed time.Duration) string {
speed := 0.0
if elapsed.Seconds() > 0 {
speed = float64(done) / elapsed.Seconds()
}
return renderDetailedProgressBarWithSpeed(done, total, speed)
}
// renderDetailedProgressBarWithSpeed renders a schollz-style progress bar with pre-calculated rolling speed
func renderDetailedProgressBarWithSpeed(done, total int64, speed float64) string {
var s strings.Builder
// Calculate percentage
percent := 0
if total > 0 {
percent = int((done * 100) / total)
if percent > 100 {
percent = 100
}
}
// Render progress bar
width := 30
filled := (percent * width) / 100
barFilled := strings.Repeat("█", filled)
barEmpty := strings.Repeat("░", width-filled)
s.WriteString(successStyle.Render("["))
s.WriteString(successStyle.Render(barFilled))
s.WriteString(infoStyle.Render(barEmpty))
s.WriteString(successStyle.Render("]"))
// Percentage
s.WriteString(fmt.Sprintf(" %3d%%", percent))
// Bytes progress
s.WriteString(fmt.Sprintf(" %s / %s", FormatBytes(done), FormatBytes(total)))
// Speed display (using rolling window speed)
if speed > 0 {
s.WriteString(fmt.Sprintf(" %s/s", FormatBytes(int64(speed))))
// ETA calculation based on rolling speed
if done < total {
remaining := total - done
etaSeconds := float64(remaining) / speed
eta := time.Duration(etaSeconds) * time.Second
s.WriteString(fmt.Sprintf(" ETA: %s", FormatDurationShort(eta)))
}
}
return s.String()
}
// renderDatabaseProgressBar renders a progress bar for database count (cluster restore)
func renderDatabaseProgressBar(done, total int) string {
var s strings.Builder
// Calculate percentage
percent := 0
if total > 0 {
percent = (done * 100) / total
if percent > 100 {
percent = 100
}
}
// Render progress bar
width := 30
filled := (percent * width) / 100
barFilled := strings.Repeat("█", filled)
barEmpty := strings.Repeat("░", width-filled)
s.WriteString(successStyle.Render("["))
s.WriteString(successStyle.Render(barFilled))
s.WriteString(infoStyle.Render(barEmpty))
s.WriteString(successStyle.Render("]"))
// Count and percentage
s.WriteString(fmt.Sprintf(" %3d%% %d / %d databases", percent, done, total))
return s.String()
}
// formatDuration formats duration in human readable format
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%.1fs", d.Seconds())
}
if d < time.Hour {
minutes := int(d.Minutes())
seconds := int(d.Seconds()) % 60
return fmt.Sprintf("%dm %ds", minutes, seconds)
}
hours := int(d.Hours())
minutes := int(d.Minutes()) % 60
return fmt.Sprintf("%dh %dm", hours, minutes)
}
// dropDatabaseCLI drops a database using command-line psql
// This avoids needing an active database connection
func dropDatabaseCLI(ctx context.Context, cfg *config.Config, dbName string) error {
args := []string{
"-p", fmt.Sprintf("%d", cfg.Port),
"-U", cfg.User,
"-d", "postgres", // Connect to postgres maintenance DB
"-c", fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName),
}
// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
if cfg.Host != "localhost" && cfg.Host != "127.0.0.1" && cfg.Host != "" {
args = append([]string{"-h", cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "psql", args...)
// Set password if provided
if cfg.Password != "" {
cmd.Env = append(cmd.Environ(), fmt.Sprintf("PGPASSWORD=%s", cfg.Password))
}
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to drop database %s: %w\nOutput: %s", dbName, err, string(output))
}
return nil
}