Added explicit context checks at critical points: 1. After extraction completes - logs error if context was cancelled 2. Before database restore loop starts - catches premature cancellation This helps diagnose issues where all database restores fail with 'context cancelled' even though extraction completed successfully. The user reported this happening after 4h20m extraction - all 6 DBs showed 'restore skipped (context cancelled)'. These checks will log exactly when/where the context becomes invalid.
2428 lines
83 KiB
Go
Executable File
2428 lines
83 KiB
Go
Executable File
package restore
|
||
|
||
import (
|
||
"archive/tar"
|
||
"compress/gzip"
|
||
"context"
|
||
"database/sql"
|
||
"fmt"
|
||
"io"
|
||
"os"
|
||
"os/exec"
|
||
"path/filepath"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"dbbackup/internal/checks"
|
||
"dbbackup/internal/config"
|
||
"dbbackup/internal/database"
|
||
"dbbackup/internal/logger"
|
||
"dbbackup/internal/progress"
|
||
"dbbackup/internal/security"
|
||
|
||
"github.com/hashicorp/go-multierror"
|
||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||
)
|
||
|
||
// ProgressCallback is called with progress updates during long operations
|
||
// Parameters: current bytes/items done, total bytes/items, description
|
||
type ProgressCallback func(current, total int64, description string)
|
||
|
||
// DatabaseProgressCallback is called with database count progress during cluster restore
|
||
type DatabaseProgressCallback func(done, total int, dbName string)
|
||
|
||
// DatabaseProgressWithTimingCallback is called with database progress including timing info
|
||
// Parameters: done count, total count, database name, elapsed time for current restore phase, avg duration per DB
|
||
type DatabaseProgressWithTimingCallback func(done, total int, dbName string, phaseElapsed, avgPerDB time.Duration)
|
||
|
||
// DatabaseProgressByBytesCallback is called with progress weighted by database sizes (bytes)
|
||
// Parameters: bytes completed, total bytes, current database name, databases done count, total database count
|
||
type DatabaseProgressByBytesCallback func(bytesDone, bytesTotal int64, dbName string, dbDone, dbTotal int)
|
||
|
||
// Engine handles database restore operations
|
||
type Engine struct {
|
||
cfg *config.Config
|
||
log logger.Logger
|
||
db database.Database
|
||
progress progress.Indicator
|
||
detailedReporter *progress.DetailedReporter
|
||
dryRun bool
|
||
debugLogPath string // Path to save debug log on error
|
||
|
||
// TUI progress callback for detailed progress reporting
|
||
progressCallback ProgressCallback
|
||
dbProgressCallback DatabaseProgressCallback
|
||
dbProgressTimingCallback DatabaseProgressWithTimingCallback
|
||
dbProgressByBytesCallback DatabaseProgressByBytesCallback
|
||
}
|
||
|
||
// New creates a new restore engine
|
||
func New(cfg *config.Config, log logger.Logger, db database.Database) *Engine {
|
||
progressIndicator := progress.NewIndicator(true, "line")
|
||
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
|
||
|
||
return &Engine{
|
||
cfg: cfg,
|
||
log: log,
|
||
db: db,
|
||
progress: progressIndicator,
|
||
detailedReporter: detailedReporter,
|
||
dryRun: false,
|
||
}
|
||
}
|
||
|
||
// NewSilent creates a new restore engine with no stdout progress (for TUI mode)
|
||
func NewSilent(cfg *config.Config, log logger.Logger, db database.Database) *Engine {
|
||
progressIndicator := progress.NewNullIndicator()
|
||
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
|
||
|
||
return &Engine{
|
||
cfg: cfg,
|
||
log: log,
|
||
db: db,
|
||
progress: progressIndicator,
|
||
detailedReporter: detailedReporter,
|
||
dryRun: false,
|
||
}
|
||
}
|
||
|
||
// NewWithProgress creates a restore engine with custom progress indicator
|
||
func NewWithProgress(cfg *config.Config, log logger.Logger, db database.Database, progressIndicator progress.Indicator, dryRun bool) *Engine {
|
||
if progressIndicator == nil {
|
||
progressIndicator = progress.NewNullIndicator()
|
||
}
|
||
|
||
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
|
||
|
||
return &Engine{
|
||
cfg: cfg,
|
||
log: log,
|
||
db: db,
|
||
progress: progressIndicator,
|
||
detailedReporter: detailedReporter,
|
||
dryRun: dryRun,
|
||
}
|
||
}
|
||
|
||
// SetDebugLogPath enables saving detailed error reports on failure
|
||
func (e *Engine) SetDebugLogPath(path string) {
|
||
e.debugLogPath = path
|
||
}
|
||
|
||
// SetProgressCallback sets a callback for detailed progress reporting (for TUI mode)
|
||
func (e *Engine) SetProgressCallback(cb ProgressCallback) {
|
||
e.progressCallback = cb
|
||
}
|
||
|
||
// SetDatabaseProgressCallback sets a callback for database count progress during cluster restore
|
||
func (e *Engine) SetDatabaseProgressCallback(cb DatabaseProgressCallback) {
|
||
e.dbProgressCallback = cb
|
||
}
|
||
|
||
// SetDatabaseProgressWithTimingCallback sets a callback for database progress with timing info
|
||
func (e *Engine) SetDatabaseProgressWithTimingCallback(cb DatabaseProgressWithTimingCallback) {
|
||
e.dbProgressTimingCallback = cb
|
||
}
|
||
|
||
// SetDatabaseProgressByBytesCallback sets a callback for progress weighted by database sizes
|
||
func (e *Engine) SetDatabaseProgressByBytesCallback(cb DatabaseProgressByBytesCallback) {
|
||
e.dbProgressByBytesCallback = cb
|
||
}
|
||
|
||
// reportProgress safely calls the progress callback if set
|
||
func (e *Engine) reportProgress(current, total int64, description string) {
|
||
if e.progressCallback != nil {
|
||
e.progressCallback(current, total, description)
|
||
}
|
||
}
|
||
|
||
// reportDatabaseProgress safely calls the database progress callback if set
|
||
func (e *Engine) reportDatabaseProgress(done, total int, dbName string) {
|
||
if e.dbProgressCallback != nil {
|
||
e.dbProgressCallback(done, total, dbName)
|
||
}
|
||
}
|
||
|
||
// reportDatabaseProgressWithTiming safely calls the timing-aware callback if set
|
||
func (e *Engine) reportDatabaseProgressWithTiming(done, total int, dbName string, phaseElapsed, avgPerDB time.Duration) {
|
||
if e.dbProgressTimingCallback != nil {
|
||
e.dbProgressTimingCallback(done, total, dbName, phaseElapsed, avgPerDB)
|
||
}
|
||
}
|
||
|
||
// reportDatabaseProgressByBytes safely calls the bytes-weighted callback if set
|
||
func (e *Engine) reportDatabaseProgressByBytes(bytesDone, bytesTotal int64, dbName string, dbDone, dbTotal int) {
|
||
if e.dbProgressByBytesCallback != nil {
|
||
e.dbProgressByBytesCallback(bytesDone, bytesTotal, dbName, dbDone, dbTotal)
|
||
}
|
||
}
|
||
|
||
// loggerAdapter adapts our logger to the progress.Logger interface
|
||
type loggerAdapter struct {
|
||
logger logger.Logger
|
||
}
|
||
|
||
func (la *loggerAdapter) Info(msg string, args ...any) {
|
||
la.logger.Info(msg, args...)
|
||
}
|
||
|
||
func (la *loggerAdapter) Warn(msg string, args ...any) {
|
||
la.logger.Warn(msg, args...)
|
||
}
|
||
|
||
func (la *loggerAdapter) Error(msg string, args ...any) {
|
||
la.logger.Error(msg, args...)
|
||
}
|
||
|
||
func (la *loggerAdapter) Debug(msg string, args ...any) {
|
||
la.logger.Debug(msg, args...)
|
||
}
|
||
|
||
// RestoreSingle restores a single database from an archive
|
||
func (e *Engine) RestoreSingle(ctx context.Context, archivePath, targetDB string, cleanFirst, createIfMissing bool) error {
|
||
operation := e.log.StartOperation("Single Database Restore")
|
||
|
||
// Validate and sanitize archive path
|
||
validArchivePath, pathErr := security.ValidateArchivePath(archivePath)
|
||
if pathErr != nil {
|
||
operation.Fail(fmt.Sprintf("Invalid archive path: %v", pathErr))
|
||
return fmt.Errorf("invalid archive path: %w", pathErr)
|
||
}
|
||
archivePath = validArchivePath
|
||
|
||
// Validate archive exists
|
||
if _, err := os.Stat(archivePath); os.IsNotExist(err) {
|
||
operation.Fail("Archive not found")
|
||
return fmt.Errorf("archive not found: %s", archivePath)
|
||
}
|
||
|
||
// Verify checksum if .sha256 file exists
|
||
if checksumErr := security.LoadAndVerifyChecksum(archivePath); checksumErr != nil {
|
||
e.log.Warn("Checksum verification failed", "error", checksumErr)
|
||
e.log.Warn("Continuing restore without checksum verification (use with caution)")
|
||
} else {
|
||
e.log.Info("[OK] Archive checksum verified successfully")
|
||
}
|
||
|
||
// Detect archive format
|
||
format := DetectArchiveFormat(archivePath)
|
||
e.log.Info("Detected archive format", "format", format, "path", archivePath)
|
||
|
||
// Check version compatibility for PostgreSQL dumps
|
||
if format == FormatPostgreSQLDump || format == FormatPostgreSQLDumpGz {
|
||
if compatResult, err := e.CheckRestoreVersionCompatibility(ctx, archivePath); err == nil && compatResult != nil {
|
||
e.log.Info(compatResult.Message,
|
||
"source_version", compatResult.SourceVersion.Full,
|
||
"target_version", compatResult.TargetVersion.Full,
|
||
"compatibility", compatResult.Level.String())
|
||
|
||
// Block unsupported downgrades
|
||
if !compatResult.Compatible {
|
||
operation.Fail(compatResult.Message)
|
||
return fmt.Errorf("version compatibility error: %s", compatResult.Message)
|
||
}
|
||
|
||
// Show warnings for risky upgrades
|
||
if compatResult.Level == CompatibilityLevelRisky || compatResult.Level == CompatibilityLevelWarning {
|
||
for _, warning := range compatResult.Warnings {
|
||
e.log.Warn(warning)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if e.dryRun {
|
||
e.log.Info("DRY RUN: Would restore single database", "archive", archivePath, "target", targetDB)
|
||
return e.previewRestore(archivePath, targetDB, format)
|
||
}
|
||
|
||
// Start progress tracking
|
||
e.progress.Start(fmt.Sprintf("Restoring database '%s' from %s", targetDB, filepath.Base(archivePath)))
|
||
|
||
// Create database if requested and it doesn't exist
|
||
if createIfMissing {
|
||
e.log.Info("Checking if target database exists", "database", targetDB)
|
||
if err := e.ensureDatabaseExists(ctx, targetDB); err != nil {
|
||
operation.Fail(fmt.Sprintf("Failed to create database: %v", err))
|
||
return fmt.Errorf("failed to create database '%s': %w", targetDB, err)
|
||
}
|
||
}
|
||
|
||
// Handle different archive formats
|
||
var err error
|
||
switch format {
|
||
case FormatPostgreSQLDump, FormatPostgreSQLDumpGz:
|
||
err = e.restorePostgreSQLDump(ctx, archivePath, targetDB, format == FormatPostgreSQLDumpGz, cleanFirst)
|
||
case FormatPostgreSQLSQL, FormatPostgreSQLSQLGz:
|
||
err = e.restorePostgreSQLSQL(ctx, archivePath, targetDB, format == FormatPostgreSQLSQLGz)
|
||
case FormatMySQLSQL, FormatMySQLSQLGz:
|
||
err = e.restoreMySQLSQL(ctx, archivePath, targetDB, format == FormatMySQLSQLGz)
|
||
default:
|
||
operation.Fail("Unsupported archive format")
|
||
return fmt.Errorf("unsupported archive format: %s", format)
|
||
}
|
||
|
||
if err != nil {
|
||
e.progress.Fail(fmt.Sprintf("Restore failed: %v", err))
|
||
operation.Fail(fmt.Sprintf("Restore failed: %v", err))
|
||
return err
|
||
}
|
||
|
||
e.progress.Complete(fmt.Sprintf("Database '%s' restored successfully", targetDB))
|
||
operation.Complete(fmt.Sprintf("Restored database '%s' from %s", targetDB, filepath.Base(archivePath)))
|
||
return nil
|
||
}
|
||
|
||
// restorePostgreSQLDump restores from PostgreSQL custom dump format
|
||
func (e *Engine) restorePostgreSQLDump(ctx context.Context, archivePath, targetDB string, compressed bool, cleanFirst bool) error {
|
||
// Build restore command
|
||
opts := database.RestoreOptions{
|
||
Parallel: 1,
|
||
Clean: cleanFirst,
|
||
NoOwner: true,
|
||
NoPrivileges: true,
|
||
SingleTransaction: false, // CRITICAL: Disabled to prevent lock exhaustion with large objects
|
||
Verbose: true, // Enable verbose for single database restores (not cluster)
|
||
}
|
||
|
||
cmd := e.db.BuildRestoreCommand(targetDB, archivePath, opts)
|
||
|
||
if compressed {
|
||
// For compressed dumps, decompress first
|
||
return e.executeRestoreWithDecompression(ctx, archivePath, cmd)
|
||
}
|
||
|
||
return e.executeRestoreCommand(ctx, cmd)
|
||
}
|
||
|
||
// restorePostgreSQLDumpWithOwnership restores from PostgreSQL custom dump with ownership control
|
||
func (e *Engine) restorePostgreSQLDumpWithOwnership(ctx context.Context, archivePath, targetDB string, compressed bool, preserveOwnership bool) error {
|
||
// Check if dump contains large objects (BLOBs) - if so, use phased restore
|
||
// to prevent lock table exhaustion (max_locks_per_transaction OOM)
|
||
hasLargeObjects := e.checkDumpHasLargeObjects(archivePath)
|
||
|
||
if hasLargeObjects {
|
||
e.log.Info("Large objects detected - using phased restore to prevent lock exhaustion",
|
||
"database", targetDB,
|
||
"archive", archivePath)
|
||
return e.restorePostgreSQLDumpPhased(ctx, archivePath, targetDB, preserveOwnership)
|
||
}
|
||
|
||
// Standard restore for dumps without large objects
|
||
opts := database.RestoreOptions{
|
||
Parallel: 1,
|
||
Clean: false, // We already dropped the database
|
||
NoOwner: !preserveOwnership, // Preserve ownership if we're superuser
|
||
NoPrivileges: !preserveOwnership, // Preserve privileges if we're superuser
|
||
SingleTransaction: false, // CRITICAL: Disabled to prevent lock exhaustion with large objects
|
||
Verbose: false, // CRITICAL: disable verbose to prevent OOM on large restores
|
||
}
|
||
|
||
e.log.Info("Restoring database",
|
||
"database", targetDB,
|
||
"preserveOwnership", preserveOwnership,
|
||
"noOwner", opts.NoOwner,
|
||
"noPrivileges", opts.NoPrivileges)
|
||
|
||
cmd := e.db.BuildRestoreCommand(targetDB, archivePath, opts)
|
||
|
||
if compressed {
|
||
// For compressed dumps, decompress first
|
||
return e.executeRestoreWithDecompression(ctx, archivePath, cmd)
|
||
}
|
||
|
||
return e.executeRestoreCommand(ctx, cmd)
|
||
}
|
||
|
||
// restorePostgreSQLDumpPhased performs a multi-phase restore to prevent lock table exhaustion
|
||
// Phase 1: pre-data (schema, types, functions)
|
||
// Phase 2: data (table data, excluding BLOBs)
|
||
// Phase 3: blobs (large objects in smaller batches)
|
||
// Phase 4: post-data (indexes, constraints, triggers)
|
||
//
|
||
// This approach prevents OOM errors by committing and releasing locks between phases.
|
||
func (e *Engine) restorePostgreSQLDumpPhased(ctx context.Context, archivePath, targetDB string, preserveOwnership bool) error {
|
||
e.log.Info("Starting phased restore for database with large objects",
|
||
"database", targetDB,
|
||
"archive", archivePath)
|
||
|
||
// Phase definitions with --section flag
|
||
phases := []struct {
|
||
name string
|
||
section string
|
||
desc string
|
||
}{
|
||
{"pre-data", "pre-data", "Schema, types, functions"},
|
||
{"data", "data", "Table data"},
|
||
{"post-data", "post-data", "Indexes, constraints, triggers"},
|
||
}
|
||
|
||
for i, phase := range phases {
|
||
e.log.Info(fmt.Sprintf("Phase %d/%d: Restoring %s", i+1, len(phases), phase.name),
|
||
"database", targetDB,
|
||
"section", phase.section,
|
||
"description", phase.desc)
|
||
|
||
if err := e.restoreSection(ctx, archivePath, targetDB, phase.section, preserveOwnership); err != nil {
|
||
// Check if it's an ignorable error
|
||
if e.isIgnorableError(err.Error()) {
|
||
e.log.Warn(fmt.Sprintf("Phase %d completed with ignorable errors", i+1),
|
||
"section", phase.section,
|
||
"error", err)
|
||
continue
|
||
}
|
||
return fmt.Errorf("phase %d (%s) failed: %w", i+1, phase.name, err)
|
||
}
|
||
|
||
e.log.Info(fmt.Sprintf("Phase %d/%d completed successfully", i+1, len(phases)),
|
||
"section", phase.section)
|
||
}
|
||
|
||
e.log.Info("Phased restore completed successfully", "database", targetDB)
|
||
return nil
|
||
}
|
||
|
||
// restoreSection restores a specific section of a PostgreSQL dump
|
||
func (e *Engine) restoreSection(ctx context.Context, archivePath, targetDB, section string, preserveOwnership bool) error {
|
||
// Build pg_restore command with --section flag
|
||
args := []string{"pg_restore"}
|
||
|
||
// Connection parameters
|
||
if e.cfg.Host != "localhost" {
|
||
args = append(args, "-h", e.cfg.Host)
|
||
args = append(args, "-p", fmt.Sprintf("%d", e.cfg.Port))
|
||
args = append(args, "--no-password")
|
||
}
|
||
args = append(args, "-U", e.cfg.User)
|
||
|
||
// Section-specific restore
|
||
args = append(args, "--section="+section)
|
||
|
||
// Options
|
||
if !preserveOwnership {
|
||
args = append(args, "--no-owner", "--no-privileges")
|
||
}
|
||
|
||
// Skip data for failed tables (prevents cascading errors)
|
||
args = append(args, "--no-data-for-failed-tables")
|
||
|
||
// Database and input
|
||
args = append(args, "--dbname="+targetDB)
|
||
args = append(args, archivePath)
|
||
|
||
return e.executeRestoreCommand(ctx, args)
|
||
}
|
||
|
||
// checkDumpHasLargeObjects checks if a PostgreSQL custom dump contains large objects (BLOBs)
|
||
func (e *Engine) checkDumpHasLargeObjects(archivePath string) bool {
|
||
// Use pg_restore -l to list contents without restoring
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
cmd := exec.CommandContext(ctx, "pg_restore", "-l", archivePath)
|
||
output, err := cmd.Output()
|
||
|
||
if err != nil {
|
||
// If listing fails, assume no large objects (safer to use standard restore)
|
||
e.log.Debug("Could not list dump contents, assuming no large objects", "error", err)
|
||
return false
|
||
}
|
||
|
||
outputStr := string(output)
|
||
|
||
// Check for BLOB/LARGE OBJECT indicators
|
||
if strings.Contains(outputStr, "BLOB") ||
|
||
strings.Contains(outputStr, "LARGE OBJECT") ||
|
||
strings.Contains(outputStr, " BLOBS ") ||
|
||
strings.Contains(outputStr, "lo_create") {
|
||
return true
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// restorePostgreSQLSQL restores from PostgreSQL SQL script
|
||
func (e *Engine) restorePostgreSQLSQL(ctx context.Context, archivePath, targetDB string, compressed bool) error {
|
||
// Pre-validate SQL dump to detect truncation BEFORE attempting restore
|
||
// This saves time by catching corrupted files early (vs 49min failures)
|
||
if err := e.quickValidateSQLDump(archivePath, compressed); err != nil {
|
||
e.log.Error("Pre-restore validation failed - dump file appears corrupted",
|
||
"file", archivePath,
|
||
"error", err)
|
||
return fmt.Errorf("dump validation failed: %w - the backup file may be truncated or corrupted", err)
|
||
}
|
||
|
||
// Use psql for SQL scripts
|
||
var cmd []string
|
||
|
||
// For localhost, omit -h to use Unix socket (avoids Ident auth issues)
|
||
// But always include -p for port (in case of non-standard port)
|
||
hostArg := ""
|
||
portArg := fmt.Sprintf("-p %d", e.cfg.Port)
|
||
if e.cfg.Host != "localhost" && e.cfg.Host != "" {
|
||
hostArg = fmt.Sprintf("-h %s", e.cfg.Host)
|
||
}
|
||
|
||
if compressed {
|
||
// Use ON_ERROR_STOP=1 to fail fast on first error (prevents millions of errors on truncated dumps)
|
||
psqlCmd := fmt.Sprintf("psql %s -U %s -d %s -v ON_ERROR_STOP=1", portArg, e.cfg.User, targetDB)
|
||
if hostArg != "" {
|
||
psqlCmd = fmt.Sprintf("psql %s %s -U %s -d %s -v ON_ERROR_STOP=1", hostArg, portArg, e.cfg.User, targetDB)
|
||
}
|
||
// Set PGPASSWORD in the bash command for password-less auth
|
||
cmd = []string{
|
||
"bash", "-c",
|
||
fmt.Sprintf("PGPASSWORD='%s' gunzip -c %s | %s", e.cfg.Password, archivePath, psqlCmd),
|
||
}
|
||
} else {
|
||
if hostArg != "" {
|
||
cmd = []string{
|
||
"psql",
|
||
"-h", e.cfg.Host,
|
||
"-p", fmt.Sprintf("%d", e.cfg.Port),
|
||
"-U", e.cfg.User,
|
||
"-d", targetDB,
|
||
"-v", "ON_ERROR_STOP=1",
|
||
"-f", archivePath,
|
||
}
|
||
} else {
|
||
cmd = []string{
|
||
"psql",
|
||
"-p", fmt.Sprintf("%d", e.cfg.Port),
|
||
"-U", e.cfg.User,
|
||
"-d", targetDB,
|
||
"-v", "ON_ERROR_STOP=1",
|
||
"-f", archivePath,
|
||
}
|
||
}
|
||
}
|
||
|
||
return e.executeRestoreCommand(ctx, cmd)
|
||
}
|
||
|
||
// restoreMySQLSQL restores from MySQL SQL script
|
||
func (e *Engine) restoreMySQLSQL(ctx context.Context, archivePath, targetDB string, compressed bool) error {
|
||
options := database.RestoreOptions{}
|
||
|
||
cmd := e.db.BuildRestoreCommand(targetDB, archivePath, options)
|
||
|
||
if compressed {
|
||
// For compressed SQL, decompress on the fly
|
||
cmd = []string{
|
||
"bash", "-c",
|
||
fmt.Sprintf("gunzip -c %s | %s", archivePath, strings.Join(cmd, " ")),
|
||
}
|
||
}
|
||
|
||
return e.executeRestoreCommand(ctx, cmd)
|
||
}
|
||
|
||
// executeRestoreCommand executes a restore command
|
||
func (e *Engine) executeRestoreCommand(ctx context.Context, cmdArgs []string) error {
|
||
return e.executeRestoreCommandWithContext(ctx, cmdArgs, "", "", FormatUnknown)
|
||
}
|
||
|
||
// executeRestoreCommandWithContext executes a restore command with error collection context
|
||
func (e *Engine) executeRestoreCommandWithContext(ctx context.Context, cmdArgs []string, archivePath, targetDB string, format ArchiveFormat) error {
|
||
e.log.Info("Executing restore command", "command", strings.Join(cmdArgs, " "))
|
||
|
||
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||
|
||
// Set environment variables
|
||
cmd.Env = append(os.Environ(),
|
||
fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password),
|
||
fmt.Sprintf("MYSQL_PWD=%s", e.cfg.Password),
|
||
)
|
||
|
||
// Create error collector if debug log path is set
|
||
var collector *ErrorCollector
|
||
if e.debugLogPath != "" {
|
||
collector = NewErrorCollector(e.cfg, e.log, archivePath, targetDB, format, true)
|
||
}
|
||
|
||
// Stream stderr to avoid memory issues with large output
|
||
// Don't use CombinedOutput() as it loads everything into memory
|
||
stderr, err := cmd.StderrPipe()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create stderr pipe: %w", err)
|
||
}
|
||
|
||
if err := cmd.Start(); err != nil {
|
||
return fmt.Errorf("failed to start restore command: %w", err)
|
||
}
|
||
|
||
// Read stderr in goroutine to avoid blocking
|
||
var lastError string
|
||
var errorCount int
|
||
stderrDone := make(chan struct{})
|
||
go func() {
|
||
defer close(stderrDone)
|
||
buf := make([]byte, 4096)
|
||
const maxErrors = 10 // Limit captured errors to prevent OOM
|
||
for {
|
||
n, err := stderr.Read(buf)
|
||
if n > 0 {
|
||
chunk := string(buf[:n])
|
||
|
||
// Feed to error collector if enabled
|
||
if collector != nil {
|
||
collector.CaptureStderr(chunk)
|
||
}
|
||
|
||
// Only capture REAL errors, not verbose output
|
||
if strings.Contains(chunk, "ERROR:") || strings.Contains(chunk, "FATAL:") || strings.Contains(chunk, "error:") {
|
||
lastError = strings.TrimSpace(chunk)
|
||
errorCount++
|
||
if errorCount <= maxErrors {
|
||
e.log.Warn("Restore stderr", "output", chunk)
|
||
}
|
||
}
|
||
// Note: --verbose output is discarded to prevent OOM
|
||
}
|
||
if err != nil {
|
||
break
|
||
}
|
||
}
|
||
}()
|
||
|
||
// Wait for command with proper context handling
|
||
cmdDone := make(chan error, 1)
|
||
go func() {
|
||
cmdDone <- cmd.Wait()
|
||
}()
|
||
|
||
var cmdErr error
|
||
select {
|
||
case cmdErr = <-cmdDone:
|
||
// Command completed (success or failure)
|
||
case <-ctx.Done():
|
||
// Context cancelled - kill process
|
||
e.log.Warn("Restore cancelled - killing process")
|
||
cmd.Process.Kill()
|
||
<-cmdDone
|
||
cmdErr = ctx.Err()
|
||
}
|
||
|
||
// Wait for stderr reader to finish
|
||
<-stderrDone
|
||
|
||
if cmdErr != nil {
|
||
// Get exit code
|
||
exitCode := 1
|
||
if exitErr, ok := cmdErr.(*exec.ExitError); ok {
|
||
exitCode = exitErr.ExitCode()
|
||
}
|
||
|
||
// PostgreSQL pg_restore returns exit code 1 even for ignorable errors
|
||
// Check if errors are ignorable (already exists, duplicate, etc.)
|
||
if lastError != "" && e.isIgnorableError(lastError) {
|
||
e.log.Warn("Restore completed with ignorable errors", "error_count", errorCount, "last_error", lastError)
|
||
return nil // Success despite ignorable errors
|
||
}
|
||
|
||
// Classify error and provide helpful hints
|
||
var classification *checks.ErrorClassification
|
||
var errType, errHint string
|
||
if lastError != "" {
|
||
classification = checks.ClassifyError(lastError)
|
||
errType = classification.Type
|
||
errHint = classification.Hint
|
||
e.log.Error("Restore command failed",
|
||
"error", err,
|
||
"last_stderr", lastError,
|
||
"error_count", errorCount,
|
||
"error_type", classification.Type,
|
||
"hint", classification.Hint,
|
||
"action", classification.Action)
|
||
} else {
|
||
e.log.Error("Restore command failed", "error", err, "error_count", errorCount)
|
||
}
|
||
|
||
// Generate and save error report if collector is enabled
|
||
if collector != nil {
|
||
collector.SetExitCode(exitCode)
|
||
report := collector.GenerateReport(
|
||
lastError,
|
||
errType,
|
||
errHint,
|
||
)
|
||
|
||
// Print report to console
|
||
collector.PrintReport(report)
|
||
|
||
// Save to file
|
||
if e.debugLogPath != "" {
|
||
if saveErr := collector.SaveReport(report, e.debugLogPath); saveErr != nil {
|
||
e.log.Warn("Failed to save debug log", "error", saveErr)
|
||
} else {
|
||
e.log.Info("Debug log saved", "path", e.debugLogPath)
|
||
fmt.Printf("\n[LOG] Detailed error report saved to: %s\n", e.debugLogPath)
|
||
}
|
||
}
|
||
}
|
||
|
||
if lastError != "" {
|
||
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
|
||
err, lastError, errorCount, errHint)
|
||
}
|
||
return fmt.Errorf("restore failed: %w", err)
|
||
}
|
||
|
||
e.log.Info("Restore command completed successfully")
|
||
return nil
|
||
}
|
||
|
||
// executeRestoreWithDecompression handles decompression during restore
|
||
func (e *Engine) executeRestoreWithDecompression(ctx context.Context, archivePath string, restoreCmd []string) error {
|
||
// Check if pigz is available for faster decompression
|
||
decompressCmd := "gunzip"
|
||
if _, err := exec.LookPath("pigz"); err == nil {
|
||
decompressCmd = "pigz"
|
||
e.log.Info("Using pigz for parallel decompression")
|
||
}
|
||
|
||
// Build pipeline: decompress | restore
|
||
pipeline := fmt.Sprintf("%s -dc %s | %s", decompressCmd, archivePath, strings.Join(restoreCmd, " "))
|
||
cmd := exec.CommandContext(ctx, "bash", "-c", pipeline)
|
||
|
||
cmd.Env = append(os.Environ(),
|
||
fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password),
|
||
fmt.Sprintf("MYSQL_PWD=%s", e.cfg.Password),
|
||
)
|
||
|
||
// Stream stderr to avoid memory issues with large output
|
||
stderr, err := cmd.StderrPipe()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create stderr pipe: %w", err)
|
||
}
|
||
|
||
if err := cmd.Start(); err != nil {
|
||
return fmt.Errorf("failed to start restore command: %w", err)
|
||
}
|
||
|
||
// Read stderr in goroutine to avoid blocking
|
||
var lastError string
|
||
var errorCount int
|
||
stderrDone := make(chan struct{})
|
||
go func() {
|
||
defer close(stderrDone)
|
||
buf := make([]byte, 4096)
|
||
const maxErrors = 10 // Limit captured errors to prevent OOM
|
||
for {
|
||
n, err := stderr.Read(buf)
|
||
if n > 0 {
|
||
chunk := string(buf[:n])
|
||
// Only capture REAL errors, not verbose output
|
||
if strings.Contains(chunk, "ERROR:") || strings.Contains(chunk, "FATAL:") || strings.Contains(chunk, "error:") {
|
||
lastError = strings.TrimSpace(chunk)
|
||
errorCount++
|
||
if errorCount <= maxErrors {
|
||
e.log.Warn("Restore stderr", "output", chunk)
|
||
}
|
||
}
|
||
// Note: --verbose output is discarded to prevent OOM
|
||
}
|
||
if err != nil {
|
||
break
|
||
}
|
||
}
|
||
}()
|
||
|
||
// Wait for command with proper context handling
|
||
cmdDone := make(chan error, 1)
|
||
go func() {
|
||
cmdDone <- cmd.Wait()
|
||
}()
|
||
|
||
var cmdErr error
|
||
select {
|
||
case cmdErr = <-cmdDone:
|
||
// Command completed (success or failure)
|
||
case <-ctx.Done():
|
||
// Context cancelled - kill process
|
||
e.log.Warn("Restore with decompression cancelled - killing process")
|
||
cmd.Process.Kill()
|
||
<-cmdDone
|
||
cmdErr = ctx.Err()
|
||
}
|
||
|
||
// Wait for stderr reader to finish
|
||
<-stderrDone
|
||
|
||
if cmdErr != nil {
|
||
// PostgreSQL pg_restore returns exit code 1 even for ignorable errors
|
||
// Check if errors are ignorable (already exists, duplicate, etc.)
|
||
if lastError != "" && e.isIgnorableError(lastError) {
|
||
e.log.Warn("Restore with decompression completed with ignorable errors", "error_count", errorCount, "last_error", lastError)
|
||
return nil // Success despite ignorable errors
|
||
}
|
||
|
||
// Classify error and provide helpful hints
|
||
if lastError != "" {
|
||
classification := checks.ClassifyError(lastError)
|
||
e.log.Error("Restore with decompression failed",
|
||
"error", cmdErr,
|
||
"last_stderr", lastError,
|
||
"error_count", errorCount,
|
||
"error_type", classification.Type,
|
||
"hint", classification.Hint,
|
||
"action", classification.Action)
|
||
return fmt.Errorf("restore failed: %w (last error: %s, total errors: %d) - %s",
|
||
cmdErr, lastError, errorCount, classification.Hint)
|
||
}
|
||
|
||
e.log.Error("Restore with decompression failed", "error", cmdErr, "last_stderr", lastError, "error_count", errorCount)
|
||
return fmt.Errorf("restore failed: %w", cmdErr)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// previewRestore shows what would be done without executing
|
||
func (e *Engine) previewRestore(archivePath, targetDB string, format ArchiveFormat) error {
|
||
fmt.Println("\n" + strings.Repeat("=", 60))
|
||
fmt.Println(" RESTORE PREVIEW (DRY RUN)")
|
||
fmt.Println(strings.Repeat("=", 60))
|
||
|
||
stat, _ := os.Stat(archivePath)
|
||
fmt.Printf("\nArchive: %s\n", filepath.Base(archivePath))
|
||
fmt.Printf("Format: %s\n", format)
|
||
if stat != nil {
|
||
fmt.Printf("Size: %s\n", FormatBytes(stat.Size()))
|
||
fmt.Printf("Modified: %s\n", stat.ModTime().Format("2006-01-02 15:04:05"))
|
||
}
|
||
fmt.Printf("Target Database: %s\n", targetDB)
|
||
fmt.Printf("Target Host: %s:%d\n", e.cfg.Host, e.cfg.Port)
|
||
|
||
fmt.Println("\nOperations that would be performed:")
|
||
switch format {
|
||
case FormatPostgreSQLDump:
|
||
fmt.Printf(" 1. Execute: pg_restore -d %s %s\n", targetDB, archivePath)
|
||
case FormatPostgreSQLDumpGz:
|
||
fmt.Printf(" 1. Decompress: %s\n", archivePath)
|
||
fmt.Printf(" 2. Execute: pg_restore -d %s\n", targetDB)
|
||
case FormatPostgreSQLSQL, FormatPostgreSQLSQLGz:
|
||
fmt.Printf(" 1. Execute: psql -d %s -f %s\n", targetDB, archivePath)
|
||
case FormatMySQLSQL, FormatMySQLSQLGz:
|
||
fmt.Printf(" 1. Execute: mysql %s < %s\n", targetDB, archivePath)
|
||
}
|
||
|
||
fmt.Println("\n[WARN] WARNING: This will restore data to the target database.")
|
||
fmt.Println(" Existing data may be overwritten or merged.")
|
||
fmt.Println("\nTo execute this restore, add the --confirm flag.")
|
||
fmt.Println(strings.Repeat("=", 60) + "\n")
|
||
|
||
return nil
|
||
}
|
||
|
||
// RestoreCluster restores a full cluster from a tar.gz archive
|
||
func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
|
||
operation := e.log.StartOperation("Cluster Restore")
|
||
|
||
// Validate and sanitize archive path
|
||
validArchivePath, pathErr := security.ValidateArchivePath(archivePath)
|
||
if pathErr != nil {
|
||
operation.Fail(fmt.Sprintf("Invalid archive path: %v", pathErr))
|
||
return fmt.Errorf("invalid archive path: %w", pathErr)
|
||
}
|
||
archivePath = validArchivePath
|
||
|
||
// Validate archive exists
|
||
if _, err := os.Stat(archivePath); os.IsNotExist(err) {
|
||
operation.Fail("Archive not found")
|
||
return fmt.Errorf("archive not found: %s", archivePath)
|
||
}
|
||
|
||
// Verify checksum if .sha256 file exists
|
||
if checksumErr := security.LoadAndVerifyChecksum(archivePath); checksumErr != nil {
|
||
e.log.Warn("Checksum verification failed", "error", checksumErr)
|
||
e.log.Warn("Continuing restore without checksum verification (use with caution)")
|
||
} else {
|
||
e.log.Info("[OK] Cluster archive checksum verified successfully")
|
||
}
|
||
|
||
format := DetectArchiveFormat(archivePath)
|
||
if format != FormatClusterTarGz {
|
||
operation.Fail("Invalid cluster archive format")
|
||
return fmt.Errorf("not a cluster archive: %s (detected format: %s)", archivePath, format)
|
||
}
|
||
|
||
// Check disk space before starting restore
|
||
e.log.Info("Checking disk space for restore")
|
||
archiveInfo, err := os.Stat(archivePath)
|
||
if err == nil {
|
||
spaceCheck := checks.CheckDiskSpaceForRestore(e.cfg.BackupDir, archiveInfo.Size())
|
||
|
||
if spaceCheck.Critical {
|
||
operation.Fail("Insufficient disk space")
|
||
return fmt.Errorf("insufficient disk space for restore: %.1f%% used - need at least 4x archive size", spaceCheck.UsedPercent)
|
||
}
|
||
|
||
if spaceCheck.Warning {
|
||
e.log.Warn("Low disk space - restore may fail",
|
||
"available_gb", float64(spaceCheck.AvailableBytes)/(1024*1024*1024),
|
||
"used_percent", spaceCheck.UsedPercent)
|
||
}
|
||
}
|
||
|
||
if e.dryRun {
|
||
e.log.Info("DRY RUN: Would restore cluster", "archive", archivePath)
|
||
return e.previewClusterRestore(archivePath)
|
||
}
|
||
|
||
e.progress.Start(fmt.Sprintf("Restoring cluster from %s", filepath.Base(archivePath)))
|
||
|
||
// Create temporary extraction directory in configured WorkDir
|
||
workDir := e.cfg.GetEffectiveWorkDir()
|
||
tempDir := filepath.Join(workDir, fmt.Sprintf(".restore_%d", time.Now().Unix()))
|
||
|
||
// Check disk space for extraction (need ~3x archive size: compressed + extracted + working space)
|
||
if archiveInfo != nil {
|
||
requiredBytes := uint64(archiveInfo.Size()) * 3
|
||
extractionCheck := checks.CheckDiskSpace(workDir)
|
||
if extractionCheck.AvailableBytes < requiredBytes {
|
||
operation.Fail("Insufficient disk space for extraction")
|
||
return fmt.Errorf("insufficient disk space for extraction in %s: need %.1f GB, have %.1f GB (archive size: %.1f GB × 3)",
|
||
workDir,
|
||
float64(requiredBytes)/(1024*1024*1024),
|
||
float64(extractionCheck.AvailableBytes)/(1024*1024*1024),
|
||
float64(archiveInfo.Size())/(1024*1024*1024))
|
||
}
|
||
e.log.Info("Disk space check for extraction passed",
|
||
"workdir", workDir,
|
||
"required_gb", float64(requiredBytes)/(1024*1024*1024),
|
||
"available_gb", float64(extractionCheck.AvailableBytes)/(1024*1024*1024))
|
||
}
|
||
|
||
if err := os.MkdirAll(tempDir, 0755); err != nil {
|
||
operation.Fail("Failed to create temporary directory")
|
||
return fmt.Errorf("failed to create temp directory in %s: %w", workDir, err)
|
||
}
|
||
defer os.RemoveAll(tempDir)
|
||
|
||
// Extract archive
|
||
e.log.Info("Extracting cluster archive", "archive", archivePath, "tempDir", tempDir)
|
||
if err := e.extractArchive(ctx, archivePath, tempDir); err != nil {
|
||
operation.Fail("Archive extraction failed")
|
||
return fmt.Errorf("failed to extract archive: %w", err)
|
||
}
|
||
|
||
// Check context validity after extraction (debugging context cancellation issues)
|
||
if ctx.Err() != nil {
|
||
e.log.Error("Context cancelled after extraction - this should not happen",
|
||
"context_error", ctx.Err(),
|
||
"extraction_completed", true)
|
||
operation.Fail("Context cancelled unexpectedly")
|
||
return fmt.Errorf("context cancelled after extraction completed: %w", ctx.Err())
|
||
}
|
||
e.log.Info("Extraction completed, context still valid")
|
||
|
||
// Check if user has superuser privileges (required for ownership restoration)
|
||
e.progress.Update("Checking privileges...")
|
||
isSuperuser, err := e.checkSuperuser(ctx)
|
||
if err != nil {
|
||
e.log.Warn("Could not verify superuser status", "error", err)
|
||
isSuperuser = false // Assume not superuser if check fails
|
||
}
|
||
|
||
if !isSuperuser {
|
||
e.log.Warn("Current user is not a superuser - database ownership may not be fully restored")
|
||
e.progress.Update("[WARN] Warning: Non-superuser - ownership restoration limited")
|
||
time.Sleep(2 * time.Second) // Give user time to see warning
|
||
} else {
|
||
e.log.Info("Superuser privileges confirmed - full ownership restoration enabled")
|
||
}
|
||
|
||
// Restore global objects FIRST (roles, tablespaces) - CRITICAL for ownership
|
||
globalsFile := filepath.Join(tempDir, "globals.sql")
|
||
if _, err := os.Stat(globalsFile); err == nil {
|
||
e.log.Info("Restoring global objects (roles, tablespaces)")
|
||
e.progress.Update("Restoring global objects (roles, tablespaces)...")
|
||
if err := e.restoreGlobals(ctx, globalsFile); err != nil {
|
||
e.log.Error("Failed to restore global objects", "error", err)
|
||
if isSuperuser {
|
||
// If we're superuser and can't restore globals, this is a problem
|
||
e.progress.Fail("Failed to restore global objects")
|
||
operation.Fail("Global objects restoration failed")
|
||
return fmt.Errorf("failed to restore global objects: %w", err)
|
||
} else {
|
||
e.log.Warn("Continuing without global objects (may cause ownership issues)")
|
||
}
|
||
} else {
|
||
e.log.Info("Successfully restored global objects")
|
||
}
|
||
} else {
|
||
e.log.Warn("No globals.sql file found in backup - roles and tablespaces will not be restored")
|
||
}
|
||
|
||
// Restore individual databases
|
||
dumpsDir := filepath.Join(tempDir, "dumps")
|
||
if _, err := os.Stat(dumpsDir); err != nil {
|
||
operation.Fail("No database dumps found in archive")
|
||
return fmt.Errorf("no database dumps found in archive")
|
||
}
|
||
|
||
entries, err := os.ReadDir(dumpsDir)
|
||
if err != nil {
|
||
operation.Fail("Failed to read dumps directory")
|
||
return fmt.Errorf("failed to read dumps directory: %w", err)
|
||
}
|
||
|
||
// PRE-VALIDATE all SQL dumps BEFORE starting restore
|
||
// This catches truncated files early instead of failing after hours of work
|
||
e.log.Info("Pre-validating dump files before restore...")
|
||
e.progress.Update("Pre-validating dump files...")
|
||
var corruptedDumps []string
|
||
diagnoser := NewDiagnoser(e.log, false)
|
||
for _, entry := range entries {
|
||
if entry.IsDir() {
|
||
continue
|
||
}
|
||
dumpFile := filepath.Join(dumpsDir, entry.Name())
|
||
if strings.HasSuffix(dumpFile, ".sql.gz") {
|
||
result, err := diagnoser.DiagnoseFile(dumpFile)
|
||
if err != nil {
|
||
e.log.Warn("Could not validate dump file", "file", entry.Name(), "error", err)
|
||
continue
|
||
}
|
||
if result.IsTruncated || result.IsCorrupted || !result.IsValid {
|
||
dbName := strings.TrimSuffix(entry.Name(), ".sql.gz")
|
||
errDetail := "unknown issue"
|
||
if len(result.Errors) > 0 {
|
||
errDetail = result.Errors[0]
|
||
}
|
||
corruptedDumps = append(corruptedDumps, fmt.Sprintf("%s: %s", dbName, errDetail))
|
||
e.log.Error("CORRUPTED dump file detected",
|
||
"database", dbName,
|
||
"file", entry.Name(),
|
||
"truncated", result.IsTruncated,
|
||
"errors", result.Errors)
|
||
}
|
||
} else if strings.HasSuffix(dumpFile, ".dump") {
|
||
// Validate custom format dumps using pg_restore --list
|
||
cmd := exec.CommandContext(ctx, "pg_restore", "--list", dumpFile)
|
||
output, err := cmd.CombinedOutput()
|
||
if err != nil {
|
||
dbName := strings.TrimSuffix(entry.Name(), ".dump")
|
||
errDetail := strings.TrimSpace(string(output))
|
||
if len(errDetail) > 100 {
|
||
errDetail = errDetail[:100] + "..."
|
||
}
|
||
// Check for truncation indicators
|
||
if strings.Contains(errDetail, "unexpected end") || strings.Contains(errDetail, "invalid") {
|
||
corruptedDumps = append(corruptedDumps, fmt.Sprintf("%s: %s", dbName, errDetail))
|
||
e.log.Error("CORRUPTED custom dump file detected",
|
||
"database", dbName,
|
||
"file", entry.Name(),
|
||
"error", errDetail)
|
||
} else {
|
||
e.log.Warn("pg_restore --list warning (may be recoverable)",
|
||
"file", entry.Name(),
|
||
"error", errDetail)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if len(corruptedDumps) > 0 {
|
||
operation.Fail("Corrupted dump files detected")
|
||
e.progress.Fail(fmt.Sprintf("Found %d corrupted dump files - restore aborted", len(corruptedDumps)))
|
||
return fmt.Errorf("pre-validation failed: %d corrupted dump files detected: %s - the backup archive appears to be damaged, restore from a different backup",
|
||
len(corruptedDumps), strings.Join(corruptedDumps, ", "))
|
||
}
|
||
e.log.Info("All dump files passed validation")
|
||
|
||
// Run comprehensive preflight checks (Linux system + PostgreSQL + Archive analysis)
|
||
preflight, preflightErr := e.RunPreflightChecks(ctx, dumpsDir, entries)
|
||
if preflightErr != nil {
|
||
e.log.Warn("Preflight checks failed", "error", preflightErr)
|
||
}
|
||
|
||
// Calculate optimal lock boost based on BLOB count
|
||
lockBoostValue := 2048 // Default
|
||
if preflight != nil && preflight.Archive.RecommendedLockBoost > 0 {
|
||
lockBoostValue = preflight.Archive.RecommendedLockBoost
|
||
}
|
||
|
||
// AUTO-TUNE: Boost PostgreSQL settings for large restores
|
||
e.progress.Update("Tuning PostgreSQL for large restore...")
|
||
originalSettings, tuneErr := e.boostPostgreSQLSettings(ctx, lockBoostValue)
|
||
if tuneErr != nil {
|
||
e.log.Warn("Could not boost PostgreSQL settings - restore may fail on BLOB-heavy databases",
|
||
"error", tuneErr)
|
||
} else {
|
||
e.log.Info("Boosted PostgreSQL settings for restore",
|
||
"max_locks_per_transaction", fmt.Sprintf("%d → %d", originalSettings.MaxLocks, lockBoostValue),
|
||
"maintenance_work_mem", fmt.Sprintf("%s → 2GB", originalSettings.MaintenanceWorkMem))
|
||
// Ensure we reset settings when done (even on failure)
|
||
defer func() {
|
||
if resetErr := e.resetPostgreSQLSettings(ctx, originalSettings); resetErr != nil {
|
||
e.log.Warn("Could not reset PostgreSQL settings", "error", resetErr)
|
||
} else {
|
||
e.log.Info("Reset PostgreSQL settings to original values")
|
||
}
|
||
}()
|
||
}
|
||
|
||
var restoreErrors *multierror.Error
|
||
var restoreErrorsMu sync.Mutex
|
||
totalDBs := 0
|
||
|
||
// Count total databases and calculate total bytes for weighted progress
|
||
var totalBytes int64
|
||
dbSizes := make(map[string]int64) // Map database name to dump file size
|
||
for _, entry := range entries {
|
||
if !entry.IsDir() {
|
||
totalDBs++
|
||
dumpFile := filepath.Join(dumpsDir, entry.Name())
|
||
if info, err := os.Stat(dumpFile); err == nil {
|
||
dbName := entry.Name()
|
||
dbName = strings.TrimSuffix(dbName, ".dump")
|
||
dbName = strings.TrimSuffix(dbName, ".sql.gz")
|
||
dbSizes[dbName] = info.Size()
|
||
totalBytes += info.Size()
|
||
}
|
||
}
|
||
}
|
||
e.log.Info("Calculated total restore size", "databases", totalDBs, "total_bytes", totalBytes)
|
||
|
||
// Track bytes completed for weighted progress
|
||
var bytesCompleted int64
|
||
var bytesCompletedMu sync.Mutex
|
||
|
||
// Create ETA estimator for database restores
|
||
estimator := progress.NewETAEstimator("Restoring cluster", totalDBs)
|
||
e.progress.SetEstimator(estimator)
|
||
|
||
// Check for large objects in dump files and adjust parallelism
|
||
hasLargeObjects := e.detectLargeObjectsInDumps(dumpsDir, entries)
|
||
|
||
// Use worker pool for parallel restore
|
||
parallelism := e.cfg.ClusterParallelism
|
||
if parallelism < 1 {
|
||
parallelism = 1 // Ensure at least sequential
|
||
}
|
||
|
||
// Automatically reduce parallelism if large objects detected
|
||
if hasLargeObjects && parallelism > 1 {
|
||
e.log.Warn("Large objects detected in dump files - reducing parallelism to avoid lock contention",
|
||
"original_parallelism", parallelism,
|
||
"adjusted_parallelism", 1)
|
||
e.progress.Update("[WARN] Large objects detected - using sequential restore to avoid lock conflicts")
|
||
time.Sleep(2 * time.Second) // Give user time to see warning
|
||
parallelism = 1
|
||
}
|
||
|
||
var successCount, failCount int32
|
||
var mu sync.Mutex // Protect shared resources (progress, logger)
|
||
|
||
// CRITICAL: Check context before starting database restore loop
|
||
// This helps debug issues where context gets cancelled between extraction and restore
|
||
if ctx.Err() != nil {
|
||
e.log.Error("Context cancelled before database restore loop started",
|
||
"context_error", ctx.Err(),
|
||
"total_databases", totalDBs,
|
||
"parallelism", parallelism)
|
||
operation.Fail("Context cancelled before database restores could start")
|
||
return fmt.Errorf("context cancelled before database restore: %w", ctx.Err())
|
||
}
|
||
e.log.Info("Starting database restore loop", "databases", totalDBs, "parallelism", parallelism)
|
||
|
||
// Timing tracking for restore phase progress
|
||
restorePhaseStart := time.Now()
|
||
var completedDBTimes []time.Duration // Track duration for each completed DB restore
|
||
var completedDBTimesMu sync.Mutex
|
||
|
||
// Create semaphore to limit concurrency
|
||
semaphore := make(chan struct{}, parallelism)
|
||
var wg sync.WaitGroup
|
||
|
||
dbIndex := 0
|
||
for _, entry := range entries {
|
||
if entry.IsDir() {
|
||
continue
|
||
}
|
||
|
||
wg.Add(1)
|
||
semaphore <- struct{}{} // Acquire
|
||
|
||
go func(idx int, filename string) {
|
||
defer wg.Done()
|
||
defer func() { <-semaphore }() // Release
|
||
|
||
// Panic recovery - prevent one database failure from crashing entire cluster restore
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
e.log.Error("Panic in database restore goroutine", "file", filename, "panic", r)
|
||
atomic.AddInt32(&failCount, 1)
|
||
}
|
||
}()
|
||
|
||
// Check for context cancellation before starting
|
||
if ctx.Err() != nil {
|
||
e.log.Warn("Context cancelled - skipping database restore", "file", filename)
|
||
atomic.AddInt32(&failCount, 1)
|
||
restoreErrorsMu.Lock()
|
||
restoreErrors = multierror.Append(restoreErrors, fmt.Errorf("%s: restore skipped (context cancelled)", strings.TrimSuffix(strings.TrimSuffix(filename, ".dump"), ".sql.gz")))
|
||
restoreErrorsMu.Unlock()
|
||
return
|
||
}
|
||
|
||
// Track timing for this database restore
|
||
dbRestoreStart := time.Now()
|
||
|
||
// Update estimator progress (thread-safe)
|
||
mu.Lock()
|
||
estimator.UpdateProgress(idx)
|
||
mu.Unlock()
|
||
|
||
dumpFile := filepath.Join(dumpsDir, filename)
|
||
dbName := filename
|
||
dbName = strings.TrimSuffix(dbName, ".dump")
|
||
dbName = strings.TrimSuffix(dbName, ".sql.gz")
|
||
|
||
dbProgress := 15 + int(float64(idx)/float64(totalDBs)*85.0)
|
||
|
||
// Calculate average time per DB and report progress with timing
|
||
completedDBTimesMu.Lock()
|
||
var avgPerDB time.Duration
|
||
if len(completedDBTimes) > 0 {
|
||
var totalDuration time.Duration
|
||
for _, d := range completedDBTimes {
|
||
totalDuration += d
|
||
}
|
||
avgPerDB = totalDuration / time.Duration(len(completedDBTimes))
|
||
}
|
||
phaseElapsed := time.Since(restorePhaseStart)
|
||
completedDBTimesMu.Unlock()
|
||
|
||
mu.Lock()
|
||
statusMsg := fmt.Sprintf("Restoring database %s (%d/%d)", dbName, idx+1, totalDBs)
|
||
e.progress.Update(statusMsg)
|
||
e.log.Info("Restoring database", "name", dbName, "file", dumpFile, "progress", dbProgress)
|
||
// Report database progress for TUI (both callbacks)
|
||
e.reportDatabaseProgress(idx, totalDBs, dbName)
|
||
e.reportDatabaseProgressWithTiming(idx, totalDBs, dbName, phaseElapsed, avgPerDB)
|
||
mu.Unlock()
|
||
|
||
// STEP 1: Drop existing database completely (clean slate)
|
||
e.log.Info("Dropping existing database for clean restore", "name", dbName)
|
||
if err := e.dropDatabaseIfExists(ctx, dbName); err != nil {
|
||
e.log.Warn("Could not drop existing database", "name", dbName, "error", err)
|
||
}
|
||
|
||
// STEP 2: Create fresh database
|
||
if err := e.ensureDatabaseExists(ctx, dbName); err != nil {
|
||
e.log.Error("Failed to create database", "name", dbName, "error", err)
|
||
restoreErrorsMu.Lock()
|
||
restoreErrors = multierror.Append(restoreErrors, fmt.Errorf("%s: failed to create database: %w", dbName, err))
|
||
restoreErrorsMu.Unlock()
|
||
atomic.AddInt32(&failCount, 1)
|
||
return
|
||
}
|
||
|
||
// STEP 3: Restore with ownership preservation if superuser
|
||
preserveOwnership := isSuperuser
|
||
isCompressedSQL := strings.HasSuffix(dumpFile, ".sql.gz")
|
||
|
||
var restoreErr error
|
||
if isCompressedSQL {
|
||
mu.Lock()
|
||
e.log.Info("Detected compressed SQL format, using psql + gunzip", "file", dumpFile, "database", dbName)
|
||
mu.Unlock()
|
||
restoreErr = e.restorePostgreSQLSQL(ctx, dumpFile, dbName, true)
|
||
} else {
|
||
mu.Lock()
|
||
e.log.Info("Detected custom dump format, using pg_restore", "file", dumpFile, "database", dbName)
|
||
mu.Unlock()
|
||
restoreErr = e.restorePostgreSQLDumpWithOwnership(ctx, dumpFile, dbName, false, preserveOwnership)
|
||
}
|
||
|
||
if restoreErr != nil {
|
||
mu.Lock()
|
||
e.log.Error("Failed to restore database", "name", dbName, "file", dumpFile, "error", restoreErr)
|
||
mu.Unlock()
|
||
|
||
// Check for specific recoverable errors
|
||
errMsg := restoreErr.Error()
|
||
if strings.Contains(errMsg, "max_locks_per_transaction") {
|
||
mu.Lock()
|
||
e.log.Warn("Database restore failed due to insufficient locks - this is a PostgreSQL configuration issue",
|
||
"database", dbName,
|
||
"solution", "increase max_locks_per_transaction in postgresql.conf")
|
||
mu.Unlock()
|
||
} else if strings.Contains(errMsg, "total errors:") && strings.Contains(errMsg, "2562426") {
|
||
mu.Lock()
|
||
e.log.Warn("Database has massive error count - likely data corruption or incompatible dump format",
|
||
"database", dbName,
|
||
"errors", "2562426")
|
||
mu.Unlock()
|
||
}
|
||
|
||
restoreErrorsMu.Lock()
|
||
// Include more context in the error message
|
||
restoreErrors = multierror.Append(restoreErrors, fmt.Errorf("%s: restore failed: %w", dbName, restoreErr))
|
||
restoreErrorsMu.Unlock()
|
||
atomic.AddInt32(&failCount, 1)
|
||
return
|
||
}
|
||
|
||
// Track completed database restore duration for ETA calculation
|
||
dbRestoreDuration := time.Since(dbRestoreStart)
|
||
completedDBTimesMu.Lock()
|
||
completedDBTimes = append(completedDBTimes, dbRestoreDuration)
|
||
completedDBTimesMu.Unlock()
|
||
|
||
// Update bytes completed for weighted progress
|
||
dbSize := dbSizes[dbName]
|
||
bytesCompletedMu.Lock()
|
||
bytesCompleted += dbSize
|
||
currentBytesCompleted := bytesCompleted
|
||
currentSuccessCount := int(atomic.LoadInt32(&successCount)) + 1 // +1 because we're about to increment
|
||
bytesCompletedMu.Unlock()
|
||
|
||
// Report weighted progress (bytes-based)
|
||
e.reportDatabaseProgressByBytes(currentBytesCompleted, totalBytes, dbName, currentSuccessCount, totalDBs)
|
||
|
||
atomic.AddInt32(&successCount, 1)
|
||
|
||
// Small delay to ensure PostgreSQL fully closes connections before next restore
|
||
time.Sleep(100 * time.Millisecond)
|
||
}(dbIndex, entry.Name())
|
||
|
||
dbIndex++
|
||
}
|
||
|
||
// Wait for all restores to complete
|
||
wg.Wait()
|
||
|
||
successCountFinal := int(atomic.LoadInt32(&successCount))
|
||
failCountFinal := int(atomic.LoadInt32(&failCount))
|
||
|
||
// SANITY CHECK: Verify all databases were accounted for
|
||
// This catches any goroutine that exited without updating counters
|
||
accountedFor := successCountFinal + failCountFinal
|
||
if accountedFor != totalDBs {
|
||
missingCount := totalDBs - accountedFor
|
||
e.log.Error("INTERNAL ERROR: Some database restore goroutines did not report status",
|
||
"expected", totalDBs,
|
||
"success", successCountFinal,
|
||
"failed", failCountFinal,
|
||
"unaccounted", missingCount)
|
||
|
||
// Treat unaccounted databases as failures
|
||
failCountFinal += missingCount
|
||
restoreErrorsMu.Lock()
|
||
restoreErrors = multierror.Append(restoreErrors, fmt.Errorf("%d database(s) did not complete (possible goroutine crash or deadlock)", missingCount))
|
||
restoreErrorsMu.Unlock()
|
||
}
|
||
|
||
// CRITICAL: Check if no databases were restored at all
|
||
if successCountFinal == 0 {
|
||
e.progress.Fail(fmt.Sprintf("Cluster restore FAILED: 0 of %d databases restored", totalDBs))
|
||
operation.Fail("No databases were restored")
|
||
|
||
if failCountFinal > 0 && restoreErrors != nil {
|
||
return fmt.Errorf("cluster restore failed: all %d database(s) failed:\n%s", failCountFinal, restoreErrors.Error())
|
||
}
|
||
return fmt.Errorf("cluster restore failed: no databases were restored (0 of %d total). Check PostgreSQL logs for details", totalDBs)
|
||
}
|
||
|
||
if failCountFinal > 0 {
|
||
// Format multi-error with detailed output
|
||
restoreErrors.ErrorFormat = func(errs []error) string {
|
||
if len(errs) == 1 {
|
||
return errs[0].Error()
|
||
}
|
||
points := make([]string, len(errs))
|
||
for i, err := range errs {
|
||
points[i] = fmt.Sprintf(" • %s", err.Error())
|
||
}
|
||
return fmt.Sprintf("%d database(s) failed:\n%s", len(errs), strings.Join(points, "\n"))
|
||
}
|
||
|
||
// Log summary
|
||
e.log.Info("Cluster restore completed with failures",
|
||
"succeeded", successCountFinal,
|
||
"failed", failCountFinal,
|
||
"total", totalDBs)
|
||
|
||
e.progress.Fail(fmt.Sprintf("Cluster restore: %d succeeded, %d failed out of %d total", successCountFinal, failCountFinal, totalDBs))
|
||
operation.Complete(fmt.Sprintf("Partial restore: %d/%d databases succeeded", successCountFinal, totalDBs))
|
||
|
||
return fmt.Errorf("cluster restore completed with %d failures:\n%s", failCountFinal, restoreErrors.Error())
|
||
}
|
||
|
||
e.progress.Complete(fmt.Sprintf("Cluster restored successfully: %d databases", successCountFinal))
|
||
operation.Complete(fmt.Sprintf("Restored %d databases from cluster archive", successCountFinal))
|
||
return nil
|
||
}
|
||
|
||
// extractArchive extracts a tar.gz archive with progress reporting
|
||
func (e *Engine) extractArchive(ctx context.Context, archivePath, destDir string) error {
|
||
// If progress callback is set, use Go's archive/tar for progress tracking
|
||
if e.progressCallback != nil {
|
||
return e.extractArchiveWithProgress(ctx, archivePath, destDir)
|
||
}
|
||
|
||
// Otherwise use fast shell tar (no progress)
|
||
return e.extractArchiveShell(ctx, archivePath, destDir)
|
||
}
|
||
|
||
// extractArchiveWithProgress extracts using Go's archive/tar with detailed progress reporting
func (e *Engine) extractArchiveWithProgress(ctx context.Context, archivePath, destDir string) error {
	// Get archive size for progress calculation
	archiveInfo, err := os.Stat(archivePath)
	if err != nil {
		return fmt.Errorf("failed to stat archive: %w", err)
	}
	totalSize := archiveInfo.Size()

	// Open the archive file
	file, err := os.Open(archivePath)
	if err != nil {
		return fmt.Errorf("failed to open archive: %w", err)
	}
	defer file.Close()

	// Wrap with progress reader
	progressReader := &progressReader{
		reader:    file,
		totalSize: totalSize,
		callback:  e.progressCallback,
		desc:      "Extracting archive",
	}

	// Create gzip reader
	gzReader, err := gzip.NewReader(progressReader)
	if err != nil {
		return fmt.Errorf("failed to create gzip reader: %w", err)
	}
	defer gzReader.Close()

	// Create tar reader
	tarReader := tar.NewReader(gzReader)

	// Extract files
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		header, err := tarReader.Next()
		if err == io.EOF {
			break // End of archive
		}
		if err != nil {
			return fmt.Errorf("failed to read tar header: %w", err)
		}

		// Sanitize and validate path
		targetPath := filepath.Join(destDir, header.Name)

		// Security check: ensure path is within destDir (prevent path traversal).
		// A plain prefix check would also accept sibling directories such as
		// "/tmp/dest-evil" for destDir "/tmp/dest", so compare against the cleaned
		// destination plus a path separator.
		cleanDest := filepath.Clean(destDir)
		cleanTarget := filepath.Clean(targetPath)
		if cleanTarget != cleanDest && !strings.HasPrefix(cleanTarget, cleanDest+string(os.PathSeparator)) {
			e.log.Warn("Skipping potentially malicious path in archive", "path", header.Name)
			continue
		}

		switch header.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(targetPath, 0755); err != nil {
				return fmt.Errorf("failed to create directory %s: %w", targetPath, err)
			}
		case tar.TypeReg:
			// Ensure parent directory exists
			if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
				return fmt.Errorf("failed to create parent directory: %w", err)
			}

			// Create the file
			outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
			if err != nil {
				return fmt.Errorf("failed to create file %s: %w", targetPath, err)
			}

			// Copy file contents
			if _, err := io.Copy(outFile, tarReader); err != nil {
				outFile.Close()
				return fmt.Errorf("failed to write file %s: %w", targetPath, err)
			}
			outFile.Close()
		case tar.TypeSymlink:
			// Handle symlinks (common in some archives)
			if err := os.Symlink(header.Linkname, targetPath); err != nil {
				// Ignore symlink errors (may already exist or not supported)
				e.log.Debug("Could not create symlink", "path", targetPath, "target", header.Linkname)
			}
		}
	}

	// Final progress update
	e.reportProgress(totalSize, totalSize, "Extraction complete")
	return nil
}

// progressReader wraps an io.Reader to report read progress
type progressReader struct {
	reader      io.Reader
	totalSize   int64
	bytesRead   int64
	callback    ProgressCallback
	desc        string
	lastReport  time.Time
	reportEvery time.Duration
}

func (pr *progressReader) Read(p []byte) (n int, err error) {
	n, err = pr.reader.Read(p)
	pr.bytesRead += int64(n)

	// Throttle progress reporting to every 100ms
	if pr.reportEvery == 0 {
		pr.reportEvery = 100 * time.Millisecond
	}
	if time.Since(pr.lastReport) > pr.reportEvery {
		if pr.callback != nil {
			pr.callback(pr.bytesRead, pr.totalSize, pr.desc)
		}
		pr.lastReport = time.Now()
	}

	return n, err
}

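// Minimal usage sketch for progressReader (illustrative only, not called by the
// engine): any io.Reader can be wrapped, and the callback fires at most once per
// reportEvery interval, which defaults to 100ms on the first Read. The
// strings.Reader source and io.Discard sink stand in for a real archive and consumer.
func exampleProgressReaderUsage() error {
	src := strings.NewReader(strings.Repeat("x", 1<<20))
	pr := &progressReader{
		reader:    src,
		totalSize: int64(src.Len()),
		callback: func(current, total int64, desc string) {
			fmt.Printf("%s: %d/%d bytes\n", desc, current, total)
		},
		desc: "Copying sample data",
	}
	_, err := io.Copy(io.Discard, pr)
	return err
}
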
// extractArchiveShell extracts using shell tar command (faster but no progress)
func (e *Engine) extractArchiveShell(ctx context.Context, archivePath, destDir string) error {
	cmd := exec.CommandContext(ctx, "tar", "-xzf", archivePath, "-C", destDir)

	// Stream stderr to avoid memory issues - tar can produce lots of output for large archives
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to create stderr pipe: %w", err)
	}

	if err := cmd.Start(); err != nil {
		return fmt.Errorf("failed to start tar: %w", err)
	}

	// Discard stderr output in chunks to prevent memory buildup
	stderrDone := make(chan struct{})
	go func() {
		defer close(stderrDone)
		buf := make([]byte, 4096)
		for {
			_, err := stderr.Read(buf)
			if err != nil {
				break
			}
		}
	}()

	// Wait for command with proper context handling
	cmdDone := make(chan error, 1)
	go func() {
		cmdDone <- cmd.Wait()
	}()

	var cmdErr error
	select {
	case cmdErr = <-cmdDone:
		// Command completed
	case <-ctx.Done():
		e.log.Warn("Archive extraction cancelled - killing process")
		cmd.Process.Kill()
		<-cmdDone
		cmdErr = ctx.Err()
	}

	<-stderrDone

	if cmdErr != nil {
		return fmt.Errorf("tar extraction failed: %w", cmdErr)
	}
	return nil
}

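// Illustrative sketch of calling the shell extraction path with a bounded context
// (not used by the engine; the timeout and paths are placeholders). Because the
// command is started with exec.CommandContext and the select above watches
// ctx.Done(), cancelling the context kills the tar process instead of leaking it.
func exampleExtractWithTimeout(e *Engine) error {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
	defer cancel()
	return e.extractArchiveShell(ctx, "/tmp/cluster_backup.tar.gz", "/tmp/restore-work")
}
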
// restoreGlobals restores global objects (roles, tablespaces)
// Note: psql returns 0 even when some statements fail (e.g., role already exists)
// We track errors but only fail on FATAL errors that would prevent restore
func (e *Engine) restoreGlobals(ctx context.Context, globalsFile string) error {
	args := []string{
		"-p", fmt.Sprintf("%d", e.cfg.Port),
		"-U", e.cfg.User,
		"-d", "postgres",
		"-f", globalsFile,
	}

	// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
	if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
		args = append([]string{"-h", e.cfg.Host}, args...)
	}

	cmd := exec.CommandContext(ctx, "psql", args...)

	cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))

	// Stream output to avoid memory issues with large globals.sql files
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to create stderr pipe: %w", err)
	}

	if err := cmd.Start(); err != nil {
		return fmt.Errorf("failed to start psql: %w", err)
	}

	// Read stderr in chunks in goroutine
	var lastError string
	var errorCount int
	var fatalError bool
	stderrDone := make(chan struct{})
	go func() {
		defer close(stderrDone)
		buf := make([]byte, 4096)
		for {
			n, err := stderr.Read(buf)
			if n > 0 {
				chunk := string(buf[:n])
				// Track different error types
				if strings.Contains(chunk, "FATAL") {
					fatalError = true
					lastError = chunk
					e.log.Error("Globals restore FATAL error", "output", chunk)
				} else if strings.Contains(chunk, "ERROR") {
					errorCount++
					lastError = chunk
					// Only log first few errors to avoid spam
					if errorCount <= 5 {
						// Check if it's an ignorable "already exists" error
						if strings.Contains(chunk, "already exists") {
							e.log.Debug("Globals restore: object already exists (expected)", "output", chunk)
						} else {
							e.log.Warn("Globals restore error", "output", chunk)
						}
					}
				}
			}
			if err != nil {
				break
			}
		}
	}()

	// Wait for command with proper context handling
	cmdDone := make(chan error, 1)
	go func() {
		cmdDone <- cmd.Wait()
	}()

	var cmdErr error
	select {
	case cmdErr = <-cmdDone:
		// Command completed
	case <-ctx.Done():
		e.log.Warn("Globals restore cancelled - killing process")
		cmd.Process.Kill()
		<-cmdDone
		cmdErr = ctx.Err()
	}

	<-stderrDone

	// Only fail on actual command errors or FATAL PostgreSQL errors
	// Regular ERROR messages (like "role already exists") are expected
	if cmdErr != nil {
		return fmt.Errorf("failed to restore globals: %w (last error: %s)", cmdErr, lastError)
	}

	// If we had FATAL errors, those are real problems
	if fatalError {
		return fmt.Errorf("globals restore had FATAL error: %s", lastError)
	}

	// Log summary if there were errors (but don't fail)
	if errorCount > 0 {
		e.log.Info("Globals restore completed with some errors (usually 'already exists' - expected)",
			"error_count", errorCount)
	}

	return nil
}

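// For context on the input to restoreGlobals: the globals file in a cluster archive
// is conventionally produced with pg_dumpall's globals-only mode. The sketch below
// is illustrative only (it is not part of the restore path) and assumes the same
// pg_dumpall binary and PGPASSWORD-based authentication used elsewhere in this file.
func exampleDumpGlobals(ctx context.Context, e *Engine, outFile string) error {
	// pg_dumpall -g writes only roles and tablespaces, which is exactly what
	// restoreGlobals later replays against the "postgres" database.
	cmd := exec.CommandContext(ctx, "pg_dumpall", "-g",
		"-p", fmt.Sprintf("%d", e.cfg.Port),
		"-U", e.cfg.User,
		"-f", outFile)
	cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
	return cmd.Run()
}
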
// checkSuperuser verifies if the current user has superuser privileges
func (e *Engine) checkSuperuser(ctx context.Context) (bool, error) {
	args := []string{
		"-p", fmt.Sprintf("%d", e.cfg.Port),
		"-U", e.cfg.User,
		"-d", "postgres",
		"-tAc", "SELECT usesuper FROM pg_user WHERE usename = current_user",
	}

	// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
	if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
		args = append([]string{"-h", e.cfg.Host}, args...)
	}

	cmd := exec.CommandContext(ctx, "psql", args...)

	// Always set PGPASSWORD (empty string is fine for peer/ident auth)
	cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))

	output, err := cmd.CombinedOutput()
	if err != nil {
		return false, fmt.Errorf("failed to check superuser status: %w", err)
	}

	isSuperuser := strings.TrimSpace(string(output)) == "t"
	return isSuperuser, nil
}

// terminateConnections kills all active connections to a database
func (e *Engine) terminateConnections(ctx context.Context, dbName string) error {
	query := fmt.Sprintf(`
		SELECT pg_terminate_backend(pid)
		FROM pg_stat_activity
		WHERE datname = '%s'
		  AND pid <> pg_backend_pid()
	`, dbName)

	args := []string{
		"-p", fmt.Sprintf("%d", e.cfg.Port),
		"-U", e.cfg.User,
		"-d", "postgres",
		"-tAc", query,
	}

	// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
	if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
		args = append([]string{"-h", e.cfg.Host}, args...)
	}

	cmd := exec.CommandContext(ctx, "psql", args...)

	// Always set PGPASSWORD (empty string is fine for peer/ident auth)
	cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))

	output, err := cmd.CombinedOutput()
	if err != nil {
		e.log.Warn("Failed to terminate connections", "database", dbName, "error", err, "output", string(output))
		// Don't fail - database might not exist or have no connections
	}

	return nil
}

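// terminateConnections interpolates dbName into the SQL above, which is fine for
// names the engine already trusts but would break on a name containing a single
// quote. A parameterized variant through database/sql (sketch only, using the pgx
// stdlib driver already imported by this file; the connection string is a
// placeholder) avoids that class of problem entirely:
func exampleTerminateConnectionsParameterized(ctx context.Context, connStr, dbName string) error {
	db, err := sql.Open("pgx", connStr)
	if err != nil {
		return err
	}
	defer db.Close()
	// $1 is bound server-side, so dbName never needs quoting or escaping.
	_, err = db.ExecContext(ctx,
		"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1 AND pid <> pg_backend_pid()",
		dbName)
	return err
}
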
// dropDatabaseIfExists drops a database completely (clean slate)
// Uses PostgreSQL 13+ WITH (FORCE) option to forcefully drop even with active connections
func (e *Engine) dropDatabaseIfExists(ctx context.Context, dbName string) error {
	// First terminate all connections
	if err := e.terminateConnections(ctx, dbName); err != nil {
		e.log.Warn("Could not terminate connections", "database", dbName, "error", err)
	}

	// Wait a moment for connections to terminate
	time.Sleep(500 * time.Millisecond)

	// Try to revoke new connections (prevents race condition)
	// This only works if we have the privilege to do so
	revokeArgs := []string{
		"-p", fmt.Sprintf("%d", e.cfg.Port),
		"-U", e.cfg.User,
		"-d", "postgres",
		"-c", fmt.Sprintf("REVOKE CONNECT ON DATABASE \"%s\" FROM PUBLIC", dbName),
	}
	if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
		revokeArgs = append([]string{"-h", e.cfg.Host}, revokeArgs...)
	}
	revokeCmd := exec.CommandContext(ctx, "psql", revokeArgs...)
	revokeCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
	revokeCmd.Run() // Ignore errors - database might not exist

	// Terminate connections again after revoking connect privilege
	e.terminateConnections(ctx, dbName)
	time.Sleep(200 * time.Millisecond)

	// Try DROP DATABASE WITH (FORCE) first (PostgreSQL 13+)
	// This forcefully terminates connections and drops the database atomically
	forceArgs := []string{
		"-p", fmt.Sprintf("%d", e.cfg.Port),
		"-U", e.cfg.User,
		"-d", "postgres",
		"-c", fmt.Sprintf("DROP DATABASE IF EXISTS \"%s\" WITH (FORCE)", dbName),
	}
	if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
		forceArgs = append([]string{"-h", e.cfg.Host}, forceArgs...)
	}
	forceCmd := exec.CommandContext(ctx, "psql", forceArgs...)
	forceCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))

	output, err := forceCmd.CombinedOutput()
	if err == nil {
		e.log.Info("Dropped existing database (with FORCE)", "name", dbName)
		return nil
	}

	// If FORCE option failed (PostgreSQL < 13), try regular drop
	if strings.Contains(string(output), "syntax error") || strings.Contains(string(output), "WITH (FORCE)") {
		e.log.Debug("WITH (FORCE) not supported, using standard DROP", "name", dbName)

		args := []string{
			"-p", fmt.Sprintf("%d", e.cfg.Port),
			"-U", e.cfg.User,
			"-d", "postgres",
			"-c", fmt.Sprintf("DROP DATABASE IF EXISTS \"%s\"", dbName),
		}
		if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
			args = append([]string{"-h", e.cfg.Host}, args...)
		}

		cmd := exec.CommandContext(ctx, "psql", args...)
		cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))

		output, err = cmd.CombinedOutput()
		if err != nil {
			return fmt.Errorf("failed to drop database '%s': %w\nOutput: %s", dbName, err, string(output))
		}
	} else if err != nil {
		return fmt.Errorf("failed to drop database '%s': %w\nOutput: %s", dbName, err, string(output))
	}

	e.log.Info("Dropped existing database", "name", dbName)
	return nil
}

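// The DROP/CREATE statements above wrap dbName in double quotes, which handles
// mixed-case and reserved-word names but not a name that itself contains a double
// quote. A small quoting helper (sketch only; the engine does not currently call it)
// shows the standard SQL rule of doubling embedded quotes:
func exampleQuotePgIdentifier(name string) string {
	// `my""db` is the quoted form of the identifier my"db.
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}
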
// ensureDatabaseExists checks if a database exists and creates it if not
func (e *Engine) ensureDatabaseExists(ctx context.Context, dbName string) error {
	// Route to appropriate implementation based on database type
	if e.cfg.DatabaseType == "mysql" || e.cfg.DatabaseType == "mariadb" {
		return e.ensureMySQLDatabaseExists(ctx, dbName)
	}
	return e.ensurePostgresDatabaseExists(ctx, dbName)
}

// ensureMySQLDatabaseExists checks if a MySQL database exists and creates it if not
func (e *Engine) ensureMySQLDatabaseExists(ctx context.Context, dbName string) error {
	// Build mysql command
	args := []string{
		"-h", e.cfg.Host,
		"-P", fmt.Sprintf("%d", e.cfg.Port),
		"-u", e.cfg.User,
		"-e", fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", dbName),
	}

	if e.cfg.Password != "" {
		args = append(args, fmt.Sprintf("-p%s", e.cfg.Password))
	}

	cmd := exec.CommandContext(ctx, "mysql", args...)
	output, err := cmd.CombinedOutput()
	if err != nil {
		e.log.Warn("MySQL database creation failed", "name", dbName, "error", err, "output", string(output))
		return fmt.Errorf("failed to create database '%s': %w (output: %s)", dbName, err, strings.TrimSpace(string(output)))
	}

	e.log.Info("Successfully ensured MySQL database exists", "name", dbName)
	return nil
}

// ensurePostgresDatabaseExists checks if a PostgreSQL database exists and creates it if not.
// Missing databases are created from template0 with UTF8 encoding and the server's
// default locale, falling back to a plain CREATE DATABASE if that fails.
func (e *Engine) ensurePostgresDatabaseExists(ctx context.Context, dbName string) error {
	// Skip creation for postgres and template databases - they should already exist
	if dbName == "postgres" || dbName == "template0" || dbName == "template1" {
		e.log.Info("Skipping create for system database (assume exists)", "name", dbName)
		return nil
	}

	// Build psql command with authentication
	buildPsqlCmd := func(ctx context.Context, database, query string) *exec.Cmd {
		args := []string{
			"-p", fmt.Sprintf("%d", e.cfg.Port),
			"-U", e.cfg.User,
			"-d", database,
			"-tAc", query,
		}

		// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
		if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
			args = append([]string{"-h", e.cfg.Host}, args...)
		}

		cmd := exec.CommandContext(ctx, "psql", args...)

		// Always set PGPASSWORD (empty string is fine for peer/ident auth)
		cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))

		return cmd
	}

	// Check if database exists
	checkCmd := buildPsqlCmd(ctx, "postgres", fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", dbName))

	output, err := checkCmd.CombinedOutput()
	if err != nil {
		e.log.Warn("Database existence check failed", "name", dbName, "error", err, "output", string(output))
		// Continue anyway - maybe we can create it
	}

	// If database exists, we're done
	if strings.TrimSpace(string(output)) == "1" {
		e.log.Info("Database already exists", "name", dbName)
		return nil
	}

	// Database doesn't exist, create it
	// IMPORTANT: Use template0 to avoid duplicate definition errors from local additions to template1
	// Also use UTF8 encoding explicitly as it's the most common and safest choice
	// See PostgreSQL docs: https://www.postgresql.org/docs/current/app-pgrestore.html#APP-PGRESTORE-NOTES
	e.log.Info("Creating database from template0 with UTF8 encoding", "name", dbName)

	// Get server's default locale for LC_COLLATE and LC_CTYPE
	// This ensures compatibility while using the correct encoding
	localeCmd := buildPsqlCmd(ctx, "postgres", "SHOW lc_collate")
	localeOutput, _ := localeCmd.CombinedOutput()
	serverLocale := strings.TrimSpace(string(localeOutput))
	if serverLocale == "" {
		serverLocale = "en_US.UTF-8" // Fallback to common default
	}

	// Build CREATE DATABASE command with encoding and locale
	// Using ENCODING 'UTF8' explicitly ensures the dump can be restored
	createSQL := fmt.Sprintf(
		"CREATE DATABASE \"%s\" WITH TEMPLATE template0 ENCODING 'UTF8' LC_COLLATE '%s' LC_CTYPE '%s'",
		dbName, serverLocale, serverLocale,
	)

	createArgs := []string{
		"-p", fmt.Sprintf("%d", e.cfg.Port),
		"-U", e.cfg.User,
		"-d", "postgres",
		"-c", createSQL,
	}

	// Only add -h flag if host is not localhost (to use Unix socket for peer auth)
	if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
		createArgs = append([]string{"-h", e.cfg.Host}, createArgs...)
	}

	createCmd := exec.CommandContext(ctx, "psql", createArgs...)

	// Always set PGPASSWORD (empty string is fine for peer/ident auth)
	createCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))

	output, err = createCmd.CombinedOutput()
	if err != nil {
		// If encoding/locale fails, try simpler CREATE DATABASE
		e.log.Warn("Database creation with encoding failed, trying simple create", "name", dbName, "error", err)

		simpleArgs := []string{
			"-p", fmt.Sprintf("%d", e.cfg.Port),
			"-U", e.cfg.User,
			"-d", "postgres",
			"-c", fmt.Sprintf("CREATE DATABASE \"%s\" WITH TEMPLATE template0", dbName),
		}
		if e.cfg.Host != "localhost" && e.cfg.Host != "127.0.0.1" && e.cfg.Host != "" {
			simpleArgs = append([]string{"-h", e.cfg.Host}, simpleArgs...)
		}

		simpleCmd := exec.CommandContext(ctx, "psql", simpleArgs...)
		simpleCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))

		output, err = simpleCmd.CombinedOutput()
		if err != nil {
			e.log.Warn("Database creation failed", "name", dbName, "error", err, "output", string(output))
			return fmt.Errorf("failed to create database '%s': %w (output: %s)", dbName, err, strings.TrimSpace(string(output)))
		}
	}

	e.log.Info("Successfully created database from template0", "name", dbName)
	return nil
}

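// Worked example of the statement ensurePostgresDatabaseExists generates (values are
// illustrative): for dbName "appdb" and a server lc_collate of "en_US.UTF-8" the
// first attempt is
//
//	CREATE DATABASE "appdb" WITH TEMPLATE template0 ENCODING 'UTF8' LC_COLLATE 'en_US.UTF-8' LC_CTYPE 'en_US.UTF-8'
//
// and only if that fails does it fall back to
//
//	CREATE DATABASE "appdb" WITH TEMPLATE template0
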
// previewClusterRestore shows cluster restore preview
func (e *Engine) previewClusterRestore(archivePath string) error {
	fmt.Println("\n" + strings.Repeat("=", 60))
	fmt.Println(" CLUSTER RESTORE PREVIEW (DRY RUN)")
	fmt.Println(strings.Repeat("=", 60))

	stat, _ := os.Stat(archivePath)
	fmt.Printf("\nArchive: %s\n", filepath.Base(archivePath))
	if stat != nil {
		fmt.Printf("Size: %s\n", FormatBytes(stat.Size()))
		fmt.Printf("Modified: %s\n", stat.ModTime().Format("2006-01-02 15:04:05"))
	}
	fmt.Printf("Target Host: %s:%d\n", e.cfg.Host, e.cfg.Port)

	fmt.Println("\nOperations that would be performed:")
	fmt.Println("  1. Extract cluster archive to temporary directory")
	fmt.Println("  2. Restore global objects (roles, tablespaces)")
	fmt.Println("  3. Restore all databases found in archive")
	fmt.Println("  4. Cleanup temporary files")

	fmt.Println("\n[WARN] WARNING: This will restore multiple databases.")
	fmt.Println("       Existing databases may be overwritten or merged.")
	fmt.Println("\nTo execute this restore, add the --confirm flag.")
	fmt.Println(strings.Repeat("=", 60) + "\n")

	return nil
}

// detectLargeObjectsInDumps checks if any dump files contain large objects
func (e *Engine) detectLargeObjectsInDumps(dumpsDir string, entries []os.DirEntry) bool {
	hasLargeObjects := false
	checkedCount := 0
	maxChecks := 5 // Only check first 5 dumps to avoid slowdown

	for _, entry := range entries {
		if entry.IsDir() || checkedCount >= maxChecks {
			continue
		}

		dumpFile := filepath.Join(dumpsDir, entry.Name())

		// Skip compressed SQL files (can't easily check without decompressing)
		if strings.HasSuffix(dumpFile, ".sql.gz") {
			continue
		}

		// Use pg_restore -l to list contents (fast, doesn't restore data)
		// 2 minutes for large dumps with many objects
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)

		cmd := exec.CommandContext(ctx, "pg_restore", "-l", dumpFile)
		output, err := cmd.Output()
		cancel() // release the timeout now; a deferred cancel would pile up inside the loop

		if err != nil {
			// If pg_restore -l fails, it might not be custom format - skip
			continue
		}

		checkedCount++

		// Check if output contains "BLOB" or "LARGE OBJECT" entries
		outputStr := string(output)
		if strings.Contains(outputStr, "BLOB") ||
			strings.Contains(outputStr, "LARGE OBJECT") ||
			strings.Contains(outputStr, " BLOBS ") {
			e.log.Info("Large objects detected in dump file", "file", entry.Name())
			hasLargeObjects = true
			// Don't break - log all files with large objects
		}
	}

	if hasLargeObjects {
		e.log.Warn("Cluster contains databases with large objects - parallel restore may cause lock contention")
	}

	return hasLargeObjects
}

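// For reference, the pg_restore -l listing scanned above contains TOC lines roughly
// like the following (illustrative; the exact OIDs and labels vary by PostgreSQL
// version, with newer releases printing "LARGE OBJECT" where older ones printed
// "BLOB"/"BLOBS"):
//
//	3500; 2613 16450 BLOB - 16450 postgres
//	4000; 0 0 BLOBS - BLOBS postgres
//
// Either spelling triggers the large-object warning above.
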
// isIgnorableError checks if an error message represents an ignorable PostgreSQL restore error
func (e *Engine) isIgnorableError(errorMsg string) bool {
	// Convert to lowercase for case-insensitive matching
	lowerMsg := strings.ToLower(errorMsg)

	// CRITICAL: Syntax errors are NOT ignorable - indicates corrupted dump
	if strings.Contains(lowerMsg, "syntax error") {
		e.log.Error("CRITICAL: Syntax error in dump file - dump may be corrupted", "error", errorMsg)
		return false
	}

	// CRITICAL: If error count is extremely high (>100k), dump is likely corrupted
	if strings.Contains(errorMsg, "total errors:") {
		// Extract error count if present in message
		parts := strings.Split(errorMsg, "total errors:")
		if len(parts) > 1 {
			errorCountStr := strings.TrimSpace(strings.Split(parts[1], ")")[0])
			// Try to parse as number
			var count int
			if _, err := fmt.Sscanf(errorCountStr, "%d", &count); err == nil && count > 100000 {
				e.log.Error("CRITICAL: Excessive errors indicate corrupted dump", "error_count", count)
				return false
			}
		}
	}

	// List of ignorable error patterns (objects that already exist)
	ignorablePatterns := []string{
		"already exists",
		"duplicate key",
		"does not exist, skipping", // For DROP IF EXISTS
		"no pg_hba.conf entry",     // Permission warnings (not fatal)
	}

	for _, pattern := range ignorablePatterns {
		if strings.Contains(lowerMsg, pattern) {
			return true
		}
	}

	return false
}

// FormatBytes formats bytes to human readable format
func FormatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	div, exp := int64(unit), 0
	for n := bytes / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}

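// Quick sanity examples for FormatBytes (worked by hand against the 1024-based
// loop above):
//
//	FormatBytes(512)        // "512 B"
//	FormatBytes(1536)       // "1.5 KB"
//	FormatBytes(1073741824) // "1.0 GB"
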
// quickValidateSQLDump performs a fast validation of SQL dump files
// by checking for truncated COPY blocks. This catches corrupted dumps
// BEFORE attempting a full restore (which could waste 49+ minutes).
func (e *Engine) quickValidateSQLDump(archivePath string, compressed bool) error {
	e.log.Debug("Pre-validating SQL dump file", "path", archivePath, "compressed", compressed)

	diagnoser := NewDiagnoser(e.log, false) // non-verbose for speed
	result, err := diagnoser.DiagnoseFile(archivePath)
	if err != nil {
		return fmt.Errorf("diagnosis error: %w", err)
	}

	// Check for critical issues that would cause restore failure
	if result.IsTruncated {
		errMsg := "SQL dump file is TRUNCATED"
		if result.Details != nil && result.Details.UnterminatedCopy {
			errMsg = fmt.Sprintf("%s - unterminated COPY block for table '%s' at line %d",
				errMsg, result.Details.LastCopyTable, result.Details.LastCopyLineNumber)
			if len(result.Details.SampleCopyData) > 0 {
				errMsg = fmt.Sprintf("%s (sample orphaned data: %s)", errMsg, result.Details.SampleCopyData[0])
			}
		}
		return fmt.Errorf("%s", errMsg)
	}

	if result.IsCorrupted {
		return fmt.Errorf("SQL dump file is corrupted: %v", result.Errors)
	}

	if !result.IsValid {
		if len(result.Errors) > 0 {
			return fmt.Errorf("dump validation failed: %s", result.Errors[0])
		}
		return fmt.Errorf("dump file is invalid (unknown reason)")
	}

	// Log any warnings but don't fail
	for _, warning := range result.Warnings {
		e.log.Warn("Dump validation warning", "warning", warning)
	}

	e.log.Debug("SQL dump validation passed", "path", archivePath)
	return nil
}

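// quickValidateSQLDump delegates the heavy lifting to the Diagnoser, but the core
// idea it relies on is simple: in a plain-format dump every "COPY ... FROM stdin;"
// block must be terminated by a line containing only "\.". A standalone sketch of
// that check (illustrative only; the real Diagnoser also handles compression and
// partial reads) could look like:
func exampleHasUnterminatedCopy(dumpText string) bool {
	inCopy := false
	for _, line := range strings.Split(dumpText, "\n") {
		trimmed := strings.TrimSpace(line)
		if !inCopy && strings.HasPrefix(trimmed, "COPY ") && strings.HasSuffix(trimmed, "FROM stdin;") {
			inCopy = true
			continue
		}
		if inCopy && trimmed == `\.` {
			inCopy = false
		}
	}
	// Still inside a COPY block at EOF means the dump was cut off mid-data.
	return inCopy
}
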
// boostLockCapacity temporarily increases max_locks_per_transaction to prevent
// "out of shared memory" failures during large restores with many BLOBs. Returns
// the original value for later reset. Uses ALTER SYSTEM + pg_reload_conf(); note
// that max_locks_per_transaction only takes effect after a PostgreSQL restart, so
// the new value is staged in postgresql.auto.conf rather than applied immediately.
func (e *Engine) boostLockCapacity(ctx context.Context) (int, error) {
	// Connect to PostgreSQL to run system commands
	connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable",
		e.cfg.Host, e.cfg.Port, e.cfg.User, e.cfg.Password)

	// For localhost, use Unix socket
	if e.cfg.Host == "localhost" || e.cfg.Host == "" {
		connStr = fmt.Sprintf("user=%s password=%s dbname=postgres sslmode=disable",
			e.cfg.User, e.cfg.Password)
	}

	db, err := sql.Open("pgx", connStr)
	if err != nil {
		return 0, fmt.Errorf("failed to connect: %w", err)
	}
	defer db.Close()

	// Get current value
	var currentValue int
	err = db.QueryRowContext(ctx, "SHOW max_locks_per_transaction").Scan(&currentValue)
	if err != nil {
		// Try parsing as string (some versions return string)
		var currentValueStr string
		err = db.QueryRowContext(ctx, "SHOW max_locks_per_transaction").Scan(&currentValueStr)
		if err != nil {
			return 0, fmt.Errorf("failed to get current max_locks_per_transaction: %w", err)
		}
		fmt.Sscanf(currentValueStr, "%d", &currentValue)
	}

	// Skip if already high enough
	if currentValue >= 2048 {
		e.log.Info("max_locks_per_transaction already sufficient", "value", currentValue)
		return currentValue, nil
	}

	// Boost to 2048 (enough for most BLOB-heavy databases)
	_, err = db.ExecContext(ctx, "ALTER SYSTEM SET max_locks_per_transaction = 2048")
	if err != nil {
		return currentValue, fmt.Errorf("failed to set max_locks_per_transaction: %w", err)
	}

	// Reload config (the new lock setting itself still requires a restart to apply)
	_, err = db.ExecContext(ctx, "SELECT pg_reload_conf()")
	if err != nil {
		return currentValue, fmt.Errorf("failed to reload config: %w", err)
	}

	return currentValue, nil
}

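// Why 2048: PostgreSQL sizes the shared lock table to roughly
// max_locks_per_transaction * (max_connections + max_prepared_transactions)
// entries, and pg_restore takes one lock per restored large object within a single
// transaction, so BLOB-heavy databases can exhaust the default of 64. A rough
// sizing sketch (illustrative heuristic only; the engine simply uses a fixed 2048):
func exampleSuggestedMaxLocks(largeObjectCount, maxConnections int) int {
	if maxConnections <= 0 {
		maxConnections = 100 // PostgreSQL's default max_connections
	}
	// Spread the expected lock demand across the lock table and keep headroom;
	// never go below the 64 default.
	suggested := largeObjectCount/maxConnections + 64
	if suggested < 64 {
		suggested = 64
	}
	return suggested
}
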
// resetLockCapacity restores the original max_locks_per_transaction value
func (e *Engine) resetLockCapacity(ctx context.Context, originalValue int) error {
	connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable",
		e.cfg.Host, e.cfg.Port, e.cfg.User, e.cfg.Password)

	if e.cfg.Host == "localhost" || e.cfg.Host == "" {
		connStr = fmt.Sprintf("user=%s password=%s dbname=postgres sslmode=disable",
			e.cfg.User, e.cfg.Password)
	}

	db, err := sql.Open("pgx", connStr)
	if err != nil {
		return fmt.Errorf("failed to connect: %w", err)
	}
	defer db.Close()

	// Reset to original value (or use RESET to go back to default)
	if originalValue == 64 { // Default value
		_, err = db.ExecContext(ctx, "ALTER SYSTEM RESET max_locks_per_transaction")
	} else {
		_, err = db.ExecContext(ctx, fmt.Sprintf("ALTER SYSTEM SET max_locks_per_transaction = %d", originalValue))
	}
	if err != nil {
		return fmt.Errorf("failed to reset max_locks_per_transaction: %w", err)
	}

	// Reload config
	_, err = db.ExecContext(ctx, "SELECT pg_reload_conf()")
	if err != nil {
		return fmt.Errorf("failed to reload config: %w", err)
	}

	return nil
}

// OriginalSettings stores PostgreSQL settings to restore after operation
type OriginalSettings struct {
	MaxLocks           int
	MaintenanceWorkMem string
}

// boostPostgreSQLSettings boosts multiple PostgreSQL settings for large restores
// NOTE: max_locks_per_transaction requires a PostgreSQL RESTART to take effect!
// maintenance_work_mem can be changed with pg_reload_conf().
func (e *Engine) boostPostgreSQLSettings(ctx context.Context, lockBoostValue int) (*OriginalSettings, error) {
	connStr := e.buildConnString()
	db, err := sql.Open("pgx", connStr)
	if err != nil {
		return nil, fmt.Errorf("failed to connect: %w", err)
	}
	defer db.Close()

	original := &OriginalSettings{}

	// Get current max_locks_per_transaction
	var maxLocksStr string
	if err := db.QueryRowContext(ctx, "SHOW max_locks_per_transaction").Scan(&maxLocksStr); err == nil {
		original.MaxLocks, _ = strconv.Atoi(maxLocksStr)
	}

	// Get current maintenance_work_mem
	db.QueryRowContext(ctx, "SHOW maintenance_work_mem").Scan(&original.MaintenanceWorkMem)

	// CRITICAL: max_locks_per_transaction requires a PostgreSQL RESTART!
	// pg_reload_conf() is NOT sufficient for this parameter.
	needsRestart := false
	if original.MaxLocks < lockBoostValue {
		_, err = db.ExecContext(ctx, fmt.Sprintf("ALTER SYSTEM SET max_locks_per_transaction = %d", lockBoostValue))
		if err != nil {
			e.log.Warn("Could not set max_locks_per_transaction", "error", err)
		} else {
			needsRestart = true
			e.log.Warn("max_locks_per_transaction requires PostgreSQL restart to take effect",
				"current", original.MaxLocks,
				"target", lockBoostValue)
		}
	}

	// Boost maintenance_work_mem to 2GB for faster index creation
	// (this one CAN be applied via pg_reload_conf)
	_, err = db.ExecContext(ctx, "ALTER SYSTEM SET maintenance_work_mem = '2GB'")
	if err != nil {
		e.log.Warn("Could not boost maintenance_work_mem", "error", err)
	}

	// Reload config to apply maintenance_work_mem
	_, err = db.ExecContext(ctx, "SELECT pg_reload_conf()")
	if err != nil {
		return original, fmt.Errorf("failed to reload config: %w", err)
	}

	// If max_locks_per_transaction needs a restart, try to do it
	if needsRestart {
		if restarted := e.tryRestartPostgreSQL(ctx); restarted {
			e.log.Info("PostgreSQL restarted successfully - max_locks_per_transaction now active")
			// Wait for PostgreSQL to be ready
			time.Sleep(3 * time.Second)
		} else {
			// Cannot restart - warn user but continue
			// The setting is written to postgresql.auto.conf and will take effect on next restart
			e.log.Warn("=" + strings.Repeat("=", 70))
			e.log.Warn("NOTE: max_locks_per_transaction change requires PostgreSQL restart")
			e.log.Warn("Current value: " + strconv.Itoa(original.MaxLocks) + ", target: " + strconv.Itoa(lockBoostValue))
			e.log.Warn("")
			e.log.Warn("The setting has been saved to postgresql.auto.conf and will take")
			e.log.Warn("effect on the next PostgreSQL restart. If restore fails with")
			e.log.Warn("'out of shared memory' errors, ask your DBA to restart PostgreSQL.")
			e.log.Warn("")
			e.log.Warn("Continuing with restore - this may succeed if your databases")
			e.log.Warn("don't have many large objects (BLOBs).")
			e.log.Warn("=" + strings.Repeat("=", 70))
			// Continue anyway - might work for small restores or DBs without BLOBs
		}
	}

	return original, nil
}

// canRestartPostgreSQL checks if we have the ability to restart PostgreSQL
// Returns false if running in a restricted environment (e.g., su postgres on enterprise systems)
func (e *Engine) canRestartPostgreSQL() bool {
	// Check if we're running as postgres user - if so, we likely can't restart
	// because PostgreSQL is managed by init/systemd, not directly by pg_ctl
	currentUser := os.Getenv("USER")
	if currentUser == "" {
		currentUser = os.Getenv("LOGNAME")
	}

	// If we're the postgres user, check if we have sudo access
	if currentUser == "postgres" {
		// Try a quick sudo check - if this fails, we can't restart
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		cmd := exec.CommandContext(ctx, "sudo", "-n", "true")
		cmd.Stdin = nil
		if err := cmd.Run(); err != nil {
			e.log.Info("Running as postgres user without sudo access - cannot restart PostgreSQL",
				"user", currentUser,
				"hint", "Ask system administrator to restart PostgreSQL if needed")
			return false
		}
	}

	return true
}

// tryRestartPostgreSQL attempts to restart PostgreSQL using various methods
// Returns true if restart was successful
// IMPORTANT: Uses short timeouts and non-interactive sudo to avoid blocking on password prompts
// NOTE: This function will return false immediately if running as postgres without sudo
func (e *Engine) tryRestartPostgreSQL(ctx context.Context) bool {
	// First check if we can even attempt a restart
	if !e.canRestartPostgreSQL() {
		e.log.Info("Skipping PostgreSQL restart attempt (no privileges)")
		return false
	}

	e.progress.Update("Attempting PostgreSQL restart for lock settings...")

	// Use short timeout for each restart attempt (don't block on sudo password prompts)
	runWithTimeout := func(args ...string) bool {
		cmdCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		cmd := exec.CommandContext(cmdCtx, args[0], args[1:]...)
		// Set stdin to /dev/null to prevent sudo from waiting for password
		cmd.Stdin = nil
		return cmd.Run() == nil
	}

	// Method 1: systemctl (most common on modern Linux) - use sudo -n for non-interactive
	if runWithTimeout("sudo", "-n", "systemctl", "restart", "postgresql") {
		return true
	}

	// Method 2: systemctl with version suffix (e.g., postgresql-15)
	for _, ver := range []string{"17", "16", "15", "14", "13", "12"} {
		if runWithTimeout("sudo", "-n", "systemctl", "restart", "postgresql-"+ver) {
			return true
		}
	}

	// Method 3: service command (older systems)
	if runWithTimeout("sudo", "-n", "service", "postgresql", "restart") {
		return true
	}

	// Method 4: pg_ctl as postgres user (if we ARE postgres user, no sudo needed)
	if runWithTimeout("pg_ctl", "restart", "-D", "/var/lib/postgresql/data", "-m", "fast") {
		return true
	}

	// Method 5: Try common PGDATA paths with pg_ctl directly (for postgres user)
	pgdataPaths := []string{
		"/var/lib/pgsql/data",
		"/var/lib/pgsql/17/data",
		"/var/lib/pgsql/16/data",
		"/var/lib/pgsql/15/data",
		"/var/lib/postgresql/17/main",
		"/var/lib/postgresql/16/main",
		"/var/lib/postgresql/15/main",
	}
	for _, pgdata := range pgdataPaths {
		if runWithTimeout("pg_ctl", "restart", "-D", pgdata, "-m", "fast") {
			return true
		}
	}

	return false
}

// resetPostgreSQLSettings restores original PostgreSQL settings
// NOTE: max_locks_per_transaction changes are written but require restart to take effect.
// We don't restart here since we're done with the restore.
func (e *Engine) resetPostgreSQLSettings(ctx context.Context, original *OriginalSettings) error {
	connStr := e.buildConnString()
	db, err := sql.Open("pgx", connStr)
	if err != nil {
		return fmt.Errorf("failed to connect: %w", err)
	}
	defer db.Close()

	// Reset max_locks_per_transaction (will take effect on next restart)
	if original.MaxLocks == 64 { // Default
		db.ExecContext(ctx, "ALTER SYSTEM RESET max_locks_per_transaction")
	} else if original.MaxLocks > 0 {
		db.ExecContext(ctx, fmt.Sprintf("ALTER SYSTEM SET max_locks_per_transaction = %d", original.MaxLocks))
	}

	// Reset maintenance_work_mem (takes effect immediately with reload)
	if original.MaintenanceWorkMem == "64MB" { // Default
		db.ExecContext(ctx, "ALTER SYSTEM RESET maintenance_work_mem")
	} else if original.MaintenanceWorkMem != "" {
		db.ExecContext(ctx, fmt.Sprintf("ALTER SYSTEM SET maintenance_work_mem = '%s'", original.MaintenanceWorkMem))
	}

	// Reload config (only maintenance_work_mem will take effect immediately)
	_, err = db.ExecContext(ctx, "SELECT pg_reload_conf()")
	if err != nil {
		return fmt.Errorf("failed to reload config: %w", err)
	}

	e.log.Info("PostgreSQL settings reset queued",
		"note", "max_locks_per_transaction will revert on next PostgreSQL restart")

	return nil
}