dbbackup/internal/engine/native/parallel_restore.go
Renz a101fb81ab
v5.8.22: Defensive fixes for potential restore hang issues
- Add context cancellation check during COPY data parsing loop
  (prevents hangs when parsing large tables with millions of rows)
- Add 5-second timeout for stderr reader in globals restore
  (prevents indefinite hang if psql process doesn't terminate cleanly)
- Reduce database drop timeout from 5 minutes to 60 seconds
  (improves TUI responsiveness during cluster cleanup)
2026-02-05 12:40:26 +00:00
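
Of the three fixes, only the COPY-parsing cancellation check lands in this file; the stderr timeout and drop timeout live in the globals-restore and cleanup paths. A minimal sketch of the stderr-timeout pattern described above, assuming the reader holds psql's stderr pipe (the function name and shape are illustrative, not the actual implementation):

    package main

    import (
        "io"
        "time"
    )

    // drainStderrWithTimeout reads stderr to EOF, but gives up after the
    // timeout so a psql process that never closes stderr cannot hang us.
    func drainStderrWithTimeout(stderr io.Reader, timeout time.Duration) string {
        done := make(chan string, 1) // buffered: the goroutine can finish even if we stop listening
        go func() {
            b, _ := io.ReadAll(stderr) // returns when psql closes stderr
            done <- string(b)
        }()
        select {
        case s := <-done:
            return s
        case <-time.After(timeout): // e.g. 5 * time.Second, per the note above
            return "" // psql didn't terminate cleanly; stop waiting instead of hanging
        }
    }

Note the buffered channel: if the timeout fires first, the reader goroutine can still complete its send and exit once the process dies, rather than leaking forever.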

package native

import (
    "bufio"
    "bytes"
    "context"
    "fmt"
    "io"
    "os"
    "strings"
    "sync"
    "sync/atomic"
    "time"

    "github.com/jackc/pgx/v5/pgxpool"
    "github.com/klauspost/pgzip"

    "dbbackup/internal/logger"
)

// ParallelRestoreEngine provides high-performance parallel SQL restore
// that can match pg_restore -j8 performance for SQL format dumps
type ParallelRestoreEngine struct {
    config *PostgreSQLNativeConfig
    pool   *pgxpool.Pool
    log    logger.Logger

    // Configuration
    parallelWorkers int

    // closeCh is closed by Close(); no background cleanup goroutine is
    // started (see the NOTE in NewParallelRestoreEngineWithContext)
    closeCh chan struct{}
}

// ParallelRestoreOptions configures parallel restore behavior
type ParallelRestoreOptions struct {
    // Number of parallel workers for COPY operations (like pg_restore -j)
    Workers int

    // Continue on error instead of stopping
    ContinueOnError bool

    // Progress callback
    ProgressCallback func(phase string, current, total int, tableName string)
}

// ParallelRestoreResult contains restore statistics
type ParallelRestoreResult struct {
    Duration         time.Duration
    SchemaStatements int64
    TablesRestored   int64
    RowsRestored     int64
    IndexesCreated   int64
    Errors           []string
}

// SQLStatement represents a parsed SQL statement with metadata
type SQLStatement struct {
    SQL       string
    Type      StatementType
    TableName string       // For COPY statements
    CopyData  bytes.Buffer // Data for COPY FROM STDIN
}

// StatementType classifies SQL statements for parallel execution
type StatementType int

const (
    StmtSchema   StatementType = iota // CREATE TABLE, TYPE, FUNCTION, etc.
    StmtCopyData                      // COPY ... FROM stdin with data
    StmtPostData                      // CREATE INDEX, ADD CONSTRAINT, etc.
    StmtOther                         // SET, COMMENT, etc.
)

// NewParallelRestoreEngine creates a new parallel restore engine.
// NOTE: Pass a cancellable context to ensure the pool is properly closed on Ctrl+C
func NewParallelRestoreEngine(config *PostgreSQLNativeConfig, log logger.Logger, workers int) (*ParallelRestoreEngine, error) {
    return NewParallelRestoreEngineWithContext(context.Background(), config, log, workers)
}

// NewParallelRestoreEngineWithContext creates a new parallel restore engine with context support.
// This ensures the connection pool is properly closed when the context is cancelled
func NewParallelRestoreEngineWithContext(ctx context.Context, config *PostgreSQLNativeConfig, log logger.Logger, workers int) (*ParallelRestoreEngine, error) {
    if workers < 1 {
        workers = 4 // Default to 4 parallel workers
    }

    // Build connection string
    sslMode := config.SSLMode
    if sslMode == "" {
        sslMode = "prefer"
    }
    connString := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
        config.Host, config.Port, config.User, config.Password, config.Database, sslMode)

    // Create connection pool with enough connections for parallel workers
    poolConfig, err := pgxpool.ParseConfig(connString)
    if err != nil {
        return nil, fmt.Errorf("failed to parse connection config: %w", err)
    }

    // Pool size = workers + 2 (extra headroom for schema and bookkeeping operations)
    poolConfig.MaxConns = int32(workers + 2)
    poolConfig.MinConns = int32(workers)

    // CRITICAL: Reduce health check period to allow faster shutdown.
    // The default is 1 minute, which causes hangs on Ctrl+C
    poolConfig.HealthCheckPeriod = 5 * time.Second

    // CRITICAL: Set connection-level timeouts to ensure queries can be cancelled.
    // This prevents infinite hangs on slow/stuck operations
    poolConfig.ConnConfig.RuntimeParams = map[string]string{
        "statement_timeout":                   "3600000", // 1 hour max per statement (in ms)
        "lock_timeout":                        "300000",  // 5 min max wait for locks (in ms)
        "idle_in_transaction_session_timeout": "600000",  // 10 min idle timeout (in ms)
    }

    // Use the provided context so pool health checks stop when the context is cancelled
    pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
    if err != nil {
        return nil, fmt.Errorf("failed to create connection pool: %w", err)
    }

    closeCh := make(chan struct{})
    engine := &ParallelRestoreEngine{
        config:          config,
        pool:            pool,
        log:             log,
        parallelWorkers: workers,
        closeCh:         closeCh,
    }

    // NOTE: We intentionally do NOT start a goroutine to close the pool on context cancellation.
    // The pool is closed via defer parallelEngine.Close() in the caller (restore/engine.go).
    // The Close() method properly signals closeCh and closes the pool.
    // Starting a goroutine here can cause:
    //   1. Race conditions with explicit Close() calls
    //   2. Goroutine leaks if neither ctx nor Close() fires
    //   3. Deadlocks with BubbleTea's event loop
    return engine, nil
}
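
// Typical lifecycle (sketch; restore/engine.go is the real call site mentioned
// in the NOTE above, and cfg/log are assumed to exist at the caller):
//
//     ctx, cancel := context.WithCancel(context.Background())
//     defer cancel()
//     eng, err := NewParallelRestoreEngineWithContext(ctx, cfg, log, 8)
//     if err != nil {
//         return err
//     }
//     defer eng.Close() // signals closeCh and closes the pgx pool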

// RestoreFile restores from a SQL file with parallel execution
func (e *ParallelRestoreEngine) RestoreFile(ctx context.Context, filePath string, options *ParallelRestoreOptions) (*ParallelRestoreResult, error) {
    startTime := time.Now()
    result := &ParallelRestoreResult{}

    if options == nil {
        options = &ParallelRestoreOptions{Workers: e.parallelWorkers}
    }
    if options.Workers < 1 {
        options.Workers = e.parallelWorkers
    }

    e.log.Info("Starting parallel SQL restore",
        "file", filePath,
        "workers", options.Workers)

    // Open file (handle gzip)
    file, err := os.Open(filePath)
    if err != nil {
        return result, fmt.Errorf("failed to open file: %w", err)
    }
    defer file.Close()

    var reader io.Reader = file
    if strings.HasSuffix(filePath, ".gz") {
        gzReader, err := pgzip.NewReader(file)
        if err != nil {
            return result, fmt.Errorf("failed to create gzip reader: %w", err)
        }
        defer gzReader.Close()
        reader = gzReader
    }

    // Phase 1: Parse and classify statements
    e.log.Info("Phase 1: Parsing SQL dump...")
    if options.ProgressCallback != nil {
        options.ProgressCallback("parsing", 0, 0, "")
    }
    statements, err := e.parseStatementsWithContext(ctx, reader)
    if err != nil {
        return result, fmt.Errorf("failed to parse SQL: %w", err)
    }

    // Count by type
    var schemaCount, copyCount, postDataCount int
    for _, stmt := range statements {
        switch stmt.Type {
        case StmtSchema:
            schemaCount++
        case StmtCopyData:
            copyCount++
        case StmtPostData:
            postDataCount++
        }
    }
    e.log.Info("Parsed SQL dump",
        "schema_statements", schemaCount,
        "copy_operations", copyCount,
        "post_data_statements", postDataCount)

    // Phase 2: Execute schema statements (sequential - must be in order)
    e.log.Info("Phase 2: Creating schema (sequential)...")
    if options.ProgressCallback != nil {
        options.ProgressCallback("schema", 0, schemaCount, "")
    }
    schemaStmts := 0
    for _, stmt := range statements {
        // Check for context cancellation periodically
        select {
        case <-ctx.Done():
            return result, ctx.Err()
        default:
        }
        if stmt.Type == StmtSchema || stmt.Type == StmtOther {
            if err := e.executeStatement(ctx, stmt.SQL); err != nil {
                if options.ContinueOnError {
                    result.Errors = append(result.Errors, err.Error())
                } else {
                    return result, fmt.Errorf("schema creation failed: %w", err)
                }
            }
            schemaStmts++
            result.SchemaStatements++
            if options.ProgressCallback != nil && schemaStmts%100 == 0 {
                options.ProgressCallback("schema", schemaStmts, schemaCount, "")
            }
        }
    }

    // Phase 3: Execute COPY operations in parallel (THE KEY TO PERFORMANCE!)
    e.log.Info("Phase 3: Loading data in parallel...",
        "tables", copyCount,
        "workers", options.Workers)
    if options.ProgressCallback != nil {
        options.ProgressCallback("data", 0, copyCount, "")
    }

    copyStmts := make([]*SQLStatement, 0, copyCount)
    for i := range statements {
        if statements[i].Type == StmtCopyData {
            copyStmts = append(copyStmts, &statements[i])
        }
    }

    // Execute COPY operations in parallel using a worker pool
    var wg sync.WaitGroup
    semaphore := make(chan struct{}, options.Workers)
    var completedCopies int64
    var totalRows int64
    var cancelled int32 // Atomic flag to signal cancellation

copyLoop:
    for _, stmt := range copyStmts {
        // Check for context cancellation before starting new work
        if ctx.Err() != nil {
            break
        }
        wg.Add(1)
        select {
        case semaphore <- struct{}{}: // Acquire worker slot
        case <-ctx.Done():
            wg.Done()
            atomic.StoreInt32(&cancelled, 1)
            break copyLoop // CRITICAL: Use labeled break to exit the for loop, not just the select
        }
        go func(s *SQLStatement) {
            defer wg.Done()
            defer func() { <-semaphore }() // Release worker slot

            // Check cancellation before executing
            if ctx.Err() != nil || atomic.LoadInt32(&cancelled) == 1 {
                return
            }

            rows, err := e.executeCopy(ctx, s)
            if err != nil {
                if ctx.Err() != nil {
                    // Context cancelled, don't log as error
                    return
                }
                if options.ContinueOnError {
                    e.log.Warn("COPY failed", "table", s.TableName, "error", err)
                } else {
                    e.log.Error("COPY failed", "table", s.TableName, "error", err)
                }
            } else {
                atomic.AddInt64(&totalRows, rows)
            }

            completed := atomic.AddInt64(&completedCopies, 1)
            if options.ProgressCallback != nil {
                options.ProgressCallback("data", int(completed), copyCount, s.TableName)
            }
        }(stmt)
    }
    wg.Wait()

    // Check if cancelled
    if ctx.Err() != nil {
        return result, ctx.Err()
    }

    result.TablesRestored = completedCopies
    result.RowsRestored = totalRows

    // Phase 4: Execute post-data statements in parallel (indexes, constraints)
    e.log.Info("Phase 4: Creating indexes and constraints in parallel...",
        "statements", postDataCount,
        "workers", options.Workers)
    if options.ProgressCallback != nil {
        options.ProgressCallback("indexes", 0, postDataCount, "")
    }

    postDataStmts := make([]string, 0, postDataCount)
    for _, stmt := range statements {
        if stmt.Type == StmtPostData {
            postDataStmts = append(postDataStmts, stmt.SQL)
        }
    }

    // Execute post-data in parallel
    var completedPostData int64
    atomic.StoreInt32(&cancelled, 0) // Reset for phase 4

postDataLoop:
    for _, sql := range postDataStmts {
        // Check for context cancellation before starting new work
        if ctx.Err() != nil {
            break
        }
        wg.Add(1)
        select {
        case semaphore <- struct{}{}:
        case <-ctx.Done():
            wg.Done()
            atomic.StoreInt32(&cancelled, 1)
            break postDataLoop // CRITICAL: Use labeled break to exit the for loop, not just the select
        }
        go func(stmt string) {
            defer wg.Done()
            defer func() { <-semaphore }()

            // Check cancellation before executing
            if ctx.Err() != nil || atomic.LoadInt32(&cancelled) == 1 {
                return
            }

            if err := e.executeStatement(ctx, stmt); err != nil {
                if ctx.Err() != nil {
                    return // Context cancelled
                }
                if options.ContinueOnError {
                    e.log.Warn("Post-data statement failed", "error", err)
                } else {
                    // Mirror the COPY phase: log the failure instead of silently dropping it
                    e.log.Error("Post-data statement failed", "error", err)
                }
            } else {
                atomic.AddInt64(&result.IndexesCreated, 1)
            }

            completed := atomic.AddInt64(&completedPostData, 1)
            if options.ProgressCallback != nil {
                options.ProgressCallback("indexes", int(completed), postDataCount, "")
            }
        }(sql)
    }
    wg.Wait()

    // Check if cancelled
    if ctx.Err() != nil {
        return result, ctx.Err()
    }

    result.Duration = time.Since(startTime)
    e.log.Info("Parallel restore completed",
        "duration", result.Duration,
        "tables", result.TablesRestored,
        "rows", result.RowsRestored,
        "indexes", result.IndexesCreated)

    return result, nil
}
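
// Example invocation (sketch; path and option values are illustrative):
//
//     res, err := eng.RestoreFile(ctx, "/backups/mydb.sql.gz", &ParallelRestoreOptions{
//         Workers:         8,
//         ContinueOnError: true,
//     })
//     if err != nil {
//         return err
//     }
//     log.Info("restored", "tables", res.TablesRestored, "rows", res.RowsRestored)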

// parseStatements reads and classifies all SQL statements
func (e *ParallelRestoreEngine) parseStatements(reader io.Reader) ([]SQLStatement, error) {
    return e.parseStatementsWithContext(context.Background(), reader)
}
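
// The parser expects pg_dump's plain-text COPY framing. For example, this
// input (values illustrative) yields one SQLStatement of type StmtCopyData
// whose CopyData buffer holds the two tab-separated data rows:
//
//     COPY public.users (id, name) FROM stdin;
//     1	alice
//     2	bob
//     \.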

// parseStatementsWithContext reads and classifies all SQL statements with context support
func (e *ParallelRestoreEngine) parseStatementsWithContext(ctx context.Context, reader io.Reader) ([]SQLStatement, error) {
    scanner := bufio.NewScanner(reader)
    scanner.Buffer(make([]byte, 1024*1024), 64*1024*1024) // 64MB max for large statements

    var statements []SQLStatement
    var stmtBuffer bytes.Buffer
    var inCopyMode bool
    var currentCopyStmt *SQLStatement
    lineCount := 0

    for scanner.Scan() {
        // Check for context cancellation every 10000 lines. Because this runs
        // before the COPY-mode branch below, it also covers COPY data lines,
        // preventing hangs when parsing tables with millions of rows.
        lineCount++
        if lineCount%10000 == 0 {
            select {
            case <-ctx.Done():
                return statements, ctx.Err()
            default:
            }
        }

        line := scanner.Text()

        // Handle COPY data mode
        if inCopyMode {
            if line == "\\." {
                // End of COPY data
                if currentCopyStmt != nil {
                    statements = append(statements, *currentCopyStmt)
                    currentCopyStmt = nil
                }
                inCopyMode = false
                continue
            }
            if currentCopyStmt != nil {
                currentCopyStmt.CopyData.WriteString(line)
                currentCopyStmt.CopyData.WriteByte('\n')
            }
            continue
        }

        // Check for COPY statement start
        trimmed := strings.TrimSpace(line)
        upperTrimmed := strings.ToUpper(trimmed)
        if strings.HasPrefix(upperTrimmed, "COPY ") && strings.HasSuffix(trimmed, "FROM stdin;") {
            // Extract table name
            parts := strings.Fields(line)
            tableName := ""
            if len(parts) >= 2 {
                tableName = parts[1]
            }
            currentCopyStmt = &SQLStatement{
                SQL:       line,
                Type:      StmtCopyData,
                TableName: tableName,
            }
            inCopyMode = true
            continue
        }

        // Skip comments and empty lines
        if trimmed == "" || strings.HasPrefix(trimmed, "--") {
            continue
        }

        // Accumulate statement
        stmtBuffer.WriteString(line)
        stmtBuffer.WriteByte('\n')

        // Check if statement is complete
        if strings.HasSuffix(trimmed, ";") {
            sql := stmtBuffer.String()
            stmtBuffer.Reset()
            stmt := SQLStatement{
                SQL:  sql,
                Type: classifyStatement(sql),
            }
            statements = append(statements, stmt)
        }
    }

    if err := scanner.Err(); err != nil {
        return nil, fmt.Errorf("error scanning SQL: %w", err)
    }
    return statements, nil
}

// classifyStatement determines the type of SQL statement
func classifyStatement(sql string) StatementType {
    upper := strings.ToUpper(strings.TrimSpace(sql))

    // Post-data statements (can be parallelized). The parentheses are
    // redundant given Go's precedence (&& binds tighter than ||) but make
    // the grouping explicit.
    if strings.HasPrefix(upper, "CREATE INDEX") ||
        strings.HasPrefix(upper, "CREATE UNIQUE INDEX") ||
        (strings.HasPrefix(upper, "ALTER TABLE") && strings.Contains(upper, "ADD CONSTRAINT")) ||
        (strings.HasPrefix(upper, "ALTER TABLE") && strings.Contains(upper, "ADD FOREIGN KEY")) ||
        strings.HasPrefix(upper, "CREATE TRIGGER") ||
        (strings.HasPrefix(upper, "ALTER TABLE") && strings.Contains(upper, "ENABLE TRIGGER")) {
        return StmtPostData
    }

    // Schema statements (must be sequential)
    if strings.HasPrefix(upper, "CREATE ") ||
        strings.HasPrefix(upper, "ALTER ") ||
        strings.HasPrefix(upper, "DROP ") ||
        strings.HasPrefix(upper, "GRANT ") ||
        strings.HasPrefix(upper, "REVOKE ") {
        return StmtSchema
    }

    return StmtOther
}
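
// Classification examples (illustrative inputs):
//
//     classifyStatement("CREATE INDEX idx_users ON users (id);")           // StmtPostData
//     classifyStatement("ALTER TABLE users ADD CONSTRAINT u UNIQUE (id);") // StmtPostData
//     classifyStatement("CREATE TABLE users (id int);")                    // StmtSchema
//     classifyStatement("SET search_path = public, pg_catalog;")           // StmtOther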

// executeStatement executes a single SQL statement
func (e *ParallelRestoreEngine) executeStatement(ctx context.Context, sql string) error {
    conn, err := e.pool.Acquire(ctx)
    if err != nil {
        return fmt.Errorf("failed to acquire connection: %w", err)
    }
    defer conn.Release()

    _, err = conn.Exec(ctx, sql)
    return err
}

// executeCopy executes a COPY FROM STDIN operation with BLOB optimization
func (e *ParallelRestoreEngine) executeCopy(ctx context.Context, stmt *SQLStatement) (int64, error) {
    conn, err := e.pool.Acquire(ctx)
    if err != nil {
        return 0, fmt.Errorf("failed to acquire connection: %w", err)
    }
    defer conn.Release()

    // Apply per-connection settings recommended for maximum BLOB throughput.
    // Errors are deliberately ignored: wal_buffers and
    // checkpoint_completion_target are server-level parameters that cannot be
    // changed with a session-level SET, so those two fail harmlessly.
    optimizations := []string{
        "SET synchronous_commit = 'off'",           // Don't wait for WAL sync
        "SET session_replication_role = 'replica'", // Disable triggers during load
        "SET work_mem = '256MB'",                   // More memory for sorting
        "SET maintenance_work_mem = '512MB'",       // For constraint validation
        "SET wal_buffers = '64MB'",                 // Larger WAL buffer (server-level; SET fails)
        "SET checkpoint_completion_target = '0.9'", // Spread checkpoint I/O (server-level; SET fails)
    }
    for _, opt := range optimizations {
        _, _ = conn.Exec(ctx, opt) // best-effort; see note above
    }

    // Execute the COPY
    copySQL := fmt.Sprintf("COPY %s FROM STDIN", stmt.TableName)
    tag, err := conn.Conn().PgConn().CopyFrom(ctx, strings.NewReader(stmt.CopyData.String()), copySQL)
    if err != nil {
        return 0, err
    }
    return tag.RowsAffected(), nil
}

// Close closes the connection pool and signals shutdown via closeCh
func (e *ParallelRestoreEngine) Close() error {
    // Signal any waiters that the engine is shutting down
    if e.closeCh != nil {
        close(e.closeCh)
    }
    // Close the pool
    if e.pool != nil {
        e.pool.Close()
    }
    return nil
}