Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 79dc604eb6 | |||
| de88e38f93 | |||
| 97c52ab9e5 | |||
| 3c9e5f04ca | |||
| 86a28b6ec5 | |||
| 63b35414d2 | |||
| db46770e7f | |||
| 51764a677a | |||
| bdbbb59e51 | |||
| 1a6ea13222 |
30
CHANGELOG.md
30
CHANGELOG.md
@ -5,6 +5,36 @@ All notable changes to dbbackup will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [5.8.12] - 2026-02-04
|
||||
|
||||
### Fixed
|
||||
- **Config Loading**: Fixed config not loading for users without standard home directories
|
||||
- Now searches: current dir → home dir → /etc/dbbackup.conf → /etc/dbbackup/dbbackup.conf
|
||||
- Works for postgres user with home at /var/lib/postgresql
|
||||
- Added `ConfigSearchPaths()` and `LoadLocalConfigWithPath()` functions
|
||||
- Log now shows which config path was actually loaded
|
||||
|
||||
## [5.8.11] - 2026-02-04
|
||||
|
||||
### Fixed
|
||||
- **TUI Deadlock**: Fixed goroutine leaks in pgxpool connection handling
|
||||
- Removed redundant goroutines waiting on ctx.Done() in postgresql.go and parallel_restore.go
|
||||
- These were causing WaitGroup deadlocks when BubbleTea tried to shutdown
|
||||
|
||||
### Added
|
||||
- **systemd-run Resource Isolation**: New `internal/cleanup/cgroups.go` for long-running jobs
|
||||
- `RunWithResourceLimits()` wraps commands in systemd-run scopes
|
||||
- Configurable: MemoryHigh, MemoryMax, CPUQuota, IOWeight, Nice, Slice
|
||||
- Automatic cleanup on context cancellation
|
||||
- **Restore Dry-Run Checks**: New `internal/restore/dryrun.go` with 10 pre-restore validations
|
||||
- Archive access, format, connectivity, permissions, target conflicts
|
||||
- Disk space, work directory, required tools, lock settings, memory estimation
|
||||
- Returns pass/warning/fail status with detailed messages
|
||||
- **Audit Log Signing**: Enhanced `internal/security/audit.go` with Ed25519 cryptographic signing
|
||||
- `SignedAuditEntry` with sequence numbers, hash chains, and signatures
|
||||
- `GenerateSigningKeys()`, `SavePrivateKey()`, `LoadPublicKey()`
|
||||
- `EnableSigning()`, `ExportSignedLog()`, `VerifyAuditLog()` for tamper detection
|
||||
|
||||
## [5.7.10] - 2026-02-03
|
||||
|
||||
### Fixed
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"dbbackup/internal/database"
|
||||
"dbbackup/internal/engine/native"
|
||||
"dbbackup/internal/metadata"
|
||||
"dbbackup/internal/notify"
|
||||
|
||||
"github.com/klauspost/pgzip"
|
||||
@ -163,6 +164,54 @@ func runNativeBackup(ctx context.Context, db database.Database, databaseName, ba
|
||||
"duration", backupDuration,
|
||||
"engine", result.EngineUsed)
|
||||
|
||||
// Get actual file size from disk
|
||||
fileInfo, err := os.Stat(outputFile)
|
||||
var actualSize int64
|
||||
if err == nil {
|
||||
actualSize = fileInfo.Size()
|
||||
} else {
|
||||
actualSize = result.BytesProcessed
|
||||
}
|
||||
|
||||
// Calculate SHA256 checksum
|
||||
sha256sum, err := metadata.CalculateSHA256(outputFile)
|
||||
if err != nil {
|
||||
log.Warn("Failed to calculate SHA256", "error", err)
|
||||
sha256sum = ""
|
||||
}
|
||||
|
||||
// Create and save metadata file
|
||||
meta := &metadata.BackupMetadata{
|
||||
Version: "1.0",
|
||||
Timestamp: backupStartTime,
|
||||
Database: databaseName,
|
||||
DatabaseType: dbType,
|
||||
Host: cfg.Host,
|
||||
Port: cfg.Port,
|
||||
User: cfg.User,
|
||||
BackupFile: filepath.Base(outputFile),
|
||||
SizeBytes: actualSize,
|
||||
SHA256: sha256sum,
|
||||
Compression: "gzip",
|
||||
BackupType: backupType,
|
||||
Duration: backupDuration.Seconds(),
|
||||
ExtraInfo: map[string]string{
|
||||
"engine": result.EngineUsed,
|
||||
"objects_processed": fmt.Sprintf("%d", result.ObjectsProcessed),
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.CompressionLevel == 0 {
|
||||
meta.Compression = "none"
|
||||
}
|
||||
|
||||
metaPath := outputFile + ".meta.json"
|
||||
if err := metadata.Save(metaPath, meta); err != nil {
|
||||
log.Warn("Failed to save metadata", "error", err)
|
||||
} else {
|
||||
log.Debug("Metadata saved", "path", metaPath)
|
||||
}
|
||||
|
||||
// Audit log: backup completed
|
||||
auditLogger.LogBackupComplete(user, databaseName, cfg.BackupDir, result.BytesProcessed)
|
||||
|
||||
|
||||
33
cmd/root.go
33
cmd/root.go
@ -15,11 +15,12 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
cfg *config.Config
|
||||
log logger.Logger
|
||||
auditLogger *security.AuditLogger
|
||||
rateLimiter *security.RateLimiter
|
||||
notifyManager *notify.Manager
|
||||
cfg *config.Config
|
||||
log logger.Logger
|
||||
auditLogger *security.AuditLogger
|
||||
rateLimiter *security.RateLimiter
|
||||
notifyManager *notify.Manager
|
||||
deprecatedPassword string
|
||||
)
|
||||
|
||||
// rootCmd represents the base command when called without any subcommands
|
||||
@ -47,6 +48,11 @@ For help with specific commands, use: dbbackup [command] --help`,
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for deprecated password flag
|
||||
if deprecatedPassword != "" {
|
||||
return fmt.Errorf("--password flag is not supported for security reasons. Use environment variables instead:\n - MySQL/MariaDB: export MYSQL_PWD='your_password'\n - PostgreSQL: export PGPASSWORD='your_password' or use .pgpass file")
|
||||
}
|
||||
|
||||
// Store which flags were explicitly set by user
|
||||
flagsSet := make(map[string]bool)
|
||||
cmd.Flags().Visit(func(f *pflag.Flag) {
|
||||
@ -55,22 +61,24 @@ For help with specific commands, use: dbbackup [command] --help`,
|
||||
|
||||
// Load local config if not disabled
|
||||
if !cfg.NoLoadConfig {
|
||||
// Use custom config path if specified, otherwise default to current directory
|
||||
// Use custom config path if specified, otherwise search standard locations
|
||||
var localCfg *config.LocalConfig
|
||||
var configPath string
|
||||
var err error
|
||||
if cfg.ConfigPath != "" {
|
||||
localCfg, err = config.LoadLocalConfigFromPath(cfg.ConfigPath)
|
||||
configPath = cfg.ConfigPath
|
||||
if err != nil {
|
||||
log.Warn("Failed to load config from specified path", "path", cfg.ConfigPath, "error", err)
|
||||
} else if localCfg != nil {
|
||||
log.Info("Loaded configuration", "path", cfg.ConfigPath)
|
||||
}
|
||||
} else {
|
||||
localCfg, err = config.LoadLocalConfig()
|
||||
localCfg, configPath, err = config.LoadLocalConfigWithPath()
|
||||
if err != nil {
|
||||
log.Warn("Failed to load local config", "error", err)
|
||||
log.Warn("Failed to load config", "error", err)
|
||||
} else if localCfg != nil {
|
||||
log.Info("Loaded configuration from .dbbackup.conf")
|
||||
log.Info("Loaded configuration", "path", configPath)
|
||||
}
|
||||
}
|
||||
|
||||
@ -171,15 +179,8 @@ func Execute(ctx context.Context, config *config.Config, logger logger.Logger) e
|
||||
rootCmd.PersistentFlags().StringVar(&cfg.Database, "database", cfg.Database, "Database name")
|
||||
// SECURITY: Password flag removed - use PGPASSWORD/MYSQL_PWD environment variable or .pgpass file
|
||||
// Provide helpful error message for users expecting --password flag
|
||||
var deprecatedPassword string
|
||||
rootCmd.PersistentFlags().StringVar(&deprecatedPassword, "password", "", "DEPRECATED: Use MYSQL_PWD or PGPASSWORD environment variable instead")
|
||||
rootCmd.PersistentFlags().MarkHidden("password")
|
||||
rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
|
||||
if deprecatedPassword != "" {
|
||||
return fmt.Errorf("--password flag is not supported for security reasons. Use environment variables instead:\n - MySQL/MariaDB: export MYSQL_PWD='your_password'\n - PostgreSQL: export PGPASSWORD='your_password' or use .pgpass file")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
rootCmd.PersistentFlags().StringVarP(&cfg.DatabaseType, "db-type", "d", cfg.DatabaseType, "Database type (postgres|mysql|mariadb)")
|
||||
rootCmd.PersistentFlags().StringVar(&cfg.BackupDir, "backup-dir", cfg.BackupDir, "Backup directory")
|
||||
rootCmd.PersistentFlags().BoolVar(&cfg.NoColor, "no-color", cfg.NoColor, "Disable colored output")
|
||||
|
||||
236
internal/cleanup/cgroups.go
Normal file
236
internal/cleanup/cgroups.go
Normal file
@ -0,0 +1,236 @@
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// ResourceLimits defines resource constraints for long-running operations
// executed via systemd-run transient scopes (see RunWithResourceLimits).
// Empty string fields and zero numeric fields are omitted from the
// systemd-run invocation (see buildSystemdArgs).
type ResourceLimits struct {
	// MemoryHigh is the high memory limit (e.g., "4G", "2048M").
	// When exceeded, kernel will throttle and reclaim memory aggressively.
	MemoryHigh string

	// MemoryMax is the hard memory limit (e.g., "6G").
	// Process is killed if exceeded.
	MemoryMax string

	// CPUQuota limits CPU usage (e.g., "70%" for 70% of one CPU).
	CPUQuota string

	// IOWeight sets I/O priority (1-10000, default 100).
	IOWeight int

	// Nice sets process priority (-20 to 19).
	Nice int

	// Slice is the systemd slice to run under (e.g., "dbbackup.slice").
	Slice string
}
|
||||
|
||||
// DefaultResourceLimits returns sensible defaults for backup/restore operations
|
||||
func DefaultResourceLimits() *ResourceLimits {
|
||||
return &ResourceLimits{
|
||||
MemoryHigh: "4G",
|
||||
MemoryMax: "6G",
|
||||
CPUQuota: "80%",
|
||||
IOWeight: 100, // Default priority
|
||||
Nice: 10, // Slightly lower priority than interactive processes
|
||||
Slice: "dbbackup.slice",
|
||||
}
|
||||
}
|
||||
|
||||
// SystemdRunAvailable checks if systemd-run is available on this system.
// It is only considered available on Linux with the binary on PATH.
func SystemdRunAvailable() bool {
	if runtime.GOOS == "linux" {
		_, err := exec.LookPath("systemd-run")
		return err == nil
	}
	return false
}
|
||||
|
||||
// RunWithResourceLimits executes a command with resource limits via systemd-run
|
||||
// Falls back to direct execution if systemd-run is not available
|
||||
func RunWithResourceLimits(ctx context.Context, log logger.Logger, limits *ResourceLimits, name string, args ...string) error {
|
||||
if limits == nil {
|
||||
limits = DefaultResourceLimits()
|
||||
}
|
||||
|
||||
// If systemd-run not available, fall back to direct execution
|
||||
if !SystemdRunAvailable() {
|
||||
log.Debug("systemd-run not available, running without resource limits")
|
||||
cmd := exec.CommandContext(ctx, name, args...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// Build systemd-run command
|
||||
systemdArgs := buildSystemdArgs(limits, name, args)
|
||||
|
||||
log.Info("Running with systemd resource limits",
|
||||
"command", name,
|
||||
"memory_high", limits.MemoryHigh,
|
||||
"cpu_quota", limits.CPUQuota)
|
||||
|
||||
cmd := exec.CommandContext(ctx, "systemd-run", systemdArgs...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// RunWithResourceLimitsOutput executes with limits and returns combined output
|
||||
func RunWithResourceLimitsOutput(ctx context.Context, log logger.Logger, limits *ResourceLimits, name string, args ...string) ([]byte, error) {
|
||||
if limits == nil {
|
||||
limits = DefaultResourceLimits()
|
||||
}
|
||||
|
||||
// If systemd-run not available, fall back to direct execution
|
||||
if !SystemdRunAvailable() {
|
||||
log.Debug("systemd-run not available, running without resource limits")
|
||||
cmd := exec.CommandContext(ctx, name, args...)
|
||||
return cmd.CombinedOutput()
|
||||
}
|
||||
|
||||
// Build systemd-run command
|
||||
systemdArgs := buildSystemdArgs(limits, name, args)
|
||||
|
||||
log.Debug("Running with systemd resource limits",
|
||||
"command", name,
|
||||
"memory_high", limits.MemoryHigh)
|
||||
|
||||
cmd := exec.CommandContext(ctx, "systemd-run", systemdArgs...)
|
||||
return cmd.CombinedOutput()
|
||||
}
|
||||
|
||||
// buildSystemdArgs constructs the systemd-run argument list for the given
// limits and command. Empty string fields, IOWeight <= 0, and Nice == 0 are
// omitted; the command and its args follow a "--" separator.
func buildSystemdArgs(limits *ResourceLimits, name string, args []string) []string {
	systemdArgs := []string{
		"--scope", // Run as transient scope (not service)
		"--user",  // Run in user session (no root required)
		// NOTE(review): --user requires a running per-user systemd manager;
		// when dbbackup is invoked as root from cron or a system service there
		// may be no user session, so systemd-run could fail outright rather
		// than falling back — confirm against deployment environments.
		"--quiet",   // Reduce systemd noise
		"--collect", // Automatically clean up after exit
	}

	// Add description for easier identification (visible in systemctl output)
	systemdArgs = append(systemdArgs, fmt.Sprintf("--description=dbbackup: %s", name))

	// Add resource properties; each is optional and only emitted when set
	if limits.MemoryHigh != "" {
		systemdArgs = append(systemdArgs, fmt.Sprintf("--property=MemoryHigh=%s", limits.MemoryHigh))
	}

	if limits.MemoryMax != "" {
		systemdArgs = append(systemdArgs, fmt.Sprintf("--property=MemoryMax=%s", limits.MemoryMax))
	}

	if limits.CPUQuota != "" {
		systemdArgs = append(systemdArgs, fmt.Sprintf("--property=CPUQuota=%s", limits.CPUQuota))
	}

	if limits.IOWeight > 0 {
		systemdArgs = append(systemdArgs, fmt.Sprintf("--property=IOWeight=%d", limits.IOWeight))
	}

	if limits.Nice != 0 {
		systemdArgs = append(systemdArgs, fmt.Sprintf("--property=Nice=%d", limits.Nice))
	}

	if limits.Slice != "" {
		systemdArgs = append(systemdArgs, fmt.Sprintf("--slice=%s", limits.Slice))
	}

	// Add separator and command
	systemdArgs = append(systemdArgs, "--")
	systemdArgs = append(systemdArgs, name)
	systemdArgs = append(systemdArgs, args...)

	return systemdArgs
}
|
||||
|
||||
// WrapCommand creates an exec.Cmd that runs with resource limits
|
||||
// This allows the caller to customize stdin/stdout/stderr before running
|
||||
func WrapCommand(ctx context.Context, log logger.Logger, limits *ResourceLimits, name string, args ...string) *exec.Cmd {
|
||||
if limits == nil {
|
||||
limits = DefaultResourceLimits()
|
||||
}
|
||||
|
||||
// If systemd-run not available, return direct command
|
||||
if !SystemdRunAvailable() {
|
||||
log.Debug("systemd-run not available, returning unwrapped command")
|
||||
return exec.CommandContext(ctx, name, args...)
|
||||
}
|
||||
|
||||
// Build systemd-run command
|
||||
systemdArgs := buildSystemdArgs(limits, name, args)
|
||||
|
||||
log.Debug("Wrapping command with systemd resource limits",
|
||||
"command", name,
|
||||
"memory_high", limits.MemoryHigh)
|
||||
|
||||
return exec.CommandContext(ctx, "systemd-run", systemdArgs...)
|
||||
}
|
||||
|
||||
// ResourceLimitsFromConfig creates resource limits from size estimates
|
||||
// Useful for dynamically setting limits based on backup/restore size
|
||||
func ResourceLimitsFromConfig(estimatedSizeBytes int64, isRestore bool) *ResourceLimits {
|
||||
limits := DefaultResourceLimits()
|
||||
|
||||
// Estimate memory needs based on data size
|
||||
// Restore needs more memory than backup
|
||||
var memoryMultiplier float64 = 0.1 // 10% of data size for backup
|
||||
if isRestore {
|
||||
memoryMultiplier = 0.2 // 20% of data size for restore
|
||||
}
|
||||
|
||||
estimatedMemMB := int64(float64(estimatedSizeBytes/1024/1024) * memoryMultiplier)
|
||||
|
||||
// Clamp to reasonable values
|
||||
if estimatedMemMB < 512 {
|
||||
estimatedMemMB = 512 // Minimum 512MB
|
||||
}
|
||||
if estimatedMemMB > 16384 {
|
||||
estimatedMemMB = 16384 // Maximum 16GB
|
||||
}
|
||||
|
||||
limits.MemoryHigh = fmt.Sprintf("%dM", estimatedMemMB)
|
||||
limits.MemoryMax = fmt.Sprintf("%dM", estimatedMemMB*2) // 2x high limit
|
||||
|
||||
return limits
|
||||
}
|
||||
|
||||
// GetActiveResourceUsage returns current resource usage if running in systemd scope
|
||||
func GetActiveResourceUsage() (string, error) {
|
||||
if !SystemdRunAvailable() {
|
||||
return "", fmt.Errorf("systemd not available")
|
||||
}
|
||||
|
||||
// Check if we're running in a scope
|
||||
cmd := exec.Command("systemctl", "--user", "status", "--no-pager")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get systemd status: %w", err)
|
||||
}
|
||||
|
||||
// Extract dbbackup-related scopes
|
||||
lines := strings.Split(string(output), "\n")
|
||||
var dbbackupLines []string
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "dbbackup") {
|
||||
dbbackupLines = append(dbbackupLines, strings.TrimSpace(line))
|
||||
}
|
||||
}
|
||||
|
||||
if len(dbbackupLines) == 0 {
|
||||
return "No active dbbackup scopes", nil
|
||||
}
|
||||
|
||||
return strings.Join(dbbackupLines, "\n"), nil
|
||||
}
|
||||
@ -41,9 +41,53 @@ type LocalConfig struct {
|
||||
MaxRetries int
|
||||
}
|
||||
|
||||
// LoadLocalConfig loads configuration from .dbbackup.conf in current directory
|
||||
// ConfigSearchPaths returns all paths where config files are searched, in order of priority
|
||||
func ConfigSearchPaths() []string {
|
||||
paths := []string{
|
||||
filepath.Join(".", ConfigFileName), // Current directory (highest priority)
|
||||
}
|
||||
|
||||
// User's home directory
|
||||
if home, err := os.UserHomeDir(); err == nil && home != "" {
|
||||
paths = append(paths, filepath.Join(home, ConfigFileName))
|
||||
}
|
||||
|
||||
// System-wide config locations
|
||||
paths = append(paths,
|
||||
"/etc/dbbackup.conf",
|
||||
"/etc/dbbackup/dbbackup.conf",
|
||||
)
|
||||
|
||||
return paths
|
||||
}
|
||||
|
||||
// LoadLocalConfig loads configuration from .dbbackup.conf
|
||||
// Search order: 1) current directory, 2) user's home directory, 3) /etc/dbbackup.conf, 4) /etc/dbbackup/dbbackup.conf
|
||||
func LoadLocalConfig() (*LocalConfig, error) {
|
||||
return LoadLocalConfigFromPath(filepath.Join(".", ConfigFileName))
|
||||
for _, path := range ConfigSearchPaths() {
|
||||
cfg, err := LoadLocalConfigFromPath(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg != nil {
|
||||
return cfg, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// LoadLocalConfigWithPath loads configuration and returns the path it was loaded from
|
||||
func LoadLocalConfigWithPath() (*LocalConfig, string, error) {
|
||||
for _, path := range ConfigSearchPaths() {
|
||||
cfg, err := LoadLocalConfigFromPath(path)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if cfg != nil {
|
||||
return cfg, path, nil
|
||||
}
|
||||
}
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
// LoadLocalConfigFromPath loads configuration from a specific path
|
||||
|
||||
@ -74,7 +74,7 @@ func (p *PostgreSQL) Connect(ctx context.Context) error {
|
||||
config.MinConns = 2 // Keep minimum connections ready
|
||||
config.MaxConnLifetime = 0 // No limit on connection lifetime
|
||||
config.MaxConnIdleTime = 0 // No idle timeout
|
||||
config.HealthCheckPeriod = 1 * time.Minute // Health check every minute
|
||||
config.HealthCheckPeriod = 5 * time.Second // Faster health check for quicker shutdown on Ctrl+C
|
||||
|
||||
// Optimize for large query results (BLOB data)
|
||||
config.ConnConfig.RuntimeParams["work_mem"] = "64MB"
|
||||
@ -97,6 +97,14 @@ func (p *PostgreSQL) Connect(ctx context.Context) error {
|
||||
|
||||
p.pool = pool
|
||||
p.db = db
|
||||
|
||||
// NOTE: We intentionally do NOT start a goroutine to close the pool on context cancellation.
|
||||
// The pool is closed via defer dbClient.Close() in the caller, which is the correct pattern.
|
||||
// Starting a goroutine here causes goroutine leaks and potential double-close issues when:
|
||||
// 1. The caller's defer runs first (normal case)
|
||||
// 2. Then context is cancelled and the goroutine tries to close an already-closed pool
|
||||
// This was causing deadlocks in the TUI when tea.Batch was waiting for commands to complete.
|
||||
|
||||
p.log.Info("Connected to PostgreSQL successfully", "driver", "pgx", "max_conns", config.MaxConns)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -28,6 +28,9 @@ type ParallelRestoreEngine struct {
|
||||
|
||||
// Configuration
|
||||
parallelWorkers int
|
||||
|
||||
// Internal cancel channel to stop the pool cleanup goroutine
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
// ParallelRestoreOptions configures parallel restore behavior
|
||||
@ -71,7 +74,14 @@ const (
|
||||
)
|
||||
|
||||
// NewParallelRestoreEngine creates a new parallel restore engine
|
||||
// NOTE: Pass a cancellable context to ensure the pool is properly closed on Ctrl+C
|
||||
func NewParallelRestoreEngine(config *PostgreSQLNativeConfig, log logger.Logger, workers int) (*ParallelRestoreEngine, error) {
|
||||
return NewParallelRestoreEngineWithContext(context.Background(), config, log, workers)
|
||||
}
|
||||
|
||||
// NewParallelRestoreEngineWithContext creates a new parallel restore engine with context support
|
||||
// This ensures the connection pool is properly closed when the context is cancelled
|
||||
func NewParallelRestoreEngineWithContext(ctx context.Context, config *PostgreSQLNativeConfig, log logger.Logger, workers int) (*ParallelRestoreEngine, error) {
|
||||
if workers < 1 {
|
||||
workers = 4 // Default to 4 parallel workers
|
||||
}
|
||||
@ -94,17 +104,35 @@ func NewParallelRestoreEngine(config *PostgreSQLNativeConfig, log logger.Logger,
|
||||
poolConfig.MaxConns = int32(workers + 2)
|
||||
poolConfig.MinConns = int32(workers)
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig)
|
||||
// CRITICAL: Reduce health check period to allow faster shutdown
|
||||
// Default is 1 minute which causes hangs on Ctrl+C
|
||||
poolConfig.HealthCheckPeriod = 5 * time.Second
|
||||
|
||||
// Use the provided context so pool health checks stop when context is cancelled
|
||||
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
return &ParallelRestoreEngine{
|
||||
closeCh := make(chan struct{})
|
||||
|
||||
engine := &ParallelRestoreEngine{
|
||||
config: config,
|
||||
pool: pool,
|
||||
log: log,
|
||||
parallelWorkers: workers,
|
||||
}, nil
|
||||
closeCh: closeCh,
|
||||
}
|
||||
|
||||
// NOTE: We intentionally do NOT start a goroutine to close the pool on context cancellation.
|
||||
// The pool is closed via defer parallelEngine.Close() in the caller (restore/engine.go).
|
||||
// The Close() method properly signals closeCh and closes the pool.
|
||||
// Starting a goroutine here can cause:
|
||||
// 1. Race conditions with explicit Close() calls
|
||||
// 2. Goroutine leaks if neither ctx nor Close() fires
|
||||
// 3. Deadlocks with BubbleTea's event loop
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// RestoreFile restores from a SQL file with parallel execution
|
||||
@ -215,17 +243,38 @@ func (e *ParallelRestoreEngine) RestoreFile(ctx context.Context, filePath string
|
||||
semaphore := make(chan struct{}, options.Workers)
|
||||
var completedCopies int64
|
||||
var totalRows int64
|
||||
var cancelled int32 // Atomic flag to signal cancellation
|
||||
|
||||
for _, stmt := range copyStmts {
|
||||
// Check for context cancellation before starting new work
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{} // Acquire worker slot
|
||||
select {
|
||||
case semaphore <- struct{}{}: // Acquire worker slot
|
||||
case <-ctx.Done():
|
||||
wg.Done()
|
||||
atomic.StoreInt32(&cancelled, 1)
|
||||
break
|
||||
}
|
||||
|
||||
go func(s *SQLStatement) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }() // Release worker slot
|
||||
|
||||
// Check cancellation before executing
|
||||
if ctx.Err() != nil || atomic.LoadInt32(&cancelled) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := e.executeCopy(ctx, s)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
// Context cancelled, don't log as error
|
||||
return
|
||||
}
|
||||
if options.ContinueOnError {
|
||||
e.log.Warn("COPY failed", "table", s.TableName, "error", err)
|
||||
} else {
|
||||
@ -243,6 +292,12 @@ func (e *ParallelRestoreEngine) RestoreFile(ctx context.Context, filePath string
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check if cancelled
|
||||
if ctx.Err() != nil {
|
||||
return result, ctx.Err()
|
||||
}
|
||||
|
||||
result.TablesRestored = completedCopies
|
||||
result.RowsRestored = totalRows
|
||||
|
||||
@ -264,15 +319,35 @@ func (e *ParallelRestoreEngine) RestoreFile(ctx context.Context, filePath string
|
||||
|
||||
// Execute post-data in parallel
|
||||
var completedPostData int64
|
||||
cancelled = 0 // Reset for phase 4
|
||||
for _, sql := range postDataStmts {
|
||||
// Check for context cancellation before starting new work
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
select {
|
||||
case semaphore <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
wg.Done()
|
||||
atomic.StoreInt32(&cancelled, 1)
|
||||
break
|
||||
}
|
||||
|
||||
go func(stmt string) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
// Check cancellation before executing
|
||||
if ctx.Err() != nil || atomic.LoadInt32(&cancelled) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.executeStatement(ctx, stmt); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return // Context cancelled
|
||||
}
|
||||
if options.ContinueOnError {
|
||||
e.log.Warn("Post-data statement failed", "error", err)
|
||||
}
|
||||
@ -289,6 +364,11 @@ func (e *ParallelRestoreEngine) RestoreFile(ctx context.Context, filePath string
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check if cancelled
|
||||
if ctx.Err() != nil {
|
||||
return result, ctx.Err()
|
||||
}
|
||||
|
||||
result.Duration = time.Since(startTime)
|
||||
e.log.Info("Parallel restore completed",
|
||||
"duration", result.Duration,
|
||||
@ -450,8 +530,13 @@ func (e *ParallelRestoreEngine) executeCopy(ctx context.Context, stmt *SQLStatem
|
||||
return tag.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// Close closes the connection pool
|
||||
// Close closes the connection pool and stops the cleanup goroutine
|
||||
func (e *ParallelRestoreEngine) Close() error {
|
||||
// Signal the cleanup goroutine to exit
|
||||
if e.closeCh != nil {
|
||||
close(e.closeCh)
|
||||
}
|
||||
// Close the pool
|
||||
if e.pool != nil {
|
||||
e.pool.Close()
|
||||
}
|
||||
|
||||
666
internal/restore/dryrun.go
Normal file
666
internal/restore/dryrun.go
Normal file
@ -0,0 +1,666 @@
|
||||
package restore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/cleanup"
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// DryRunCheck represents a single dry-run check result.
type DryRunCheck struct {
	Name     string       // human-readable check name (e.g. "Archive Access")
	Status   DryRunStatus // pass/warn/fail/skip outcome
	Message  string       // one-line summary of the outcome
	Details  string       // optional extra context (path, error text)
	Critical bool         // If true, restore will definitely fail
}
|
||||
|
||||
// DryRunStatus represents the status of a dry-run check.
type DryRunStatus int

const (
	DryRunPassed DryRunStatus = iota
	DryRunWarning
	DryRunFailed
	DryRunSkipped
)

// dryRunLabels and dryRunIcons are indexed by DryRunStatus value and must
// stay in sync with the constant block above.
var (
	dryRunLabels = [...]string{"PASS", "WARN", "FAIL", "SKIP"}
	dryRunIcons  = [...]string{"[+]", "[!]", "[-]", "[ ]"}
)

// String returns the short uppercase label for the status, or "UNKNOWN" for
// values outside the defined range.
func (s DryRunStatus) String() string {
	if s < DryRunPassed || s > DryRunSkipped {
		return "UNKNOWN"
	}
	return dryRunLabels[s]
}

// Icon returns a bracketed ASCII marker for the status, or "[?]" for values
// outside the defined range.
func (s DryRunStatus) Icon() string {
	if s < DryRunPassed || s > DryRunSkipped {
		return "[?]"
	}
	return dryRunIcons[s]
}
|
||||
|
||||
// DryRunResult contains all dry-run check results.
type DryRunResult struct {
	Checks          []DryRunCheck // individual check outcomes, in execution order
	CanProceed      bool          // false once any critical check fails
	HasWarnings     bool          // true when any check warned or failed non-critically
	CriticalCount   int           // number of failed critical checks
	WarningCount    int           // warnings plus non-critical failures
	EstimatedTime   time.Duration // rough restore duration estimate (see estimateRestoreTime)
	RequiredDiskMB  int64         // disk space the restore is expected to need (from checkDiskSpace)
	AvailableDiskMB int64         // disk space currently available (from checkDiskSpace)
}
|
||||
|
||||
// RestoreDryRun performs comprehensive pre-restore validation.
type RestoreDryRun struct {
	cfg     *config.Config // connection and restore settings
	log     logger.Logger  // structured logger
	safety  *Safety        // helper used for archive validation and database queries
	archive string         // path to the backup archive being restored
	target  string         // target database name
}
|
||||
|
||||
// NewRestoreDryRun creates a new restore dry-run validator
|
||||
func NewRestoreDryRun(cfg *config.Config, log logger.Logger, archivePath, targetDB string) *RestoreDryRun {
|
||||
return &RestoreDryRun{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
safety: NewSafety(cfg, log),
|
||||
archive: archivePath,
|
||||
target: targetDB,
|
||||
}
|
||||
}
|
||||
|
||||
// Run executes all dry-run checks
|
||||
func (r *RestoreDryRun) Run(ctx context.Context) (*DryRunResult, error) {
|
||||
result := &DryRunResult{
|
||||
Checks: make([]DryRunCheck, 0, 10),
|
||||
CanProceed: true,
|
||||
}
|
||||
|
||||
r.log.Info("Running restore dry-run checks",
|
||||
"archive", r.archive,
|
||||
"target", r.target)
|
||||
|
||||
// 1. Archive existence and accessibility
|
||||
result.Checks = append(result.Checks, r.checkArchiveAccess())
|
||||
|
||||
// 2. Archive format validation
|
||||
result.Checks = append(result.Checks, r.checkArchiveFormat())
|
||||
|
||||
// 3. Database connectivity
|
||||
result.Checks = append(result.Checks, r.checkDatabaseConnectivity(ctx))
|
||||
|
||||
// 4. User permissions (CREATE DATABASE, DROP, etc.)
|
||||
result.Checks = append(result.Checks, r.checkUserPermissions(ctx))
|
||||
|
||||
// 5. Target database conflicts
|
||||
result.Checks = append(result.Checks, r.checkTargetConflicts(ctx))
|
||||
|
||||
// 6. Disk space requirements
|
||||
diskCheck, requiredMB, availableMB := r.checkDiskSpace()
|
||||
result.Checks = append(result.Checks, diskCheck)
|
||||
result.RequiredDiskMB = requiredMB
|
||||
result.AvailableDiskMB = availableMB
|
||||
|
||||
// 7. Work directory permissions
|
||||
result.Checks = append(result.Checks, r.checkWorkDirectory())
|
||||
|
||||
// 8. Required tools availability
|
||||
result.Checks = append(result.Checks, r.checkRequiredTools())
|
||||
|
||||
// 9. PostgreSQL lock settings (for parallel restore)
|
||||
result.Checks = append(result.Checks, r.checkLockSettings(ctx))
|
||||
|
||||
// 10. Memory availability
|
||||
result.Checks = append(result.Checks, r.checkMemoryAvailability())
|
||||
|
||||
// Calculate summary
|
||||
for _, check := range result.Checks {
|
||||
switch check.Status {
|
||||
case DryRunFailed:
|
||||
if check.Critical {
|
||||
result.CriticalCount++
|
||||
result.CanProceed = false
|
||||
} else {
|
||||
result.WarningCount++
|
||||
result.HasWarnings = true
|
||||
}
|
||||
case DryRunWarning:
|
||||
result.WarningCount++
|
||||
result.HasWarnings = true
|
||||
}
|
||||
}
|
||||
|
||||
// Estimate restore time based on archive size
|
||||
result.EstimatedTime = r.estimateRestoreTime()
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// checkArchiveAccess verifies the archive file is accessible
|
||||
func (r *RestoreDryRun) checkArchiveAccess() DryRunCheck {
|
||||
check := DryRunCheck{
|
||||
Name: "Archive Access",
|
||||
Critical: true,
|
||||
}
|
||||
|
||||
info, err := os.Stat(r.archive)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Archive file not found"
|
||||
check.Details = r.archive
|
||||
} else if os.IsPermission(err) {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Permission denied reading archive"
|
||||
check.Details = err.Error()
|
||||
} else {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Cannot access archive"
|
||||
check.Details = err.Error()
|
||||
}
|
||||
return check
|
||||
}
|
||||
|
||||
if info.Size() == 0 {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Archive file is empty"
|
||||
return check
|
||||
}
|
||||
|
||||
check.Status = DryRunPassed
|
||||
check.Message = fmt.Sprintf("Archive accessible (%s)", formatBytesSize(info.Size()))
|
||||
return check
|
||||
}
|
||||
|
||||
// checkArchiveFormat validates the archive format
|
||||
func (r *RestoreDryRun) checkArchiveFormat() DryRunCheck {
|
||||
check := DryRunCheck{
|
||||
Name: "Archive Format",
|
||||
Critical: true,
|
||||
}
|
||||
|
||||
err := r.safety.ValidateArchive(r.archive)
|
||||
if err != nil {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Invalid archive format"
|
||||
check.Details = err.Error()
|
||||
return check
|
||||
}
|
||||
|
||||
format := DetectArchiveFormat(r.archive)
|
||||
check.Status = DryRunPassed
|
||||
check.Message = fmt.Sprintf("Valid %s format", format.String())
|
||||
return check
|
||||
}
|
||||
|
||||
// checkDatabaseConnectivity tests database connection
|
||||
func (r *RestoreDryRun) checkDatabaseConnectivity(ctx context.Context) DryRunCheck {
|
||||
check := DryRunCheck{
|
||||
Name: "Database Connectivity",
|
||||
Critical: true,
|
||||
}
|
||||
|
||||
// Try to list databases as a connectivity check
|
||||
_, err := r.safety.ListUserDatabases(ctx)
|
||||
if err != nil {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Cannot connect to database server"
|
||||
check.Details = err.Error()
|
||||
return check
|
||||
}
|
||||
|
||||
check.Status = DryRunPassed
|
||||
check.Message = fmt.Sprintf("Connected to %s:%d", r.cfg.Host, r.cfg.Port)
|
||||
return check
|
||||
}
|
||||
|
||||
// checkUserPermissions verifies required database permissions
|
||||
func (r *RestoreDryRun) checkUserPermissions(ctx context.Context) DryRunCheck {
|
||||
check := DryRunCheck{
|
||||
Name: "User Permissions",
|
||||
Critical: true,
|
||||
}
|
||||
|
||||
if r.cfg.DatabaseType != "postgres" {
|
||||
check.Status = DryRunSkipped
|
||||
check.Message = "Permission check only implemented for PostgreSQL"
|
||||
return check
|
||||
}
|
||||
|
||||
// Check if user has CREATEDB privilege
|
||||
query := `SELECT rolcreatedb, rolsuper FROM pg_roles WHERE rolname = current_user`
|
||||
|
||||
args := []string{
|
||||
"-h", r.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", r.cfg.Port),
|
||||
"-U", r.cfg.User,
|
||||
"-d", "postgres",
|
||||
"-tA",
|
||||
"-c", query,
|
||||
}
|
||||
|
||||
cmd := cleanup.SafeCommand(ctx, "psql", args...)
|
||||
if r.cfg.Password != "" {
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", r.cfg.Password))
|
||||
}
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
check.Status = DryRunWarning
|
||||
check.Message = "Could not verify permissions"
|
||||
check.Details = err.Error()
|
||||
return check
|
||||
}
|
||||
|
||||
result := strings.TrimSpace(string(output))
|
||||
parts := strings.Split(result, "|")
|
||||
|
||||
if len(parts) >= 2 {
|
||||
canCreate := parts[0] == "t"
|
||||
isSuper := parts[1] == "t"
|
||||
|
||||
if isSuper {
|
||||
check.Status = DryRunPassed
|
||||
check.Message = "User is superuser (full permissions)"
|
||||
return check
|
||||
}
|
||||
|
||||
if canCreate {
|
||||
check.Status = DryRunPassed
|
||||
check.Message = "User has CREATEDB privilege"
|
||||
return check
|
||||
}
|
||||
}
|
||||
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "User lacks CREATEDB privilege"
|
||||
check.Details = "Required for creating target database. Run: ALTER USER " + r.cfg.User + " CREATEDB;"
|
||||
return check
|
||||
}
|
||||
|
||||
// checkTargetConflicts checks if target database already exists
|
||||
func (r *RestoreDryRun) checkTargetConflicts(ctx context.Context) DryRunCheck {
|
||||
check := DryRunCheck{
|
||||
Name: "Target Database",
|
||||
Critical: false, // Not critical - can be overwritten with --clean
|
||||
}
|
||||
|
||||
if r.target == "" {
|
||||
check.Status = DryRunSkipped
|
||||
check.Message = "Cluster restore - checking multiple databases"
|
||||
return check
|
||||
}
|
||||
|
||||
databases, err := r.safety.ListUserDatabases(ctx)
|
||||
if err != nil {
|
||||
check.Status = DryRunWarning
|
||||
check.Message = "Could not check existing databases"
|
||||
check.Details = err.Error()
|
||||
return check
|
||||
}
|
||||
|
||||
for _, db := range databases {
|
||||
if db == r.target {
|
||||
check.Status = DryRunWarning
|
||||
check.Message = fmt.Sprintf("Database '%s' already exists", r.target)
|
||||
check.Details = "Use --clean to drop and recreate, or choose different target"
|
||||
return check
|
||||
}
|
||||
}
|
||||
|
||||
check.Status = DryRunPassed
|
||||
check.Message = fmt.Sprintf("Target '%s' is available", r.target)
|
||||
return check
|
||||
}
|
||||
|
||||
// checkDiskSpace verifies sufficient disk space for the restore.
// Returns the check plus (requiredMB, availableMB) so the caller can show
// the numbers in the summary. Sizing is heuristic: assume a 3x compression
// ratio for the archive, then double it (work-dir extraction + the restored
// database itself).
func (r *RestoreDryRun) checkDiskSpace() (DryRunCheck, int64, int64) {
	check := DryRunCheck{
		Name:     "Disk Space",
		Critical: true,
	}

	// Get archive size
	info, err := os.Stat(r.archive)
	if err != nil {
		// Archive access problems are reported by checkArchiveAccess;
		// here we just skip the estimate.
		check.Status = DryRunSkipped
		check.Message = "Cannot determine archive size"
		return check, 0, 0
	}

	// Estimate uncompressed size (assume 3x compression ratio)
	archiveSizeMB := info.Size() / 1024 / 1024
	estimatedUncompressedMB := archiveSizeMB * 3

	// Need space for: work dir extraction + restored database
	// Work dir: full uncompressed size
	// Database: roughly same as uncompressed SQL
	requiredMB := estimatedUncompressedMB * 2

	// Check available disk space in work directory
	workDir := r.cfg.GetEffectiveWorkDir()
	if workDir == "" {
		workDir = r.cfg.BackupDir
	}

	// syscall.Statfs is Unix-only; failure degrades to a warning rather
	// than blocking the restore.
	var stat syscall.Statfs_t
	if err := syscall.Statfs(workDir, &stat); err != nil {
		check.Status = DryRunWarning
		check.Message = "Cannot check disk space"
		check.Details = err.Error()
		return check, requiredMB, 0
	}

	// Bavail = blocks available to unprivileged users (excludes the
	// root-reserved blocks), multiplied out in uint64 before conversion.
	availableMB := int64(stat.Bavail*uint64(stat.Bsize)) / 1024 / 1024

	if availableMB < requiredMB {
		check.Status = DryRunFailed
		check.Message = fmt.Sprintf("Insufficient disk space: need %d MB, have %d MB", requiredMB, availableMB)
		check.Details = fmt.Sprintf("Work directory: %s", workDir)
		return check, requiredMB, availableMB
	}

	// Warn if less than 20% buffer
	if availableMB < requiredMB*12/10 {
		check.Status = DryRunWarning
		check.Message = fmt.Sprintf("Low disk space margin: need %d MB, have %d MB", requiredMB, availableMB)
		return check, requiredMB, availableMB
	}

	check.Status = DryRunPassed
	check.Message = fmt.Sprintf("Sufficient space: need ~%d MB, have %d MB", requiredMB, availableMB)
	return check, requiredMB, availableMB
}
|
||||
|
||||
// checkWorkDirectory verifies work directory is writable
|
||||
func (r *RestoreDryRun) checkWorkDirectory() DryRunCheck {
|
||||
check := DryRunCheck{
|
||||
Name: "Work Directory",
|
||||
Critical: true,
|
||||
}
|
||||
|
||||
workDir := r.cfg.GetEffectiveWorkDir()
|
||||
if workDir == "" {
|
||||
workDir = r.cfg.BackupDir
|
||||
}
|
||||
|
||||
// Check if directory exists
|
||||
info, err := os.Stat(workDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Work directory does not exist"
|
||||
check.Details = workDir
|
||||
} else {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Cannot access work directory"
|
||||
check.Details = err.Error()
|
||||
}
|
||||
return check
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Work path is not a directory"
|
||||
check.Details = workDir
|
||||
return check
|
||||
}
|
||||
|
||||
// Try to create a test file
|
||||
testFile := filepath.Join(workDir, ".dbbackup-dryrun-test")
|
||||
f, err := os.Create(testFile)
|
||||
if err != nil {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = "Work directory is not writable"
|
||||
check.Details = err.Error()
|
||||
return check
|
||||
}
|
||||
f.Close()
|
||||
os.Remove(testFile)
|
||||
|
||||
check.Status = DryRunPassed
|
||||
check.Message = fmt.Sprintf("Work directory writable: %s", workDir)
|
||||
return check
|
||||
}
|
||||
|
||||
// checkRequiredTools verifies required CLI tools are available
|
||||
func (r *RestoreDryRun) checkRequiredTools() DryRunCheck {
|
||||
check := DryRunCheck{
|
||||
Name: "Required Tools",
|
||||
Critical: true,
|
||||
}
|
||||
|
||||
var required []string
|
||||
switch r.cfg.DatabaseType {
|
||||
case "postgres":
|
||||
required = []string{"pg_restore", "psql", "createdb"}
|
||||
case "mysql", "mariadb":
|
||||
required = []string{"mysql", "mysqldump"}
|
||||
default:
|
||||
check.Status = DryRunSkipped
|
||||
check.Message = "Unknown database type"
|
||||
return check
|
||||
}
|
||||
|
||||
missing := []string{}
|
||||
for _, tool := range required {
|
||||
if _, err := LookPath(tool); err != nil {
|
||||
missing = append(missing, tool)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
check.Status = DryRunFailed
|
||||
check.Message = fmt.Sprintf("Missing tools: %s", strings.Join(missing, ", "))
|
||||
check.Details = "Install the database client tools package"
|
||||
return check
|
||||
}
|
||||
|
||||
check.Status = DryRunPassed
|
||||
check.Message = fmt.Sprintf("All tools available: %s", strings.Join(required, ", "))
|
||||
return check
|
||||
}
|
||||
|
||||
// checkLockSettings checks PostgreSQL lock settings for parallel restore.
// Non-critical: a low max_locks_per_transaction only risks failures when
// restoring schemas with many objects in parallel, so problems are warnings.
func (r *RestoreDryRun) checkLockSettings(ctx context.Context) DryRunCheck {
	check := DryRunCheck{
		Name:     "Lock Settings",
		Critical: false,
	}

	if r.cfg.DatabaseType != "postgres" {
		check.Status = DryRunSkipped
		check.Message = "Lock check only for PostgreSQL"
		return check
	}

	// Check max_locks_per_transaction
	query := `SHOW max_locks_per_transaction`
	args := []string{
		"-h", r.cfg.Host,
		"-p", fmt.Sprintf("%d", r.cfg.Port),
		"-U", r.cfg.User,
		"-d", "postgres",
		"-tA", // tuples-only, unaligned: output is the bare setting value
		"-c", query,
	}

	cmd := cleanup.SafeCommand(ctx, "psql", args...)
	if r.cfg.Password != "" {
		cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", r.cfg.Password))
	}

	output, err := cmd.Output()
	if err != nil {
		check.Status = DryRunWarning
		check.Message = "Could not check lock settings"
		return check
	}

	locks := strings.TrimSpace(string(output))
	if locks == "" {
		check.Status = DryRunWarning
		check.Message = "Could not determine max_locks_per_transaction"
		return check
	}

	// Default is 64, recommend at least 128 for parallel restores.
	// The Sscanf error is deliberately ignored: on parse failure
	// lockCount stays 0 and falls into the warning branch below.
	var lockCount int
	fmt.Sscanf(locks, "%d", &lockCount)

	if lockCount < 128 {
		check.Status = DryRunWarning
		check.Message = fmt.Sprintf("max_locks_per_transaction=%d (recommend 128+ for parallel)", lockCount)
		check.Details = "Set: ALTER SYSTEM SET max_locks_per_transaction = 128; then restart PostgreSQL"
		return check
	}

	check.Status = DryRunPassed
	check.Message = fmt.Sprintf("max_locks_per_transaction=%d (sufficient)", lockCount)
	return check
}
|
||||
|
||||
// checkMemoryAvailability checks if enough memory is available.
// Linux-only: parses MemAvailable from /proc/meminfo and is skipped on
// other platforms. Non-critical - low memory degrades, not blocks, restore.
func (r *RestoreDryRun) checkMemoryAvailability() DryRunCheck {
	check := DryRunCheck{
		Name:     "Memory Availability",
		Critical: false,
	}

	// Read /proc/meminfo on Linux
	data, err := os.ReadFile("/proc/meminfo")
	if err != nil {
		check.Status = DryRunSkipped
		check.Message = "Cannot check memory (non-Linux?)"
		return check
	}

	// MemAvailable is the kernel's estimate of memory usable without
	// swapping. If the line is absent (very old kernels), availableKB
	// stays 0 and the low-memory warning below fires.
	var availableKB int64
	for _, line := range strings.Split(string(data), "\n") {
		if strings.HasPrefix(line, "MemAvailable:") {
			fmt.Sscanf(line, "MemAvailable: %d kB", &availableKB)
			break
		}
	}

	availableMB := availableKB / 1024

	// Recommend at least 1GB for restore operations
	if availableMB < 1024 {
		check.Status = DryRunWarning
		check.Message = fmt.Sprintf("Low available memory: %d MB", availableMB)
		check.Details = "Restore may be slow or fail. Consider closing other applications."
		return check
	}

	check.Status = DryRunPassed
	check.Message = fmt.Sprintf("Available memory: %d MB", availableMB)
	return check
}
|
||||
|
||||
// estimateRestoreTime estimates restore duration based on archive size
|
||||
func (r *RestoreDryRun) estimateRestoreTime() time.Duration {
|
||||
info, err := os.Stat(r.archive)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Rough estimate: 100 MB/minute for restore operations
|
||||
// This accounts for decompression, SQL parsing, and database writes
|
||||
sizeMB := info.Size() / 1024 / 1024
|
||||
minutes := sizeMB / 100
|
||||
if minutes < 1 {
|
||||
minutes = 1
|
||||
}
|
||||
|
||||
return time.Duration(minutes) * time.Minute
|
||||
}
|
||||
|
||||
// formatBytesSize formats a byte count as a human-readable string with
// one decimal place (GB/MB/KB) or a plain integer for values below 1 KB.
func formatBytesSize(bytes int64) string {
	const (
		kb = int64(1024)
		mb = kb * 1024
		gb = mb * 1024
	)

	if bytes >= gb {
		return fmt.Sprintf("%.1f GB", float64(bytes)/float64(gb))
	}
	if bytes >= mb {
		return fmt.Sprintf("%.1f MB", float64(bytes)/float64(mb))
	}
	if bytes >= kb {
		return fmt.Sprintf("%.1f KB", float64(bytes)/float64(kb))
	}
	return fmt.Sprintf("%d B", bytes)
}
|
||||
|
||||
// LookPath is a wrapper around exec.LookPath for testing:
// tests can reassign this variable to simulate missing client binaries
// without touching the real PATH.
var LookPath = func(file string) (string, error) {
	return exec.LookPath(file)
}
|
||||
|
||||
// PrintDryRunResult prints a formatted dry-run result to stdout:
// one line per check (status icon, name, message) with optional detail
// lines, followed by time/disk estimates and an overall verdict.
func PrintDryRunResult(result *DryRunResult) {
	fmt.Println("\n" + strings.Repeat("=", 60))
	fmt.Println("RESTORE DRY-RUN RESULTS")
	fmt.Println(strings.Repeat("=", 60))

	for _, check := range result.Checks {
		// %-20s pads the name column so messages line up.
		fmt.Printf("%s %-20s %s\n", check.Status.Icon(), check.Name+":", check.Message)
		if check.Details != "" {
			fmt.Printf(" └─ %s\n", check.Details)
		}
	}

	fmt.Println(strings.Repeat("-", 60))

	if result.EstimatedTime > 0 {
		fmt.Printf("Estimated restore time: %s\n", result.EstimatedTime)
	}

	if result.RequiredDiskMB > 0 {
		fmt.Printf("Disk space: %d MB required, %d MB available\n",
			result.RequiredDiskMB, result.AvailableDiskMB)
	}

	fmt.Println()
	if result.CanProceed {
		if result.HasWarnings {
			fmt.Println("⚠️ DRY-RUN: PASSED with warnings - restore can proceed")
		} else {
			fmt.Println("✅ DRY-RUN: PASSED - restore can proceed")
		}
	} else {
		fmt.Printf("❌ DRY-RUN: FAILED - %d critical issue(s) must be resolved\n", result.CriticalCount)
	}
	fmt.Println()
}
|
||||
@ -635,7 +635,8 @@ func (e *Engine) restoreWithNativeEngine(ctx context.Context, archivePath, targe
|
||||
"database", targetDB,
|
||||
"archive", archivePath)
|
||||
|
||||
parallelEngine, err := native.NewParallelRestoreEngine(nativeCfg, e.log, parallelWorkers)
|
||||
// Pass context to ensure pool is properly closed on Ctrl+C cancellation
|
||||
parallelEngine, err := native.NewParallelRestoreEngineWithContext(ctx, nativeCfg, e.log, parallelWorkers)
|
||||
if err != nil {
|
||||
e.log.Warn("Failed to create parallel restore engine, falling back to sequential", "error", err)
|
||||
// Fall back to sequential restore
|
||||
@ -1342,9 +1343,14 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string, preExtr
|
||||
}
|
||||
|
||||
format := DetectArchiveFormat(archivePath)
|
||||
if format != FormatClusterTarGz {
|
||||
if !format.CanBeClusterRestore() {
|
||||
operation.Fail("Invalid cluster archive format")
|
||||
return fmt.Errorf("not a cluster archive: %s (detected format: %s)", archivePath, format)
|
||||
return fmt.Errorf("not a valid cluster restore format: %s (detected format: %s). Supported: .tar.gz, .sql, .sql.gz", archivePath, format)
|
||||
}
|
||||
|
||||
// For SQL-based cluster restores, use a different restore path
|
||||
if format == FormatPostgreSQLSQL || format == FormatPostgreSQLSQLGz {
|
||||
return e.restoreClusterFromSQL(ctx, archivePath, operation)
|
||||
}
|
||||
|
||||
// Check if we have a pre-extracted directory (optimization to avoid double extraction)
|
||||
@ -2177,6 +2183,45 @@ func (e *Engine) RestoreCluster(ctx context.Context, archivePath string, preExtr
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreClusterFromSQL restores a pg_dumpall SQL file using the native engine.
// This handles .sql and .sql.gz files containing full cluster dumps.
// The dump is replayed against the 'postgres' maintenance database; a
// pg_dumpall dump is expected to contain the CREATE ROLE/DATABASE
// statements itself — TODO(review): confirm globals are present in the dump.
func (e *Engine) restoreClusterFromSQL(ctx context.Context, archivePath string, operation logger.OperationLogger) error {
	e.log.Info("Restoring cluster from SQL file (pg_dumpall format)",
		"file", filepath.Base(archivePath),
		"native_engine", true)

	clusterStartTime := time.Now()

	// Determine if compressed (only the .gz suffix is inspected)
	compressed := strings.HasSuffix(strings.ToLower(archivePath), ".gz")

	// Use native engine to restore directly to postgres database (globals + all databases)
	e.log.Info("Restoring SQL dump using native engine...",
		"compressed", compressed,
		"size", FormatBytes(getFileSize(archivePath)))

	e.progress.Start("Restoring cluster from SQL dump...")

	// For pg_dumpall, we restore to the 'postgres' database which then creates other databases
	targetDB := "postgres"

	err := e.restoreWithNativeEngine(ctx, archivePath, targetDB, compressed)
	if err != nil {
		operation.Fail(fmt.Sprintf("SQL cluster restore failed: %v", err))
		e.recordClusterRestoreMetrics(clusterStartTime, archivePath, 0, 0, false, err.Error())
		return fmt.Errorf("SQL cluster restore failed: %w", err)
	}

	duration := time.Since(clusterStartTime)
	e.progress.Complete(fmt.Sprintf("Cluster restored successfully from SQL in %s", duration.Round(time.Second)))
	operation.Complete("SQL cluster restore completed")

	// Record metrics. Counts are nominal 1/1: the dump may contain many
	// databases but they are not enumerated on this path.
	e.recordClusterRestoreMetrics(clusterStartTime, archivePath, 1, 1, true, "")

	return nil
}
|
||||
|
||||
// recordClusterRestoreMetrics records metrics for cluster restore operations
|
||||
func (e *Engine) recordClusterRestoreMetrics(startTime time.Time, archivePath string, totalDBs, successCount int, success bool, errorMsg string) {
|
||||
duration := time.Since(startTime)
|
||||
@ -2924,6 +2969,15 @@ func (e *Engine) isIgnorableError(errorMsg string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// getFileSize returns the size of a file, or 0 if it can't be read
|
||||
func getFileSize(path string) int64 {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return info.Size()
|
||||
}
|
||||
|
||||
// FormatBytes formats bytes to human readable format
|
||||
func FormatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
|
||||
@ -168,11 +168,19 @@ func (f ArchiveFormat) IsCompressed() bool {
|
||||
f == FormatClusterTarGz
|
||||
}
|
||||
|
||||
// IsClusterBackup returns true if the archive is a cluster backup
|
||||
// IsClusterBackup returns true if the archive is a cluster backup (.tar.gz format created by dbbackup)
|
||||
func (f ArchiveFormat) IsClusterBackup() bool {
|
||||
return f == FormatClusterTarGz
|
||||
}
|
||||
|
||||
// CanBeClusterRestore returns true if the format can be used for cluster restore
|
||||
// This includes .tar.gz (dbbackup format) and .sql/.sql.gz (pg_dumpall format for native engine)
|
||||
func (f ArchiveFormat) CanBeClusterRestore() bool {
|
||||
return f == FormatClusterTarGz ||
|
||||
f == FormatPostgreSQLSQL ||
|
||||
f == FormatPostgreSQLSQLGz
|
||||
}
|
||||
|
||||
// IsPostgreSQL returns true if the archive is PostgreSQL format
|
||||
func (f ArchiveFormat) IsPostgreSQL() bool {
|
||||
return f == FormatPostgreSQLDump ||
|
||||
|
||||
@ -1,7 +1,15 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/logger"
|
||||
@ -21,13 +29,36 @@ type AuditEvent struct {
|
||||
type AuditLogger struct {
|
||||
log logger.Logger
|
||||
enabled bool
|
||||
|
||||
// For signed audit log support
|
||||
mu sync.Mutex
|
||||
entries []SignedAuditEntry
|
||||
privateKey ed25519.PrivateKey
|
||||
publicKey ed25519.PublicKey
|
||||
prevHash string // Hash of previous entry for chaining
|
||||
}
|
||||
|
||||
// SignedAuditEntry represents an audit entry with cryptographic signature.
// Entries form a hash chain: each embeds the previous entry's hash, so
// deleting, reordering, or editing entries is detectable at verification.
type SignedAuditEntry struct {
	Sequence  int64  `json:"seq"` // 1-based position in the log
	Timestamp string `json:"ts"`  // RFC3339Nano
	User      string `json:"user"`
	Action    string `json:"action"`
	Resource  string `json:"resource"`
	Result    string `json:"result"`
	Details   string `json:"details,omitempty"` // JSON-serialized event details
	PrevHash  string `json:"prev_hash"`         // Hash chain for tamper detection
	Hash      string `json:"hash"`              // SHA-256 of this entry (without signature)
	Signature string `json:"sig"`               // Ed25519 signature of Hash
}
|
||||
|
||||
// NewAuditLogger creates a new audit logger
|
||||
func NewAuditLogger(log logger.Logger, enabled bool) *AuditLogger {
|
||||
return &AuditLogger{
|
||||
log: log,
|
||||
enabled: enabled,
|
||||
log: log,
|
||||
enabled: enabled,
|
||||
entries: make([]SignedAuditEntry, 0),
|
||||
prevHash: "genesis", // Initial hash for first entry
|
||||
}
|
||||
}
|
||||
|
||||
@ -232,3 +263,337 @@ func GetCurrentUser() string {
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Audit Log Signing and Verification
|
||||
// =============================================================================
|
||||
|
||||
// GenerateSigningKeys generates a new Ed25519 key pair for audit log signing
|
||||
func GenerateSigningKeys() (privateKey ed25519.PrivateKey, publicKey ed25519.PublicKey, err error) {
|
||||
publicKey, privateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
return
|
||||
}
|
||||
|
||||
// SavePrivateKey saves the private key to a file (PEM-like format)
|
||||
func SavePrivateKey(path string, key ed25519.PrivateKey) error {
|
||||
encoded := base64.StdEncoding.EncodeToString(key)
|
||||
content := fmt.Sprintf("-----BEGIN DBBACKUP AUDIT PRIVATE KEY-----\n%s\n-----END DBBACKUP AUDIT PRIVATE KEY-----\n", encoded)
|
||||
return os.WriteFile(path, []byte(content), 0600) // Restrictive permissions
|
||||
}
|
||||
|
||||
// SavePublicKey saves the public key to a file (PEM-like format)
|
||||
func SavePublicKey(path string, key ed25519.PublicKey) error {
|
||||
encoded := base64.StdEncoding.EncodeToString(key)
|
||||
content := fmt.Sprintf("-----BEGIN DBBACKUP AUDIT PUBLIC KEY-----\n%s\n-----END DBBACKUP AUDIT PUBLIC KEY-----\n", encoded)
|
||||
return os.WriteFile(path, []byte(content), 0644)
|
||||
}
|
||||
|
||||
// LoadPrivateKey loads a private key from file.
// Expects the PEM-like format written by SavePrivateKey; the decoded
// payload must be exactly ed25519.PrivateKeySize bytes.
func LoadPrivateKey(path string) (ed25519.PrivateKey, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read private key: %w", err)
	}

	// Extract base64 content between PEM markers
	content := extractPEMContent(string(data))
	if content == "" {
		return nil, fmt.Errorf("invalid private key format")
	}

	decoded, err := base64.StdEncoding.DecodeString(content)
	if err != nil {
		return nil, fmt.Errorf("failed to decode private key: %w", err)
	}

	// Length check guards against truncated or foreign key material.
	if len(decoded) != ed25519.PrivateKeySize {
		return nil, fmt.Errorf("invalid private key size")
	}

	return ed25519.PrivateKey(decoded), nil
}
|
||||
|
||||
// LoadPublicKey loads a public key from file.
// Expects the PEM-like format written by SavePublicKey; the decoded
// payload must be exactly ed25519.PublicKeySize bytes.
func LoadPublicKey(path string) (ed25519.PublicKey, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read public key: %w", err)
	}

	// Extract base64 content between PEM markers
	content := extractPEMContent(string(data))
	if content == "" {
		return nil, fmt.Errorf("invalid public key format")
	}

	decoded, err := base64.StdEncoding.DecodeString(content)
	if err != nil {
		return nil, fmt.Errorf("failed to decode public key: %w", err)
	}

	// Length check guards against truncated or foreign key material.
	if len(decoded) != ed25519.PublicKeySize {
		return nil, fmt.Errorf("invalid public key size")
	}

	return ed25519.PublicKey(decoded), nil
}
|
||||
|
||||
// extractPEMContent extracts the base64 payload from a PEM-like key file:
// everything between the first header line (a line ending in '-') and the
// trailing footer line, with all whitespace removed. Returns "" when no
// payload span can be located.
func extractPEMContent(data string) string {
	// Find the first newline that terminates a marker line ("...-\n").
	start := 0
	for i := 0; i < len(data); i++ {
		if data[i] == '\n' && i > 0 && data[i-1] == '-' {
			start = i + 1
			break
		}
	}

	// Find the last newline immediately followed by a marker line ("\n-...").
	end := len(data)
	for i := len(data) - 1; i > start; i-- {
		if data[i] == '\n' && i+1 < len(data) && data[i+1] == '-' {
			end = i
			break
		}
	}

	if start >= end {
		return ""
	}

	// Strip whitespace. FIX: accumulate into a pre-sized byte slice instead
	// of repeated string concatenation - the original was O(n^2) in the
	// payload size; behavior is unchanged for the base64 (ASCII) payloads
	// this parses.
	out := make([]byte, 0, end-start)
	for i := start; i < end; i++ {
		c := data[i]
		if c != '\n' && c != '\r' && c != ' ' {
			out = append(out, c)
		}
	}
	return string(out)
}
|
||||
|
||||
// EnableSigning enables cryptographic signing for audit entries.
// The public key is derived from the private key, so callers supply only
// the private half. Safe for concurrent use (takes the logger mutex).
func (a *AuditLogger) EnableSigning(privateKey ed25519.PrivateKey) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.privateKey = privateKey
	// ed25519.PrivateKey.Public() always returns ed25519.PublicKey,
	// so this type assertion cannot panic for a well-formed key.
	a.publicKey = privateKey.Public().(ed25519.PublicKey)
}
|
||||
|
||||
// AddSignedEntry adds a signed entry to the audit log.
// Entries are hash-chained (each records the previous entry's hash) and,
// when EnableSigning has been called, Ed25519-signed. No-op when the
// logger is disabled. Always returns nil; kept as error for API symmetry.
func (a *AuditLogger) AddSignedEntry(event AuditEvent) error {
	if !a.enabled {
		return nil
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	// Serialize details; a marshal failure silently leaves Details empty
	// (best-effort - the core audit fields are still recorded).
	detailsJSON := ""
	if len(event.Details) > 0 {
		if data, err := json.Marshal(event.Details); err == nil {
			detailsJSON = string(data)
		}
	}

	entry := SignedAuditEntry{
		Sequence:  int64(len(a.entries) + 1),
		Timestamp: event.Timestamp.Format(time.RFC3339Nano),
		User:      event.User,
		Action:    event.Action,
		Resource:  event.Resource,
		Result:    event.Result,
		Details:   detailsJSON,
		PrevHash:  a.prevHash,
	}

	// Calculate hash of entry (without signature)
	entry.Hash = a.calculateEntryHash(entry)

	// Sign if private key is available
	if a.privateKey != nil {
		// Hash is hex we just produced, so DecodeString cannot fail here.
		hashBytes, _ := hex.DecodeString(entry.Hash)
		signature := ed25519.Sign(a.privateKey, hashBytes)
		entry.Signature = base64.StdEncoding.EncodeToString(signature)
	}

	// Update chain: this entry's hash becomes the next entry's PrevHash.
	a.prevHash = entry.Hash
	a.entries = append(a.entries, entry)

	// Also log to standard logger
	a.logEvent(event)

	return nil
}
|
||||
|
||||
// calculateEntryHash computes SHA-256 hash of an entry (without signature field).
// The canonical form is pipe-delimited in fixed field order and must stay
// in sync with the package-level calculateVerifyHash used at verification.
// NOTE(review): field values containing '|' could in principle produce
// colliding canonical strings - confirm whether that matters for the
// tamper-evidence guarantees required here.
func (a *AuditLogger) calculateEntryHash(entry SignedAuditEntry) string {
	// Create canonical representation for hashing
	data := fmt.Sprintf("%d|%s|%s|%s|%s|%s|%s|%s",
		entry.Sequence,
		entry.Timestamp,
		entry.User,
		entry.Action,
		entry.Resource,
		entry.Result,
		entry.Details,
		entry.PrevHash,
	)

	hash := sha256.Sum256([]byte(data))
	return hex.EncodeToString(hash[:])
}
|
||||
|
||||
// ExportSignedLog exports the signed audit log to a file as indented JSON.
// Mode 0644: the exported log is intended to be handed to verifiers.
// Takes the logger mutex so a concurrent AddSignedEntry cannot race the export.
func (a *AuditLogger) ExportSignedLog(path string) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	data, err := json.MarshalIndent(a.entries, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal audit log: %w", err)
	}

	return os.WriteFile(path, data, 0644)
}
|
||||
|
||||
// VerifyAuditLog verifies the integrity of an exported audit log
|
||||
func VerifyAuditLog(logPath string, publicKeyPath string) (*AuditVerificationResult, error) {
|
||||
// Load public key
|
||||
publicKey, err := LoadPublicKey(publicKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load public key: %w", err)
|
||||
}
|
||||
|
||||
// Load audit log
|
||||
data, err := os.ReadFile(logPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read audit log: %w", err)
|
||||
}
|
||||
|
||||
var entries []SignedAuditEntry
|
||||
if err := json.Unmarshal(data, &entries); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse audit log: %w", err)
|
||||
}
|
||||
|
||||
result := &AuditVerificationResult{
|
||||
TotalEntries: len(entries),
|
||||
ValidEntries: 0,
|
||||
Errors: make([]string, 0),
|
||||
}
|
||||
|
||||
prevHash := "genesis"
|
||||
|
||||
for i, entry := range entries {
|
||||
// Verify hash chain
|
||||
if entry.PrevHash != prevHash {
|
||||
result.Errors = append(result.Errors,
|
||||
fmt.Sprintf("Entry %d: hash chain broken (expected %s, got %s)",
|
||||
i+1, prevHash[:16]+"...", entry.PrevHash[:min(16, len(entry.PrevHash))]+"..."))
|
||||
}
|
||||
|
||||
// Recalculate hash
|
||||
expectedHash := calculateVerifyHash(entry)
|
||||
if entry.Hash != expectedHash {
|
||||
result.Errors = append(result.Errors,
|
||||
fmt.Sprintf("Entry %d: hash mismatch (entry may be tampered)", i+1))
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if entry.Signature != "" {
|
||||
hashBytes, _ := hex.DecodeString(entry.Hash)
|
||||
sigBytes, err := base64.StdEncoding.DecodeString(entry.Signature)
|
||||
if err != nil {
|
||||
result.Errors = append(result.Errors,
|
||||
fmt.Sprintf("Entry %d: invalid signature encoding", i+1))
|
||||
} else if !ed25519.Verify(publicKey, hashBytes, sigBytes) {
|
||||
result.Errors = append(result.Errors,
|
||||
fmt.Sprintf("Entry %d: signature verification failed", i+1))
|
||||
} else {
|
||||
result.ValidEntries++
|
||||
}
|
||||
} else {
|
||||
result.Errors = append(result.Errors,
|
||||
fmt.Sprintf("Entry %d: missing signature", i+1))
|
||||
}
|
||||
|
||||
prevHash = entry.Hash
|
||||
}
|
||||
|
||||
result.ChainValid = len(result.Errors) == 0 ||
|
||||
!containsChainError(result.Errors)
|
||||
result.AllSignaturesValid = result.ValidEntries == result.TotalEntries
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditVerificationResult contains the result of audit log verification.
type AuditVerificationResult struct {
	TotalEntries       int      // number of entries examined
	ValidEntries       int      // entries whose signature verified
	ChainValid         bool     // hash chain intact (no breaks or hash mismatches)
	AllSignaturesValid bool     // every entry's signature verified
	Errors             []string // human-readable problems found, in entry order
}
|
||||
|
||||
// IsValid returns true if the audit log is completely valid
|
||||
func (r *AuditVerificationResult) IsValid() bool {
|
||||
return r.ChainValid && r.AllSignaturesValid && len(r.Errors) == 0
|
||||
}
|
||||
|
||||
// String returns a human-readable summary: a one-line success message when
// the log is fully valid, otherwise a failure line with valid/total counts.
func (r *AuditVerificationResult) String() string {
	if r.IsValid() {
		return fmt.Sprintf("✅ Audit log verified: %d entries, chain intact, all signatures valid",
			r.TotalEntries)
	}

	return fmt.Sprintf("❌ Audit log verification failed: %d/%d valid entries, %d errors",
		r.ValidEntries, r.TotalEntries, len(r.Errors))
}
|
||||
|
||||
// calculateVerifyHash recalculates an entry's hash for verification.
// Must mirror AuditLogger.calculateEntryHash exactly (same pipe-delimited
// canonical form, same field order) or every exported log would fail
// verification.
func calculateVerifyHash(entry SignedAuditEntry) string {
	data := fmt.Sprintf("%d|%s|%s|%s|%s|%s|%s|%s",
		entry.Sequence,
		entry.Timestamp,
		entry.User,
		entry.Action,
		entry.Resource,
		entry.Result,
		entry.Details,
		entry.PrevHash,
	)

	hash := sha256.Sum256([]byte(data))
	return hex.EncodeToString(hash[:])
}
|
||||
|
||||
// containsChainError reports whether any of the error messages describe a
// hash-chain integrity problem (as opposed to, e.g., a signature failure).
//
// Chain-related messages are produced elsewhere in this file in the form
// "Entry N: hash mismatch ..." or "Entry N: hash chain ...", so we match the
// "Entry" prefix plus one of those two phrases.
func containsChainError(errors []string) bool {
	for _, msg := range errors {
		// BUG FIX: the previous code compared msg[0:min(20, len(msg))]
		// against the 5-character literal "Entry". A slice of length != 5
		// can never equal "Entry", so for realistic messages (which are
		// longer than 20 chars) the check always failed and chain errors
		// were never detected, leaving ChainValid incorrectly true.
		// Compare exactly the first 5 characters instead.
		if len(msg) >= 5 && msg[:5] == "Entry" &&
			(contains(msg, "hash chain") || contains(msg, "hash mismatch")) {
			return true
		}
	}
	return false
}

// contains is a simple string contains helper (avoids importing strings).
func contains(s, substr string) bool {
	for i := 0; i+len(substr) <= len(s); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}

// min returns the minimum of two ints.
// Kept for callers elsewhere in the file and for compatibility with Go
// versions predating the built-in min (Go 1.21).
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}
|
||||
|
||||
@ -205,11 +205,20 @@ func (m ArchiveBrowserModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return diagnoseView, diagnoseView.Init()
|
||||
}
|
||||
|
||||
// For restore-cluster mode:
|
||||
// - .tar.gz cluster archives → full cluster restore
|
||||
// - .sql/.sql.gz files → single database restore (Native Engine supports these)
|
||||
// - .dump/.dump.gz → single database restore (pg_restore)
|
||||
// ALL formats are now allowed for restore operations!
|
||||
// For restore-cluster mode: check if format can be used for cluster restore
|
||||
// - .tar.gz: dbbackup cluster format (works with pg_restore)
|
||||
// - .sql/.sql.gz: pg_dumpall format (works with native engine or psql)
|
||||
if m.mode == "restore-cluster" && !selected.Format.CanBeClusterRestore() {
|
||||
m.message = errorStyle.Render(fmt.Sprintf("⚠️ %s cannot be used for cluster restore.\n\n Supported formats: .tar.gz (dbbackup), .sql, .sql.gz (pg_dumpall)",
|
||||
selected.Name))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// For SQL-based cluster restore, enable native engine automatically
|
||||
if m.mode == "restore-cluster" && !selected.Format.IsClusterBackup() {
|
||||
// This is a .sql or .sql.gz file - use native engine
|
||||
m.config.UseNativeEngine = true
|
||||
}
|
||||
|
||||
// For single restore mode with cluster backup selected - offer to select individual database
|
||||
if m.mode == "restore-single" && selected.Format.IsClusterBackup() {
|
||||
@ -217,7 +226,7 @@ func (m ArchiveBrowserModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return clusterSelector, clusterSelector.Init()
|
||||
}
|
||||
|
||||
// Open restore preview for any valid format
|
||||
// Open restore preview for valid format
|
||||
preview := NewRestorePreview(m.config, m.logger, m.parent, m.ctx, selected, m.mode)
|
||||
return preview, preview.Init()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user