Compare commits
8 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 63b35414d2 | |
| | db46770e7f | |
| | 51764a677a | |
| | bdbbb59e51 | |
| | 1a6ea13222 | |
| | 598056ffe3 | |
| | 185c8fb0f3 | |
| | d80ac4cae4 | |
23
cmd/root.go
@@ -15,11 +15,12 @@ import (
)

var (
    cfg *config.Config
    log logger.Logger
    auditLogger *security.AuditLogger
    rateLimiter *security.RateLimiter
    notifyManager *notify.Manager
+   deprecatedPassword string
)

// rootCmd represents the base command when called without any subcommands
@@ -47,6 +48,11 @@ For help with specific commands, use: dbbackup [command] --help`,
        return nil
    }

+   // Check for deprecated password flag
+   if deprecatedPassword != "" {
+       return fmt.Errorf("--password flag is not supported for security reasons. Use environment variables instead:\n - MySQL/MariaDB: export MYSQL_PWD='your_password'\n - PostgreSQL: export PGPASSWORD='your_password' or use .pgpass file")
+   }
+
    // Store which flags were explicitly set by user
    flagsSet := make(map[string]bool)
    cmd.Flags().Visit(func(f *pflag.Flag) {
@@ -171,15 +177,8 @@ func Execute(ctx context.Context, config *config.Config, logger logger.Logger) e
    rootCmd.PersistentFlags().StringVar(&cfg.Database, "database", cfg.Database, "Database name")
    // SECURITY: Password flag removed - use PGPASSWORD/MYSQL_PWD environment variable or .pgpass file
    // Provide helpful error message for users expecting --password flag
-   var deprecatedPassword string
    rootCmd.PersistentFlags().StringVar(&deprecatedPassword, "password", "", "DEPRECATED: Use MYSQL_PWD or PGPASSWORD environment variable instead")
    rootCmd.PersistentFlags().MarkHidden("password")
-   rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
-       if deprecatedPassword != "" {
-           return fmt.Errorf("--password flag is not supported for security reasons. Use environment variables instead:\n - MySQL/MariaDB: export MYSQL_PWD='your_password'\n - PostgreSQL: export PGPASSWORD='your_password' or use .pgpass file")
-       }
-       return nil
-   }
    rootCmd.PersistentFlags().StringVarP(&cfg.DatabaseType, "db-type", "d", cfg.DatabaseType, "Database type (postgres|mysql|mariadb)")
    rootCmd.PersistentFlags().StringVar(&cfg.BackupDir, "backup-dir", cfg.BackupDir, "Backup directory")
    rootCmd.PersistentFlags().BoolVar(&cfg.NoColor, "no-color", cfg.NoColor, "Disable colored output")
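Taken together, the three cmd/root.go hunks move the deprecated `--password` handling into the shared var block and into the main `PersistentPreRunE`. As a quick reference, here is a minimal, self-contained sketch of that cobra pattern — the flag name and the environment-variable guidance match the diff, but the command wiring is illustrative only, not the project's actual root command:

```go
package main

import (
	"fmt"
	"os"

	"github.com/spf13/cobra"
)

func main() {
	var deprecatedPassword string

	rootCmd := &cobra.Command{
		Use: "dbbackup",
		// Reject the hidden --password flag before any subcommand runs.
		PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
			if deprecatedPassword != "" {
				return fmt.Errorf("--password is not supported; set MYSQL_PWD or PGPASSWORD instead")
			}
			return nil
		},
		RunE: func(cmd *cobra.Command, args []string) error { return nil },
	}

	// Keep the flag registered (but hidden) so old invocations fail with guidance
	// instead of a generic "unknown flag" parse error.
	rootCmd.PersistentFlags().StringVar(&deprecatedPassword, "password", "", "DEPRECATED")
	_ = rootCmd.PersistentFlags().MarkHidden("password")

	if err := rootCmd.Execute(); err != nil {
		os.Exit(1)
	}
}
```

Because the flag stays registered, `dbbackup --password=...` produces the migration hint rather than an "unknown flag" error, which appears to be the point of keeping it around at all.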
@@ -6,6 +6,7 @@ import (
    "path/filepath"
    "strconv"
    "strings"
+   "time"
)

const ConfigFileName = ".dbbackup.conf"
@@ -159,87 +160,58 @@ func LoadLocalConfigFromPath(configPath string) (*LocalConfig, error) {

// SaveLocalConfig saves configuration to .dbbackup.conf in current directory
func SaveLocalConfig(cfg *LocalConfig) error {
+   return SaveLocalConfigToPath(cfg, filepath.Join(".", ConfigFileName))
+}
+
+// SaveLocalConfigToPath saves configuration to a specific path
+func SaveLocalConfigToPath(cfg *LocalConfig, configPath string) error {
    var sb strings.Builder

    sb.WriteString("# dbbackup configuration\n")
-   sb.WriteString("# This file is auto-generated. Edit with care.\n\n")
+   sb.WriteString("# This file is auto-generated. Edit with care.\n")
+   sb.WriteString(fmt.Sprintf("# Saved: %s\n\n", time.Now().Format(time.RFC3339)))

-   // Database section
+   // Database section - ALWAYS write all values
    sb.WriteString("[database]\n")
-   if cfg.DBType != "" {
-       sb.WriteString(fmt.Sprintf("type = %s\n", cfg.DBType))
-   }
-   if cfg.Host != "" {
-       sb.WriteString(fmt.Sprintf("host = %s\n", cfg.Host))
-   }
-   if cfg.Port != 0 {
-       sb.WriteString(fmt.Sprintf("port = %d\n", cfg.Port))
-   }
-   if cfg.User != "" {
-       sb.WriteString(fmt.Sprintf("user = %s\n", cfg.User))
-   }
-   if cfg.Database != "" {
-       sb.WriteString(fmt.Sprintf("database = %s\n", cfg.Database))
-   }
-   if cfg.SSLMode != "" {
-       sb.WriteString(fmt.Sprintf("ssl_mode = %s\n", cfg.SSLMode))
-   }
+   sb.WriteString(fmt.Sprintf("type = %s\n", cfg.DBType))
+   sb.WriteString(fmt.Sprintf("host = %s\n", cfg.Host))
+   sb.WriteString(fmt.Sprintf("port = %d\n", cfg.Port))
+   sb.WriteString(fmt.Sprintf("user = %s\n", cfg.User))
+   sb.WriteString(fmt.Sprintf("database = %s\n", cfg.Database))
+   sb.WriteString(fmt.Sprintf("ssl_mode = %s\n", cfg.SSLMode))
    sb.WriteString("\n")

-   // Backup section
+   // Backup section - ALWAYS write all values (including 0)
    sb.WriteString("[backup]\n")
-   if cfg.BackupDir != "" {
-       sb.WriteString(fmt.Sprintf("backup_dir = %s\n", cfg.BackupDir))
-   }
+   sb.WriteString(fmt.Sprintf("backup_dir = %s\n", cfg.BackupDir))
    if cfg.WorkDir != "" {
        sb.WriteString(fmt.Sprintf("work_dir = %s\n", cfg.WorkDir))
    }
-   if cfg.Compression != 0 {
-       sb.WriteString(fmt.Sprintf("compression = %d\n", cfg.Compression))
-   }
-   if cfg.Jobs != 0 {
-       sb.WriteString(fmt.Sprintf("jobs = %d\n", cfg.Jobs))
-   }
-   if cfg.DumpJobs != 0 {
-       sb.WriteString(fmt.Sprintf("dump_jobs = %d\n", cfg.DumpJobs))
-   }
+   sb.WriteString(fmt.Sprintf("compression = %d\n", cfg.Compression))
+   sb.WriteString(fmt.Sprintf("jobs = %d\n", cfg.Jobs))
+   sb.WriteString(fmt.Sprintf("dump_jobs = %d\n", cfg.DumpJobs))
    sb.WriteString("\n")

-   // Performance section
+   // Performance section - ALWAYS write all values
    sb.WriteString("[performance]\n")
-   if cfg.CPUWorkload != "" {
-       sb.WriteString(fmt.Sprintf("cpu_workload = %s\n", cfg.CPUWorkload))
-   }
-   if cfg.MaxCores != 0 {
-       sb.WriteString(fmt.Sprintf("max_cores = %d\n", cfg.MaxCores))
-   }
-   if cfg.ClusterTimeout != 0 {
-       sb.WriteString(fmt.Sprintf("cluster_timeout = %d\n", cfg.ClusterTimeout))
-   }
+   sb.WriteString(fmt.Sprintf("cpu_workload = %s\n", cfg.CPUWorkload))
+   sb.WriteString(fmt.Sprintf("max_cores = %d\n", cfg.MaxCores))
+   sb.WriteString(fmt.Sprintf("cluster_timeout = %d\n", cfg.ClusterTimeout))
    if cfg.ResourceProfile != "" {
        sb.WriteString(fmt.Sprintf("resource_profile = %s\n", cfg.ResourceProfile))
    }
-   if cfg.LargeDBMode {
-       sb.WriteString("large_db_mode = true\n")
-   }
+   sb.WriteString(fmt.Sprintf("large_db_mode = %t\n", cfg.LargeDBMode))
    sb.WriteString("\n")

-   // Security section
+   // Security section - ALWAYS write all values
    sb.WriteString("[security]\n")
-   if cfg.RetentionDays != 0 {
-       sb.WriteString(fmt.Sprintf("retention_days = %d\n", cfg.RetentionDays))
-   }
-   if cfg.MinBackups != 0 {
-       sb.WriteString(fmt.Sprintf("min_backups = %d\n", cfg.MinBackups))
-   }
-   if cfg.MaxRetries != 0 {
-       sb.WriteString(fmt.Sprintf("max_retries = %d\n", cfg.MaxRetries))
-   }
+   sb.WriteString(fmt.Sprintf("retention_days = %d\n", cfg.RetentionDays))
+   sb.WriteString(fmt.Sprintf("min_backups = %d\n", cfg.MinBackups))
+   sb.WriteString(fmt.Sprintf("max_retries = %d\n", cfg.MaxRetries))

-   configPath := filepath.Join(".", ConfigFileName)
-   // Use 0600 permissions for security (readable/writable only by owner)
-   if err := os.WriteFile(configPath, []byte(sb.String()), 0600); err != nil {
-       return fmt.Errorf("failed to write config file: %w", err)
+   // Use 0644 permissions for readability
+   if err := os.WriteFile(configPath, []byte(sb.String()), 0644); err != nil {
+       return fmt.Errorf("failed to write config file %s: %w", configPath, err)
    }

    return nil
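The motivation behind the "ALWAYS write all values" change is that conditional writes silently drop zero values: a user who explicitly sets compression to 0 gets the default back on the next load, because the key is simply missing from the file. A minimal sketch of that failure mode — the saveGuarded/saveAlways/load helpers below are hypothetical stand-ins, not the project's parser:

```go
package main

import (
	"fmt"
	"strings"
)

// Old behaviour: skip zero values when saving.
func saveGuarded(compression int) string {
	var sb strings.Builder
	if compression != 0 {
		sb.WriteString(fmt.Sprintf("compression = %d\n", compression))
	}
	return sb.String()
}

// New behaviour: always write the key, even when it is zero.
func saveAlways(compression int) string {
	return fmt.Sprintf("compression = %d\n", compression)
}

// Load the key back, falling back to a default when it is missing.
func load(ini string, def int) int {
	for _, line := range strings.Split(ini, "\n") {
		var v int
		if _, err := fmt.Sscanf(line, "compression = %d", &v); err == nil {
			return v
		}
	}
	return def
}

func main() {
	// User explicitly chose "no compression" (0); the default is 6.
	fmt.Println(load(saveGuarded(0), 6)) // 6 — the user's choice is lost
	fmt.Println(load(saveAlways(0), 6))  // 0 — the choice round-trips
}
```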
178
internal/config/persist_test.go
Normal file
@@ -0,0 +1,178 @@
package config

import (
    "os"
    "path/filepath"
    "testing"
)

func TestConfigSaveLoad(t *testing.T) {
    // Create a temp directory
    tmpDir, err := os.MkdirTemp("", "dbbackup-config-test")
    if err != nil {
        t.Fatalf("Failed to create temp dir: %v", err)
    }
    defer os.RemoveAll(tmpDir)

    configPath := filepath.Join(tmpDir, ".dbbackup.conf")

    // Create test config with ALL fields set
    original := &LocalConfig{
        DBType:          "postgres",
        Host:            "test-host-123",
        Port:            5432,
        User:            "testuser",
        Database:        "testdb",
        SSLMode:         "require",
        BackupDir:       "/test/backups",
        WorkDir:         "/test/work",
        Compression:     9,
        Jobs:            16,
        DumpJobs:        8,
        CPUWorkload:     "aggressive",
        MaxCores:        32,
        ClusterTimeout:  180,
        ResourceProfile: "high",
        LargeDBMode:     true,
        RetentionDays:   14,
        MinBackups:      3,
        MaxRetries:      5,
    }

    // Save to specific path
    err = SaveLocalConfigToPath(original, configPath)
    if err != nil {
        t.Fatalf("Failed to save config: %v", err)
    }

    // Verify file exists
    if _, err := os.Stat(configPath); os.IsNotExist(err) {
        t.Fatalf("Config file not created at %s", configPath)
    }

    // Load it back
    loaded, err := LoadLocalConfigFromPath(configPath)
    if err != nil {
        t.Fatalf("Failed to load config: %v", err)
    }

    if loaded == nil {
        t.Fatal("Loaded config is nil")
    }

    // Verify ALL values
    if loaded.DBType != original.DBType {
        t.Errorf("DBType mismatch: got %s, want %s", loaded.DBType, original.DBType)
    }
    if loaded.Host != original.Host {
        t.Errorf("Host mismatch: got %s, want %s", loaded.Host, original.Host)
    }
    if loaded.Port != original.Port {
        t.Errorf("Port mismatch: got %d, want %d", loaded.Port, original.Port)
    }
    if loaded.User != original.User {
        t.Errorf("User mismatch: got %s, want %s", loaded.User, original.User)
    }
    if loaded.Database != original.Database {
        t.Errorf("Database mismatch: got %s, want %s", loaded.Database, original.Database)
    }
    if loaded.SSLMode != original.SSLMode {
        t.Errorf("SSLMode mismatch: got %s, want %s", loaded.SSLMode, original.SSLMode)
    }
    if loaded.BackupDir != original.BackupDir {
        t.Errorf("BackupDir mismatch: got %s, want %s", loaded.BackupDir, original.BackupDir)
    }
    if loaded.WorkDir != original.WorkDir {
        t.Errorf("WorkDir mismatch: got %s, want %s", loaded.WorkDir, original.WorkDir)
    }
    if loaded.Compression != original.Compression {
        t.Errorf("Compression mismatch: got %d, want %d", loaded.Compression, original.Compression)
    }
    if loaded.Jobs != original.Jobs {
        t.Errorf("Jobs mismatch: got %d, want %d", loaded.Jobs, original.Jobs)
    }
    if loaded.DumpJobs != original.DumpJobs {
        t.Errorf("DumpJobs mismatch: got %d, want %d", loaded.DumpJobs, original.DumpJobs)
    }
    if loaded.CPUWorkload != original.CPUWorkload {
        t.Errorf("CPUWorkload mismatch: got %s, want %s", loaded.CPUWorkload, original.CPUWorkload)
    }
    if loaded.MaxCores != original.MaxCores {
        t.Errorf("MaxCores mismatch: got %d, want %d", loaded.MaxCores, original.MaxCores)
    }
    if loaded.ClusterTimeout != original.ClusterTimeout {
        t.Errorf("ClusterTimeout mismatch: got %d, want %d", loaded.ClusterTimeout, original.ClusterTimeout)
    }
    if loaded.ResourceProfile != original.ResourceProfile {
        t.Errorf("ResourceProfile mismatch: got %s, want %s", loaded.ResourceProfile, original.ResourceProfile)
    }
    if loaded.LargeDBMode != original.LargeDBMode {
        t.Errorf("LargeDBMode mismatch: got %t, want %t", loaded.LargeDBMode, original.LargeDBMode)
    }
    if loaded.RetentionDays != original.RetentionDays {
        t.Errorf("RetentionDays mismatch: got %d, want %d", loaded.RetentionDays, original.RetentionDays)
    }
    if loaded.MinBackups != original.MinBackups {
        t.Errorf("MinBackups mismatch: got %d, want %d", loaded.MinBackups, original.MinBackups)
    }
    if loaded.MaxRetries != original.MaxRetries {
        t.Errorf("MaxRetries mismatch: got %d, want %d", loaded.MaxRetries, original.MaxRetries)
    }

    t.Log("✅ All config fields save/load correctly!")
}

func TestConfigSaveZeroValues(t *testing.T) {
    // This tests that 0 values are saved and loaded correctly
    tmpDir, err := os.MkdirTemp("", "dbbackup-config-test-zero")
    if err != nil {
        t.Fatalf("Failed to create temp dir: %v", err)
    }
    defer os.RemoveAll(tmpDir)

    configPath := filepath.Join(tmpDir, ".dbbackup.conf")

    // Config with 0/false values intentionally
    original := &LocalConfig{
        DBType:         "postgres",
        Host:           "localhost",
        Port:           5432,
        User:           "postgres",
        Database:       "test",
        SSLMode:        "disable",
        BackupDir:      "/backups",
        Compression:    0, // Intentionally 0 = no compression
        Jobs:           1,
        DumpJobs:       1,
        CPUWorkload:    "conservative",
        MaxCores:       1,
        ClusterTimeout: 0, // No timeout
        LargeDBMode:    false,
        RetentionDays:  0, // Keep forever
        MinBackups:     0,
        MaxRetries:     0,
    }

    // Save
    err = SaveLocalConfigToPath(original, configPath)
    if err != nil {
        t.Fatalf("Failed to save config: %v", err)
    }

    // Load
    loaded, err := LoadLocalConfigFromPath(configPath)
    if err != nil {
        t.Fatalf("Failed to load config: %v", err)
    }

    // The values that are 0/false should still load correctly
    // Note: In INI format, 0 values ARE written and loaded
    if loaded.Compression != 0 {
        t.Errorf("Compression should be 0, got %d", loaded.Compression)
    }
    if loaded.LargeDBMode != false {
        t.Errorf("LargeDBMode should be false, got %t", loaded.LargeDBMode)
    }

    t.Log("✅ Zero values handled correctly!")
}
@@ -71,7 +71,14 @@ const (
)

// NewParallelRestoreEngine creates a new parallel restore engine
+// NOTE: Pass a cancellable context to ensure the pool is properly closed on Ctrl+C
func NewParallelRestoreEngine(config *PostgreSQLNativeConfig, log logger.Logger, workers int) (*ParallelRestoreEngine, error) {
+   return NewParallelRestoreEngineWithContext(context.Background(), config, log, workers)
+}
+
+// NewParallelRestoreEngineWithContext creates a new parallel restore engine with context support
+// This ensures the connection pool is properly closed when the context is cancelled
+func NewParallelRestoreEngineWithContext(ctx context.Context, config *PostgreSQLNativeConfig, log logger.Logger, workers int) (*ParallelRestoreEngine, error) {
    if workers < 1 {
        workers = 4 // Default to 4 parallel workers
    }
@@ -94,7 +101,8 @@ func NewParallelRestoreEngine(config *PostgreSQLNativeConfig, log logger.Logger,
    poolConfig.MaxConns = int32(workers + 2)
    poolConfig.MinConns = int32(workers)

-   pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig)
+   // Use the provided context so pool health checks stop when context is cancelled
+   pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
    if err != nil {
        return nil, fmt.Errorf("failed to create connection pool: %w", err)
    }
@@ -215,17 +223,38 @@ func (e *ParallelRestoreEngine) RestoreFile(ctx context.Context, filePath string
    semaphore := make(chan struct{}, options.Workers)
    var completedCopies int64
    var totalRows int64
+   var cancelled int32 // Atomic flag to signal cancellation

    for _, stmt := range copyStmts {
+       // Check for context cancellation before starting new work
+       if ctx.Err() != nil {
+           break
+       }
+
        wg.Add(1)
-       semaphore <- struct{}{} // Acquire worker slot
+       select {
+       case semaphore <- struct{}{}: // Acquire worker slot
+       case <-ctx.Done():
+           wg.Done()
+           atomic.StoreInt32(&cancelled, 1)
+           break
+       }
+
        go func(s *SQLStatement) {
            defer wg.Done()
            defer func() { <-semaphore }() // Release worker slot

+           // Check cancellation before executing
+           if ctx.Err() != nil || atomic.LoadInt32(&cancelled) == 1 {
+               return
+           }
+
            rows, err := e.executeCopy(ctx, s)
            if err != nil {
+               if ctx.Err() != nil {
+                   // Context cancelled, don't log as error
+                   return
+               }
                if options.ContinueOnError {
                    e.log.Warn("COPY failed", "table", s.TableName, "error", err)
                } else {
@@ -243,6 +272,12 @@ func (e *ParallelRestoreEngine) RestoreFile(ctx context.Context, filePath string
    }

    wg.Wait()

+   // Check if cancelled
+   if ctx.Err() != nil {
+       return result, ctx.Err()
+   }
+
    result.TablesRestored = completedCopies
    result.RowsRestored = totalRows
@@ -264,15 +299,35 @@ func (e *ParallelRestoreEngine) RestoreFile(ctx context.Context, filePath string

    // Execute post-data in parallel
    var completedPostData int64
+   cancelled = 0 // Reset for phase 4
    for _, sql := range postDataStmts {
+       // Check for context cancellation before starting new work
+       if ctx.Err() != nil {
+           break
+       }
+
        wg.Add(1)
-       semaphore <- struct{}{}
+       select {
+       case semaphore <- struct{}{}:
+       case <-ctx.Done():
+           wg.Done()
+           atomic.StoreInt32(&cancelled, 1)
+           break
+       }
+
        go func(stmt string) {
            defer wg.Done()
            defer func() { <-semaphore }()

+           // Check cancellation before executing
+           if ctx.Err() != nil || atomic.LoadInt32(&cancelled) == 1 {
+               return
+           }
+
            if err := e.executeStatement(ctx, stmt); err != nil {
+               if ctx.Err() != nil {
+                   return // Context cancelled
+               }
                if options.ContinueOnError {
                    e.log.Warn("Post-data statement failed", "error", err)
                }
@@ -289,6 +344,11 @@ func (e *ParallelRestoreEngine) RestoreFile(ctx context.Context, filePath string

    wg.Wait()

+   // Check if cancelled
+   if ctx.Err() != nil {
+       return result, ctx.Err()
+   }
+
    result.Duration = time.Since(startTime)
    e.log.Info("Parallel restore completed",
        "duration", result.Duration,
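The RestoreFile hunks all apply one pattern: acquire the worker-slot semaphore inside a select so a cancelled context is never stuck behind a full channel, record cancellation in an atomic flag, and have each goroutine re-check the flag and context before doing any work. Below is a self-contained sketch of that pattern as a generic job runner, not the restore engine itself. One design note: a bare `break` inside a `select` only exits the select, so the sketch leaves the loop via the flag check that follows the select.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// runAll schedules jobs onto a bounded worker pool and stops promptly on cancellation.
func runAll(ctx context.Context, jobs []string, workers int) error {
	sem := make(chan struct{}, workers)
	var wg sync.WaitGroup
	var cancelled int32

	for _, job := range jobs {
		if ctx.Err() != nil {
			break // stop scheduling new work
		}
		wg.Add(1)
		select {
		case sem <- struct{}{}: // acquired a worker slot
		case <-ctx.Done():
			wg.Done()
			atomic.StoreInt32(&cancelled, 1)
		}
		if atomic.LoadInt32(&cancelled) == 1 {
			break // leave the loop; the select's own break would not
		}
		go func(j string) {
			defer wg.Done()
			defer func() { <-sem }() // release the slot
			// Re-check cancellation before doing the actual work.
			if ctx.Err() != nil || atomic.LoadInt32(&cancelled) == 1 {
				return
			}
			time.Sleep(10 * time.Millisecond) // stand-in for executeCopy / executeStatement
			fmt.Println("done:", j)
		}(job)
	}

	wg.Wait()
	return ctx.Err()
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
	defer cancel()
	_ = runAll(ctx, []string{"a", "b", "c", "d", "e", "f"}, 2)
}
```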
@@ -635,7 +635,8 @@ func (e *Engine) restoreWithNativeEngine(ctx context.Context, archivePath, targe
        "database", targetDB,
        "archive", archivePath)

-   parallelEngine, err := native.NewParallelRestoreEngine(nativeCfg, e.log, parallelWorkers)
+   // Pass context to ensure pool is properly closed on Ctrl+C cancellation
+   parallelEngine, err := native.NewParallelRestoreEngineWithContext(ctx, nativeCfg, e.log, parallelWorkers)
    if err != nil {
        e.log.Warn("Failed to create parallel restore engine, falling back to sequential", "error", err)
        // Fall back to sequential restore
@@ -1342,9 +1343,14 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string, preExtr
    }

    format := DetectArchiveFormat(archivePath)
-   if format != FormatClusterTarGz {
+   if !format.CanBeClusterRestore() {
        operation.Fail("Invalid cluster archive format")
-       return fmt.Errorf("not a cluster archive: %s (detected format: %s)", archivePath, format)
+       return fmt.Errorf("not a valid cluster restore format: %s (detected format: %s). Supported: .tar.gz, .sql, .sql.gz", archivePath, format)
    }

+   // For SQL-based cluster restores, use a different restore path
+   if format == FormatPostgreSQLSQL || format == FormatPostgreSQLSQLGz {
+       return e.restoreClusterFromSQL(ctx, archivePath, operation)
+   }
+
    // Check if we have a pre-extracted directory (optimization to avoid double extraction)
@@ -2177,6 +2183,45 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string, preExtr
    return nil
}

+// restoreClusterFromSQL restores a pg_dumpall SQL file using the native engine
+// This handles .sql and .sql.gz files containing full cluster dumps
+func (e *Engine) restoreClusterFromSQL(ctx context.Context, archivePath string, operation logger.OperationLogger) error {
+   e.log.Info("Restoring cluster from SQL file (pg_dumpall format)",
+       "file", filepath.Base(archivePath),
+       "native_engine", true)
+
+   clusterStartTime := time.Now()
+
+   // Determine if compressed
+   compressed := strings.HasSuffix(strings.ToLower(archivePath), ".gz")
+
+   // Use native engine to restore directly to postgres database (globals + all databases)
+   e.log.Info("Restoring SQL dump using native engine...",
+       "compressed", compressed,
+       "size", FormatBytes(getFileSize(archivePath)))
+
+   e.progress.Start("Restoring cluster from SQL dump...")
+
+   // For pg_dumpall, we restore to the 'postgres' database which then creates other databases
+   targetDB := "postgres"
+
+   err := e.restoreWithNativeEngine(ctx, archivePath, targetDB, compressed)
+   if err != nil {
+       operation.Fail(fmt.Sprintf("SQL cluster restore failed: %v", err))
+       e.recordClusterRestoreMetrics(clusterStartTime, archivePath, 0, 0, false, err.Error())
+       return fmt.Errorf("SQL cluster restore failed: %w", err)
+   }
+
+   duration := time.Since(clusterStartTime)
+   e.progress.Complete(fmt.Sprintf("Cluster restored successfully from SQL in %s", duration.Round(time.Second)))
+   operation.Complete("SQL cluster restore completed")
+
+   // Record metrics
+   e.recordClusterRestoreMetrics(clusterStartTime, archivePath, 1, 1, true, "")
+
+   return nil
+}
+
// recordClusterRestoreMetrics records metrics for cluster restore operations
func (e *Engine) recordClusterRestoreMetrics(startTime time.Time, archivePath string, totalDBs, successCount int, success bool, errorMsg string) {
    duration := time.Since(startTime)
@@ -2924,6 +2969,15 @@ func (e *Engine) isIgnorableError(errorMsg string) bool {
    return false
}

+// getFileSize returns the size of a file, or 0 if it can't be read
+func getFileSize(path string) int64 {
+   info, err := os.Stat(path)
+   if err != nil {
+       return 0
+   }
+   return info.Size()
+}
+
// FormatBytes formats bytes to human readable format
func FormatBytes(bytes int64) string {
    const unit = 1024
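FormatBytes is only visible up to `const unit = 1024` in this hunk. For reference, a conventional base-1024 formatter for such a helper looks like the sketch below; this is a generic illustration, not necessarily the body the repository uses.

```go
package main

import "fmt"

// formatBytes renders a byte count with binary (1024-based) units, the usual
// shape of a "human readable size" helper like the truncated FormatBytes above.
func formatBytes(b int64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}
	div, exp := int64(unit), 0
	for n := b / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp])
}

func main() {
	fmt.Println(formatBytes(532))           // 532 B
	fmt.Println(formatBytes(1536))          // 1.5 KiB
	fmt.Println(formatBytes(3_221_225_472)) // 3.0 GiB
}
```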
@@ -47,7 +47,12 @@ func DetectArchiveFormat(filename string) ArchiveFormat {
    lower := strings.ToLower(filename)

    // Check for cluster archives first (most specific)
-   if strings.Contains(lower, "cluster") && strings.HasSuffix(lower, ".tar.gz") {
+   // A .tar.gz file is considered a cluster backup if:
+   // 1. Contains "cluster" in name, OR
+   // 2. Is a .tar.gz file (likely a cluster backup archive)
+   if strings.HasSuffix(lower, ".tar.gz") {
+       // All .tar.gz files are treated as cluster backups
+       // since that's the format used for cluster archives
        return FormatClusterTarGz
    }
@@ -163,11 +168,19 @@ func (f ArchiveFormat) IsCompressed() bool {
        f == FormatClusterTarGz
}

-// IsClusterBackup returns true if the archive is a cluster backup
+// IsClusterBackup returns true if the archive is a cluster backup (.tar.gz format created by dbbackup)
func (f ArchiveFormat) IsClusterBackup() bool {
    return f == FormatClusterTarGz
}

+// CanBeClusterRestore returns true if the format can be used for cluster restore
+// This includes .tar.gz (dbbackup format) and .sql/.sql.gz (pg_dumpall format for native engine)
+func (f ArchiveFormat) CanBeClusterRestore() bool {
+   return f == FormatClusterTarGz ||
+       f == FormatPostgreSQLSQL ||
+       f == FormatPostgreSQLSQLGz
+}
+
// IsPostgreSQL returns true if the archive is PostgreSQL format
func (f ArchiveFormat) IsPostgreSQL() bool {
    return f == FormatPostgreSQLDump ||
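Together with the RestoreCluster change above, CanBeClusterRestore is what lets plain pg_dumpall output participate in cluster restore: .tar.gz keeps the existing extract-and-restore path, while .sql/.sql.gz is routed to the SQL/native-engine path. The sketch below mirrors that dispatch in a self-contained form; the ArchiveFormat constants are redeclared locally so it compiles on its own, and the detection is simplified to suffix checks.

```go
package main

import (
	"fmt"
	"strings"
)

// Local stand-ins for the repository's ArchiveFormat values.
type ArchiveFormat string

const (
	FormatClusterTarGz    ArchiveFormat = "cluster-tar-gz"
	FormatPostgreSQLSQL   ArchiveFormat = "postgres-sql"
	FormatPostgreSQLSQLGz ArchiveFormat = "postgres-sql-gz"
	FormatUnknown         ArchiveFormat = "unknown"
)

// detect mimics DetectArchiveFormat's suffix-based classification.
func detect(name string) ArchiveFormat {
	lower := strings.ToLower(name)
	switch {
	case strings.HasSuffix(lower, ".tar.gz"):
		return FormatClusterTarGz
	case strings.HasSuffix(lower, ".sql.gz"):
		return FormatPostgreSQLSQLGz
	case strings.HasSuffix(lower, ".sql"):
		return FormatPostgreSQLSQL
	default:
		return FormatUnknown
	}
}

func canBeClusterRestore(f ArchiveFormat) bool {
	return f == FormatClusterTarGz || f == FormatPostgreSQLSQL || f == FormatPostgreSQLSQLGz
}

// restoreCluster shows how the format gates the two restore paths.
func restoreCluster(archive string) error {
	f := detect(archive)
	if !canBeClusterRestore(f) {
		return fmt.Errorf("not a valid cluster restore format: %s (%s)", archive, f)
	}
	if f == FormatPostgreSQLSQL || f == FormatPostgreSQLSQLGz {
		fmt.Println("restoring via native engine (pg_dumpall SQL):", archive)
		return nil
	}
	fmt.Println("restoring via tar.gz cluster path:", archive)
	return nil
}

func main() {
	for _, a := range []string{"cluster_backup.tar.gz", "all_dbs.sql.gz", "random.txt"} {
		fmt.Println(restoreCluster(a))
	}
}
```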
@@ -220,3 +220,34 @@ func TestDetectArchiveFormatWithRealFiles(t *testing.T) {
        })
    }
}

+func TestDetectArchiveFormatAll(t *testing.T) {
+   tests := []struct {
+       filename  string
+       want      ArchiveFormat
+       isCluster bool
+   }{
+       {"testdb.sql", FormatPostgreSQLSQL, false},
+       {"testdb.sql.gz", FormatPostgreSQLSQLGz, false},
+       {"testdb.dump", FormatPostgreSQLDump, false},
+       {"testdb.dump.gz", FormatPostgreSQLDumpGz, false},
+       {"cluster_backup.tar.gz", FormatClusterTarGz, true},
+       {"mybackup.tar.gz", FormatClusterTarGz, true},
+       {"testdb_20260130_204350_native.sql.gz", FormatPostgreSQLSQLGz, false},
+       {"mysql_backup.sql", FormatMySQLSQL, false},
+       {"mysql_dump.sql.gz", FormatMySQLSQLGz, false}, // Has "mysql" in name = MySQL
+       {"randomfile.txt", FormatUnknown, false},
+   }
+
+   for _, tt := range tests {
+       t.Run(tt.filename, func(t *testing.T) {
+           got := DetectArchiveFormat(tt.filename)
+           if got != tt.want {
+               t.Errorf("DetectArchiveFormat(%q) = %v, want %v", tt.filename, got, tt.want)
+           }
+           if got.IsClusterBackup() != tt.isCluster {
+               t.Errorf("DetectArchiveFormat(%q).IsClusterBackup() = %v, want %v", tt.filename, got.IsClusterBackup(), tt.isCluster)
+           }
+       })
+   }
+}
@@ -205,19 +205,28 @@ func (m ArchiveBrowserModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
            return diagnoseView, diagnoseView.Init()
        }

-       // Validate selection based on mode
-       if m.mode == "restore-cluster" && !selected.Format.IsClusterBackup() {
-           m.message = errorStyle.Render("[FAIL] Please select a cluster backup (.tar.gz)")
+       // For restore-cluster mode: check if format can be used for cluster restore
+       // - .tar.gz: dbbackup cluster format (works with pg_restore)
+       // - .sql/.sql.gz: pg_dumpall format (works with native engine or psql)
+       if m.mode == "restore-cluster" && !selected.Format.CanBeClusterRestore() {
+           m.message = errorStyle.Render(fmt.Sprintf("⚠️ %s cannot be used for cluster restore.\n\n Supported formats: .tar.gz (dbbackup), .sql, .sql.gz (pg_dumpall)",
+               selected.Name))
            return m, nil
        }

+       // For SQL-based cluster restore, enable native engine automatically
+       if m.mode == "restore-cluster" && !selected.Format.IsClusterBackup() {
+           // This is a .sql or .sql.gz file - use native engine
+           m.config.UseNativeEngine = true
+       }
+
        // For single restore mode with cluster backup selected - offer to select individual database
        if m.mode == "restore-single" && selected.Format.IsClusterBackup() {
            // Cluster backup selected in single restore mode - offer to select individual database
            clusterSelector := NewClusterDatabaseSelector(m.config, m.logger, m, m.ctx, selected, "single", false)
            return clusterSelector, clusterSelector.Init()
        }

-       // Open restore preview
+       // Open restore preview for valid format
        preview := NewRestorePreview(m.config, m.logger, m.parent, m.ctx, selected, m.mode)
        return preview, preview.Init()
    }
@@ -382,6 +391,7 @@ func (m ArchiveBrowserModel) filterArchives(archives []ArchiveInfo) []ArchiveInf
    for _, archive := range archives {
        switch m.filterType {
        case "postgres":
+           // Show all PostgreSQL formats (single DB)
            if archive.Format.IsPostgreSQL() && !archive.Format.IsClusterBackup() {
                filtered = append(filtered, archive)
            }
@@ -390,6 +400,7 @@ func (m ArchiveBrowserModel) filterArchives(archives []ArchiveInfo) []ArchiveInf
            filtered = append(filtered, archive)
        }
    case "cluster":
+       // Show .tar.gz cluster archives
        if archive.Format.IsClusterBackup() {
            filtered = append(filtered, archive)
        }