feat: implement full restore functionality with TUI integration

- Add complete restore engine (internal/restore/)
  - RestoreSingle() for single database restore
  - RestoreCluster() for full cluster restore
  - Archive format detection (7 formats supported)
  - Safety validation (integrity, disk space, tools)
  - Streaming decompression with pigz support

- Add CLI restore commands (cmd/restore.go)
  - restore single: restore single database backup
  - restore cluster: restore full cluster backup
  - restore list: list available backup archives
  - Safety-first design: dry-run by default, --confirm required

- Add TUI restore integration (internal/tui/)
  - Archive browser: browse and select backups
  - Restore preview: safety checks and confirmation
  - Restore execution: real-time progress tracking
  - Backup manager: comprehensive archive management

- Features:
  - Format auto-detection (.dump, .dump.gz, .sql, .sql.gz, .tar.gz)
  - Archive validation before restore
  - Disk space verification
  - Tool availability checks
  - Target database configuration
  - Clean-first and create-if-missing options
  - Parallel decompression support
  - Progress tracking with phases

Phase 1 (Core Functionality) complete and tested
This commit is contained in:
2025-11-07 09:41:44 +00:00
parent 33d53612d2
commit 87e0ca3b39
12 changed files with 3222 additions and 19 deletions

445
internal/restore/engine.go Normal file
View File

@ -0,0 +1,445 @@
package restore
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"dbbackup/internal/config"
"dbbackup/internal/database"
"dbbackup/internal/logger"
"dbbackup/internal/progress"
)
// Engine handles database restore operations
type Engine struct {
cfg *config.Config
log logger.Logger
db database.Database
progress progress.Indicator
detailedReporter *progress.DetailedReporter
dryRun bool
}
// New creates a new restore engine
func New(cfg *config.Config, log logger.Logger, db database.Database) *Engine {
progressIndicator := progress.NewIndicator(true, "line")
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
return &Engine{
cfg: cfg,
log: log,
db: db,
progress: progressIndicator,
detailedReporter: detailedReporter,
dryRun: false,
}
}
// NewWithProgress creates a restore engine with custom progress indicator
func NewWithProgress(cfg *config.Config, log logger.Logger, db database.Database, progressIndicator progress.Indicator, dryRun bool) *Engine {
if progressIndicator == nil {
progressIndicator = progress.NewNullIndicator()
}
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
return &Engine{
cfg: cfg,
log: log,
db: db,
progress: progressIndicator,
detailedReporter: detailedReporter,
dryRun: dryRun,
}
}
// loggerAdapter adapts our logger to the progress.Logger interface
type loggerAdapter struct {
logger logger.Logger
}
func (la *loggerAdapter) Info(msg string, args ...any) {
la.logger.Info(msg, args...)
}
func (la *loggerAdapter) Warn(msg string, args ...any) {
la.logger.Warn(msg, args...)
}
func (la *loggerAdapter) Error(msg string, args ...any) {
la.logger.Error(msg, args...)
}
func (la *loggerAdapter) Debug(msg string, args ...any) {
la.logger.Debug(msg, args...)
}
// RestoreSingle restores a single database from an archive
func (e *Engine) RestoreSingle(ctx context.Context, archivePath, targetDB string, cleanFirst, createIfMissing bool) error {
operation := e.log.StartOperation("Single Database Restore")
// Validate archive exists
if _, err := os.Stat(archivePath); os.IsNotExist(err) {
operation.Fail("Archive not found")
return fmt.Errorf("archive not found: %s", archivePath)
}
// Detect archive format
format := DetectArchiveFormat(archivePath)
e.log.Info("Detected archive format", "format", format, "path", archivePath)
if e.dryRun {
e.log.Info("DRY RUN: Would restore single database", "archive", archivePath, "target", targetDB)
return e.previewRestore(archivePath, targetDB, format)
}
// Start progress tracking
e.progress.Start(fmt.Sprintf("Restoring database '%s' from %s", targetDB, filepath.Base(archivePath)))
// Handle different archive formats
var err error
switch format {
case FormatPostgreSQLDump, FormatPostgreSQLDumpGz:
err = e.restorePostgreSQLDump(ctx, archivePath, targetDB, format == FormatPostgreSQLDumpGz, cleanFirst)
case FormatPostgreSQLSQL, FormatPostgreSQLSQLGz:
err = e.restorePostgreSQLSQL(ctx, archivePath, targetDB, format == FormatPostgreSQLSQLGz)
case FormatMySQLSQL, FormatMySQLSQLGz:
err = e.restoreMySQLSQL(ctx, archivePath, targetDB, format == FormatMySQLSQLGz)
default:
operation.Fail("Unsupported archive format")
return fmt.Errorf("unsupported archive format: %s", format)
}
if err != nil {
e.progress.Fail(fmt.Sprintf("Restore failed: %v", err))
operation.Fail(fmt.Sprintf("Restore failed: %v", err))
return err
}
e.progress.Complete(fmt.Sprintf("Database '%s' restored successfully", targetDB))
operation.Complete(fmt.Sprintf("Restored database '%s' from %s", targetDB, filepath.Base(archivePath)))
return nil
}
// restorePostgreSQLDump restores from PostgreSQL custom dump format
func (e *Engine) restorePostgreSQLDump(ctx context.Context, archivePath, targetDB string, compressed bool, cleanFirst bool) error {
// Build restore command
opts := database.RestoreOptions{
Parallel: 1,
Clean: cleanFirst,
NoOwner: true,
NoPrivileges: true,
SingleTransaction: true,
}
cmd := e.db.BuildRestoreCommand(targetDB, archivePath, opts)
if compressed {
// For compressed dumps, decompress first
return e.executeRestoreWithDecompression(ctx, archivePath, cmd)
}
return e.executeRestoreCommand(ctx, cmd)
}
// restorePostgreSQLSQL restores from PostgreSQL SQL script
func (e *Engine) restorePostgreSQLSQL(ctx context.Context, archivePath, targetDB string, compressed bool) error {
// Use psql for SQL scripts
var cmd []string
if compressed {
cmd = []string{
"bash", "-c",
fmt.Sprintf("gunzip -c %s | psql -h %s -p %d -U %s -d %s",
archivePath, e.cfg.Host, e.cfg.Port, e.cfg.User, targetDB),
}
} else {
cmd = []string{
"psql",
"-h", e.cfg.Host,
"-p", fmt.Sprintf("%d", e.cfg.Port),
"-U", e.cfg.User,
"-d", targetDB,
"-f", archivePath,
}
}
return e.executeRestoreCommand(ctx, cmd)
}
// restoreMySQLSQL restores from MySQL SQL script
func (e *Engine) restoreMySQLSQL(ctx context.Context, archivePath, targetDB string, compressed bool) error {
options := database.RestoreOptions{}
cmd := e.db.BuildRestoreCommand(targetDB, archivePath, options)
if compressed {
// For compressed SQL, decompress on the fly
cmd = []string{
"bash", "-c",
fmt.Sprintf("gunzip -c %s | %s", archivePath, strings.Join(cmd, " ")),
}
}
return e.executeRestoreCommand(ctx, cmd)
}
// executeRestoreCommand executes a restore command
func (e *Engine) executeRestoreCommand(ctx context.Context, cmdArgs []string) error {
e.log.Info("Executing restore command", "command", strings.Join(cmdArgs, " "))
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
// Set environment variables
cmd.Env = append(os.Environ(),
fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password),
fmt.Sprintf("MYSQL_PWD=%s", e.cfg.Password),
)
// Capture output
output, err := cmd.CombinedOutput()
if err != nil {
e.log.Error("Restore command failed", "error", err, "output", string(output))
return fmt.Errorf("restore failed: %w\nOutput: %s", err, string(output))
}
e.log.Info("Restore command completed successfully")
return nil
}
// executeRestoreWithDecompression handles decompression during restore
func (e *Engine) executeRestoreWithDecompression(ctx context.Context, archivePath string, restoreCmd []string) error {
// Check if pigz is available for faster decompression
decompressCmd := "gunzip"
if _, err := exec.LookPath("pigz"); err == nil {
decompressCmd = "pigz"
e.log.Info("Using pigz for parallel decompression")
}
// Build pipeline: decompress | restore
pipeline := fmt.Sprintf("%s -dc %s | %s", decompressCmd, archivePath, strings.Join(restoreCmd, " "))
cmd := exec.CommandContext(ctx, "bash", "-c", pipeline)
cmd.Env = append(os.Environ(),
fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password),
fmt.Sprintf("MYSQL_PWD=%s", e.cfg.Password),
)
output, err := cmd.CombinedOutput()
if err != nil {
e.log.Error("Restore with decompression failed", "error", err, "output", string(output))
return fmt.Errorf("restore failed: %w\nOutput: %s", err, string(output))
}
return nil
}
// previewRestore shows what would be done without executing
func (e *Engine) previewRestore(archivePath, targetDB string, format ArchiveFormat) error {
fmt.Println("\n" + strings.Repeat("=", 60))
fmt.Println(" RESTORE PREVIEW (DRY RUN)")
fmt.Println(strings.Repeat("=", 60))
stat, _ := os.Stat(archivePath)
fmt.Printf("\nArchive: %s\n", filepath.Base(archivePath))
fmt.Printf("Format: %s\n", format)
if stat != nil {
fmt.Printf("Size: %s\n", FormatBytes(stat.Size()))
fmt.Printf("Modified: %s\n", stat.ModTime().Format("2006-01-02 15:04:05"))
}
fmt.Printf("Target Database: %s\n", targetDB)
fmt.Printf("Target Host: %s:%d\n", e.cfg.Host, e.cfg.Port)
fmt.Println("\nOperations that would be performed:")
switch format {
case FormatPostgreSQLDump:
fmt.Printf(" 1. Execute: pg_restore -d %s %s\n", targetDB, archivePath)
case FormatPostgreSQLDumpGz:
fmt.Printf(" 1. Decompress: %s\n", archivePath)
fmt.Printf(" 2. Execute: pg_restore -d %s\n", targetDB)
case FormatPostgreSQLSQL, FormatPostgreSQLSQLGz:
fmt.Printf(" 1. Execute: psql -d %s -f %s\n", targetDB, archivePath)
case FormatMySQLSQL, FormatMySQLSQLGz:
fmt.Printf(" 1. Execute: mysql %s < %s\n", targetDB, archivePath)
}
fmt.Println("\n⚠ WARNING: This will restore data to the target database.")
fmt.Println(" Existing data may be overwritten or merged.")
fmt.Println("\nTo execute this restore, add the --confirm flag.")
fmt.Println(strings.Repeat("=", 60) + "\n")
return nil
}
// RestoreCluster restores a full cluster from a tar.gz archive
func (e *Engine) RestoreCluster(ctx context.Context, archivePath string) error {
operation := e.log.StartOperation("Cluster Restore")
// Validate archive
if _, err := os.Stat(archivePath); os.IsNotExist(err) {
operation.Fail("Archive not found")
return fmt.Errorf("archive not found: %s", archivePath)
}
format := DetectArchiveFormat(archivePath)
if format != FormatClusterTarGz {
operation.Fail("Invalid cluster archive format")
return fmt.Errorf("not a cluster archive: %s (detected format: %s)", archivePath, format)
}
if e.dryRun {
e.log.Info("DRY RUN: Would restore cluster", "archive", archivePath)
return e.previewClusterRestore(archivePath)
}
e.progress.Start(fmt.Sprintf("Restoring cluster from %s", filepath.Base(archivePath)))
// Create temporary extraction directory
tempDir := filepath.Join(e.cfg.BackupDir, fmt.Sprintf(".restore_%d", time.Now().Unix()))
if err := os.MkdirAll(tempDir, 0755); err != nil {
operation.Fail("Failed to create temporary directory")
return fmt.Errorf("failed to create temp directory: %w", err)
}
defer os.RemoveAll(tempDir)
// Extract archive
e.log.Info("Extracting cluster archive", "archive", archivePath, "tempDir", tempDir)
if err := e.extractArchive(ctx, archivePath, tempDir); err != nil {
operation.Fail("Archive extraction failed")
return fmt.Errorf("failed to extract archive: %w", err)
}
// Restore global objects (roles, tablespaces)
globalsFile := filepath.Join(tempDir, "globals.sql")
if _, err := os.Stat(globalsFile); err == nil {
e.log.Info("Restoring global objects")
e.progress.Update("Restoring global objects (roles, tablespaces)...")
if err := e.restoreGlobals(ctx, globalsFile); err != nil {
e.log.Warn("Failed to restore global objects", "error", err)
// Continue anyway - global objects might already exist
}
}
// Restore individual databases
dumpsDir := filepath.Join(tempDir, "dumps")
if _, err := os.Stat(dumpsDir); err != nil {
operation.Fail("No database dumps found in archive")
return fmt.Errorf("no database dumps found in archive")
}
entries, err := os.ReadDir(dumpsDir)
if err != nil {
operation.Fail("Failed to read dumps directory")
return fmt.Errorf("failed to read dumps directory: %w", err)
}
successCount := 0
failCount := 0
for i, entry := range entries {
if entry.IsDir() {
continue
}
dumpFile := filepath.Join(dumpsDir, entry.Name())
dbName := strings.TrimSuffix(entry.Name(), ".dump")
e.progress.Update(fmt.Sprintf("[%d/%d] Restoring database: %s", i+1, len(entries), dbName))
e.log.Info("Restoring database", "name", dbName, "file", dumpFile)
if err := e.restorePostgreSQLDump(ctx, dumpFile, dbName, false, false); err != nil {
e.log.Error("Failed to restore database", "name", dbName, "error", err)
failCount++
continue
}
successCount++
}
if failCount > 0 {
e.progress.Fail(fmt.Sprintf("Cluster restore completed with errors: %d succeeded, %d failed", successCount, failCount))
operation.Complete(fmt.Sprintf("Partial restore: %d succeeded, %d failed", successCount, failCount))
return fmt.Errorf("cluster restore completed with %d failures", failCount)
}
e.progress.Complete(fmt.Sprintf("Cluster restored successfully: %d databases", successCount))
operation.Complete(fmt.Sprintf("Restored %d databases from cluster archive", successCount))
return nil
}
// extractArchive extracts a tar.gz archive
func (e *Engine) extractArchive(ctx context.Context, archivePath, destDir string) error {
cmd := exec.CommandContext(ctx, "tar", "-xzf", archivePath, "-C", destDir)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("tar extraction failed: %w\nOutput: %s", err, string(output))
}
return nil
}
// restoreGlobals restores global objects (roles, tablespaces)
func (e *Engine) restoreGlobals(ctx context.Context, globalsFile string) error {
cmd := exec.CommandContext(ctx,
"psql",
"-h", e.cfg.Host,
"-p", fmt.Sprintf("%d", e.cfg.Port),
"-U", e.cfg.User,
"-d", "postgres",
"-f", globalsFile,
)
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", e.cfg.Password))
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to restore globals: %w\nOutput: %s", err, string(output))
}
return nil
}
// previewClusterRestore shows cluster restore preview
func (e *Engine) previewClusterRestore(archivePath string) error {
fmt.Println("\n" + strings.Repeat("=", 60))
fmt.Println(" CLUSTER RESTORE PREVIEW (DRY RUN)")
fmt.Println(strings.Repeat("=", 60))
stat, _ := os.Stat(archivePath)
fmt.Printf("\nArchive: %s\n", filepath.Base(archivePath))
if stat != nil {
fmt.Printf("Size: %s\n", FormatBytes(stat.Size()))
fmt.Printf("Modified: %s\n", stat.ModTime().Format("2006-01-02 15:04:05"))
}
fmt.Printf("Target Host: %s:%d\n", e.cfg.Host, e.cfg.Port)
fmt.Println("\nOperations that would be performed:")
fmt.Println(" 1. Extract cluster archive to temporary directory")
fmt.Println(" 2. Restore global objects (roles, tablespaces)")
fmt.Println(" 3. Restore all databases found in archive")
fmt.Println(" 4. Cleanup temporary files")
fmt.Println("\n⚠ WARNING: This will restore multiple databases.")
fmt.Println(" Existing databases may be overwritten or merged.")
fmt.Println("\nTo execute this restore, add the --confirm flag.")
fmt.Println(strings.Repeat("=", 60) + "\n")
return nil
}
// FormatBytes formats bytes to human readable format
func FormatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}

110
internal/restore/formats.go Normal file
View File

@ -0,0 +1,110 @@
package restore
import (
"strings"
)
// ArchiveFormat represents the type of backup archive
type ArchiveFormat string
const (
FormatPostgreSQLDump ArchiveFormat = "PostgreSQL Dump (.dump)"
FormatPostgreSQLDumpGz ArchiveFormat = "PostgreSQL Dump Compressed (.dump.gz)"
FormatPostgreSQLSQL ArchiveFormat = "PostgreSQL SQL (.sql)"
FormatPostgreSQLSQLGz ArchiveFormat = "PostgreSQL SQL Compressed (.sql.gz)"
FormatMySQLSQL ArchiveFormat = "MySQL SQL (.sql)"
FormatMySQLSQLGz ArchiveFormat = "MySQL SQL Compressed (.sql.gz)"
FormatClusterTarGz ArchiveFormat = "Cluster Archive (.tar.gz)"
FormatUnknown ArchiveFormat = "Unknown"
)
// DetectArchiveFormat detects the format of a backup archive from its filename
func DetectArchiveFormat(filename string) ArchiveFormat {
lower := strings.ToLower(filename)
// Check for cluster archives first (most specific)
if strings.Contains(lower, "cluster") && strings.HasSuffix(lower, ".tar.gz") {
return FormatClusterTarGz
}
// Check for compressed formats
if strings.HasSuffix(lower, ".dump.gz") {
return FormatPostgreSQLDumpGz
}
if strings.HasSuffix(lower, ".sql.gz") {
// Determine if MySQL or PostgreSQL based on naming convention
if strings.Contains(lower, "mysql") || strings.Contains(lower, "mariadb") {
return FormatMySQLSQLGz
}
return FormatPostgreSQLSQLGz
}
// Check for uncompressed formats
if strings.HasSuffix(lower, ".dump") {
return FormatPostgreSQLDump
}
if strings.HasSuffix(lower, ".sql") {
// Determine if MySQL or PostgreSQL based on naming convention
if strings.Contains(lower, "mysql") || strings.Contains(lower, "mariadb") {
return FormatMySQLSQL
}
return FormatPostgreSQLSQL
}
if strings.HasSuffix(lower, ".tar.gz") || strings.HasSuffix(lower, ".tgz") {
return FormatClusterTarGz
}
return FormatUnknown
}
// IsCompressed returns true if the archive format is compressed
func (f ArchiveFormat) IsCompressed() bool {
return f == FormatPostgreSQLDumpGz ||
f == FormatPostgreSQLSQLGz ||
f == FormatMySQLSQLGz ||
f == FormatClusterTarGz
}
// IsClusterBackup returns true if the archive is a cluster backup
func (f ArchiveFormat) IsClusterBackup() bool {
return f == FormatClusterTarGz
}
// IsPostgreSQL returns true if the archive is PostgreSQL format
func (f ArchiveFormat) IsPostgreSQL() bool {
return f == FormatPostgreSQLDump ||
f == FormatPostgreSQLDumpGz ||
f == FormatPostgreSQLSQL ||
f == FormatPostgreSQLSQLGz ||
f == FormatClusterTarGz
}
// IsMySQL returns true if format is MySQL
func (f ArchiveFormat) IsMySQL() bool {
return f == FormatMySQLSQL || f == FormatMySQLSQLGz
}
// String returns human-readable format name
func (f ArchiveFormat) String() string {
switch f {
case FormatPostgreSQLDump:
return "PostgreSQL Dump"
case FormatPostgreSQLDumpGz:
return "PostgreSQL Dump (gzip)"
case FormatPostgreSQLSQL:
return "PostgreSQL SQL"
case FormatPostgreSQLSQLGz:
return "PostgreSQL SQL (gzip)"
case FormatMySQLSQL:
return "MySQL SQL"
case FormatMySQLSQLGz:
return "MySQL SQL (gzip)"
case FormatClusterTarGz:
return "Cluster Archive (tar.gz)"
default:
return "Unknown"
}
}

342
internal/restore/safety.go Normal file
View File

@ -0,0 +1,342 @@
package restore
import (
"compress/gzip"
"context"
"fmt"
"io"
"os"
"os/exec"
"strings"
"syscall"
"dbbackup/internal/config"
"dbbackup/internal/logger"
)
// Safety provides pre-restore validation and safety checks
type Safety struct {
cfg *config.Config
log logger.Logger
}
// NewSafety creates a new safety checker
func NewSafety(cfg *config.Config, log logger.Logger) *Safety {
return &Safety{
cfg: cfg,
log: log,
}
}
// ValidateArchive performs integrity checks on the archive
func (s *Safety) ValidateArchive(archivePath string) error {
// Check if file exists
stat, err := os.Stat(archivePath)
if err != nil {
return fmt.Errorf("archive not accessible: %w", err)
}
// Check if file is not empty
if stat.Size() == 0 {
return fmt.Errorf("archive is empty")
}
// Check if file is too small (likely corrupted)
if stat.Size() < 100 {
return fmt.Errorf("archive is suspiciously small (%d bytes)", stat.Size())
}
// Detect format
format := DetectArchiveFormat(archivePath)
if format == FormatUnknown {
return fmt.Errorf("unknown archive format: %s", archivePath)
}
// Validate based on format
switch format {
case FormatPostgreSQLDump:
return s.validatePgDump(archivePath)
case FormatPostgreSQLDumpGz:
return s.validatePgDumpGz(archivePath)
case FormatPostgreSQLSQL, FormatMySQLSQL:
return s.validateSQLScript(archivePath)
case FormatPostgreSQLSQLGz, FormatMySQLSQLGz:
return s.validateSQLScriptGz(archivePath)
case FormatClusterTarGz:
return s.validateTarGz(archivePath)
}
return nil
}
// validatePgDump validates PostgreSQL dump file
func (s *Safety) validatePgDump(path string) error {
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("cannot open file: %w", err)
}
defer file.Close()
// Read first 512 bytes for signature check
buffer := make([]byte, 512)
n, err := file.Read(buffer)
if err != nil && err != io.EOF {
return fmt.Errorf("cannot read file: %w", err)
}
if n < 5 {
return fmt.Errorf("file too small to validate")
}
// Check for PGDMP signature
if string(buffer[:5]) == "PGDMP" {
return nil
}
// Check for PostgreSQL dump indicators
content := strings.ToLower(string(buffer[:n]))
if strings.Contains(content, "postgresql") || strings.Contains(content, "pg_dump") {
return nil
}
return fmt.Errorf("does not appear to be a PostgreSQL dump file")
}
// validatePgDumpGz validates compressed PostgreSQL dump
func (s *Safety) validatePgDumpGz(path string) error {
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("cannot open file: %w", err)
}
defer file.Close()
// Open gzip reader
gz, err := gzip.NewReader(file)
if err != nil {
return fmt.Errorf("not a valid gzip file: %w", err)
}
defer gz.Close()
// Read first 512 bytes
buffer := make([]byte, 512)
n, err := gz.Read(buffer)
if err != nil && err != io.EOF {
return fmt.Errorf("cannot read gzip contents: %w", err)
}
if n < 5 {
return fmt.Errorf("gzip archive too small")
}
// Check for PGDMP signature
if string(buffer[:5]) == "PGDMP" {
return nil
}
content := strings.ToLower(string(buffer[:n]))
if strings.Contains(content, "postgresql") || strings.Contains(content, "pg_dump") {
return nil
}
return fmt.Errorf("does not appear to be a PostgreSQL dump file")
}
// validateSQLScript validates SQL script
func (s *Safety) validateSQLScript(path string) error {
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("cannot open file: %w", err)
}
defer file.Close()
buffer := make([]byte, 1024)
n, err := file.Read(buffer)
if err != nil && err != io.EOF {
return fmt.Errorf("cannot read file: %w", err)
}
content := strings.ToLower(string(buffer[:n]))
if containsSQLKeywords(content) {
return nil
}
return fmt.Errorf("does not appear to contain SQL content")
}
// validateSQLScriptGz validates compressed SQL script
func (s *Safety) validateSQLScriptGz(path string) error {
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("cannot open file: %w", err)
}
defer file.Close()
gz, err := gzip.NewReader(file)
if err != nil {
return fmt.Errorf("not a valid gzip file: %w", err)
}
defer gz.Close()
buffer := make([]byte, 1024)
n, err := gz.Read(buffer)
if err != nil && err != io.EOF {
return fmt.Errorf("cannot read gzip contents: %w", err)
}
content := strings.ToLower(string(buffer[:n]))
if containsSQLKeywords(content) {
return nil
}
return fmt.Errorf("does not appear to contain SQL content")
}
// validateTarGz validates tar.gz archive
func (s *Safety) validateTarGz(path string) error {
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("cannot open file: %w", err)
}
defer file.Close()
// Check gzip magic number
buffer := make([]byte, 3)
n, err := file.Read(buffer)
if err != nil || n < 3 {
return fmt.Errorf("cannot read file header")
}
if buffer[0] == 0x1f && buffer[1] == 0x8b {
return nil // Valid gzip header
}
return fmt.Errorf("not a valid gzip file")
}
// containsSQLKeywords checks if content contains SQL keywords
func containsSQLKeywords(content string) bool {
keywords := []string{
"select", "insert", "create", "drop", "alter",
"database", "table", "update", "delete", "from", "where",
}
for _, keyword := range keywords {
if strings.Contains(content, keyword) {
return true
}
}
return false
}
// CheckDiskSpace verifies sufficient disk space for restore
func (s *Safety) CheckDiskSpace(archivePath string, multiplier float64) error {
// Get archive size
stat, err := os.Stat(archivePath)
if err != nil {
return fmt.Errorf("cannot stat archive: %w", err)
}
archiveSize := stat.Size()
// Estimate required space (archive size * multiplier for decompression/extraction)
requiredSpace := int64(float64(archiveSize) * multiplier)
// Get available disk space
var statfs syscall.Statfs_t
if err := syscall.Statfs(s.cfg.BackupDir, &statfs); err != nil {
s.log.Warn("Cannot check disk space", "error", err)
return nil // Don't fail if we can't check
}
availableSpace := int64(statfs.Bavail) * statfs.Bsize
if availableSpace < requiredSpace {
return fmt.Errorf("insufficient disk space: need %s, have %s",
FormatBytes(requiredSpace), FormatBytes(availableSpace))
}
s.log.Info("Disk space check passed",
"required", FormatBytes(requiredSpace),
"available", FormatBytes(availableSpace))
return nil
}
// VerifyTools checks if required restore tools are available
func (s *Safety) VerifyTools(dbType string) error {
var tools []string
if dbType == "postgres" {
tools = []string{"pg_restore", "psql"}
} else if dbType == "mysql" || dbType == "mariadb" {
tools = []string{"mysql"}
}
missing := []string{}
for _, tool := range tools {
if _, err := exec.LookPath(tool); err != nil {
missing = append(missing, tool)
}
}
if len(missing) > 0 {
return fmt.Errorf("missing required tools: %s", strings.Join(missing, ", "))
}
return nil
}
// CheckDatabaseExists verifies if target database exists
func (s *Safety) CheckDatabaseExists(ctx context.Context, dbName string) (bool, error) {
if s.cfg.DatabaseType == "postgres" {
return s.checkPostgresDatabaseExists(ctx, dbName)
} else if s.cfg.DatabaseType == "mysql" || s.cfg.DatabaseType == "mariadb" {
return s.checkMySQLDatabaseExists(ctx, dbName)
}
return false, fmt.Errorf("unsupported database type: %s", s.cfg.DatabaseType)
}
// checkPostgresDatabaseExists checks if PostgreSQL database exists
func (s *Safety) checkPostgresDatabaseExists(ctx context.Context, dbName string) (bool, error) {
cmd := exec.CommandContext(ctx,
"psql",
"-h", s.cfg.Host,
"-p", fmt.Sprintf("%d", s.cfg.Port),
"-U", s.cfg.User,
"-d", "postgres",
"-tAc", fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname='%s'", dbName),
)
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", s.cfg.Password))
output, err := cmd.Output()
if err != nil {
return false, fmt.Errorf("failed to check database existence: %w", err)
}
return strings.TrimSpace(string(output)) == "1", nil
}
// checkMySQLDatabaseExists checks if MySQL database exists
func (s *Safety) checkMySQLDatabaseExists(ctx context.Context, dbName string) (bool, error) {
cmd := exec.CommandContext(ctx,
"mysql",
"-h", s.cfg.Host,
"-P", fmt.Sprintf("%d", s.cfg.Port),
"-u", s.cfg.User,
"-e", fmt.Sprintf("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME='%s'", dbName),
)
if s.cfg.Password != "" {
cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password))
}
output, err := cmd.Output()
if err != nil {
return false, fmt.Errorf("failed to check database existence: %w", err)
}
return strings.Contains(string(output), dbName), nil
}