feat: Add enterprise DBA features for production reliability

New features implemented:

1. Backup Catalog (internal/catalog/)
   - SQLite-based backup tracking
   - Gap detection and RPO monitoring
   - Search and statistics
   - Filesystem sync

2. DR Drill Testing (internal/drill/)
   - Automated restore testing in Docker containers
   - Database validation with custom queries
   - Catalog integration for drill-tested status

3. Smart Notifications (internal/notify/)
   - Event batching with configurable intervals
   - Time-based escalation policies
   - HTML/text/Slack templates

4. Compliance Reports (internal/report/)
   - SOC2, GDPR, HIPAA, PCI-DSS, ISO27001 frameworks
   - Evidence collection from catalog
   - JSON, Markdown, HTML output formats

5. RTO/RPO Calculator (internal/rto/)
   - Recovery objective analysis
   - RTO breakdown by phase
   - Recommendations for improvement

6. Replica-Aware Backup (internal/replica/)
   - Topology detection for PostgreSQL/MySQL
   - Automatic replica selection
   - Configurable selection strategies

7. Parallel Table Backup (internal/parallel/)
   - Concurrent table dumps
   - Worker pool with progress tracking
   - Large table optimization

8. MySQL/MariaDB PITR (internal/pitr/)
   - Binary log parsing and replay
   - Point-in-time recovery support
   - Transaction filtering

CLI commands added: catalog, drill, report, rto

All changes support the goal: reliable 3 AM database recovery.
This commit is contained in:
2025-12-13 20:28:55 +01:00
parent d0d83b61ef
commit f69bfe7071
34 changed files with 13469 additions and 41 deletions

298
internal/drill/docker.go Normal file
View File

@@ -0,0 +1,298 @@
// Package drill - Docker container management for DR drills
package drill
import (
"context"
"fmt"
"os/exec"
"strings"
"time"
)
// DockerManager handles Docker container operations for DR drills
type DockerManager struct {
verbose bool
}
// NewDockerManager creates a new Docker manager
func NewDockerManager(verbose bool) *DockerManager {
return &DockerManager{verbose: verbose}
}
// ContainerConfig holds Docker container configuration
type ContainerConfig struct {
Image string // Docker image (e.g., "postgres:15")
Name string // Container name
Port int // Host port to map
ContainerPort int // Container port
Environment map[string]string // Environment variables
Volumes []string // Volume mounts
Network string // Docker network
Timeout int // Startup timeout in seconds
}
// ContainerInfo holds information about a running container
type ContainerInfo struct {
ID string
Name string
Image string
Port int
Status string
Started time.Time
Healthy bool
}
// CheckDockerAvailable verifies Docker is installed and running
func (dm *DockerManager) CheckDockerAvailable(ctx context.Context) error {
cmd := exec.CommandContext(ctx, "docker", "version")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("docker not available: %w (output: %s)", err, string(output))
}
return nil
}
// PullImage pulls a Docker image if not present
func (dm *DockerManager) PullImage(ctx context.Context, image string) error {
// Check if image exists locally
checkCmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
if err := checkCmd.Run(); err == nil {
// Image exists
return nil
}
// Pull the image
pullCmd := exec.CommandContext(ctx, "docker", "pull", image)
output, err := pullCmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to pull image %s: %w (output: %s)", image, err, string(output))
}
return nil
}
// CreateContainer creates and starts a database container
func (dm *DockerManager) CreateContainer(ctx context.Context, config *ContainerConfig) (*ContainerInfo, error) {
args := []string{
"run", "-d",
"--name", config.Name,
"-p", fmt.Sprintf("%d:%d", config.Port, config.ContainerPort),
}
// Add environment variables
for k, v := range config.Environment {
args = append(args, "-e", fmt.Sprintf("%s=%s", k, v))
}
// Add volumes
for _, v := range config.Volumes {
args = append(args, "-v", v)
}
// Add network if specified
if config.Network != "" {
args = append(args, "--network", config.Network)
}
// Add image
args = append(args, config.Image)
cmd := exec.CommandContext(ctx, "docker", args...)
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to create container: %w (output: %s)", err, string(output))
}
containerID := strings.TrimSpace(string(output))
return &ContainerInfo{
ID: containerID,
Name: config.Name,
Image: config.Image,
Port: config.Port,
Status: "created",
Started: time.Now(),
}, nil
}
// WaitForHealth waits for container to be healthy
func (dm *DockerManager) WaitForHealth(ctx context.Context, containerID string, dbType string, timeout int) error {
deadline := time.Now().Add(time.Duration(timeout) * time.Second)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if time.Now().After(deadline) {
return fmt.Errorf("timeout waiting for container to be healthy")
}
// Check container health
healthCmd := dm.healthCheckCommand(dbType)
args := append([]string{"exec", containerID}, healthCmd...)
cmd := exec.CommandContext(ctx, "docker", args...)
if err := cmd.Run(); err == nil {
return nil // Container is healthy
}
}
}
}
// healthCheckCommand returns the health check command for a database type
func (dm *DockerManager) healthCheckCommand(dbType string) []string {
switch dbType {
case "postgresql", "postgres":
return []string{"pg_isready", "-U", "postgres"}
case "mysql":
return []string{"mysqladmin", "ping", "-h", "localhost", "-u", "root", "--password=root"}
case "mariadb":
return []string{"mariadb-admin", "ping", "-h", "localhost", "-u", "root", "--password=root"}
default:
return []string{"echo", "ok"}
}
}
// ExecCommand executes a command inside the container
func (dm *DockerManager) ExecCommand(ctx context.Context, containerID string, command []string) (string, error) {
args := append([]string{"exec", containerID}, command...)
cmd := exec.CommandContext(ctx, "docker", args...)
output, err := cmd.CombinedOutput()
if err != nil {
return string(output), fmt.Errorf("exec failed: %w", err)
}
return string(output), nil
}
// CopyToContainer copies a file to the container
func (dm *DockerManager) CopyToContainer(ctx context.Context, containerID, src, dest string) error {
cmd := exec.CommandContext(ctx, "docker", "cp", src, fmt.Sprintf("%s:%s", containerID, dest))
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("copy failed: %w (output: %s)", err, string(output))
}
return nil
}
// StopContainer stops a running container
func (dm *DockerManager) StopContainer(ctx context.Context, containerID string) error {
cmd := exec.CommandContext(ctx, "docker", "stop", containerID)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to stop container: %w (output: %s)", err, string(output))
}
return nil
}
// RemoveContainer removes a container
func (dm *DockerManager) RemoveContainer(ctx context.Context, containerID string) error {
cmd := exec.CommandContext(ctx, "docker", "rm", "-f", containerID)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove container: %w (output: %s)", err, string(output))
}
return nil
}
// GetContainerLogs retrieves container logs
func (dm *DockerManager) GetContainerLogs(ctx context.Context, containerID string, tail int) (string, error) {
args := []string{"logs"}
if tail > 0 {
args = append(args, "--tail", fmt.Sprintf("%d", tail))
}
args = append(args, containerID)
cmd := exec.CommandContext(ctx, "docker", args...)
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("failed to get logs: %w", err)
}
return string(output), nil
}
// ListDrillContainers lists all containers created by drill operations
func (dm *DockerManager) ListDrillContainers(ctx context.Context) ([]*ContainerInfo, error) {
cmd := exec.CommandContext(ctx, "docker", "ps", "-a",
"--filter", "name=drill_",
"--format", "{{.ID}}\t{{.Names}}\t{{.Image}}\t{{.Status}}")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to list containers: %w", err)
}
var containers []*ContainerInfo
lines := strings.Split(strings.TrimSpace(string(output)), "\n")
for _, line := range lines {
if line == "" {
continue
}
parts := strings.Split(line, "\t")
if len(parts) >= 4 {
containers = append(containers, &ContainerInfo{
ID: parts[0],
Name: parts[1],
Image: parts[2],
Status: parts[3],
})
}
}
return containers, nil
}
// GetDefaultImage returns the default Docker image for a database type
func GetDefaultImage(dbType, version string) string {
if version == "" {
version = "latest"
}
switch dbType {
case "postgresql", "postgres":
return fmt.Sprintf("postgres:%s", version)
case "mysql":
return fmt.Sprintf("mysql:%s", version)
case "mariadb":
return fmt.Sprintf("mariadb:%s", version)
default:
return ""
}
}
// GetDefaultPort returns the default port for a database type
func GetDefaultPort(dbType string) int {
switch dbType {
case "postgresql", "postgres":
return 5432
case "mysql", "mariadb":
return 3306
default:
return 0
}
}
// GetDefaultEnvironment returns default environment variables for a database container
func GetDefaultEnvironment(dbType string) map[string]string {
switch dbType {
case "postgresql", "postgres":
return map[string]string{
"POSTGRES_PASSWORD": "drill_test_password",
"POSTGRES_USER": "postgres",
"POSTGRES_DB": "postgres",
}
case "mysql":
return map[string]string{
"MYSQL_ROOT_PASSWORD": "root",
"MYSQL_DATABASE": "test",
}
case "mariadb":
return map[string]string{
"MARIADB_ROOT_PASSWORD": "root",
"MARIADB_DATABASE": "test",
}
default:
return map[string]string{}
}
}

247
internal/drill/drill.go Normal file
View File

@@ -0,0 +1,247 @@
// Package drill provides Disaster Recovery drill functionality
// for testing backup restorability in isolated environments
package drill
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
// DrillConfig holds configuration for a DR drill
type DrillConfig struct {
// Backup configuration
BackupPath string `json:"backup_path"`
DatabaseName string `json:"database_name"`
DatabaseType string `json:"database_type"` // postgresql, mysql, mariadb
// Docker configuration
ContainerImage string `json:"container_image"` // e.g., "postgres:15"
ContainerName string `json:"container_name"` // Generated if empty
ContainerPort int `json:"container_port"` // Host port mapping
ContainerTimeout int `json:"container_timeout"` // Startup timeout in seconds
CleanupOnExit bool `json:"cleanup_on_exit"` // Remove container after drill
KeepOnFailure bool `json:"keep_on_failure"` // Keep container if drill fails
// Validation configuration
ValidationQueries []ValidationQuery `json:"validation_queries"`
MinRowCount int64 `json:"min_row_count"` // Minimum rows expected
ExpectedTables []string `json:"expected_tables"` // Tables that must exist
CustomChecks []CustomCheck `json:"custom_checks"`
// Encryption (if backup is encrypted)
EncryptionKeyFile string `json:"encryption_key_file,omitempty"`
EncryptionKeyEnv string `json:"encryption_key_env,omitempty"`
// Performance thresholds
MaxRestoreSeconds int `json:"max_restore_seconds"` // RTO threshold
MaxQuerySeconds int `json:"max_query_seconds"` // Query timeout
// Output
OutputDir string `json:"output_dir"` // Directory for drill reports
ReportFormat string `json:"report_format"` // json, markdown, html
Verbose bool `json:"verbose"`
}
// ValidationQuery represents a SQL query to validate restored data
type ValidationQuery struct {
Name string `json:"name"` // Human-readable name
Query string `json:"query"` // SQL query
ExpectedValue string `json:"expected_value"` // Expected result (optional)
MinValue int64 `json:"min_value"` // Minimum expected value
MaxValue int64 `json:"max_value"` // Maximum expected value
MustSucceed bool `json:"must_succeed"` // Fail drill if query fails
}
// CustomCheck represents a custom validation check
type CustomCheck struct {
Name string `json:"name"`
Type string `json:"type"` // row_count, table_exists, column_check
Table string `json:"table"`
Column string `json:"column,omitempty"`
Condition string `json:"condition,omitempty"` // SQL condition
MinValue int64 `json:"min_value,omitempty"`
MustSucceed bool `json:"must_succeed"`
}
// DrillResult contains the complete result of a DR drill
type DrillResult struct {
// Identification
DrillID string `json:"drill_id"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Duration float64 `json:"duration_seconds"`
// Configuration
BackupPath string `json:"backup_path"`
DatabaseName string `json:"database_name"`
DatabaseType string `json:"database_type"`
// Overall status
Success bool `json:"success"`
Status DrillStatus `json:"status"`
Message string `json:"message"`
// Phase timings
Phases []DrillPhase `json:"phases"`
// Validation results
ValidationResults []ValidationResult `json:"validation_results"`
CheckResults []CheckResult `json:"check_results"`
// Database metrics
TableCount int `json:"table_count"`
TotalRows int64 `json:"total_rows"`
DatabaseSize int64 `json:"database_size_bytes"`
// Performance metrics
RestoreTime float64 `json:"restore_time_seconds"`
ValidationTime float64 `json:"validation_time_seconds"`
QueryTimeAvg float64 `json:"query_time_avg_ms"`
// RTO/RPO metrics
ActualRTO float64 `json:"actual_rto_seconds"` // Total time to usable database
TargetRTO float64 `json:"target_rto_seconds"`
RTOMet bool `json:"rto_met"`
// Container info
ContainerID string `json:"container_id,omitempty"`
ContainerKept bool `json:"container_kept"`
// Errors and warnings
Errors []string `json:"errors,omitempty"`
Warnings []string `json:"warnings,omitempty"`
}
// DrillStatus represents the current status of a drill
type DrillStatus string
const (
StatusPending DrillStatus = "pending"
StatusRunning DrillStatus = "running"
StatusCompleted DrillStatus = "completed"
StatusFailed DrillStatus = "failed"
StatusAborted DrillStatus = "aborted"
StatusPartial DrillStatus = "partial" // Some validations failed
)
// DrillPhase represents a phase in the drill process
type DrillPhase struct {
Name string `json:"name"`
Status string `json:"status"` // pending, running, completed, failed, skipped
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Duration float64 `json:"duration_seconds"`
Message string `json:"message,omitempty"`
}
// ValidationResult holds the result of a validation query
type ValidationResult struct {
Name string `json:"name"`
Query string `json:"query"`
Success bool `json:"success"`
Result string `json:"result,omitempty"`
Expected string `json:"expected,omitempty"`
Duration float64 `json:"duration_ms"`
Error string `json:"error,omitempty"`
}
// CheckResult holds the result of a custom check
type CheckResult struct {
Name string `json:"name"`
Type string `json:"type"`
Success bool `json:"success"`
Actual int64 `json:"actual,omitempty"`
Expected int64 `json:"expected,omitempty"`
Message string `json:"message"`
}
// DefaultConfig returns a DrillConfig with sensible defaults
func DefaultConfig() *DrillConfig {
return &DrillConfig{
ContainerTimeout: 60,
CleanupOnExit: true,
KeepOnFailure: true,
MaxRestoreSeconds: 300, // 5 minutes
MaxQuerySeconds: 30,
ReportFormat: "json",
Verbose: false,
ValidationQueries: []ValidationQuery{},
ExpectedTables: []string{},
CustomChecks: []CustomCheck{},
}
}
// NewDrillID generates a unique drill ID
func NewDrillID() string {
return fmt.Sprintf("drill_%s", time.Now().Format("20060102_150405"))
}
// SaveResult saves the drill result to a file
func (r *DrillResult) SaveResult(outputDir string) error {
if err := os.MkdirAll(outputDir, 0755); err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}
filename := fmt.Sprintf("%s_report.json", r.DrillID)
filepath := filepath.Join(outputDir, filename)
data, err := json.MarshalIndent(r, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal result: %w", err)
}
if err := os.WriteFile(filepath, data, 0644); err != nil {
return fmt.Errorf("failed to write result file: %w", err)
}
return nil
}
// LoadResult loads a drill result from a file
func LoadResult(filepath string) (*DrillResult, error) {
data, err := os.ReadFile(filepath)
if err != nil {
return nil, fmt.Errorf("failed to read result file: %w", err)
}
var result DrillResult
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse result: %w", err)
}
return &result, nil
}
// IsSuccess returns true if the drill was successful
func (r *DrillResult) IsSuccess() bool {
return r.Success && r.Status == StatusCompleted
}
// Summary returns a human-readable summary of the drill
func (r *DrillResult) Summary() string {
status := "✅ PASSED"
if !r.Success {
status = "❌ FAILED"
} else if r.Status == StatusPartial {
status = "⚠️ PARTIAL"
}
return fmt.Sprintf("%s - %s (%.2fs) - %d tables, %d rows",
status, r.DatabaseName, r.Duration, r.TableCount, r.TotalRows)
}
// Drill is the interface for DR drill operations
type Drill interface {
// Run executes the full DR drill
Run(ctx context.Context, config *DrillConfig) (*DrillResult, error)
// Validate runs validation queries against an existing database
Validate(ctx context.Context, config *DrillConfig) ([]ValidationResult, error)
// Cleanup removes drill resources (containers, temp files)
Cleanup(ctx context.Context, drillID string) error
}

532
internal/drill/engine.go Normal file
View File

@@ -0,0 +1,532 @@
// Package drill - Main drill execution engine
package drill
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"dbbackup/internal/logger"
)
// Engine executes DR drills
type Engine struct {
docker *DockerManager
log logger.Logger
verbose bool
}
// NewEngine creates a new drill engine
func NewEngine(log logger.Logger, verbose bool) *Engine {
return &Engine{
docker: NewDockerManager(verbose),
log: log,
verbose: verbose,
}
}
// Run executes a complete DR drill
func (e *Engine) Run(ctx context.Context, config *DrillConfig) (*DrillResult, error) {
result := &DrillResult{
DrillID: NewDrillID(),
StartTime: time.Now(),
BackupPath: config.BackupPath,
DatabaseName: config.DatabaseName,
DatabaseType: config.DatabaseType,
Status: StatusRunning,
Phases: make([]DrillPhase, 0),
TargetRTO: float64(config.MaxRestoreSeconds),
}
e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
e.log.Info(" 🧪 DR Drill: " + result.DrillID)
e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
e.log.Info("")
// Cleanup function for error cases
var containerID string
cleanup := func() {
if containerID != "" && config.CleanupOnExit && (result.Success || !config.KeepOnFailure) {
e.log.Info("🗑️ Cleaning up container...")
e.docker.RemoveContainer(context.Background(), containerID)
} else if containerID != "" {
result.ContainerKept = true
e.log.Info("📦 Container kept for debugging: " + containerID)
}
}
defer cleanup()
// Phase 1: Preflight checks
phase := e.startPhase("Preflight Checks")
if err := e.preflightChecks(ctx, config); err != nil {
e.failPhase(&phase, err.Error())
result.Phases = append(result.Phases, phase)
result.Status = StatusFailed
result.Message = "Preflight checks failed: " + err.Error()
result.Errors = append(result.Errors, err.Error())
e.finalize(result)
return result, nil
}
e.completePhase(&phase, "All checks passed")
result.Phases = append(result.Phases, phase)
// Phase 2: Start container
phase = e.startPhase("Start Container")
containerConfig := e.buildContainerConfig(config)
container, err := e.docker.CreateContainer(ctx, containerConfig)
if err != nil {
e.failPhase(&phase, err.Error())
result.Phases = append(result.Phases, phase)
result.Status = StatusFailed
result.Message = "Failed to start container: " + err.Error()
result.Errors = append(result.Errors, err.Error())
e.finalize(result)
return result, nil
}
containerID = container.ID
result.ContainerID = containerID
e.log.Info("📦 Container started: " + containerID[:12])
// Wait for container to be healthy
if err := e.docker.WaitForHealth(ctx, containerID, config.DatabaseType, config.ContainerTimeout); err != nil {
e.failPhase(&phase, "Container health check failed: "+err.Error())
result.Phases = append(result.Phases, phase)
result.Status = StatusFailed
result.Message = "Container failed to start"
result.Errors = append(result.Errors, err.Error())
e.finalize(result)
return result, nil
}
e.completePhase(&phase, "Container healthy")
result.Phases = append(result.Phases, phase)
// Phase 3: Restore backup
phase = e.startPhase("Restore Backup")
restoreStart := time.Now()
if err := e.restoreBackup(ctx, config, containerID, containerConfig); err != nil {
e.failPhase(&phase, err.Error())
result.Phases = append(result.Phases, phase)
result.Status = StatusFailed
result.Message = "Restore failed: " + err.Error()
result.Errors = append(result.Errors, err.Error())
e.finalize(result)
return result, nil
}
result.RestoreTime = time.Since(restoreStart).Seconds()
e.completePhase(&phase, fmt.Sprintf("Restored in %.2fs", result.RestoreTime))
result.Phases = append(result.Phases, phase)
e.log.Info(fmt.Sprintf("✅ Backup restored in %.2fs", result.RestoreTime))
// Phase 4: Validate
phase = e.startPhase("Validate Database")
validateStart := time.Now()
validationErrors := e.validateDatabase(ctx, config, result, containerConfig)
result.ValidationTime = time.Since(validateStart).Seconds()
if validationErrors > 0 {
e.completePhase(&phase, fmt.Sprintf("Completed with %d errors", validationErrors))
} else {
e.completePhase(&phase, "All validations passed")
}
result.Phases = append(result.Phases, phase)
// Determine overall status
result.ActualRTO = result.RestoreTime + result.ValidationTime
result.RTOMet = result.ActualRTO <= result.TargetRTO
criticalFailures := 0
for _, vr := range result.ValidationResults {
if !vr.Success {
criticalFailures++
}
}
for _, cr := range result.CheckResults {
if !cr.Success {
criticalFailures++
}
}
if criticalFailures == 0 {
result.Success = true
result.Status = StatusCompleted
result.Message = "DR drill completed successfully"
} else if criticalFailures < len(result.ValidationResults)+len(result.CheckResults) {
result.Success = false
result.Status = StatusPartial
result.Message = fmt.Sprintf("DR drill completed with %d validation failures", criticalFailures)
} else {
result.Success = false
result.Status = StatusFailed
result.Message = "All validations failed"
}
e.finalize(result)
// Save result if output dir specified
if config.OutputDir != "" {
if err := result.SaveResult(config.OutputDir); err != nil {
e.log.Warn("Failed to save drill result", "error", err)
} else {
e.log.Info("📄 Report saved to: " + filepath.Join(config.OutputDir, result.DrillID+"_report.json"))
}
}
return result, nil
}
// preflightChecks runs preflight checks before the drill
func (e *Engine) preflightChecks(ctx context.Context, config *DrillConfig) error {
// Check Docker is available
if err := e.docker.CheckDockerAvailable(ctx); err != nil {
return fmt.Errorf("docker not available: %w", err)
}
e.log.Info("✓ Docker is available")
// Check backup file exists
if _, err := os.Stat(config.BackupPath); err != nil {
return fmt.Errorf("backup file not found: %s", config.BackupPath)
}
e.log.Info("✓ Backup file exists: " + filepath.Base(config.BackupPath))
// Pull Docker image
image := config.ContainerImage
if image == "" {
image = GetDefaultImage(config.DatabaseType, "")
}
e.log.Info("⬇️ Pulling image: " + image)
if err := e.docker.PullImage(ctx, image); err != nil {
return fmt.Errorf("failed to pull image: %w", err)
}
e.log.Info("✓ Image ready: " + image)
return nil
}
// buildContainerConfig creates container configuration
func (e *Engine) buildContainerConfig(config *DrillConfig) *ContainerConfig {
containerName := config.ContainerName
if containerName == "" {
containerName = fmt.Sprintf("drill_%s_%s", config.DatabaseName, time.Now().Format("20060102_150405"))
}
image := config.ContainerImage
if image == "" {
image = GetDefaultImage(config.DatabaseType, "")
}
port := config.ContainerPort
if port == 0 {
port = 15432 // Default drill port (different from production)
if config.DatabaseType == "mysql" || config.DatabaseType == "mariadb" {
port = 13306
}
}
containerPort := GetDefaultPort(config.DatabaseType)
env := GetDefaultEnvironment(config.DatabaseType)
return &ContainerConfig{
Image: image,
Name: containerName,
Port: port,
ContainerPort: containerPort,
Environment: env,
Timeout: config.ContainerTimeout,
}
}
// restoreBackup restores the backup into the container
func (e *Engine) restoreBackup(ctx context.Context, config *DrillConfig, containerID string, containerConfig *ContainerConfig) error {
// Copy backup to container
backupName := filepath.Base(config.BackupPath)
containerBackupPath := "/tmp/" + backupName
e.log.Info("📁 Copying backup to container...")
if err := e.docker.CopyToContainer(ctx, containerID, config.BackupPath, containerBackupPath); err != nil {
return fmt.Errorf("failed to copy backup: %w", err)
}
// Handle encrypted backups
if config.EncryptionKeyFile != "" {
// For encrypted backups, we'd need to decrypt first
// This is a simplified implementation
e.log.Warn("Encrypted backup handling not fully implemented in drill mode")
}
// Restore based on database type and format
e.log.Info("🔄 Restoring backup...")
return e.executeRestore(ctx, config, containerID, containerBackupPath, containerConfig)
}
// executeRestore runs the actual restore command
func (e *Engine) executeRestore(ctx context.Context, config *DrillConfig, containerID, backupPath string, containerConfig *ContainerConfig) error {
var cmd []string
switch config.DatabaseType {
case "postgresql", "postgres":
// Decompress if needed
if strings.HasSuffix(backupPath, ".gz") {
decompressedPath := strings.TrimSuffix(backupPath, ".gz")
_, err := e.docker.ExecCommand(ctx, containerID, []string{
"sh", "-c", fmt.Sprintf("gunzip -c %s > %s", backupPath, decompressedPath),
})
if err != nil {
return fmt.Errorf("decompression failed: %w", err)
}
backupPath = decompressedPath
}
// Create database
_, err := e.docker.ExecCommand(ctx, containerID, []string{
"psql", "-U", "postgres", "-c", fmt.Sprintf("CREATE DATABASE %s", config.DatabaseName),
})
if err != nil {
// Database might already exist
e.log.Debug("Create database returned (may already exist)")
}
// Detect restore method based on file content
isCustomFormat := strings.Contains(backupPath, ".dump") || strings.Contains(backupPath, ".custom")
if isCustomFormat {
cmd = []string{"pg_restore", "-U", "postgres", "-d", config.DatabaseName, "-v", backupPath}
} else {
cmd = []string{"sh", "-c", fmt.Sprintf("psql -U postgres -d %s < %s", config.DatabaseName, backupPath)}
}
case "mysql":
// Decompress if needed
if strings.HasSuffix(backupPath, ".gz") {
decompressedPath := strings.TrimSuffix(backupPath, ".gz")
_, err := e.docker.ExecCommand(ctx, containerID, []string{
"sh", "-c", fmt.Sprintf("gunzip -c %s > %s", backupPath, decompressedPath),
})
if err != nil {
return fmt.Errorf("decompression failed: %w", err)
}
backupPath = decompressedPath
}
cmd = []string{"sh", "-c", fmt.Sprintf("mysql -u root --password=root %s < %s", config.DatabaseName, backupPath)}
case "mariadb":
if strings.HasSuffix(backupPath, ".gz") {
decompressedPath := strings.TrimSuffix(backupPath, ".gz")
_, err := e.docker.ExecCommand(ctx, containerID, []string{
"sh", "-c", fmt.Sprintf("gunzip -c %s > %s", backupPath, decompressedPath),
})
if err != nil {
return fmt.Errorf("decompression failed: %w", err)
}
backupPath = decompressedPath
}
cmd = []string{"sh", "-c", fmt.Sprintf("mariadb -u root --password=root %s < %s", config.DatabaseName, backupPath)}
default:
return fmt.Errorf("unsupported database type: %s", config.DatabaseType)
}
output, err := e.docker.ExecCommand(ctx, containerID, cmd)
if err != nil {
return fmt.Errorf("restore failed: %w (output: %s)", err, output)
}
return nil
}
// validateDatabase runs validation against the restored database
func (e *Engine) validateDatabase(ctx context.Context, config *DrillConfig, result *DrillResult, containerConfig *ContainerConfig) int {
errorCount := 0
// Connect to database
var user, password string
switch config.DatabaseType {
case "postgresql", "postgres":
user = "postgres"
password = containerConfig.Environment["POSTGRES_PASSWORD"]
case "mysql":
user = "root"
password = "root"
case "mariadb":
user = "root"
password = "root"
}
validator, err := NewValidator(config.DatabaseType, "localhost", containerConfig.Port, user, password, config.DatabaseName, e.verbose)
if err != nil {
e.log.Error("Failed to connect for validation", "error", err)
result.Errors = append(result.Errors, "Validation connection failed: "+err.Error())
return 1
}
defer validator.Close()
// Get database metrics
tables, err := validator.GetTableList(ctx)
if err == nil {
result.TableCount = len(tables)
e.log.Info(fmt.Sprintf("📊 Tables found: %d", result.TableCount))
}
totalRows, err := validator.GetTotalRowCount(ctx)
if err == nil {
result.TotalRows = totalRows
e.log.Info(fmt.Sprintf("📊 Total rows: %d", result.TotalRows))
}
dbSize, err := validator.GetDatabaseSize(ctx, config.DatabaseName)
if err == nil {
result.DatabaseSize = dbSize
}
// Run expected tables check
if len(config.ExpectedTables) > 0 {
tableResults := validator.ValidateExpectedTables(ctx, config.ExpectedTables)
for _, tr := range tableResults {
result.CheckResults = append(result.CheckResults, tr)
if !tr.Success {
errorCount++
e.log.Warn("❌ " + tr.Message)
} else {
e.log.Info("✓ " + tr.Message)
}
}
}
// Run validation queries
if len(config.ValidationQueries) > 0 {
queryResults := validator.RunValidationQueries(ctx, config.ValidationQueries)
result.ValidationResults = append(result.ValidationResults, queryResults...)
var totalQueryTime float64
for _, qr := range queryResults {
totalQueryTime += qr.Duration
if !qr.Success {
errorCount++
e.log.Warn(fmt.Sprintf("❌ %s: %s", qr.Name, qr.Error))
} else {
e.log.Info(fmt.Sprintf("✓ %s: %s (%.0fms)", qr.Name, qr.Result, qr.Duration))
}
}
if len(queryResults) > 0 {
result.QueryTimeAvg = totalQueryTime / float64(len(queryResults))
}
}
// Run custom checks
if len(config.CustomChecks) > 0 {
checkResults := validator.RunCustomChecks(ctx, config.CustomChecks)
for _, cr := range checkResults {
result.CheckResults = append(result.CheckResults, cr)
if !cr.Success {
errorCount++
e.log.Warn("❌ " + cr.Message)
} else {
e.log.Info("✓ " + cr.Message)
}
}
}
// Check minimum row count if specified
if config.MinRowCount > 0 && result.TotalRows < config.MinRowCount {
errorCount++
msg := fmt.Sprintf("Total rows (%d) below minimum (%d)", result.TotalRows, config.MinRowCount)
result.Warnings = append(result.Warnings, msg)
e.log.Warn("⚠️ " + msg)
}
return errorCount
}
// startPhase starts a new drill phase
func (e *Engine) startPhase(name string) DrillPhase {
e.log.Info("▶️ " + name)
return DrillPhase{
Name: name,
Status: "running",
StartTime: time.Now(),
}
}
// completePhase marks a phase as completed
func (e *Engine) completePhase(phase *DrillPhase, message string) {
phase.EndTime = time.Now()
phase.Duration = phase.EndTime.Sub(phase.StartTime).Seconds()
phase.Status = "completed"
phase.Message = message
}
// failPhase marks a phase as failed
func (e *Engine) failPhase(phase *DrillPhase, message string) {
phase.EndTime = time.Now()
phase.Duration = phase.EndTime.Sub(phase.StartTime).Seconds()
phase.Status = "failed"
phase.Message = message
e.log.Error("❌ Phase failed: " + message)
}
// finalize completes the drill result
func (e *Engine) finalize(result *DrillResult) {
result.EndTime = time.Now()
result.Duration = result.EndTime.Sub(result.StartTime).Seconds()
e.log.Info("")
e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
e.log.Info(" " + result.Summary())
e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
if result.Success {
e.log.Info(fmt.Sprintf(" RTO: %.2fs (target: %.0fs) %s",
result.ActualRTO, result.TargetRTO, boolIcon(result.RTOMet)))
}
}
func boolIcon(b bool) string {
if b {
return "✅"
}
return "❌"
}
// Cleanup removes drill resources
func (e *Engine) Cleanup(ctx context.Context, drillID string) error {
containers, err := e.docker.ListDrillContainers(ctx)
if err != nil {
return err
}
for _, c := range containers {
if strings.Contains(c.Name, drillID) || (drillID == "" && strings.HasPrefix(c.Name, "drill_")) {
e.log.Info("🗑️ Removing container: " + c.Name)
if err := e.docker.RemoveContainer(ctx, c.ID); err != nil {
e.log.Warn("Failed to remove container", "id", c.ID, "error", err)
}
}
}
return nil
}
// QuickTest runs a quick restore test without full validation
func (e *Engine) QuickTest(ctx context.Context, backupPath, dbType, dbName string) (*DrillResult, error) {
config := DefaultConfig()
config.BackupPath = backupPath
config.DatabaseType = dbType
config.DatabaseName = dbName
config.CleanupOnExit = true
config.MaxRestoreSeconds = 600
return e.Run(ctx, config)
}
// Validate runs validation queries against an existing database (non-Docker)
func (e *Engine) Validate(ctx context.Context, config *DrillConfig, host string, port int, user, password string) ([]ValidationResult, error) {
validator, err := NewValidator(config.DatabaseType, host, port, user, password, config.DatabaseName, e.verbose)
if err != nil {
return nil, err
}
defer validator.Close()
return validator.RunValidationQueries(ctx, config.ValidationQueries), nil
}

358
internal/drill/validate.go Normal file
View File

@@ -0,0 +1,358 @@
// Package drill - Validation logic for DR drills
package drill
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/v5/stdlib"
)
// Validator handles database validation during DR drills
type Validator struct {
db *sql.DB
dbType string
verbose bool
}
// NewValidator creates a new database validator
func NewValidator(dbType string, host string, port int, user, password, dbname string, verbose bool) (*Validator, error) {
var dsn string
var driver string
switch dbType {
case "postgresql", "postgres":
driver = "pgx"
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
host, port, user, password, dbname)
case "mysql":
driver = "mysql"
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
user, password, host, port, dbname)
case "mariadb":
driver = "mysql"
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
user, password, host, port, dbname)
default:
return nil, fmt.Errorf("unsupported database type: %s", dbType)
}
db, err := sql.Open(driver, dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &Validator{
db: db,
dbType: dbType,
verbose: verbose,
}, nil
}
// Close closes the database connection
func (v *Validator) Close() error {
return v.db.Close()
}
// RunValidationQueries executes validation queries and returns results
func (v *Validator) RunValidationQueries(ctx context.Context, queries []ValidationQuery) []ValidationResult {
var results []ValidationResult
for _, q := range queries {
result := v.runQuery(ctx, q)
results = append(results, result)
}
return results
}
// runQuery executes a single validation query
func (v *Validator) runQuery(ctx context.Context, query ValidationQuery) ValidationResult {
result := ValidationResult{
Name: query.Name,
Query: query.Query,
Expected: query.ExpectedValue,
}
start := time.Now()
rows, err := v.db.QueryContext(ctx, query.Query)
result.Duration = float64(time.Since(start).Milliseconds())
if err != nil {
result.Success = false
result.Error = err.Error()
return result
}
defer rows.Close()
// Get result
if rows.Next() {
var value interface{}
if err := rows.Scan(&value); err != nil {
result.Success = false
result.Error = fmt.Sprintf("scan error: %v", err)
return result
}
result.Result = fmt.Sprintf("%v", value)
}
// Validate result
result.Success = true
if query.ExpectedValue != "" && result.Result != query.ExpectedValue {
result.Success = false
result.Error = fmt.Sprintf("expected %s, got %s", query.ExpectedValue, result.Result)
}
// Check min/max if specified
if query.MinValue > 0 || query.MaxValue > 0 {
var numValue int64
fmt.Sscanf(result.Result, "%d", &numValue)
if query.MinValue > 0 && numValue < query.MinValue {
result.Success = false
result.Error = fmt.Sprintf("value %d below minimum %d", numValue, query.MinValue)
}
if query.MaxValue > 0 && numValue > query.MaxValue {
result.Success = false
result.Error = fmt.Sprintf("value %d above maximum %d", numValue, query.MaxValue)
}
}
return result
}
// RunCustomChecks executes custom validation checks
func (v *Validator) RunCustomChecks(ctx context.Context, checks []CustomCheck) []CheckResult {
var results []CheckResult
for _, check := range checks {
result := v.runCheck(ctx, check)
results = append(results, result)
}
return results
}
// runCheck executes a single custom check
func (v *Validator) runCheck(ctx context.Context, check CustomCheck) CheckResult {
result := CheckResult{
Name: check.Name,
Type: check.Type,
Expected: check.MinValue,
}
switch check.Type {
case "row_count":
count, err := v.getRowCount(ctx, check.Table, check.Condition)
if err != nil {
result.Success = false
result.Message = fmt.Sprintf("failed to get row count: %v", err)
return result
}
result.Actual = count
result.Success = count >= check.MinValue
if result.Success {
result.Message = fmt.Sprintf("Table %s has %d rows (min: %d)", check.Table, count, check.MinValue)
} else {
result.Message = fmt.Sprintf("Table %s has %d rows, expected at least %d", check.Table, count, check.MinValue)
}
case "table_exists":
exists, err := v.tableExists(ctx, check.Table)
if err != nil {
result.Success = false
result.Message = fmt.Sprintf("failed to check table: %v", err)
return result
}
result.Success = exists
if exists {
result.Actual = 1
result.Message = fmt.Sprintf("Table %s exists", check.Table)
} else {
result.Actual = 0
result.Message = fmt.Sprintf("Table %s does not exist", check.Table)
}
case "column_check":
exists, err := v.columnExists(ctx, check.Table, check.Column)
if err != nil {
result.Success = false
result.Message = fmt.Sprintf("failed to check column: %v", err)
return result
}
result.Success = exists
if exists {
result.Actual = 1
result.Message = fmt.Sprintf("Column %s.%s exists", check.Table, check.Column)
} else {
result.Actual = 0
result.Message = fmt.Sprintf("Column %s.%s does not exist", check.Table, check.Column)
}
default:
result.Success = false
result.Message = fmt.Sprintf("unknown check type: %s", check.Type)
}
return result
}
// getRowCount returns the row count for a table
func (v *Validator) getRowCount(ctx context.Context, table, condition string) (int64, error) {
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", v.quoteIdentifier(table))
if condition != "" {
query += " WHERE " + condition
}
var count int64
err := v.db.QueryRowContext(ctx, query).Scan(&count)
return count, err
}
// tableExists checks if a table exists
func (v *Validator) tableExists(ctx context.Context, table string) (bool, error) {
var query string
switch v.dbType {
case "postgresql", "postgres":
query = `SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = $1
)`
case "mysql", "mariadb":
query = `SELECT COUNT(*) > 0 FROM information_schema.tables
WHERE table_name = ?`
}
var exists bool
err := v.db.QueryRowContext(ctx, query, table).Scan(&exists)
return exists, err
}
// columnExists checks if a column exists
func (v *Validator) columnExists(ctx context.Context, table, column string) (bool, error) {
var query string
switch v.dbType {
case "postgresql", "postgres":
query = `SELECT EXISTS (
SELECT FROM information_schema.columns
WHERE table_name = $1 AND column_name = $2
)`
case "mysql", "mariadb":
query = `SELECT COUNT(*) > 0 FROM information_schema.columns
WHERE table_name = ? AND column_name = ?`
}
var exists bool
err := v.db.QueryRowContext(ctx, query, table, column).Scan(&exists)
return exists, err
}
// GetTableList returns all tables in the database
func (v *Validator) GetTableList(ctx context.Context) ([]string, error) {
var query string
switch v.dbType {
case "postgresql", "postgres":
query = `SELECT table_name FROM information_schema.tables
WHERE table_schema = 'public' AND table_type = 'BASE TABLE'`
case "mysql", "mariadb":
query = `SELECT table_name FROM information_schema.tables
WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE'`
}
rows, err := v.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var tables []string
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, err
}
tables = append(tables, table)
}
return tables, rows.Err()
}
// GetTotalRowCount returns total row count across all tables
func (v *Validator) GetTotalRowCount(ctx context.Context) (int64, error) {
tables, err := v.GetTableList(ctx)
if err != nil {
return 0, err
}
var total int64
for _, table := range tables {
count, err := v.getRowCount(ctx, table, "")
if err != nil {
continue // Skip tables that can't be counted
}
total += count
}
return total, nil
}
// GetDatabaseSize returns the database size in bytes
func (v *Validator) GetDatabaseSize(ctx context.Context, dbname string) (int64, error) {
var query string
switch v.dbType {
case "postgresql", "postgres":
query = fmt.Sprintf("SELECT pg_database_size('%s')", dbname)
case "mysql", "mariadb":
query = fmt.Sprintf(`SELECT SUM(data_length + index_length)
FROM information_schema.tables WHERE table_schema = '%s'`, dbname)
}
var size sql.NullInt64
err := v.db.QueryRowContext(ctx, query).Scan(&size)
if err != nil {
return 0, err
}
return size.Int64, nil
}
// ValidateExpectedTables checks that all expected tables exist
func (v *Validator) ValidateExpectedTables(ctx context.Context, expectedTables []string) []CheckResult {
var results []CheckResult
for _, table := range expectedTables {
check := CustomCheck{
Name: fmt.Sprintf("Table '%s' exists", table),
Type: "table_exists",
Table: table,
}
results = append(results, v.runCheck(ctx, check))
}
return results
}
// quoteIdentifier quotes a database identifier
func (v *Validator) quoteIdentifier(id string) string {
switch v.dbType {
case "postgresql", "postgres":
return fmt.Sprintf(`"%s"`, strings.ReplaceAll(id, `"`, `""`))
case "mysql", "mariadb":
return fmt.Sprintf("`%s`", strings.ReplaceAll(id, "`", "``"))
default:
return id
}
}