From f69bfe70718f29d1c1dd2fe9b890920a7bcfb864 Mon Sep 17 00:00:00 2001 From: Alexander Renz Date: Sat, 13 Dec 2025 20:28:55 +0100 Subject: [PATCH] feat: Add enterprise DBA features for production reliability New features implemented: 1. Backup Catalog (internal/catalog/) - SQLite-based backup tracking - Gap detection and RPO monitoring - Search and statistics - Filesystem sync 2. DR Drill Testing (internal/drill/) - Automated restore testing in Docker containers - Database validation with custom queries - Catalog integration for drill-tested status 3. Smart Notifications (internal/notify/) - Event batching with configurable intervals - Time-based escalation policies - HTML/text/Slack templates 4. Compliance Reports (internal/report/) - SOC2, GDPR, HIPAA, PCI-DSS, ISO27001 frameworks - Evidence collection from catalog - JSON, Markdown, HTML output formats 5. RTO/RPO Calculator (internal/rto/) - Recovery objective analysis - RTO breakdown by phase - Recommendations for improvement 6. Replica-Aware Backup (internal/replica/) - Topology detection for PostgreSQL/MySQL - Automatic replica selection - Configurable selection strategies 7. Parallel Table Backup (internal/parallel/) - Concurrent table dumps - Worker pool with progress tracking - Large table optimization 8. MySQL/MariaDB PITR (internal/pitr/) - Binary log parsing and replay - Point-in-time recovery support - Transaction filtering CLI commands added: catalog, drill, report, rto All changes support the goal: reliable 3 AM database recovery. --- MYSQL_PITR.md | 401 ++++++++++++++ README.md | 127 ++++- cmd/catalog.go | 725 ++++++++++++++++++++++++ cmd/drill.go | 500 +++++++++++++++++ cmd/pitr.go | 812 ++++++++++++++++++++++++++- cmd/report.go | 316 +++++++++++ cmd/rto.go | 458 +++++++++++++++ go.mod | 1 + go.sum | 2 + internal/catalog/catalog.go | 188 +++++++ internal/catalog/catalog_test.go | 308 +++++++++++ internal/catalog/gaps.go | 299 ++++++++++ internal/catalog/sqlite.go | 632 +++++++++++++++++++++ internal/catalog/sync.go | 234 ++++++++ internal/config/config.go | 7 + internal/drill/docker.go | 298 ++++++++++ internal/drill/drill.go | 247 +++++++++ internal/drill/engine.go | 532 ++++++++++++++++++ internal/drill/validate.go | 358 ++++++++++++ internal/notify/batch.go | 261 +++++++++ internal/notify/escalate.go | 363 ++++++++++++ internal/notify/notify.go | 103 ++-- internal/notify/templates.go | 497 +++++++++++++++++ internal/parallel/engine.go | 619 +++++++++++++++++++++ internal/pitr/binlog.go | 865 +++++++++++++++++++++++++++++ internal/pitr/binlog_test.go | 585 +++++++++++++++++++ internal/pitr/interface.go | 155 ++++++ internal/pitr/mysql.go | 924 +++++++++++++++++++++++++++++++ internal/replica/selector.go | 499 +++++++++++++++++ internal/report/frameworks.go | 424 ++++++++++++++ internal/report/generator.go | 420 ++++++++++++++ internal/report/output.go | 544 ++++++++++++++++++ internal/report/report.go | 325 +++++++++++ internal/rto/calculator.go | 481 ++++++++++++++++ 34 files changed, 13469 insertions(+), 41 deletions(-) create mode 100644 MYSQL_PITR.md create mode 100644 cmd/catalog.go create mode 100644 cmd/drill.go create mode 100644 cmd/report.go create mode 100644 cmd/rto.go create mode 100644 internal/catalog/catalog.go create mode 100644 internal/catalog/catalog_test.go create mode 100644 internal/catalog/gaps.go create mode 100644 internal/catalog/sqlite.go create mode 100644 internal/catalog/sync.go create mode 100644 internal/drill/docker.go create mode 100644 internal/drill/drill.go create mode 100644 
internal/drill/engine.go create mode 100644 internal/drill/validate.go create mode 100644 internal/notify/batch.go create mode 100644 internal/notify/escalate.go create mode 100644 internal/notify/templates.go create mode 100644 internal/parallel/engine.go create mode 100644 internal/pitr/binlog.go create mode 100644 internal/pitr/binlog_test.go create mode 100644 internal/pitr/interface.go create mode 100644 internal/pitr/mysql.go create mode 100644 internal/replica/selector.go create mode 100644 internal/report/frameworks.go create mode 100644 internal/report/generator.go create mode 100644 internal/report/output.go create mode 100644 internal/report/report.go create mode 100644 internal/rto/calculator.go diff --git a/MYSQL_PITR.md b/MYSQL_PITR.md new file mode 100644 index 0000000..19b2969 --- /dev/null +++ b/MYSQL_PITR.md @@ -0,0 +1,401 @@ +# MySQL/MariaDB Point-in-Time Recovery (PITR) + +This guide explains how to use dbbackup for Point-in-Time Recovery with MySQL and MariaDB databases. + +## Overview + +Point-in-Time Recovery (PITR) allows you to restore your database to any specific moment in time, not just to when a backup was taken. This is essential for: + +- Recovering from accidental data deletion or corruption +- Restoring to a state just before a problematic change +- Meeting regulatory compliance requirements for data recovery + +### How MySQL PITR Works + +MySQL PITR uses binary logs (binlogs) which record all changes to the database: + +1. **Base Backup**: A full database backup with the binlog position recorded +2. **Binary Log Archiving**: Continuous archiving of binlog files +3. **Recovery**: Restore base backup, then replay binlogs up to the target time + +``` +┌─────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ +│ Base Backup │ --> │ binlog.00001 │ --> │ binlog.00002 │ --> │ binlog.00003 │ +│ (pos: 1234) │ │ │ │ │ │ (current) │ +└─────────────┘ └──────────────┘ └──────────────┘ └──────────────┘ + │ │ │ │ + ▼ ▼ ▼ ▼ + 10:00 AM 10:30 AM 11:00 AM 11:30 AM + ↑ + Target: 11:15 AM +``` + +## Prerequisites + +### MySQL Configuration + +Binary logging must be enabled in MySQL. Add to `my.cnf`: + +```ini +[mysqld] +# Enable binary logging +log_bin = mysql-bin +server_id = 1 + +# Recommended: Use ROW format for PITR +binlog_format = ROW + +# Optional but recommended: Enable GTID for easier replication and recovery +gtid_mode = ON +enforce_gtid_consistency = ON + +# Keep binlogs for at least 7 days (adjust as needed) +expire_logs_days = 7 +# Or for MySQL 8.0+: +# binlog_expire_logs_seconds = 604800 +``` + +After changing configuration, restart MySQL: +```bash +sudo systemctl restart mysql +``` + +### MariaDB Configuration + +MariaDB configuration is similar: + +```ini +[mysqld] +log_bin = mariadb-bin +server_id = 1 +binlog_format = ROW + +# MariaDB uses different GTID implementation (auto-enabled with log_slave_updates) +log_slave_updates = ON +``` + +## Quick Start + +### 1. Check PITR Status + +```bash +# Check if MySQL is properly configured for PITR +dbbackup pitr mysql-status +``` + +Example output: +``` +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + MySQL/MariaDB PITR Status (mysql) +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +PITR Status: ❌ NOT CONFIGURED +Binary Logging: ✅ ENABLED +Binlog Format: ROW +GTID Mode: ON +Current Position: mysql-bin.000042:1234 + +PITR Requirements: + ✅ Binary logging enabled + ✅ Row-based logging (recommended) +``` + +### 2. 
Enable PITR + +```bash +# Enable PITR and configure archive directory +dbbackup pitr mysql-enable --archive-dir /backups/binlog_archive +``` + +### 3. Create a Base Backup + +```bash +# Create a PITR-capable backup +dbbackup backup single mydb --pitr +``` + +### 4. Start Binlog Archiving + +```bash +# Run binlog archiver in the background +dbbackup binlog watch --binlog-dir /var/lib/mysql --archive-dir /backups/binlog_archive --interval 30s +``` + +Or set up a cron job for periodic archiving: +```bash +# Archive new binlogs every 5 minutes +*/5 * * * * dbbackup binlog archive --binlog-dir /var/lib/mysql --archive-dir /backups/binlog_archive +``` + +### 5. Restore to Point in Time + +```bash +# Restore to a specific time +dbbackup restore pitr mydb_backup.sql.gz --target-time '2024-01-15 14:30:00' +``` + +## Commands Reference + +### PITR Commands + +#### `pitr mysql-status` +Show MySQL/MariaDB PITR configuration and status. + +```bash +dbbackup pitr mysql-status +``` + +#### `pitr mysql-enable` +Enable PITR for MySQL/MariaDB. + +```bash +dbbackup pitr mysql-enable \ + --archive-dir /backups/binlog_archive \ + --retention-days 7 \ + --require-row-format \ + --require-gtid +``` + +Options: +- `--archive-dir`: Directory to store archived binlogs (required) +- `--retention-days`: Days to keep archived binlogs (default: 7) +- `--require-row-format`: Require ROW binlog format (default: true) +- `--require-gtid`: Require GTID mode enabled (default: false) + +### Binlog Commands + +#### `binlog list` +List available binary log files. + +```bash +# List binlogs from MySQL data directory +dbbackup binlog list --binlog-dir /var/lib/mysql + +# List archived binlogs +dbbackup binlog list --archive-dir /backups/binlog_archive +``` + +#### `binlog archive` +Archive binary log files. + +```bash +dbbackup binlog archive \ + --binlog-dir /var/lib/mysql \ + --archive-dir /backups/binlog_archive \ + --compress +``` + +Options: +- `--binlog-dir`: MySQL binary log directory +- `--archive-dir`: Destination for archived binlogs (required) +- `--compress`: Compress archived binlogs with gzip +- `--encrypt`: Encrypt archived binlogs +- `--encryption-key-file`: Path to encryption key file + +#### `binlog watch` +Continuously monitor and archive new binlog files. + +```bash +dbbackup binlog watch \ + --binlog-dir /var/lib/mysql \ + --archive-dir /backups/binlog_archive \ + --interval 30s \ + --compress +``` + +Options: +- `--interval`: How often to check for new binlogs (default: 30s) + +#### `binlog validate` +Validate binlog chain integrity. + +```bash +dbbackup binlog validate --binlog-dir /var/lib/mysql +``` + +Output shows: +- Whether the chain is complete (no missing files) +- Any gaps in the sequence +- Server ID changes (indicating possible failover) +- Total size and file count + +#### `binlog position` +Show current binary log position. 
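+
+Under the hood this runs `SHOW MASTER STATUS` (as noted in the command's help text), so the manual equivalent from any MySQL client session is:
+
+```sql
+-- Reports the current binlog file, byte position, and executed GTID set
+SHOW MASTER STATUS;
+```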
+
+```bash
+dbbackup binlog position
+```
+
+Output:
+```
+━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+  Current Binary Log Position
+━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+File: mysql-bin.000042
+Position: 123456
+GTID Set: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-1000
+
+Position String: mysql-bin.000042:123456
+```
+
+## Restore Scenarios
+
+### Restore to Specific Time
+
+```bash
+# Restore to January 15, 2024 at 2:30 PM
+dbbackup restore pitr mydb_backup.sql.gz \
+  --target-time '2024-01-15 14:30:00'
+```
+
+### Restore to Specific Position
+
+```bash
+# Restore to a specific binlog position
+dbbackup restore pitr mydb_backup.sql.gz \
+  --target-position 'mysql-bin.000042:12345'
+```
+
+### Dry Run (Preview)
+
+```bash
+# See what SQL would be replayed without applying
+dbbackup restore pitr mydb_backup.sql.gz \
+  --target-time '2024-01-15 14:30:00' \
+  --dry-run
+```
+
+### Restore to Backup Point Only
+
+```bash
+# Restore just the base backup without replaying binlogs
+dbbackup restore pitr mydb_backup.sql.gz --immediate
+```
+
+## Best Practices
+
+### 1. Archiving Strategy
+
+- Archive binlogs frequently (every 5-30 minutes)
+- Use compression to save disk space
+- Store archives on separate storage from the database
+
+### 2. Retention Policy
+
+- Keep archives for at least as long as your oldest valid base backup
+- Consider regulatory requirements for data retention
+- Purge old archives on a schedule; the retention window for archived binlogs is set when enabling PITR:
+
+```bash
+dbbackup pitr mysql-enable --archive-dir /backups/binlog_archive --retention-days 30
+```
+
+### 3. Validation
+
+- Regularly validate your binlog chain:
+```bash
+dbbackup binlog validate --binlog-dir /var/lib/mysql
+```
+
+- Test restoration periodically on a test environment
+
+### 4. Monitoring
+
+- Monitor the `dbbackup binlog watch` process
+- Set up alerts for:
+  - Binlog archiver failures
+  - Gaps in binlog chain
+  - Low disk space on archive directory
+
+### 5. GTID Mode
+
+Enable GTID for:
+- Easier tracking of replication position
+- Automatic failover in replication setups
+- Simpler point-in-time recovery
+
+## Troubleshooting
+
+### Binary Logging Not Enabled
+
+**Error**: "Binary logging appears to be disabled"
+
+**Solution**: Add to my.cnf and restart MySQL:
+```ini
+[mysqld]
+log_bin = mysql-bin
+server_id = 1
+```
+
+### Missing Binlog Files
+
+**Error**: "Gaps detected in binlog chain"
+
+**Causes**:
+- `RESET MASTER` was executed
+- `expire_logs_days` is too short
+- Binlogs were manually deleted
+
+**Solution**:
+- Take a new base backup immediately
+- Adjust retention settings to prevent future gaps
+
+### Permission Denied
+
+**Error**: "Failed to read binlog directory"
+
+**Solution**:
+```bash
+# Add dbbackup user to mysql group
+sudo usermod -aG mysql dbbackup_user
+
+# Or set appropriate permissions
+sudo chmod g+r /var/lib/mysql/mysql-bin.*
+```
+
+### Wrong Binlog Format
+
+**Warning**: "binlog_format = STATEMENT (ROW recommended)"
+
+**Impact**: STATEMENT format may not capture all changes accurately
+
+**Solution**: Change to ROW format (requires restart):
+```ini
+[mysqld]
+binlog_format = ROW
+```
+
+### Server ID Changes
+
+**Warning**: "server_id changed from X to Y (possible master failover)"
+
+This warning indicates the binlog chain contains events from different servers, which may happen during:
+- Failover in a replication setup
+- Restoring from a different server's backup
+
+This is usually informational, but review your topology if it is unexpected.
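+
+## Manual Replay Reference
+
+For orientation, the recovery that `dbbackup restore pitr` automates corresponds roughly to the manual sequence below. This is a sketch, not a recipe: the file names, start position, and paths are illustrative, and the binlog list must match your actual archived chain.
+
+```bash
+# 1. Restore the base backup (file name illustrative)
+gunzip -c mydb_backup.sql.gz | mysql mydb
+
+# 2. Replay archived binlogs from the backup's recorded position
+#    up to the target time
+mysqlbinlog \
+  --start-position=1234 \
+  --stop-datetime='2024-01-15 14:30:00' \
+  /backups/binlog_archive/mysql-bin.000042 \
+  /backups/binlog_archive/mysql-bin.000043 \
+  | mysql mydb
+```
+
+Note that `--start-position` applies only to the first file in the list, which is exactly what is needed when replaying from a base backup's recorded binlog position.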
+
+## MariaDB-Specific Notes
+
+### GTID Format
+
+MariaDB uses a different GTID format than MySQL:
+- **MySQL**: `3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5`
+- **MariaDB**: `0-1-100` (domain-server_id-sequence)
+
+### Tool Detection
+
+dbbackup automatically detects MariaDB and uses:
+- `mariadb-binlog` if available (MariaDB 10.4+)
+- Falls back to `mysqlbinlog` for older versions
+
+### Encrypted Binlogs
+
+MariaDB supports binlog encryption. If enabled, ensure the key is available during archive and restore operations.
+
+## See Also
+
+- [PITR.md](PITR.md) - PostgreSQL PITR documentation
+- [DOCKER.md](DOCKER.md) - Running in Docker environments
+- [CLOUD.md](CLOUD.md) - Cloud storage for archives
diff --git a/README.md b/README.md
index 9b6f1de..93fb1af 100644
--- a/README.md
+++ b/README.md
@@ -16,12 +16,22 @@ Database backup and restore utility for PostgreSQL, MySQL, and MariaDB.
 - AES-256-GCM encryption
 - Incremental backups
 - Cloud storage: S3, MinIO, B2, Azure Blob, Google Cloud Storage
-- Point-in-Time Recovery (PITR) for PostgreSQL
+- Point-in-Time Recovery (PITR) for PostgreSQL and MySQL/MariaDB
 - **GFS retention policies**: Grandfather-Father-Son backup rotation
 - **Notifications**: SMTP email and webhook alerts
 - Interactive terminal UI
 - Cross-platform binaries
 
+### Enterprise DBA Features
+
+- **Backup Catalog**: SQLite-based catalog tracking all backups with gap detection
+- **DR Drill Testing**: Automated disaster recovery testing in Docker containers
+- **Smart Notifications**: Batched alerts with escalation policies
+- **Compliance Reports**: SOC2, GDPR, HIPAA, PCI-DSS, ISO27001 report generation
+- **RTO/RPO Calculator**: Recovery objective analysis and recommendations
+- **Replica-Aware Backup**: Automatic backup from replicas to reduce primary load
+- **Parallel Table Backup**: Concurrent table dumps for faster backups
+
 ## Installation
 
 ### Docker
@@ -257,6 +267,10 @@ dbbackup backup single mydb --dry-run
 | `pitr` | PITR management |
 | `wal` | WAL archive operations |
 | `interactive` | Start interactive UI |
+| `catalog` | Backup catalog management |
+| `drill` | DR drill testing |
+| `report` | Compliance report generation |
+| `rto` | RTO/RPO analysis |
 
 ## Global Flags
 
@@ -478,6 +492,117 @@ dbbackup backup single mydb --notify
 - `cleanup_completed`
 - `verify_completed`, `verify_failed`
 - `pitr_recovery`
+- `dr_drill_passed`, `dr_drill_failed`
+- `gap_detected`, `rpo_violation`
+
+## Backup Catalog
+
+Track all backups in a SQLite catalog with gap detection and search:
+
+```bash
+# Sync backups from directory to catalog
+dbbackup catalog sync /backups
+
+# List recent backups
+dbbackup catalog list --database mydb --limit 10
+
+# Show catalog statistics
+dbbackup catalog stats
+
+# Detect backup gaps (missing scheduled backups)
+dbbackup catalog gaps mydb --interval 24h
+
+# Search backups
+dbbackup catalog search --database mydb --after 2024-01-01 --before 2024-12-31
+
+# Get backup info
+dbbackup catalog info /backups/mydb_20240115.dump.gz
+```
+
+## DR Drill Testing
+
+Automated disaster recovery testing restores backups to Docker containers:
+
+```bash
+# Run full DR drill (timeout in seconds)
+dbbackup drill run /backups/mydb_latest.dump.gz \
+  --database mydb \
+  --type postgresql \
+  --timeout 1800
+
+# Quick drill (restore + basic validation)
+dbbackup drill quick /backups/mydb_latest.dump.gz --database mydb --type postgresql
+
+# List running drill containers
+dbbackup drill list
+
+# Cleanup old drill containers
+dbbackup drill cleanup
+
+# Display a saved drill report
+dbbackup drill report drill_20240115_120000_report.json --format json
+```
+
+**Drill phases:**
+1. Container creation
+2. Backup download (if cloud)
+3. Restore execution
+4. Database validation
+5. Custom query checks
+6. Cleanup
+
+## Compliance Reports
+
+Generate compliance reports for regulatory frameworks:
+
+```bash
+# Generate SOC2 report
+dbbackup report generate --type soc2 --days 90 --format html --output soc2-report.html
+
+# HIPAA compliance report
+dbbackup report generate --type hipaa --format markdown
+
+# Show compliance summary
+dbbackup report summary --type gdpr --days 30
+
+# List available frameworks
+dbbackup report list
+
+# Show controls for a framework
+dbbackup report controls soc2
+```
+
+**Supported frameworks:**
+- SOC2 Type II (Trust Service Criteria)
+- GDPR (General Data Protection Regulation)
+- HIPAA (Health Insurance Portability and Accountability Act)
+- PCI-DSS (Payment Card Industry Data Security Standard)
+- ISO 27001 (Information Security Management)
+
+## RTO/RPO Analysis
+
+Calculate and monitor Recovery Time/Point Objectives:
+
+```bash
+# Analyze RTO/RPO for a database
+dbbackup rto analyze mydb
+
+# Show status for all databases
+dbbackup rto status
+
+# Check against targets
+dbbackup rto check --rto 4h --rpo 1h
+
+# Set target objectives
+dbbackup rto analyze mydb --target-rto 4h --target-rpo 1h
+```
+
+**Analysis includes:**
+- Current RPO (time since last backup)
+- Estimated RTO (detection + download + restore + validation)
+- RTO breakdown by phase
+- Compliance status
+- Recommendations for improvement
 
 ## Configuration
 
diff --git a/cmd/catalog.go b/cmd/catalog.go
new file mode 100644
index 0000000..d6ea38e
--- /dev/null
+++ b/cmd/catalog.go
@@ -0,0 +1,725 @@
+package cmd
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"dbbackup/internal/catalog"
+
+	"github.com/spf13/cobra"
+)
+
+var (
+	catalogDBPath    string
+	catalogFormat    string
+	catalogLimit     int
+	catalogDatabase  string
+	catalogStartDate string
+	catalogEndDate   string
+	catalogInterval  string
+	catalogVerbose   bool
+)
+
+// catalogCmd represents the catalog command group
+var catalogCmd = &cobra.Command{
+	Use:   "catalog",
+	Short: "Backup catalog management",
+	Long: `Manage the backup catalog - a SQLite database tracking all backups.
+
+The catalog provides:
+  - Searchable history of all backups
+  - Gap detection for backup schedules
+  - Statistics and reporting
+  - Integration with DR drill testing
+
+Examples:
+  # Sync backups from a directory
+  dbbackup catalog sync /backups
+
+  # List all backups
+  dbbackup catalog list
+
+  # Show catalog statistics
+  dbbackup catalog stats
+
+  # Detect gaps in backup schedule
+  dbbackup catalog gaps mydb --interval 24h
+
+  # Search backups
+  dbbackup catalog search --database mydb --after 2024-01-01`,
+}
+
+// catalogSyncCmd syncs backups from directory
+var catalogSyncCmd = &cobra.Command{
+	Use:   "sync [directory]",
+	Short: "Sync backups from directory into catalog",
+	Long: `Scan a directory for backup files and import them into the catalog.
+ +This command: + - Finds all .meta.json files + - Imports backup metadata into SQLite catalog + - Detects removed backups + - Updates changed entries + +Examples: + # Sync from backup directory + dbbackup catalog sync /backups + + # Sync with verbose output + dbbackup catalog sync /backups --verbose`, + Args: cobra.MinimumNArgs(1), + RunE: runCatalogSync, +} + +// catalogListCmd lists backups +var catalogListCmd = &cobra.Command{ + Use: "list", + Short: "List backups in catalog", + Long: `List all backups in the catalog with optional filtering. + +Examples: + # List all backups + dbbackup catalog list + + # List backups for specific database + dbbackup catalog list --database mydb + + # List last 10 backups + dbbackup catalog list --limit 10 + + # Output as JSON + dbbackup catalog list --format json`, + RunE: runCatalogList, +} + +// catalogStatsCmd shows statistics +var catalogStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show catalog statistics", + Long: `Display comprehensive backup statistics. + +Shows: + - Total backup count and size + - Backups by database + - Backups by type and status + - Verification and drill test coverage + +Examples: + # Show overall stats + dbbackup catalog stats + + # Stats for specific database + dbbackup catalog stats --database mydb + + # Output as JSON + dbbackup catalog stats --format json`, + RunE: runCatalogStats, +} + +// catalogGapsCmd detects schedule gaps +var catalogGapsCmd = &cobra.Command{ + Use: "gaps [database]", + Short: "Detect gaps in backup schedule", + Long: `Analyze backup history and detect schedule gaps. + +This helps identify: + - Missed backups + - Schedule irregularities + - RPO violations + +Examples: + # Check all databases for gaps (24h expected interval) + dbbackup catalog gaps + + # Check specific database with custom interval + dbbackup catalog gaps mydb --interval 6h + + # Check gaps in date range + dbbackup catalog gaps --after 2024-01-01 --before 2024-02-01`, + RunE: runCatalogGaps, +} + +// catalogSearchCmd searches backups +var catalogSearchCmd = &cobra.Command{ + Use: "search", + Short: "Search backups in catalog", + Long: `Search for backups matching specific criteria. + +Examples: + # Search by database name (supports wildcards) + dbbackup catalog search --database "prod*" + + # Search by date range + dbbackup catalog search --after 2024-01-01 --before 2024-02-01 + + # Search verified backups only + dbbackup catalog search --verified + + # Search encrypted backups + dbbackup catalog search --encrypted`, + RunE: runCatalogSearch, +} + +// catalogInfoCmd shows entry details +var catalogInfoCmd = &cobra.Command{ + Use: "info [backup-path]", + Short: "Show detailed info for a backup", + Long: `Display detailed information about a specific backup. 
+ +Examples: + # Show info by path + dbbackup catalog info /backups/mydb_20240115.dump.gz`, + Args: cobra.ExactArgs(1), + RunE: runCatalogInfo, +} + +func init() { + rootCmd.AddCommand(catalogCmd) + + // Default catalog path + defaultCatalogPath := filepath.Join(getDefaultConfigDir(), "catalog.db") + + // Global catalog flags + catalogCmd.PersistentFlags().StringVar(&catalogDBPath, "catalog-db", defaultCatalogPath, + "Path to catalog SQLite database") + catalogCmd.PersistentFlags().StringVar(&catalogFormat, "format", "table", + "Output format: table, json, csv") + + // Add subcommands + catalogCmd.AddCommand(catalogSyncCmd) + catalogCmd.AddCommand(catalogListCmd) + catalogCmd.AddCommand(catalogStatsCmd) + catalogCmd.AddCommand(catalogGapsCmd) + catalogCmd.AddCommand(catalogSearchCmd) + catalogCmd.AddCommand(catalogInfoCmd) + + // Sync flags + catalogSyncCmd.Flags().BoolVarP(&catalogVerbose, "verbose", "v", false, "Show detailed output") + + // List flags + catalogListCmd.Flags().IntVar(&catalogLimit, "limit", 50, "Maximum entries to show") + catalogListCmd.Flags().StringVar(&catalogDatabase, "database", "", "Filter by database name") + + // Stats flags + catalogStatsCmd.Flags().StringVar(&catalogDatabase, "database", "", "Show stats for specific database") + + // Gaps flags + catalogGapsCmd.Flags().StringVar(&catalogInterval, "interval", "24h", "Expected backup interval") + catalogGapsCmd.Flags().StringVar(&catalogStartDate, "after", "", "Start date (YYYY-MM-DD)") + catalogGapsCmd.Flags().StringVar(&catalogEndDate, "before", "", "End date (YYYY-MM-DD)") + + // Search flags + catalogSearchCmd.Flags().StringVar(&catalogDatabase, "database", "", "Filter by database name (supports wildcards)") + catalogSearchCmd.Flags().StringVar(&catalogStartDate, "after", "", "Backups after date (YYYY-MM-DD)") + catalogSearchCmd.Flags().StringVar(&catalogEndDate, "before", "", "Backups before date (YYYY-MM-DD)") + catalogSearchCmd.Flags().IntVar(&catalogLimit, "limit", 100, "Maximum results") + catalogSearchCmd.Flags().Bool("verified", false, "Only verified backups") + catalogSearchCmd.Flags().Bool("encrypted", false, "Only encrypted backups") + catalogSearchCmd.Flags().Bool("drill-tested", false, "Only drill-tested backups") +} + +func getDefaultConfigDir() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, ".dbbackup") +} + +func openCatalog() (*catalog.SQLiteCatalog, error) { + return catalog.NewSQLiteCatalog(catalogDBPath) +} + +func runCatalogSync(cmd *cobra.Command, args []string) error { + dir := args[0] + + // Validate directory + info, err := os.Stat(dir) + if err != nil { + return fmt.Errorf("directory not found: %s", dir) + } + if !info.IsDir() { + return fmt.Errorf("not a directory: %s", dir) + } + + absDir, _ := filepath.Abs(dir) + + cat, err := openCatalog() + if err != nil { + return err + } + defer cat.Close() + + fmt.Printf("📁 Syncing backups from: %s\n", absDir) + fmt.Printf("📊 Catalog database: %s\n\n", catalogDBPath) + + ctx := context.Background() + result, err := cat.SyncFromDirectory(ctx, absDir) + if err != nil { + return err + } + + // Update last sync time + cat.SetLastSync(ctx) + + // Show results + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + fmt.Printf(" Sync Results\n") + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + fmt.Printf(" ✅ Added: %d\n", result.Added) + fmt.Printf(" 🔄 Updated: %d\n", result.Updated) + fmt.Printf(" 🗑️ Removed: %d\n", result.Removed) + if result.Errors > 0 { + fmt.Printf(" ❌ Errors: 
%d\n", result.Errors) + } + fmt.Printf(" ⏱️ Duration: %.2fs\n", result.Duration) + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + + // Show details if verbose + if catalogVerbose && len(result.Details) > 0 { + fmt.Printf("\nDetails:\n") + for _, detail := range result.Details { + fmt.Printf(" %s\n", detail) + } + } + + return nil +} + +func runCatalogList(cmd *cobra.Command, args []string) error { + cat, err := openCatalog() + if err != nil { + return err + } + defer cat.Close() + + ctx := context.Background() + + query := &catalog.SearchQuery{ + Database: catalogDatabase, + Limit: catalogLimit, + OrderBy: "created_at", + OrderDesc: true, + } + + entries, err := cat.Search(ctx, query) + if err != nil { + return err + } + + if len(entries) == 0 { + fmt.Println("No backups in catalog. Run 'dbbackup catalog sync ' to import backups.") + return nil + } + + if catalogFormat == "json" { + data, _ := json.MarshalIndent(entries, "", " ") + fmt.Println(string(data)) + return nil + } + + // Table format + fmt.Printf("%-30s %-12s %-10s %-20s %-10s %s\n", + "DATABASE", "TYPE", "SIZE", "CREATED", "STATUS", "PATH") + fmt.Println(strings.Repeat("─", 120)) + + for _, entry := range entries { + dbName := truncateString(entry.Database, 28) + backupPath := truncateString(filepath.Base(entry.BackupPath), 40) + + status := string(entry.Status) + if entry.VerifyValid != nil && *entry.VerifyValid { + status = "✓ verified" + } + if entry.DrillSuccess != nil && *entry.DrillSuccess { + status = "✓ tested" + } + + fmt.Printf("%-30s %-12s %-10s %-20s %-10s %s\n", + dbName, + entry.DatabaseType, + catalog.FormatSize(entry.SizeBytes), + entry.CreatedAt.Format("2006-01-02 15:04"), + status, + backupPath, + ) + } + + fmt.Printf("\nShowing %d of %d total backups\n", len(entries), len(entries)) + return nil +} + +func runCatalogStats(cmd *cobra.Command, args []string) error { + cat, err := openCatalog() + if err != nil { + return err + } + defer cat.Close() + + ctx := context.Background() + + var stats *catalog.Stats + if catalogDatabase != "" { + stats, err = cat.StatsByDatabase(ctx, catalogDatabase) + } else { + stats, err = cat.Stats(ctx) + } + if err != nil { + return err + } + + if catalogFormat == "json" { + data, _ := json.MarshalIndent(stats, "", " ") + fmt.Println(string(data)) + return nil + } + + // Table format + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + if catalogDatabase != "" { + fmt.Printf(" Catalog Statistics: %s\n", catalogDatabase) + } else { + fmt.Printf(" Catalog Statistics\n") + } + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n") + + fmt.Printf("📊 Total Backups: %d\n", stats.TotalBackups) + fmt.Printf("💾 Total Size: %s\n", stats.TotalSizeHuman) + fmt.Printf("📏 Average Size: %s\n", catalog.FormatSize(stats.AvgSize)) + fmt.Printf("⏱️ Average Duration: %.1fs\n", stats.AvgDuration) + fmt.Printf("✅ Verified: %d\n", stats.VerifiedCount) + fmt.Printf("🧪 Drill Tested: %d\n", stats.DrillTestedCount) + + if stats.OldestBackup != nil { + fmt.Printf("📅 Oldest Backup: %s\n", stats.OldestBackup.Format("2006-01-02 15:04")) + } + if stats.NewestBackup != nil { + fmt.Printf("📅 Newest Backup: %s\n", stats.NewestBackup.Format("2006-01-02 15:04")) + } + + if len(stats.ByDatabase) > 0 && catalogDatabase == "" { + fmt.Printf("\n📁 By Database:\n") + for db, count := range stats.ByDatabase { + fmt.Printf(" %-30s %d\n", db, count) + } + } + + if len(stats.ByType) > 0 { + fmt.Printf("\n📦 By Type:\n") + for t, count := range stats.ByType { + 
fmt.Printf(" %-15s %d\n", t, count) + } + } + + if len(stats.ByStatus) > 0 { + fmt.Printf("\n📋 By Status:\n") + for s, count := range stats.ByStatus { + fmt.Printf(" %-15s %d\n", s, count) + } + } + + fmt.Printf("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + return nil +} + +func runCatalogGaps(cmd *cobra.Command, args []string) error { + cat, err := openCatalog() + if err != nil { + return err + } + defer cat.Close() + + ctx := context.Background() + + // Parse interval + interval, err := time.ParseDuration(catalogInterval) + if err != nil { + return fmt.Errorf("invalid interval: %w", err) + } + + config := &catalog.GapDetectionConfig{ + ExpectedInterval: interval, + Tolerance: interval / 4, // 25% tolerance + RPOThreshold: interval * 2, // 2x interval = critical + } + + // Parse date range + if catalogStartDate != "" { + t, err := time.Parse("2006-01-02", catalogStartDate) + if err != nil { + return fmt.Errorf("invalid start date: %w", err) + } + config.StartDate = &t + } + if catalogEndDate != "" { + t, err := time.Parse("2006-01-02", catalogEndDate) + if err != nil { + return fmt.Errorf("invalid end date: %w", err) + } + config.EndDate = &t + } + + var allGaps map[string][]*catalog.Gap + + if len(args) > 0 { + // Specific database + database := args[0] + gaps, err := cat.DetectGaps(ctx, database, config) + if err != nil { + return err + } + if len(gaps) > 0 { + allGaps = map[string][]*catalog.Gap{database: gaps} + } + } else { + // All databases + allGaps, err = cat.DetectAllGaps(ctx, config) + if err != nil { + return err + } + } + + if catalogFormat == "json" { + data, _ := json.MarshalIndent(allGaps, "", " ") + fmt.Println(string(data)) + return nil + } + + if len(allGaps) == 0 { + fmt.Printf("✅ No backup gaps detected (expected interval: %s)\n", interval) + return nil + } + + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + fmt.Printf(" Backup Gaps Detected (expected interval: %s)\n", interval) + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n") + + totalGaps := 0 + criticalGaps := 0 + + for database, gaps := range allGaps { + fmt.Printf("📁 %s (%d gaps)\n", database, len(gaps)) + + for _, gap := range gaps { + totalGaps++ + icon := "ℹ️" + switch gap.Severity { + case catalog.SeverityWarning: + icon = "⚠️" + case catalog.SeverityCritical: + icon = "🚨" + criticalGaps++ + } + + fmt.Printf(" %s %s\n", icon, gap.Description) + fmt.Printf(" Gap: %s → %s (%s)\n", + gap.GapStart.Format("2006-01-02 15:04"), + gap.GapEnd.Format("2006-01-02 15:04"), + catalog.FormatDuration(gap.Duration)) + fmt.Printf(" Expected at: %s\n", gap.ExpectedAt.Format("2006-01-02 15:04")) + } + fmt.Println() + } + + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + fmt.Printf("Total: %d gaps detected", totalGaps) + if criticalGaps > 0 { + fmt.Printf(" (%d critical)", criticalGaps) + } + fmt.Println() + + return nil +} + +func runCatalogSearch(cmd *cobra.Command, args []string) error { + cat, err := openCatalog() + if err != nil { + return err + } + defer cat.Close() + + ctx := context.Background() + + query := &catalog.SearchQuery{ + Database: catalogDatabase, + Limit: catalogLimit, + OrderBy: "created_at", + OrderDesc: true, + } + + // Parse date range + if catalogStartDate != "" { + t, err := time.Parse("2006-01-02", catalogStartDate) + if err != nil { + return fmt.Errorf("invalid start date: %w", err) + } + query.StartDate = &t + } + if catalogEndDate != "" { + t, err := time.Parse("2006-01-02", catalogEndDate) + if err != nil { + 
return fmt.Errorf("invalid end date: %w", err) + } + query.EndDate = &t + } + + // Boolean filters + if verified, _ := cmd.Flags().GetBool("verified"); verified { + t := true + query.Verified = &t + } + if encrypted, _ := cmd.Flags().GetBool("encrypted"); encrypted { + t := true + query.Encrypted = &t + } + if drillTested, _ := cmd.Flags().GetBool("drill-tested"); drillTested { + t := true + query.DrillTested = &t + } + + entries, err := cat.Search(ctx, query) + if err != nil { + return err + } + + if len(entries) == 0 { + fmt.Println("No matching backups found.") + return nil + } + + if catalogFormat == "json" { + data, _ := json.MarshalIndent(entries, "", " ") + fmt.Println(string(data)) + return nil + } + + fmt.Printf("Found %d matching backups:\n\n", len(entries)) + + for _, entry := range entries { + fmt.Printf("📁 %s\n", entry.Database) + fmt.Printf(" Path: %s\n", entry.BackupPath) + fmt.Printf(" Type: %s | Size: %s | Created: %s\n", + entry.DatabaseType, + catalog.FormatSize(entry.SizeBytes), + entry.CreatedAt.Format("2006-01-02 15:04:05")) + if entry.Encrypted { + fmt.Printf(" 🔒 Encrypted\n") + } + if entry.VerifyValid != nil && *entry.VerifyValid { + fmt.Printf(" ✅ Verified: %s\n", entry.VerifiedAt.Format("2006-01-02 15:04")) + } + if entry.DrillSuccess != nil && *entry.DrillSuccess { + fmt.Printf(" 🧪 Drill Tested: %s\n", entry.DrillTestedAt.Format("2006-01-02 15:04")) + } + fmt.Println() + } + + return nil +} + +func runCatalogInfo(cmd *cobra.Command, args []string) error { + backupPath := args[0] + + cat, err := openCatalog() + if err != nil { + return err + } + defer cat.Close() + + ctx := context.Background() + + // Try absolute path + absPath, _ := filepath.Abs(backupPath) + entry, err := cat.GetByPath(ctx, absPath) + if err != nil { + return err + } + + if entry == nil { + // Try as provided + entry, err = cat.GetByPath(ctx, backupPath) + if err != nil { + return err + } + } + + if entry == nil { + return fmt.Errorf("backup not found in catalog: %s", backupPath) + } + + if catalogFormat == "json" { + data, _ := json.MarshalIndent(entry, "", " ") + fmt.Println(string(data)) + return nil + } + + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + fmt.Printf(" Backup Details\n") + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n") + + fmt.Printf("📁 Database: %s\n", entry.Database) + fmt.Printf("🔧 Type: %s\n", entry.DatabaseType) + fmt.Printf("🖥️ Host: %s:%d\n", entry.Host, entry.Port) + fmt.Printf("📂 Path: %s\n", entry.BackupPath) + fmt.Printf("📦 Backup Type: %s\n", entry.BackupType) + fmt.Printf("💾 Size: %s (%d bytes)\n", catalog.FormatSize(entry.SizeBytes), entry.SizeBytes) + fmt.Printf("🔐 SHA256: %s\n", entry.SHA256) + fmt.Printf("📅 Created: %s\n", entry.CreatedAt.Format("2006-01-02 15:04:05 MST")) + fmt.Printf("⏱️ Duration: %.2fs\n", entry.Duration) + fmt.Printf("📋 Status: %s\n", entry.Status) + + if entry.Compression != "" { + fmt.Printf("📦 Compression: %s\n", entry.Compression) + } + if entry.Encrypted { + fmt.Printf("🔒 Encrypted: yes\n") + } + if entry.CloudLocation != "" { + fmt.Printf("☁️ Cloud: %s\n", entry.CloudLocation) + } + if entry.RetentionPolicy != "" { + fmt.Printf("📆 Retention: %s\n", entry.RetentionPolicy) + } + + fmt.Printf("\n📊 Verification:\n") + if entry.VerifiedAt != nil { + status := "❌ Failed" + if entry.VerifyValid != nil && *entry.VerifyValid { + status = "✅ Valid" + } + fmt.Printf(" Status: %s (checked %s)\n", status, entry.VerifiedAt.Format("2006-01-02 15:04")) + } else { + fmt.Printf(" Status: ⏳ Not 
verified\n") + } + + fmt.Printf("\n🧪 DR Drill Test:\n") + if entry.DrillTestedAt != nil { + status := "❌ Failed" + if entry.DrillSuccess != nil && *entry.DrillSuccess { + status = "✅ Passed" + } + fmt.Printf(" Status: %s (tested %s)\n", status, entry.DrillTestedAt.Format("2006-01-02 15:04")) + } else { + fmt.Printf(" Status: ⏳ Not tested\n") + } + + if len(entry.Metadata) > 0 { + fmt.Printf("\n📝 Additional Metadata:\n") + for k, v := range entry.Metadata { + fmt.Printf(" %s: %s\n", k, v) + } + } + + fmt.Printf("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + + return nil +} + +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen-3] + "..." +} diff --git a/cmd/drill.go b/cmd/drill.go new file mode 100644 index 0000000..2c196f0 --- /dev/null +++ b/cmd/drill.go @@ -0,0 +1,500 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "dbbackup/internal/catalog" + "dbbackup/internal/drill" + + "github.com/spf13/cobra" +) + +var ( + drillBackupPath string + drillDatabaseName string + drillDatabaseType string + drillImage string + drillPort int + drillTimeout int + drillRTOTarget int + drillKeepContainer bool + drillOutputDir string + drillFormat string + drillVerbose bool + drillExpectedTables string + drillMinRows int64 + drillQueries string +) + +// drillCmd represents the drill command group +var drillCmd = &cobra.Command{ + Use: "drill", + Short: "Disaster Recovery drill testing", + Long: `Run DR drills to verify backup restorability. + +A DR drill: + 1. Spins up a temporary Docker container + 2. Restores the backup into the container + 3. Runs validation queries + 4. Generates a detailed report + 5. Cleans up the container + +This answers the critical question: "Can I restore this backup at 3 AM?" + +Examples: + # Run a drill on a PostgreSQL backup + dbbackup drill run backup.dump.gz --database mydb --type postgresql + + # Run with validation queries + dbbackup drill run backup.dump.gz --database mydb --type postgresql \ + --validate "SELECT COUNT(*) FROM users" \ + --min-rows 1000 + + # Quick test with minimal validation + dbbackup drill quick backup.dump.gz --database mydb + + # List all drill containers + dbbackup drill list + + # Cleanup old drill containers + dbbackup drill cleanup`, +} + +// drillRunCmd runs a DR drill +var drillRunCmd = &cobra.Command{ + Use: "run [backup-file]", + Short: "Run a DR drill on a backup", + Long: `Execute a complete DR drill on a backup file. + +This will: + 1. Pull the appropriate database Docker image + 2. Start a temporary container + 3. Restore the backup + 4. Run validation queries + 5. Calculate RTO metrics + 6. Generate a report + +Examples: + # Basic drill + dbbackup drill run /backups/mydb_20240115.dump.gz --database mydb --type postgresql + + # With RTO target (5 minutes) + dbbackup drill run /backups/mydb.dump.gz --database mydb --type postgresql --rto 300 + + # With expected tables validation + dbbackup drill run /backups/mydb.dump.gz --database mydb --type postgresql \ + --tables "users,orders,products" + + # Keep container on failure for debugging + dbbackup drill run /backups/mydb.dump.gz --database mydb --type postgresql --keep`, + Args: cobra.ExactArgs(1), + RunE: runDrill, +} + +// drillQuickCmd runs a quick test +var drillQuickCmd = &cobra.Command{ + Use: "quick [backup-file]", + Short: "Quick restore test with minimal validation", + Long: `Run a quick DR test that only verifies the backup can be restored. 
+ +This is faster than a full drill but provides less validation. + +Examples: + # Quick test a PostgreSQL backup + dbbackup drill quick /backups/mydb.dump.gz --database mydb --type postgresql + + # Quick test a MySQL backup + dbbackup drill quick /backups/mydb.sql.gz --database mydb --type mysql`, + Args: cobra.ExactArgs(1), + RunE: runQuickDrill, +} + +// drillListCmd lists drill containers +var drillListCmd = &cobra.Command{ + Use: "list", + Short: "List DR drill containers", + Long: `List all Docker containers created by DR drills. + +Shows containers that may still be running or stopped from previous drills.`, + RunE: runDrillList, +} + +// drillCleanupCmd cleans up drill resources +var drillCleanupCmd = &cobra.Command{ + Use: "cleanup [drill-id]", + Short: "Cleanup DR drill containers", + Long: `Remove containers created by DR drills. + +If no drill ID is specified, removes all drill containers. + +Examples: + # Cleanup all drill containers + dbbackup drill cleanup + + # Cleanup specific drill + dbbackup drill cleanup drill_20240115_120000`, + RunE: runDrillCleanup, +} + +// drillReportCmd shows a drill report +var drillReportCmd = &cobra.Command{ + Use: "report [report-file]", + Short: "Display a DR drill report", + Long: `Display a previously saved DR drill report. + +Examples: + # Show report + dbbackup drill report drill_20240115_120000_report.json + + # Show as JSON + dbbackup drill report drill_20240115_120000_report.json --format json`, + Args: cobra.ExactArgs(1), + RunE: runDrillReport, +} + +func init() { + rootCmd.AddCommand(drillCmd) + + // Add subcommands + drillCmd.AddCommand(drillRunCmd) + drillCmd.AddCommand(drillQuickCmd) + drillCmd.AddCommand(drillListCmd) + drillCmd.AddCommand(drillCleanupCmd) + drillCmd.AddCommand(drillReportCmd) + + // Run command flags + drillRunCmd.Flags().StringVar(&drillDatabaseName, "database", "", "Target database name (required)") + drillRunCmd.Flags().StringVar(&drillDatabaseType, "type", "", "Database type: postgresql, mysql, mariadb (required)") + drillRunCmd.Flags().StringVar(&drillImage, "image", "", "Docker image (default: auto-detect)") + drillRunCmd.Flags().IntVar(&drillPort, "port", 0, "Host port for container (default: 15432/13306)") + drillRunCmd.Flags().IntVar(&drillTimeout, "timeout", 60, "Container startup timeout in seconds") + drillRunCmd.Flags().IntVar(&drillRTOTarget, "rto", 300, "RTO target in seconds") + drillRunCmd.Flags().BoolVar(&drillKeepContainer, "keep", false, "Keep container after drill") + drillRunCmd.Flags().StringVar(&drillOutputDir, "output", "", "Output directory for reports") + drillRunCmd.Flags().StringVar(&drillFormat, "format", "table", "Output format: table, json") + drillRunCmd.Flags().BoolVarP(&drillVerbose, "verbose", "v", false, "Verbose output") + drillRunCmd.Flags().StringVar(&drillExpectedTables, "tables", "", "Expected tables (comma-separated)") + drillRunCmd.Flags().Int64Var(&drillMinRows, "min-rows", 0, "Minimum expected row count") + drillRunCmd.Flags().StringVar(&drillQueries, "validate", "", "Validation SQL query") + + drillRunCmd.MarkFlagRequired("database") + drillRunCmd.MarkFlagRequired("type") + + // Quick command flags + drillQuickCmd.Flags().StringVar(&drillDatabaseName, "database", "", "Target database name (required)") + drillQuickCmd.Flags().StringVar(&drillDatabaseType, "type", "", "Database type: postgresql, mysql, mariadb (required)") + drillQuickCmd.Flags().BoolVarP(&drillVerbose, "verbose", "v", false, "Verbose output") + + drillQuickCmd.MarkFlagRequired("database") + 
drillQuickCmd.MarkFlagRequired("type") + + // Report command flags + drillReportCmd.Flags().StringVar(&drillFormat, "format", "table", "Output format: table, json") +} + +func runDrill(cmd *cobra.Command, args []string) error { + backupPath := args[0] + + // Validate backup file exists + absPath, err := filepath.Abs(backupPath) + if err != nil { + return fmt.Errorf("invalid backup path: %w", err) + } + if _, err := os.Stat(absPath); err != nil { + return fmt.Errorf("backup file not found: %s", absPath) + } + + // Build drill config + config := drill.DefaultConfig() + config.BackupPath = absPath + config.DatabaseName = drillDatabaseName + config.DatabaseType = drillDatabaseType + config.ContainerImage = drillImage + config.ContainerPort = drillPort + config.ContainerTimeout = drillTimeout + config.MaxRestoreSeconds = drillRTOTarget + config.CleanupOnExit = !drillKeepContainer + config.KeepOnFailure = true + config.OutputDir = drillOutputDir + config.Verbose = drillVerbose + + // Parse expected tables + if drillExpectedTables != "" { + config.ExpectedTables = strings.Split(drillExpectedTables, ",") + for i := range config.ExpectedTables { + config.ExpectedTables[i] = strings.TrimSpace(config.ExpectedTables[i]) + } + } + + // Set minimum row count + config.MinRowCount = drillMinRows + + // Add validation query if provided + if drillQueries != "" { + config.ValidationQueries = append(config.ValidationQueries, drill.ValidationQuery{ + Name: "Custom Query", + Query: drillQueries, + MustSucceed: true, + }) + } + + // Create drill engine + engine := drill.NewEngine(log, drillVerbose) + + // Run drill + ctx := cmd.Context() + result, err := engine.Run(ctx, config) + if err != nil { + return err + } + + // Update catalog if available + updateCatalogWithDrillResult(ctx, absPath, result) + + // Output result + if drillFormat == "json" { + data, _ := json.MarshalIndent(result, "", " ") + fmt.Println(string(data)) + } else { + printDrillResult(result) + } + + if !result.Success { + return fmt.Errorf("drill failed: %s", result.Message) + } + + return nil +} + +func runQuickDrill(cmd *cobra.Command, args []string) error { + backupPath := args[0] + + absPath, err := filepath.Abs(backupPath) + if err != nil { + return fmt.Errorf("invalid backup path: %w", err) + } + if _, err := os.Stat(absPath); err != nil { + return fmt.Errorf("backup file not found: %s", absPath) + } + + engine := drill.NewEngine(log, drillVerbose) + + ctx := cmd.Context() + result, err := engine.QuickTest(ctx, absPath, drillDatabaseType, drillDatabaseName) + if err != nil { + return err + } + + // Update catalog + updateCatalogWithDrillResult(ctx, absPath, result) + + printDrillResult(result) + + if !result.Success { + return fmt.Errorf("quick test failed: %s", result.Message) + } + + return nil +} + +func runDrillList(cmd *cobra.Command, args []string) error { + docker := drill.NewDockerManager(false) + + ctx := cmd.Context() + containers, err := docker.ListDrillContainers(ctx) + if err != nil { + return err + } + + if len(containers) == 0 { + fmt.Println("No drill containers found.") + return nil + } + + fmt.Printf("%-15s %-40s %-20s %s\n", "ID", "NAME", "IMAGE", "STATUS") + fmt.Println(strings.Repeat("─", 100)) + + for _, c := range containers { + fmt.Printf("%-15s %-40s %-20s %s\n", + c.ID[:12], + truncateString(c.Name, 38), + truncateString(c.Image, 18), + c.Status, + ) + } + + return nil +} + +func runDrillCleanup(cmd *cobra.Command, args []string) error { + drillID := "" + if len(args) > 0 { + drillID = args[0] + } + + engine := 
drill.NewEngine(log, true) + + ctx := cmd.Context() + if err := engine.Cleanup(ctx, drillID); err != nil { + return err + } + + fmt.Println("✅ Cleanup completed") + return nil +} + +func runDrillReport(cmd *cobra.Command, args []string) error { + reportPath := args[0] + + result, err := drill.LoadResult(reportPath) + if err != nil { + return err + } + + if drillFormat == "json" { + data, _ := json.MarshalIndent(result, "", " ") + fmt.Println(string(data)) + } else { + printDrillResult(result) + } + + return nil +} + +func printDrillResult(result *drill.DrillResult) { + fmt.Printf("\n") + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + fmt.Printf(" DR Drill Report: %s\n", result.DrillID) + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n") + + status := "✅ PASSED" + if !result.Success { + status = "❌ FAILED" + } else if result.Status == drill.StatusPartial { + status = "⚠️ PARTIAL" + } + + fmt.Printf("📋 Status: %s\n", status) + fmt.Printf("💾 Backup: %s\n", filepath.Base(result.BackupPath)) + fmt.Printf("🗄️ Database: %s (%s)\n", result.DatabaseName, result.DatabaseType) + fmt.Printf("⏱️ Duration: %.2fs\n", result.Duration) + fmt.Printf("📅 Started: %s\n", result.StartTime.Format(time.RFC3339)) + fmt.Printf("\n") + + // Phases + fmt.Printf("📊 Phases:\n") + for _, phase := range result.Phases { + icon := "✅" + if phase.Status == "failed" { + icon = "❌" + } else if phase.Status == "running" { + icon = "🔄" + } + fmt.Printf(" %s %-20s (%.2fs) %s\n", icon, phase.Name, phase.Duration, phase.Message) + } + fmt.Printf("\n") + + // Metrics + fmt.Printf("📈 Metrics:\n") + fmt.Printf(" Tables: %d\n", result.TableCount) + fmt.Printf(" Total Rows: %d\n", result.TotalRows) + fmt.Printf(" Restore Time: %.2fs\n", result.RestoreTime) + fmt.Printf(" Validation: %.2fs\n", result.ValidationTime) + if result.QueryTimeAvg > 0 { + fmt.Printf(" Avg Query Time: %.0fms\n", result.QueryTimeAvg) + } + fmt.Printf("\n") + + // RTO + fmt.Printf("⏱️ RTO Analysis:\n") + rtoIcon := "✅" + if !result.RTOMet { + rtoIcon = "❌" + } + fmt.Printf(" Actual RTO: %.2fs\n", result.ActualRTO) + fmt.Printf(" Target RTO: %.0fs\n", result.TargetRTO) + fmt.Printf(" RTO Met: %s\n", rtoIcon) + fmt.Printf("\n") + + // Validation results + if len(result.ValidationResults) > 0 { + fmt.Printf("🔍 Validation Queries:\n") + for _, vr := range result.ValidationResults { + icon := "✅" + if !vr.Success { + icon = "❌" + } + fmt.Printf(" %s %s: %s\n", icon, vr.Name, vr.Result) + if vr.Error != "" { + fmt.Printf(" Error: %s\n", vr.Error) + } + } + fmt.Printf("\n") + } + + // Check results + if len(result.CheckResults) > 0 { + fmt.Printf("✓ Checks:\n") + for _, cr := range result.CheckResults { + icon := "✅" + if !cr.Success { + icon = "❌" + } + fmt.Printf(" %s %s\n", icon, cr.Message) + } + fmt.Printf("\n") + } + + // Errors and warnings + if len(result.Errors) > 0 { + fmt.Printf("❌ Errors:\n") + for _, e := range result.Errors { + fmt.Printf(" • %s\n", e) + } + fmt.Printf("\n") + } + + if len(result.Warnings) > 0 { + fmt.Printf("⚠️ Warnings:\n") + for _, w := range result.Warnings { + fmt.Printf(" • %s\n", w) + } + fmt.Printf("\n") + } + + // Container info + if result.ContainerKept { + fmt.Printf("📦 Container kept: %s\n", result.ContainerID[:12]) + fmt.Printf(" Connect with: docker exec -it %s bash\n", result.ContainerID[:12]) + fmt.Printf("\n") + } + + fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + fmt.Printf(" %s\n", result.Message) + 
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") +} + +func updateCatalogWithDrillResult(ctx context.Context, backupPath string, result *drill.DrillResult) { + // Try to update the catalog with drill results + cat, err := catalog.NewSQLiteCatalog(catalogDBPath) + if err != nil { + return // Catalog not available, skip + } + defer cat.Close() + + entry, err := cat.GetByPath(ctx, backupPath) + if err != nil || entry == nil { + return // Entry not in catalog + } + + // Update drill status + if err := cat.MarkDrillTested(ctx, entry.ID, result.Success); err != nil { + log.Debug("Failed to update catalog drill status", "error", err) + } +} diff --git a/cmd/pitr.go b/cmd/pitr.go index a1ef099..32509ad 100644 --- a/cmd/pitr.go +++ b/cmd/pitr.go @@ -2,10 +2,15 @@ package cmd import ( "context" + "database/sql" "fmt" + "os" + "path/filepath" + "time" "github.com/spf13/cobra" + "dbbackup/internal/pitr" "dbbackup/internal/wal" ) @@ -32,6 +37,14 @@ var ( pitrTargetImmediate bool pitrRecoveryAction string pitrWALSource string + + // MySQL PITR flags + mysqlBinlogDir string + mysqlArchiveDir string + mysqlArchiveInterval string + mysqlRequireRowFormat bool + mysqlRequireGTID bool + mysqlWatchMode bool ) // pitrCmd represents the pitr command group @@ -183,21 +196,180 @@ Example: RunE: runWALTimeline, } +// ============================================================================ +// MySQL/MariaDB Binlog Commands +// ============================================================================ + +// binlogCmd represents the binlog command group (MySQL equivalent of WAL) +var binlogCmd = &cobra.Command{ + Use: "binlog", + Short: "Binary log operations for MySQL/MariaDB", + Long: `Manage MySQL/MariaDB binary log files for Point-in-Time Recovery. + +Binary logs contain all changes made to the database and are essential +for Point-in-Time Recovery (PITR) with MySQL and MariaDB. + +Commands: + list - List available binlog files + archive - Archive binlog files + watch - Watch for new binlog files and archive them + validate - Validate binlog chain integrity + position - Show current binlog position +`, +} + +// binlogListCmd lists binary log files +var binlogListCmd = &cobra.Command{ + Use: "list", + Short: "List binary log files", + Long: `List all available binary log files from the MySQL data directory +and/or the archive directory. + +Shows: filename, size, timestamps, server_id, and format for each binlog. + +Examples: + dbbackup binlog list --binlog-dir /var/lib/mysql + dbbackup binlog list --archive-dir /backups/binlog_archive +`, + RunE: runBinlogList, +} + +// binlogArchiveCmd archives binary log files +var binlogArchiveCmd = &cobra.Command{ + Use: "archive", + Short: "Archive binary log files", + Long: `Archive MySQL binary log files to a backup location. + +This command copies completed binlog files (not the currently active one) +to the archive directory, optionally with compression and encryption. + +Examples: + dbbackup binlog archive --binlog-dir /var/lib/mysql --archive-dir /backups/binlog + dbbackup binlog archive --compress --archive-dir /backups/binlog +`, + RunE: runBinlogArchive, +} + +// binlogWatchCmd watches for new binlogs and archives them +var binlogWatchCmd = &cobra.Command{ + Use: "watch", + Short: "Watch for new binlog files and archive them automatically", + Long: `Continuously monitor the binlog directory for new files and +archive them automatically when they are closed. 
+ +This runs as a background process and provides continuous binlog archiving +for PITR capability. + +Example: + dbbackup binlog watch --binlog-dir /var/lib/mysql --archive-dir /backups/binlog --interval 30s +`, + RunE: runBinlogWatch, +} + +// binlogValidateCmd validates binlog chain +var binlogValidateCmd = &cobra.Command{ + Use: "validate", + Short: "Validate binlog chain integrity", + Long: `Check the binary log chain for gaps or inconsistencies. + +Validates: +- Sequential numbering of binlog files +- No missing files in the chain +- Server ID consistency +- GTID continuity (if enabled) + +Example: + dbbackup binlog validate --binlog-dir /var/lib/mysql + dbbackup binlog validate --archive-dir /backups/binlog +`, + RunE: runBinlogValidate, +} + +// binlogPositionCmd shows current binlog position +var binlogPositionCmd = &cobra.Command{ + Use: "position", + Short: "Show current binary log position", + Long: `Display the current MySQL binary log position. + +This connects to MySQL and runs SHOW MASTER STATUS to get: +- Current binlog filename +- Current byte position +- Executed GTID set (if GTID mode is enabled) + +Example: + dbbackup binlog position +`, + RunE: runBinlogPosition, +} + +// mysqlPitrStatusCmd shows MySQL-specific PITR status +var mysqlPitrStatusCmd = &cobra.Command{ + Use: "mysql-status", + Short: "Show MySQL/MariaDB PITR status", + Long: `Display MySQL/MariaDB-specific PITR configuration and status. + +Shows: +- Binary log configuration (log_bin, binlog_format) +- GTID mode status +- Archive directory and statistics +- Current binlog position +- Recovery windows available + +Example: + dbbackup pitr mysql-status +`, + RunE: runMySQLPITRStatus, +} + +// mysqlPitrEnableCmd enables MySQL PITR +var mysqlPitrEnableCmd = &cobra.Command{ + Use: "mysql-enable", + Short: "Enable PITR for MySQL/MariaDB", + Long: `Configure MySQL/MariaDB for Point-in-Time Recovery. 
+ +This validates MySQL settings and sets up binlog archiving: +- Checks binary logging is enabled (log_bin=ON) +- Validates binlog_format (ROW recommended) +- Creates archive directory +- Saves PITR configuration + +Prerequisites in my.cnf: + [mysqld] + log_bin = mysql-bin + binlog_format = ROW + server_id = 1 + +Example: + dbbackup pitr mysql-enable --archive-dir /backups/binlog_archive +`, + RunE: runMySQLPITREnable, +} + func init() { rootCmd.AddCommand(pitrCmd) rootCmd.AddCommand(walCmd) + rootCmd.AddCommand(binlogCmd) // PITR subcommands pitrCmd.AddCommand(pitrEnableCmd) pitrCmd.AddCommand(pitrDisableCmd) pitrCmd.AddCommand(pitrStatusCmd) + pitrCmd.AddCommand(mysqlPitrStatusCmd) + pitrCmd.AddCommand(mysqlPitrEnableCmd) - // WAL subcommands + // WAL subcommands (PostgreSQL) walCmd.AddCommand(walArchiveCmd) walCmd.AddCommand(walListCmd) walCmd.AddCommand(walCleanupCmd) walCmd.AddCommand(walTimelineCmd) + // Binlog subcommands (MySQL/MariaDB) + binlogCmd.AddCommand(binlogListCmd) + binlogCmd.AddCommand(binlogArchiveCmd) + binlogCmd.AddCommand(binlogWatchCmd) + binlogCmd.AddCommand(binlogValidateCmd) + binlogCmd.AddCommand(binlogPositionCmd) + // PITR enable flags pitrEnableCmd.Flags().StringVar(&pitrArchiveDir, "archive-dir", "/var/backups/wal_archive", "Directory to store WAL archives") pitrEnableCmd.Flags().BoolVar(&pitrForce, "force", false, "Overwrite existing PITR configuration") @@ -219,6 +391,33 @@ func init() { // WAL timeline flags walTimelineCmd.Flags().StringVar(&walArchiveDir, "archive-dir", "/var/backups/wal_archive", "WAL archive directory") + + // MySQL binlog flags + binlogListCmd.Flags().StringVar(&mysqlBinlogDir, "binlog-dir", "/var/lib/mysql", "MySQL binary log directory") + binlogListCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "", "Binlog archive directory") + + binlogArchiveCmd.Flags().StringVar(&mysqlBinlogDir, "binlog-dir", "/var/lib/mysql", "MySQL binary log directory") + binlogArchiveCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "/var/backups/binlog_archive", "Binlog archive directory") + binlogArchiveCmd.Flags().BoolVar(&walCompress, "compress", false, "Compress binlog files") + binlogArchiveCmd.Flags().BoolVar(&walEncrypt, "encrypt", false, "Encrypt binlog files") + binlogArchiveCmd.Flags().StringVar(&walEncryptionKeyFile, "encryption-key-file", "", "Path to encryption key file") + binlogArchiveCmd.MarkFlagRequired("archive-dir") + + binlogWatchCmd.Flags().StringVar(&mysqlBinlogDir, "binlog-dir", "/var/lib/mysql", "MySQL binary log directory") + binlogWatchCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "/var/backups/binlog_archive", "Binlog archive directory") + binlogWatchCmd.Flags().StringVar(&mysqlArchiveInterval, "interval", "30s", "Check interval for new binlogs") + binlogWatchCmd.Flags().BoolVar(&walCompress, "compress", false, "Compress binlog files") + binlogWatchCmd.MarkFlagRequired("archive-dir") + + binlogValidateCmd.Flags().StringVar(&mysqlBinlogDir, "binlog-dir", "/var/lib/mysql", "MySQL binary log directory") + binlogValidateCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "", "Binlog archive directory") + + // MySQL PITR enable flags + mysqlPitrEnableCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "/var/backups/binlog_archive", "Binlog archive directory") + mysqlPitrEnableCmd.Flags().IntVar(&walRetentionDays, "retention-days", 7, "Days to keep archived binlogs") + mysqlPitrEnableCmd.Flags().BoolVar(&mysqlRequireRowFormat, "require-row-format", true, "Require ROW binlog format") + 
+	mysqlPitrEnableCmd.Flags().BoolVar(&mysqlRequireGTID, "require-gtid", false, "Require GTID mode enabled")
+	mysqlPitrEnableCmd.MarkFlagRequired("archive-dir")
 }

 // Command implementations
@@ -512,3 +711,614 @@ func formatWALSize(bytes int64) string {
 	}
 	return fmt.Sprintf("%.1f KB", float64(bytes)/float64(KB))
 }
+
+// ============================================================================
+// MySQL/MariaDB Binlog Command Implementations
+// ============================================================================
+
+func runBinlogList(cmd *cobra.Command, args []string) error {
+	ctx := context.Background()
+
+	if !cfg.IsMySQL() {
+		return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB (detected: %s)", cfg.DisplayDatabaseType())
+	}
+
+	binlogDir := mysqlBinlogDir
+	if binlogDir == "" && mysqlArchiveDir != "" {
+		binlogDir = mysqlArchiveDir
+	}
+
+	if binlogDir == "" {
+		return fmt.Errorf("please specify --binlog-dir or --archive-dir")
+	}
+
+	bmConfig := pitr.BinlogManagerConfig{
+		BinlogDir:  binlogDir,
+		ArchiveDir: mysqlArchiveDir,
+	}
+
+	bm, err := pitr.NewBinlogManager(bmConfig)
+	if err != nil {
+		return fmt.Errorf("initializing binlog manager: %w", err)
+	}
+
+	// List binlogs from source directory
+	binlogs, err := bm.DiscoverBinlogs(ctx)
+	if err != nil {
+		return fmt.Errorf("discovering binlogs: %w", err)
+	}
+
+	// Also list archived binlogs if archive dir is specified
+	var archived []pitr.BinlogArchiveInfo
+	if mysqlArchiveDir != "" {
+		archived, _ = bm.ListArchivedBinlogs(ctx)
+	}
+
+	if len(binlogs) == 0 && len(archived) == 0 {
+		fmt.Println("No binary log files found")
+		return nil
+	}
+
+	fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
+	fmt.Printf(" Binary Log Files (%s)\n", bm.ServerType())
+	fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
+	fmt.Println()
+
+	if len(binlogs) > 0 {
+		fmt.Println("Source Directory:")
+		fmt.Printf("%-24s %10s %-19s %-19s %s\n", "Filename", "Size", "Start Time", "End Time", "Format")
+		fmt.Println("────────────────────────────────────────────────────────────────────────────────")
+
+		var totalSize int64
+		for _, b := range binlogs {
+			size := formatWALSize(b.Size)
+			totalSize += b.Size
+
+			startTime := "unknown"
+			endTime := "unknown"
+			if !b.StartTime.IsZero() {
+				startTime = b.StartTime.Format("2006-01-02 15:04:05")
+			}
+			if !b.EndTime.IsZero() {
+				endTime = b.EndTime.Format("2006-01-02 15:04:05")
+			}
+
+			format := b.Format
+			if format == "" {
+				format = "-"
+			}
+
+			fmt.Printf("%-24s %10s %-19s %-19s %s\n", b.Name, size, startTime, endTime, format)
+		}
+		fmt.Printf("\nTotal: %d files, %s\n", len(binlogs), formatWALSize(totalSize))
+	}
+
+	if len(archived) > 0 {
+		fmt.Println()
+		fmt.Println("Archived Binlogs:")
+		fmt.Printf("%-24s %10s %-19s %s\n", "Original", "Size", "Archived At", "Flags")
+		fmt.Println("────────────────────────────────────────────────────────────────────────────────")
+
+		var totalSize int64
+		for _, a := range archived {
+			size := formatWALSize(a.Size)
+			totalSize += a.Size
+
+			archivedTime := a.ArchivedAt.Format("2006-01-02 15:04:05")
+
+			flags := ""
+			if a.Compressed {
+				flags += "C"
+			}
+			if a.Encrypted {
+				flags += "E"
+			}
+			if flags != "" {
+				flags = "[" + flags + "]"
+			}
+
+			fmt.Printf("%-24s %10s %-19s %s\n", a.OriginalFile, size, archivedTime, flags)
+		}
+		fmt.Printf("\nTotal archived: %d files, %s\n", len(archived), formatWALSize(totalSize))
+	}
+
+	return nil
+}
+
+func runBinlogArchive(cmd *cobra.Command, args []string) error
{ + ctx := context.Background() + + if !cfg.IsMySQL() { + return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB") + } + + if mysqlBinlogDir == "" { + return fmt.Errorf("--binlog-dir is required") + } + + // Load encryption key if needed + var encryptionKey []byte + if walEncrypt { + key, err := loadEncryptionKey(walEncryptionKeyFile, walEncryptionKeyEnv) + if err != nil { + return fmt.Errorf("failed to load encryption key: %w", err) + } + encryptionKey = key + } + + bmConfig := pitr.BinlogManagerConfig{ + BinlogDir: mysqlBinlogDir, + ArchiveDir: mysqlArchiveDir, + Compression: walCompress, + Encryption: walEncrypt, + EncryptionKey: encryptionKey, + } + + bm, err := pitr.NewBinlogManager(bmConfig) + if err != nil { + return fmt.Errorf("initializing binlog manager: %w", err) + } + + // Discover binlogs + binlogs, err := bm.DiscoverBinlogs(ctx) + if err != nil { + return fmt.Errorf("discovering binlogs: %w", err) + } + + // Get already archived + archived, _ := bm.ListArchivedBinlogs(ctx) + archivedSet := make(map[string]struct{}) + for _, a := range archived { + archivedSet[a.OriginalFile] = struct{}{} + } + + // Need to connect to MySQL to get current position + // For now, skip the active binlog by looking at which one was modified most recently + var latestModTime int64 + var latestBinlog string + for _, b := range binlogs { + if b.ModTime.Unix() > latestModTime { + latestModTime = b.ModTime.Unix() + latestBinlog = b.Name + } + } + + var newArchives []pitr.BinlogArchiveInfo + for i := range binlogs { + b := &binlogs[i] + + // Skip if already archived + if _, exists := archivedSet[b.Name]; exists { + log.Info("Skipping already archived", "binlog", b.Name) + continue + } + + // Skip the most recently modified (likely active) + if b.Name == latestBinlog { + log.Info("Skipping active binlog", "binlog", b.Name) + continue + } + + log.Info("Archiving binlog", "binlog", b.Name, "size", formatWALSize(b.Size)) + archiveInfo, err := bm.ArchiveBinlog(ctx, b) + if err != nil { + log.Error("Failed to archive binlog", "binlog", b.Name, "error", err) + continue + } + newArchives = append(newArchives, *archiveInfo) + } + + // Update metadata + if len(newArchives) > 0 { + allArchived, _ := bm.ListArchivedBinlogs(ctx) + bm.SaveArchiveMetadata(allArchived) + } + + log.Info("✅ Binlog archiving completed", "archived", len(newArchives)) + return nil +} + +func runBinlogWatch(cmd *cobra.Command, args []string) error { + ctx := context.Background() + + if !cfg.IsMySQL() { + return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB") + } + + interval, err := time.ParseDuration(mysqlArchiveInterval) + if err != nil { + return fmt.Errorf("invalid interval: %w", err) + } + + bmConfig := pitr.BinlogManagerConfig{ + BinlogDir: mysqlBinlogDir, + ArchiveDir: mysqlArchiveDir, + Compression: walCompress, + } + + bm, err := pitr.NewBinlogManager(bmConfig) + if err != nil { + return fmt.Errorf("initializing binlog manager: %w", err) + } + + log.Info("Starting binlog watcher", + "binlog_dir", mysqlBinlogDir, + "archive_dir", mysqlArchiveDir, + "interval", interval) + + // Watch for new binlogs + err = bm.WatchBinlogs(ctx, interval, func(b *pitr.BinlogFile) { + log.Info("New binlog detected, archiving", "binlog", b.Name) + archiveInfo, err := bm.ArchiveBinlog(ctx, b) + if err != nil { + log.Error("Failed to archive binlog", "binlog", b.Name, "error", err) + return + } + log.Info("Binlog archived successfully", + "binlog", b.Name, + "archive", archiveInfo.ArchivePath, + "size", 
formatWALSize(archiveInfo.Size)) + + // Update metadata + allArchived, _ := bm.ListArchivedBinlogs(ctx) + bm.SaveArchiveMetadata(allArchived) + }) + + if err != nil && err != context.Canceled { + return err + } + + return nil +} + +func runBinlogValidate(cmd *cobra.Command, args []string) error { + ctx := context.Background() + + if !cfg.IsMySQL() { + return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB") + } + + binlogDir := mysqlBinlogDir + if binlogDir == "" { + binlogDir = mysqlArchiveDir + } + + if binlogDir == "" { + return fmt.Errorf("please specify --binlog-dir or --archive-dir") + } + + bmConfig := pitr.BinlogManagerConfig{ + BinlogDir: binlogDir, + ArchiveDir: mysqlArchiveDir, + } + + bm, err := pitr.NewBinlogManager(bmConfig) + if err != nil { + return fmt.Errorf("initializing binlog manager: %w", err) + } + + // Discover binlogs + binlogs, err := bm.DiscoverBinlogs(ctx) + if err != nil { + return fmt.Errorf("discovering binlogs: %w", err) + } + + if len(binlogs) == 0 { + fmt.Println("No binlog files found to validate") + return nil + } + + // Validate chain + validation, err := bm.ValidateBinlogChain(ctx, binlogs) + if err != nil { + return fmt.Errorf("validating binlog chain: %w", err) + } + + fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + fmt.Println(" Binlog Chain Validation") + fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + fmt.Println() + + if validation.Valid { + fmt.Println("Status: ✅ VALID - Binlog chain is complete") + } else { + fmt.Println("Status: ❌ INVALID - Binlog chain has gaps") + } + + fmt.Printf("Files: %d binlog files\n", validation.LogCount) + fmt.Printf("Total Size: %s\n", formatWALSize(validation.TotalSize)) + + if validation.StartPos != nil { + fmt.Printf("Start: %s\n", validation.StartPos.String()) + } + if validation.EndPos != nil { + fmt.Printf("End: %s\n", validation.EndPos.String()) + } + + if len(validation.Gaps) > 0 { + fmt.Println() + fmt.Println("Gaps Found:") + for _, gap := range validation.Gaps { + fmt.Printf(" • After %s, before %s: %s\n", gap.After, gap.Before, gap.Reason) + } + } + + if len(validation.Warnings) > 0 { + fmt.Println() + fmt.Println("Warnings:") + for _, w := range validation.Warnings { + fmt.Printf(" ⚠ %s\n", w) + } + } + + if len(validation.Errors) > 0 { + fmt.Println() + fmt.Println("Errors:") + for _, e := range validation.Errors { + fmt.Printf(" ✗ %s\n", e) + } + } + + if !validation.Valid { + os.Exit(1) + } + + return nil +} + +func runBinlogPosition(cmd *cobra.Command, args []string) error { + ctx := context.Background() + + if !cfg.IsMySQL() { + return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB") + } + + // Connect to MySQL + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/", + cfg.User, cfg.Password, cfg.Host, cfg.Port) + + db, err := sql.Open("mysql", dsn) + if err != nil { + return fmt.Errorf("connecting to MySQL: %w", err) + } + defer db.Close() + + if err := db.PingContext(ctx); err != nil { + return fmt.Errorf("pinging MySQL: %w", err) + } + + // Get binlog position using raw query + rows, err := db.QueryContext(ctx, "SHOW MASTER STATUS") + if err != nil { + return fmt.Errorf("getting master status: %w", err) + } + defer rows.Close() + + fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + fmt.Println(" Current Binary Log Position") + fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + fmt.Println() + + if rows.Next() { + var file string + var position uint64 + var 
binlogDoDB, binlogIgnoreDB, executedGtidSet sql.NullString
+
+		cols, _ := rows.Columns()
+		switch len(cols) {
+		case 5:
+			err = rows.Scan(&file, &position, &binlogDoDB, &binlogIgnoreDB, &executedGtidSet)
+		case 4:
+			err = rows.Scan(&file, &position, &binlogDoDB, &binlogIgnoreDB)
+		default:
+			err = rows.Scan(&file, &position)
+		}
+
+		if err != nil {
+			return fmt.Errorf("scanning master status: %w", err)
+		}
+
+		fmt.Printf("File: %s\n", file)
+		fmt.Printf("Position: %d\n", position)
+		if executedGtidSet.Valid && executedGtidSet.String != "" {
+			fmt.Printf("GTID Set: %s\n", executedGtidSet.String)
+		}
+
+		// Compact format for use in restore commands
+		fmt.Println()
+		fmt.Printf("Position String: %s:%d\n", file, position)
+	} else {
+		fmt.Println("Binary logging appears to be disabled.")
+		fmt.Println("Enable binary logging by adding to my.cnf:")
+		fmt.Println("  [mysqld]")
+		fmt.Println("  log_bin = mysql-bin")
+		fmt.Println("  server_id = 1")
+	}
+
+	return nil
+}
+
+func runMySQLPITRStatus(cmd *cobra.Command, args []string) error {
+	ctx := context.Background()
+
+	if !cfg.IsMySQL() {
+		return fmt.Errorf("this command is only for MySQL/MariaDB (use 'pitr status' for PostgreSQL)")
+	}
+
+	// Connect to MySQL
+	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/",
+		cfg.User, cfg.Password, cfg.Host, cfg.Port)
+
+	db, err := sql.Open("mysql", dsn)
+	if err != nil {
+		return fmt.Errorf("connecting to MySQL: %w", err)
+	}
+	defer db.Close()
+
+	if err := db.PingContext(ctx); err != nil {
+		return fmt.Errorf("pinging MySQL: %w", err)
+	}
+
+	pitrConfig := pitr.MySQLPITRConfig{
+		Host:       cfg.Host,
+		Port:       cfg.Port,
+		User:       cfg.User,
+		Password:   cfg.Password,
+		BinlogDir:  mysqlBinlogDir,
+		ArchiveDir: mysqlArchiveDir,
+	}
+
+	mysqlPitr, err := pitr.NewMySQLPITR(db, pitrConfig)
+	if err != nil {
+		return fmt.Errorf("initializing MySQL PITR: %w", err)
+	}
+
+	status, err := mysqlPitr.Status(ctx)
+	if err != nil {
+		return fmt.Errorf("getting PITR status: %w", err)
+	}
+
+	fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
+	fmt.Printf(" MySQL/MariaDB PITR Status (%s)\n", status.DatabaseType)
+	fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
+	fmt.Println()
+
+	if status.Enabled {
+		fmt.Println("PITR Status: ✅ ENABLED")
+	} else {
+		fmt.Println("PITR Status: ❌ NOT CONFIGURED")
+	}
+
+	// Get binary logging status
+	var logBin string
+	db.QueryRowContext(ctx, "SELECT @@log_bin").Scan(&logBin)
+	if logBin == "1" || logBin == "ON" {
+		fmt.Println("Binary Logging: ✅ ENABLED")
+	} else {
+		fmt.Println("Binary Logging: ❌ DISABLED")
+	}
+
+	fmt.Printf("Binlog Format: %s\n", status.LogLevel)
+
+	// Check GTID mode
+	var gtidMode string
+	if status.DatabaseType == pitr.DatabaseMariaDB {
+		db.QueryRowContext(ctx, "SELECT @@gtid_current_pos").Scan(&gtidMode)
+		if gtidMode != "" {
+			fmt.Println("GTID Mode: ✅ ENABLED")
+		} else {
+			fmt.Println("GTID Mode: ❌ DISABLED")
+		}
+	} else {
+		db.QueryRowContext(ctx, "SELECT @@gtid_mode").Scan(&gtidMode)
+		if gtidMode == "ON" {
+			fmt.Println("GTID Mode: ✅ ENABLED")
+		} else {
+			fmt.Printf("GTID Mode: %s\n", gtidMode)
+		}
+	}
+
+	if status.Position != nil {
+		fmt.Printf("Current Position: %s\n", status.Position.String())
+	}
+
+	if status.ArchiveDir != "" {
+		fmt.Println()
+		fmt.Println("Archive Statistics:")
+		fmt.Printf(" Directory: %s\n", status.ArchiveDir)
+		fmt.Printf(" File Count: %d\n", status.ArchiveCount)
+		fmt.Printf(" Total Size: %s\n", formatWALSize(status.ArchiveSize))
+		if !status.LastArchived.IsZero() {
+			fmt.Printf(" Last
Archive: %s\n", status.LastArchived.Format("2006-01-02 15:04:05")) + } + } + + // Show requirements + fmt.Println() + fmt.Println("PITR Requirements:") + if logBin == "1" || logBin == "ON" { + fmt.Println(" ✅ Binary logging enabled") + } else { + fmt.Println(" ❌ Binary logging must be enabled (log_bin = mysql-bin)") + } + if status.LogLevel == "ROW" { + fmt.Println(" ✅ Row-based logging (recommended)") + } else { + fmt.Printf(" ⚠ binlog_format = %s (ROW recommended for PITR)\n", status.LogLevel) + } + + return nil +} + +func runMySQLPITREnable(cmd *cobra.Command, args []string) error { + ctx := context.Background() + + if !cfg.IsMySQL() { + return fmt.Errorf("this command is only for MySQL/MariaDB (use 'pitr enable' for PostgreSQL)") + } + + // Connect to MySQL + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/", + cfg.User, cfg.Password, cfg.Host, cfg.Port) + + db, err := sql.Open("mysql", dsn) + if err != nil { + return fmt.Errorf("connecting to MySQL: %w", err) + } + defer db.Close() + + if err := db.PingContext(ctx); err != nil { + return fmt.Errorf("pinging MySQL: %w", err) + } + + pitrConfig := pitr.MySQLPITRConfig{ + Host: cfg.Host, + Port: cfg.Port, + User: cfg.User, + Password: cfg.Password, + BinlogDir: mysqlBinlogDir, + ArchiveDir: mysqlArchiveDir, + RequireRowFormat: mysqlRequireRowFormat, + RequireGTID: mysqlRequireGTID, + } + + mysqlPitr, err := pitr.NewMySQLPITR(db, pitrConfig) + if err != nil { + return fmt.Errorf("initializing MySQL PITR: %w", err) + } + + enableConfig := pitr.PITREnableConfig{ + ArchiveDir: mysqlArchiveDir, + RetentionDays: walRetentionDays, + Compression: walCompress, + } + + log.Info("Enabling MySQL PITR", "archive_dir", mysqlArchiveDir) + + if err := mysqlPitr.Enable(ctx, enableConfig); err != nil { + return fmt.Errorf("enabling PITR: %w", err) + } + + log.Info("✅ MySQL PITR enabled successfully!") + log.Info("") + log.Info("Next steps:") + log.Info("1. Start binlog archiving: dbbackup binlog watch --archive-dir " + mysqlArchiveDir) + log.Info("2. Create a base backup: dbbackup backup single ") + log.Info("3. Binlogs will be archived to: " + mysqlArchiveDir) + log.Info("") + log.Info("To restore to a point in time, use:") + log.Info(" dbbackup restore pitr --target-time '2024-01-15 14:30:00'") + + return nil +} + +// getMySQLBinlogDir attempts to determine the binlog directory from MySQL +func getMySQLBinlogDir(ctx context.Context, db *sql.DB) (string, error) { + var logBinBasename string + err := db.QueryRowContext(ctx, "SELECT @@log_bin_basename").Scan(&logBinBasename) + if err != nil { + return "", err + } + + return filepath.Dir(logBinBasename), nil +} diff --git a/cmd/report.go b/cmd/report.go new file mode 100644 index 0000000..ebea15c --- /dev/null +++ b/cmd/report.go @@ -0,0 +1,316 @@ +package cmd + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "dbbackup/internal/catalog" + "dbbackup/internal/report" + + "github.com/spf13/cobra" +) + +var reportCmd = &cobra.Command{ + Use: "report", + Short: "Generate compliance reports", + Long: `Generate compliance reports for various regulatory frameworks. 
+ +Supported frameworks: + - soc2 SOC 2 Type II Trust Service Criteria + - gdpr General Data Protection Regulation + - hipaa Health Insurance Portability and Accountability Act + - pci-dss Payment Card Industry Data Security Standard + - iso27001 ISO 27001 Information Security Management + +Examples: + # Generate SOC2 report for the last 90 days + dbbackup report generate --type soc2 --days 90 + + # Generate HIPAA report as HTML + dbbackup report generate --type hipaa --format html --output report.html + + # Show report summary for current period + dbbackup report summary --type soc2`, +} + +var reportGenerateCmd = &cobra.Command{ + Use: "generate", + Short: "Generate a compliance report", + Long: "Generate a compliance report for a specified framework and time period", + RunE: runReportGenerate, +} + +var reportSummaryCmd = &cobra.Command{ + Use: "summary", + Short: "Show compliance summary", + Long: "Display a quick compliance summary for the specified framework", + RunE: runReportSummary, +} + +var reportListCmd = &cobra.Command{ + Use: "list", + Short: "List available frameworks", + Long: "Display all available compliance frameworks", + RunE: runReportList, +} + +var reportControlsCmd = &cobra.Command{ + Use: "controls [framework]", + Short: "List controls for a framework", + Long: "Display all controls for a specific compliance framework", + Args: cobra.ExactArgs(1), + RunE: runReportControls, +} + +var ( + reportType string + reportDays int + reportStartDate string + reportEndDate string + reportFormat string + reportOutput string + reportCatalog string + reportTitle string + includeEvidence bool +) + +func init() { + rootCmd.AddCommand(reportCmd) + reportCmd.AddCommand(reportGenerateCmd) + reportCmd.AddCommand(reportSummaryCmd) + reportCmd.AddCommand(reportListCmd) + reportCmd.AddCommand(reportControlsCmd) + + // Generate command flags + reportGenerateCmd.Flags().StringVarP(&reportType, "type", "t", "soc2", "Report type (soc2, gdpr, hipaa, pci-dss, iso27001)") + reportGenerateCmd.Flags().IntVarP(&reportDays, "days", "d", 90, "Number of days to include in report") + reportGenerateCmd.Flags().StringVar(&reportStartDate, "start", "", "Start date (YYYY-MM-DD)") + reportGenerateCmd.Flags().StringVar(&reportEndDate, "end", "", "End date (YYYY-MM-DD)") + reportGenerateCmd.Flags().StringVarP(&reportFormat, "format", "f", "markdown", "Output format (json, markdown, html)") + reportGenerateCmd.Flags().StringVarP(&reportOutput, "output", "o", "", "Output file path") + reportGenerateCmd.Flags().StringVar(&reportCatalog, "catalog", "", "Path to backup catalog database") + reportGenerateCmd.Flags().StringVar(&reportTitle, "title", "", "Custom report title") + reportGenerateCmd.Flags().BoolVar(&includeEvidence, "evidence", true, "Include evidence in report") + + // Summary command flags + reportSummaryCmd.Flags().StringVarP(&reportType, "type", "t", "soc2", "Report type") + reportSummaryCmd.Flags().IntVarP(&reportDays, "days", "d", 90, "Number of days to include") + reportSummaryCmd.Flags().StringVar(&reportCatalog, "catalog", "", "Path to backup catalog database") +} + +func runReportGenerate(cmd *cobra.Command, args []string) error { + // Determine time period + var startDate, endDate time.Time + endDate = time.Now() + + if reportStartDate != "" { + parsed, err := time.Parse("2006-01-02", reportStartDate) + if err != nil { + return fmt.Errorf("invalid start date: %w", err) + } + startDate = parsed + } else { + startDate = endDate.AddDate(0, 0, -reportDays) + } + + if reportEndDate != "" { + 
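		// Editor's note (not in the original patch): time.Parse's layout
+		// string "2006-01-02" is Go's reference date and means YYYY-MM-DD here.
+		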
parsed, err := time.Parse("2006-01-02", reportEndDate) + if err != nil { + return fmt.Errorf("invalid end date: %w", err) + } + endDate = parsed + } + + // Determine report type + rptType := parseReportType(reportType) + if rptType == "" { + return fmt.Errorf("unknown report type: %s", reportType) + } + + // Get catalog path + catalogPath := reportCatalog + if catalogPath == "" { + homeDir, _ := os.UserHomeDir() + catalogPath = filepath.Join(homeDir, ".dbbackup", "catalog.db") + } + + // Open catalog + cat, err := catalog.NewSQLiteCatalog(catalogPath) + if err != nil { + return fmt.Errorf("failed to open catalog: %w", err) + } + defer cat.Close() + + // Configure generator + config := report.ReportConfig{ + Type: rptType, + PeriodStart: startDate, + PeriodEnd: endDate, + CatalogPath: catalogPath, + OutputFormat: parseOutputFormat(reportFormat), + OutputPath: reportOutput, + IncludeEvidence: includeEvidence, + } + + if reportTitle != "" { + config.Title = reportTitle + } + + // Generate report + gen := report.NewGenerator(cat, config) + rpt, err := gen.Generate() + if err != nil { + return fmt.Errorf("failed to generate report: %w", err) + } + + // Get formatter + formatter := report.GetFormatter(config.OutputFormat) + + // Write output + var output *os.File + if reportOutput != "" { + output, err = os.Create(reportOutput) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer output.Close() + } else { + output = os.Stdout + } + + if err := formatter.Format(rpt, output); err != nil { + return fmt.Errorf("failed to format report: %w", err) + } + + if reportOutput != "" { + fmt.Printf("Report generated: %s\n", reportOutput) + fmt.Printf(" Type: %s\n", rpt.Type) + fmt.Printf(" Status: %s %s\n", report.StatusIcon(rpt.Status), rpt.Status) + fmt.Printf(" Score: %.1f%%\n", rpt.Score) + fmt.Printf(" Findings: %d open\n", rpt.Summary.OpenFindings) + } + + return nil +} + +func runReportSummary(cmd *cobra.Command, args []string) error { + endDate := time.Now() + startDate := endDate.AddDate(0, 0, -reportDays) + + rptType := parseReportType(reportType) + if rptType == "" { + return fmt.Errorf("unknown report type: %s", reportType) + } + + // Get catalog path + catalogPath := reportCatalog + if catalogPath == "" { + homeDir, _ := os.UserHomeDir() + catalogPath = filepath.Join(homeDir, ".dbbackup", "catalog.db") + } + + // Open catalog + cat, err := catalog.NewSQLiteCatalog(catalogPath) + if err != nil { + return fmt.Errorf("failed to open catalog: %w", err) + } + defer cat.Close() + + // Configure and generate + config := report.ReportConfig{ + Type: rptType, + PeriodStart: startDate, + PeriodEnd: endDate, + CatalogPath: catalogPath, + } + + gen := report.NewGenerator(cat, config) + rpt, err := gen.Generate() + if err != nil { + return fmt.Errorf("failed to generate report: %w", err) + } + + // Display console summary + formatter := &report.ConsoleFormatter{} + return formatter.Format(rpt, os.Stdout) +} + +func runReportList(cmd *cobra.Command, args []string) error { + fmt.Println("\nAvailable Compliance Frameworks:") + fmt.Println(strings.Repeat("-", 50)) + fmt.Printf(" %-12s %s\n", "soc2", "SOC 2 Type II Trust Service Criteria") + fmt.Printf(" %-12s %s\n", "gdpr", "General Data Protection Regulation (EU)") + fmt.Printf(" %-12s %s\n", "hipaa", "Health Insurance Portability and Accountability Act") + fmt.Printf(" %-12s %s\n", "pci-dss", "Payment Card Industry Data Security Standard") + fmt.Printf(" %-12s %s\n", "iso27001", "ISO 27001 Information Security Management") 
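+	// Editor's note (not in the original patch): the framework keys listed
+	// above must stay in sync with parseReportType below, which also accepts
+	// aliases such as "soc-2", "pci", and "iso-27001".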
+ fmt.Println() + fmt.Println("Usage: dbbackup report generate --type ") + fmt.Println() + return nil +} + +func runReportControls(cmd *cobra.Command, args []string) error { + rptType := parseReportType(args[0]) + if rptType == "" { + return fmt.Errorf("unknown report type: %s", args[0]) + } + + framework := report.GetFramework(rptType) + if framework == nil { + return fmt.Errorf("no framework defined for: %s", args[0]) + } + + fmt.Printf("\n%s Controls\n", strings.ToUpper(args[0])) + fmt.Println(strings.Repeat("=", 60)) + + for _, cat := range framework { + fmt.Printf("\n%s\n", cat.Name) + fmt.Printf("%s\n", cat.Description) + fmt.Println(strings.Repeat("-", 40)) + + for _, ctrl := range cat.Controls { + fmt.Printf(" [%s] %s\n", ctrl.Reference, ctrl.Name) + fmt.Printf(" %s\n", ctrl.Description) + } + } + + fmt.Println() + return nil +} + +func parseReportType(s string) report.ReportType { + switch strings.ToLower(s) { + case "soc2", "soc-2", "soc2-type2": + return report.ReportSOC2 + case "gdpr": + return report.ReportGDPR + case "hipaa": + return report.ReportHIPAA + case "pci-dss", "pcidss", "pci": + return report.ReportPCIDSS + case "iso27001", "iso-27001", "iso": + return report.ReportISO27001 + case "custom": + return report.ReportCustom + default: + return "" + } +} + +func parseOutputFormat(s string) report.OutputFormat { + switch strings.ToLower(s) { + case "json": + return report.FormatJSON + case "html": + return report.FormatHTML + case "md", "markdown": + return report.FormatMarkdown + case "pdf": + return report.FormatPDF + default: + return report.FormatMarkdown + } +} diff --git a/cmd/rto.go b/cmd/rto.go new file mode 100644 index 0000000..074cafb --- /dev/null +++ b/cmd/rto.go @@ -0,0 +1,458 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "dbbackup/internal/catalog" + "dbbackup/internal/rto" + + "github.com/spf13/cobra" +) + +var rtoCmd = &cobra.Command{ + Use: "rto", + Short: "RTO/RPO analysis and monitoring", + Long: `Analyze and monitor Recovery Time Objective (RTO) and +Recovery Point Objective (RPO) metrics. 
+ +RTO: How long to recover from a failure +RPO: How much data you can afford to lose + +Examples: + # Analyze RTO/RPO for all databases + dbbackup rto analyze + + # Analyze specific database + dbbackup rto analyze --database mydb + + # Show summary status + dbbackup rto status + + # Set targets and check compliance + dbbackup rto check --target-rto 4h --target-rpo 1h`, +} + +var rtoAnalyzeCmd = &cobra.Command{ + Use: "analyze", + Short: "Analyze RTO/RPO for databases", + Long: "Perform detailed RTO/RPO analysis based on backup history", + RunE: runRTOAnalyze, +} + +var rtoStatusCmd = &cobra.Command{ + Use: "status", + Short: "Show RTO/RPO status summary", + Long: "Display current RTO/RPO compliance status for all databases", + RunE: runRTOStatus, +} + +var rtoCheckCmd = &cobra.Command{ + Use: "check", + Short: "Check RTO/RPO compliance", + Long: "Check if databases meet RTO/RPO targets", + RunE: runRTOCheck, +} + +var ( + rtoDatabase string + rtoTargetRTO string + rtoTargetRPO string + rtoCatalog string + rtoFormat string + rtoOutput string +) + +func init() { + rootCmd.AddCommand(rtoCmd) + rtoCmd.AddCommand(rtoAnalyzeCmd) + rtoCmd.AddCommand(rtoStatusCmd) + rtoCmd.AddCommand(rtoCheckCmd) + + // Analyze command flags + rtoAnalyzeCmd.Flags().StringVarP(&rtoDatabase, "database", "d", "", "Database to analyze (all if not specified)") + rtoAnalyzeCmd.Flags().StringVar(&rtoTargetRTO, "target-rto", "4h", "Target RTO (e.g., 4h, 30m)") + rtoAnalyzeCmd.Flags().StringVar(&rtoTargetRPO, "target-rpo", "1h", "Target RPO (e.g., 1h, 15m)") + rtoAnalyzeCmd.Flags().StringVar(&rtoCatalog, "catalog", "", "Path to backup catalog") + rtoAnalyzeCmd.Flags().StringVarP(&rtoFormat, "format", "f", "text", "Output format (text, json)") + rtoAnalyzeCmd.Flags().StringVarP(&rtoOutput, "output", "o", "", "Output file") + + // Status command flags + rtoStatusCmd.Flags().StringVar(&rtoCatalog, "catalog", "", "Path to backup catalog") + rtoStatusCmd.Flags().StringVar(&rtoTargetRTO, "target-rto", "4h", "Target RTO") + rtoStatusCmd.Flags().StringVar(&rtoTargetRPO, "target-rpo", "1h", "Target RPO") + + // Check command flags + rtoCheckCmd.Flags().StringVarP(&rtoDatabase, "database", "d", "", "Database to check") + rtoCheckCmd.Flags().StringVar(&rtoTargetRTO, "target-rto", "4h", "Target RTO") + rtoCheckCmd.Flags().StringVar(&rtoTargetRPO, "target-rpo", "1h", "Target RPO") + rtoCheckCmd.Flags().StringVar(&rtoCatalog, "catalog", "", "Path to backup catalog") +} + +func runRTOAnalyze(cmd *cobra.Command, args []string) error { + ctx := context.Background() + + // Parse duration targets + targetRTO, err := time.ParseDuration(rtoTargetRTO) + if err != nil { + return fmt.Errorf("invalid target-rto: %w", err) + } + targetRPO, err := time.ParseDuration(rtoTargetRPO) + if err != nil { + return fmt.Errorf("invalid target-rpo: %w", err) + } + + // Get catalog + cat, err := openRTOCatalog() + if err != nil { + return err + } + defer cat.Close() + + // Create calculator + config := rto.DefaultConfig() + config.TargetRTO = targetRTO + config.TargetRPO = targetRPO + calc := rto.NewCalculator(cat, config) + + var analyses []*rto.Analysis + + if rtoDatabase != "" { + // Analyze single database + analysis, err := calc.Analyze(ctx, rtoDatabase) + if err != nil { + return fmt.Errorf("analysis failed: %w", err) + } + analyses = append(analyses, analysis) + } else { + // Analyze all databases + analyses, err = calc.AnalyzeAll(ctx) + if err != nil { + return fmt.Errorf("analysis failed: %w", err) + } + } + + // Output + if rtoFormat == "json" { + 
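		// Editor's note (not in the original patch): JSON mode marshals the
+		// rto.Analysis structs via their json tags; outputJSON (defined later
+		// in this file) writes to --output when set, otherwise to stdout.
+		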
return outputJSON(analyses, rtoOutput) + } + + return outputAnalysisText(analyses) +} + +func runRTOStatus(cmd *cobra.Command, args []string) error { + ctx := context.Background() + + // Parse targets + targetRTO, err := time.ParseDuration(rtoTargetRTO) + if err != nil { + return fmt.Errorf("invalid target-rto: %w", err) + } + targetRPO, err := time.ParseDuration(rtoTargetRPO) + if err != nil { + return fmt.Errorf("invalid target-rpo: %w", err) + } + + // Get catalog + cat, err := openRTOCatalog() + if err != nil { + return err + } + defer cat.Close() + + // Create calculator and analyze all + config := rto.DefaultConfig() + config.TargetRTO = targetRTO + config.TargetRPO = targetRPO + calc := rto.NewCalculator(cat, config) + + analyses, err := calc.AnalyzeAll(ctx) + if err != nil { + return fmt.Errorf("analysis failed: %w", err) + } + + // Create summary + summary := rto.Summarize(analyses) + + // Display status + fmt.Println() + fmt.Println("╔═══════════════════════════════════════════════════════════╗") + fmt.Println("║ RTO/RPO STATUS SUMMARY ║") + fmt.Println("╠═══════════════════════════════════════════════════════════╣") + fmt.Printf("║ Target RTO: %-15s Target RPO: %-15s ║\n", + formatDuration(config.TargetRTO), + formatDuration(config.TargetRPO)) + fmt.Println("╠═══════════════════════════════════════════════════════════╣") + + // Compliance status + rpoRate := 0.0 + rtoRate := 0.0 + fullRate := 0.0 + if summary.TotalDatabases > 0 { + rpoRate = float64(summary.RPOCompliant) / float64(summary.TotalDatabases) * 100 + rtoRate = float64(summary.RTOCompliant) / float64(summary.TotalDatabases) * 100 + fullRate = float64(summary.FullyCompliant) / float64(summary.TotalDatabases) * 100 + } + + fmt.Printf("║ Databases: %-5d ║\n", summary.TotalDatabases) + fmt.Printf("║ RPO Compliant: %-5d (%.0f%%) ║\n", summary.RPOCompliant, rpoRate) + fmt.Printf("║ RTO Compliant: %-5d (%.0f%%) ║\n", summary.RTOCompliant, rtoRate) + fmt.Printf("║ Fully Compliant: %-3d (%.0f%%) ║\n", summary.FullyCompliant, fullRate) + + if summary.CriticalIssues > 0 { + fmt.Printf("║ ⚠️ Critical Issues: %-3d ║\n", summary.CriticalIssues) + } + + fmt.Println("╠═══════════════════════════════════════════════════════════╣") + fmt.Printf("║ Average RPO: %-15s Worst: %-15s ║\n", + formatDuration(summary.AverageRPO), + formatDuration(summary.WorstRPO)) + fmt.Printf("║ Average RTO: %-15s Worst: %-15s ║\n", + formatDuration(summary.AverageRTO), + formatDuration(summary.WorstRTO)) + + if summary.WorstRPODatabase != "" { + fmt.Printf("║ Worst RPO Database: %-38s║\n", summary.WorstRPODatabase) + } + if summary.WorstRTODatabase != "" { + fmt.Printf("║ Worst RTO Database: %-38s║\n", summary.WorstRTODatabase) + } + + fmt.Println("╚═══════════════════════════════════════════════════════════╝") + fmt.Println() + + // Per-database status + if len(analyses) > 0 { + fmt.Println("Database Status:") + fmt.Println(strings.Repeat("-", 70)) + fmt.Printf("%-25s %-12s %-12s %-12s\n", "DATABASE", "RPO", "RTO", "STATUS") + fmt.Println(strings.Repeat("-", 70)) + + for _, a := range analyses { + status := "✅" + if !a.RPOCompliant || !a.RTOCompliant { + status = "❌" + } + + rpoStr := formatDuration(a.CurrentRPO) + rtoStr := formatDuration(a.CurrentRTO) + + if !a.RPOCompliant { + rpoStr = "⚠️ " + rpoStr + } + if !a.RTOCompliant { + rtoStr = "⚠️ " + rtoStr + } + + fmt.Printf("%-25s %-12s %-12s %s\n", + truncateRTO(a.Database, 24), + rpoStr, + rtoStr, + status) + } + fmt.Println(strings.Repeat("-", 70)) + } + + return nil +} + +func runRTOCheck(cmd 
*cobra.Command, args []string) error { + ctx := context.Background() + + // Parse targets + targetRTO, err := time.ParseDuration(rtoTargetRTO) + if err != nil { + return fmt.Errorf("invalid target-rto: %w", err) + } + targetRPO, err := time.ParseDuration(rtoTargetRPO) + if err != nil { + return fmt.Errorf("invalid target-rpo: %w", err) + } + + // Get catalog + cat, err := openRTOCatalog() + if err != nil { + return err + } + defer cat.Close() + + // Create calculator + config := rto.DefaultConfig() + config.TargetRTO = targetRTO + config.TargetRPO = targetRPO + calc := rto.NewCalculator(cat, config) + + var analyses []*rto.Analysis + + if rtoDatabase != "" { + analysis, err := calc.Analyze(ctx, rtoDatabase) + if err != nil { + return fmt.Errorf("analysis failed: %w", err) + } + analyses = append(analyses, analysis) + } else { + analyses, err = calc.AnalyzeAll(ctx) + if err != nil { + return fmt.Errorf("analysis failed: %w", err) + } + } + + // Check compliance + exitCode := 0 + for _, a := range analyses { + if !a.RPOCompliant { + fmt.Printf("❌ %s: RPO violation - current %s exceeds target %s\n", + a.Database, + formatDuration(a.CurrentRPO), + formatDuration(config.TargetRPO)) + exitCode = 1 + } + if !a.RTOCompliant { + fmt.Printf("❌ %s: RTO violation - estimated %s exceeds target %s\n", + a.Database, + formatDuration(a.CurrentRTO), + formatDuration(config.TargetRTO)) + exitCode = 1 + } + if a.RPOCompliant && a.RTOCompliant { + fmt.Printf("✅ %s: Compliant (RPO: %s, RTO: %s)\n", + a.Database, + formatDuration(a.CurrentRPO), + formatDuration(a.CurrentRTO)) + } + } + + if exitCode != 0 { + os.Exit(exitCode) + } + + return nil +} + +func openRTOCatalog() (*catalog.SQLiteCatalog, error) { + catalogPath := rtoCatalog + if catalogPath == "" { + homeDir, _ := os.UserHomeDir() + catalogPath = filepath.Join(homeDir, ".dbbackup", "catalog.db") + } + + cat, err := catalog.NewSQLiteCatalog(catalogPath) + if err != nil { + return nil, fmt.Errorf("failed to open catalog: %w", err) + } + + return cat, nil +} + +func outputJSON(data interface{}, outputPath string) error { + jsonData, err := json.MarshalIndent(data, "", " ") + if err != nil { + return err + } + + if outputPath != "" { + return os.WriteFile(outputPath, jsonData, 0644) + } + + fmt.Println(string(jsonData)) + return nil +} + +func outputAnalysisText(analyses []*rto.Analysis) error { + for _, a := range analyses { + fmt.Println() + fmt.Println(strings.Repeat("=", 60)) + fmt.Printf(" Database: %s\n", a.Database) + fmt.Println(strings.Repeat("=", 60)) + + // Status + rpoStatus := "✅ Compliant" + if !a.RPOCompliant { + rpoStatus = "❌ Violation" + } + rtoStatus := "✅ Compliant" + if !a.RTOCompliant { + rtoStatus = "❌ Violation" + } + + fmt.Println() + fmt.Println(" Recovery Objectives:") + fmt.Println(strings.Repeat("-", 50)) + fmt.Printf(" RPO (Current): %-15s Target: %s\n", + formatDuration(a.CurrentRPO), formatDuration(a.TargetRPO)) + fmt.Printf(" RPO Status: %s\n", rpoStatus) + fmt.Printf(" RTO (Estimated): %-14s Target: %s\n", + formatDuration(a.CurrentRTO), formatDuration(a.TargetRTO)) + fmt.Printf(" RTO Status: %s\n", rtoStatus) + + if a.LastBackup != nil { + fmt.Printf(" Last Backup: %s\n", a.LastBackup.Format("2006-01-02 15:04:05")) + } + if a.BackupInterval > 0 { + fmt.Printf(" Backup Interval: %s\n", formatDuration(a.BackupInterval)) + } + + // RTO Breakdown + fmt.Println() + fmt.Println(" RTO Breakdown:") + fmt.Println(strings.Repeat("-", 50)) + b := a.RTOBreakdown + fmt.Printf(" Detection: %s\n", formatDuration(b.DetectionTime)) + 
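		// Editor's note (not in the original patch): the remaining RTO phases
+		// are printed below; Download appears only when non-zero (i.e. the
+		// backup must first be fetched from remote storage), and Total closes
+		// the breakdown.
+		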
fmt.Printf(" Decision: %s\n", formatDuration(b.DecisionTime)) + if b.DownloadTime > 0 { + fmt.Printf(" Download: %s\n", formatDuration(b.DownloadTime)) + } + fmt.Printf(" Restore: %s\n", formatDuration(b.RestoreTime)) + fmt.Printf(" Startup: %s\n", formatDuration(b.StartupTime)) + fmt.Printf(" Validation: %s\n", formatDuration(b.ValidationTime)) + fmt.Printf(" Switchover: %s\n", formatDuration(b.SwitchoverTime)) + fmt.Println(strings.Repeat("-", 30)) + fmt.Printf(" Total: %s\n", formatDuration(b.TotalTime)) + + // Recommendations + if len(a.Recommendations) > 0 { + fmt.Println() + fmt.Println(" Recommendations:") + fmt.Println(strings.Repeat("-", 50)) + for _, r := range a.Recommendations { + icon := "💡" + switch r.Priority { + case rto.PriorityCritical: + icon = "🔴" + case rto.PriorityHigh: + icon = "🟠" + case rto.PriorityMedium: + icon = "🟡" + } + fmt.Printf(" %s [%s] %s\n", icon, r.Priority, r.Title) + fmt.Printf(" %s\n", r.Description) + } + } + } + + return nil +} + +func formatDuration(d time.Duration) string { + if d < time.Minute { + return fmt.Sprintf("%.0fs", d.Seconds()) + } + if d < time.Hour { + return fmt.Sprintf("%.0fm", d.Minutes()) + } + hours := int(d.Hours()) + mins := int(d.Minutes()) - hours*60 + return fmt.Sprintf("%dh %dm", hours, mins) +} + +func truncateRTO(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen-3] + "..." +} diff --git a/go.mod b/go.mod index d068774..f1ffb9b 100755 --- a/go.mod +++ b/go.mod @@ -79,6 +79,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.16.0 // indirect diff --git a/go.sum b/go.sum index 4263a9d..83bcc27 100755 --- a/go.sum +++ b/go.sum @@ -153,6 +153,8 @@ github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2J github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= diff --git a/internal/catalog/catalog.go b/internal/catalog/catalog.go new file mode 100644 index 0000000..193a1fd --- /dev/null +++ b/internal/catalog/catalog.go @@ -0,0 +1,188 @@ +// Package catalog provides backup catalog management with SQLite storage +package catalog + +import ( + "context" + "fmt" + "time" +) + +// Entry represents a single backup in the catalog +type Entry struct { + ID int64 `json:"id"` + Database string `json:"database"` + DatabaseType string `json:"database_type"` // postgresql, mysql, mariadb + Host string `json:"host"` + Port int `json:"port"` + BackupPath string `json:"backup_path"` + BackupType string `json:"backup_type"` // full, incremental + SizeBytes int64 `json:"size_bytes"` + 
SHA256 string `json:"sha256"` + Compression string `json:"compression"` + Encrypted bool `json:"encrypted"` + CreatedAt time.Time `json:"created_at"` + Duration float64 `json:"duration_seconds"` + Status BackupStatus `json:"status"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + VerifyValid *bool `json:"verify_valid,omitempty"` + DrillTestedAt *time.Time `json:"drill_tested_at,omitempty"` + DrillSuccess *bool `json:"drill_success,omitempty"` + CloudLocation string `json:"cloud_location,omitempty"` + RetentionPolicy string `json:"retention_policy,omitempty"` // daily, weekly, monthly, yearly + Tags map[string]string `json:"tags,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// BackupStatus represents the state of a backup +type BackupStatus string + +const ( + StatusCompleted BackupStatus = "completed" + StatusFailed BackupStatus = "failed" + StatusVerified BackupStatus = "verified" + StatusCorrupted BackupStatus = "corrupted" + StatusDeleted BackupStatus = "deleted" + StatusArchived BackupStatus = "archived" +) + +// Gap represents a detected backup gap +type Gap struct { + Database string `json:"database"` + GapStart time.Time `json:"gap_start"` + GapEnd time.Time `json:"gap_end"` + Duration time.Duration `json:"duration"` + ExpectedAt time.Time `json:"expected_at"` + Description string `json:"description"` + Severity GapSeverity `json:"severity"` +} + +// GapSeverity indicates how serious a backup gap is +type GapSeverity string + +const ( + SeverityInfo GapSeverity = "info" // Gap within tolerance + SeverityWarning GapSeverity = "warning" // Gap exceeds expected interval + SeverityCritical GapSeverity = "critical" // Gap exceeds RPO +) + +// Stats contains backup statistics +type Stats struct { + TotalBackups int64 `json:"total_backups"` + TotalSize int64 `json:"total_size_bytes"` + TotalSizeHuman string `json:"total_size_human"` + OldestBackup *time.Time `json:"oldest_backup,omitempty"` + NewestBackup *time.Time `json:"newest_backup,omitempty"` + ByDatabase map[string]int64 `json:"by_database"` + ByType map[string]int64 `json:"by_type"` + ByStatus map[string]int64 `json:"by_status"` + VerifiedCount int64 `json:"verified_count"` + DrillTestedCount int64 `json:"drill_tested_count"` + AvgDuration float64 `json:"avg_duration_seconds"` + AvgSize int64 `json:"avg_size_bytes"` + GapsDetected int `json:"gaps_detected"` +} + +// SearchQuery represents search criteria for catalog entries +type SearchQuery struct { + Database string // Filter by database name (supports wildcards) + DatabaseType string // Filter by database type + Host string // Filter by host + Status string // Filter by status + StartDate *time.Time // Backups after this date + EndDate *time.Time // Backups before this date + MinSize int64 // Minimum size in bytes + MaxSize int64 // Maximum size in bytes + BackupType string // full, incremental + Encrypted *bool // Filter by encryption status + Verified *bool // Filter by verification status + DrillTested *bool // Filter by drill test status + Limit int // Max results (0 = no limit) + Offset int // Offset for pagination + OrderBy string // Field to order by + OrderDesc bool // Order descending +} + +// GapDetectionConfig configures gap detection +type GapDetectionConfig struct { + ExpectedInterval time.Duration // Expected backup interval (e.g., 24h) + Tolerance time.Duration // Allowed variance (e.g., 1h) + RPOThreshold time.Duration // Critical threshold (RPO) + StartDate *time.Time // Start of analysis window + EndDate *time.Time // End 
of analysis window +} + +// Catalog defines the interface for backup catalog operations +type Catalog interface { + // Entry management + Add(ctx context.Context, entry *Entry) error + Update(ctx context.Context, entry *Entry) error + Delete(ctx context.Context, id int64) error + Get(ctx context.Context, id int64) (*Entry, error) + GetByPath(ctx context.Context, path string) (*Entry, error) + + // Search and listing + Search(ctx context.Context, query *SearchQuery) ([]*Entry, error) + List(ctx context.Context, database string, limit int) ([]*Entry, error) + ListDatabases(ctx context.Context) ([]string, error) + Count(ctx context.Context, query *SearchQuery) (int64, error) + + // Statistics + Stats(ctx context.Context) (*Stats, error) + StatsByDatabase(ctx context.Context, database string) (*Stats, error) + + // Gap detection + DetectGaps(ctx context.Context, database string, config *GapDetectionConfig) ([]*Gap, error) + DetectAllGaps(ctx context.Context, config *GapDetectionConfig) (map[string][]*Gap, error) + + // Verification tracking + MarkVerified(ctx context.Context, id int64, valid bool) error + MarkDrillTested(ctx context.Context, id int64, success bool) error + + // Sync with filesystem + SyncFromDirectory(ctx context.Context, dir string) (*SyncResult, error) + SyncFromCloud(ctx context.Context, provider, bucket, prefix string) (*SyncResult, error) + + // Maintenance + Prune(ctx context.Context, before time.Time) (int, error) + Vacuum(ctx context.Context) error + Close() error +} + +// SyncResult contains results from a catalog sync operation +type SyncResult struct { + Added int `json:"added"` + Updated int `json:"updated"` + Removed int `json:"removed"` + Errors int `json:"errors"` + Duration float64 `json:"duration_seconds"` + Details []string `json:"details,omitempty"` +} + +// FormatSize formats bytes as human-readable string +func FormatSize(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +// FormatDuration formats duration as human-readable string +func FormatDuration(d time.Duration) string { + if d < time.Minute { + return fmt.Sprintf("%.0fs", d.Seconds()) + } + if d < time.Hour { + mins := int(d.Minutes()) + secs := int(d.Seconds()) - mins*60 + return fmt.Sprintf("%dm %ds", mins, secs) + } + hours := int(d.Hours()) + mins := int(d.Minutes()) - hours*60 + return fmt.Sprintf("%dh %dm", hours, mins) +} diff --git a/internal/catalog/catalog_test.go b/internal/catalog/catalog_test.go new file mode 100644 index 0000000..aba96ba --- /dev/null +++ b/internal/catalog/catalog_test.go @@ -0,0 +1,308 @@ +package catalog + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" +) + +func TestSQLiteCatalog(t *testing.T) { + // Create temp directory for test database + tmpDir, err := os.MkdirTemp("", "catalog_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test_catalog.db") + + // Test creation + cat, err := NewSQLiteCatalog(dbPath) + if err != nil { + t.Fatalf("Failed to create catalog: %v", err) + } + defer cat.Close() + + ctx := context.Background() + + // Test Add + entry := &Entry{ + Database: "testdb", + DatabaseType: "postgresql", + Host: "localhost", + Port: 5432, + BackupPath: "/backups/testdb_20240115.dump.gz", + BackupType: 
"full", + SizeBytes: 1024 * 1024 * 100, // 100 MB + SHA256: "abc123def456", + Compression: "gzip", + Encrypted: false, + CreatedAt: time.Now().Add(-24 * time.Hour), + Duration: 45.5, + Status: StatusCompleted, + } + + err = cat.Add(ctx, entry) + if err != nil { + t.Fatalf("Failed to add entry: %v", err) + } + + if entry.ID == 0 { + t.Error("Expected entry ID to be set after Add") + } + + // Test Get + retrieved, err := cat.Get(ctx, entry.ID) + if err != nil { + t.Fatalf("Failed to get entry: %v", err) + } + + if retrieved == nil { + t.Fatal("Expected to retrieve entry, got nil") + } + + if retrieved.Database != "testdb" { + t.Errorf("Expected database 'testdb', got '%s'", retrieved.Database) + } + + if retrieved.SizeBytes != entry.SizeBytes { + t.Errorf("Expected size %d, got %d", entry.SizeBytes, retrieved.SizeBytes) + } + + // Test GetByPath + byPath, err := cat.GetByPath(ctx, entry.BackupPath) + if err != nil { + t.Fatalf("Failed to get by path: %v", err) + } + + if byPath == nil || byPath.ID != entry.ID { + t.Error("GetByPath returned wrong entry") + } + + // Test List + entries, err := cat.List(ctx, "testdb", 10) + if err != nil { + t.Fatalf("Failed to list entries: %v", err) + } + + if len(entries) != 1 { + t.Errorf("Expected 1 entry, got %d", len(entries)) + } + + // Test ListDatabases + databases, err := cat.ListDatabases(ctx) + if err != nil { + t.Fatalf("Failed to list databases: %v", err) + } + + if len(databases) != 1 || databases[0] != "testdb" { + t.Errorf("Expected ['testdb'], got %v", databases) + } + + // Test Stats + stats, err := cat.Stats(ctx) + if err != nil { + t.Fatalf("Failed to get stats: %v", err) + } + + if stats.TotalBackups != 1 { + t.Errorf("Expected 1 total backup, got %d", stats.TotalBackups) + } + + if stats.TotalSize != entry.SizeBytes { + t.Errorf("Expected size %d, got %d", entry.SizeBytes, stats.TotalSize) + } + + // Test MarkVerified + err = cat.MarkVerified(ctx, entry.ID, true) + if err != nil { + t.Fatalf("Failed to mark verified: %v", err) + } + + verified, _ := cat.Get(ctx, entry.ID) + if verified.VerifiedAt == nil { + t.Error("Expected VerifiedAt to be set") + } + if verified.VerifyValid == nil || !*verified.VerifyValid { + t.Error("Expected VerifyValid to be true") + } + + // Test Update + entry.SizeBytes = 200 * 1024 * 1024 // 200 MB + err = cat.Update(ctx, entry) + if err != nil { + t.Fatalf("Failed to update entry: %v", err) + } + + updated, _ := cat.Get(ctx, entry.ID) + if updated.SizeBytes != entry.SizeBytes { + t.Errorf("Update failed: expected size %d, got %d", entry.SizeBytes, updated.SizeBytes) + } + + // Test Search with filters + query := &SearchQuery{ + Database: "testdb", + Limit: 10, + OrderBy: "created_at", + OrderDesc: true, + } + + results, err := cat.Search(ctx, query) + if err != nil { + t.Fatalf("Search failed: %v", err) + } + + if len(results) != 1 { + t.Errorf("Expected 1 result, got %d", len(results)) + } + + // Test Search with wildcards + query.Database = "test*" + results, err = cat.Search(ctx, query) + if err != nil { + t.Fatalf("Wildcard search failed: %v", err) + } + + if len(results) != 1 { + t.Errorf("Expected 1 result from wildcard search, got %d", len(results)) + } + + // Test Count + count, err := cat.Count(ctx, &SearchQuery{Database: "testdb"}) + if err != nil { + t.Fatalf("Count failed: %v", err) + } + + if count != 1 { + t.Errorf("Expected count 1, got %d", count) + } + + // Test Delete + err = cat.Delete(ctx, entry.ID) + if err != nil { + t.Fatalf("Failed to delete entry: %v", err) + } + + deleted, _ := 
cat.Get(ctx, entry.ID) + if deleted != nil { + t.Error("Expected entry to be deleted") + } +} + +func TestGapDetection(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "catalog_gaps_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test_catalog.db") + cat, err := NewSQLiteCatalog(dbPath) + if err != nil { + t.Fatalf("Failed to create catalog: %v", err) + } + defer cat.Close() + + ctx := context.Background() + + // Add backups with varying intervals + now := time.Now() + backups := []time.Time{ + now.Add(-7 * 24 * time.Hour), // 7 days ago + now.Add(-6 * 24 * time.Hour), // 6 days ago (OK) + now.Add(-5 * 24 * time.Hour), // 5 days ago (OK) + // Missing 4 days ago - GAP + now.Add(-3 * 24 * time.Hour), // 3 days ago + now.Add(-2 * 24 * time.Hour), // 2 days ago (OK) + // Missing 1 day ago and today - GAP to now + } + + for i, ts := range backups { + entry := &Entry{ + Database: "gaptest", + DatabaseType: "postgresql", + BackupPath: filepath.Join(tmpDir, fmt.Sprintf("backup_%d.dump", i)), + BackupType: "full", + CreatedAt: ts, + Status: StatusCompleted, + } + cat.Add(ctx, entry) + } + + // Detect gaps with 24h expected interval + config := &GapDetectionConfig{ + ExpectedInterval: 24 * time.Hour, + Tolerance: 2 * time.Hour, + RPOThreshold: 48 * time.Hour, + } + + gaps, err := cat.DetectGaps(ctx, "gaptest", config) + if err != nil { + t.Fatalf("Gap detection failed: %v", err) + } + + // Should detect at least 2 gaps: + // 1. Between 5 days ago and 3 days ago (missing 4 days ago) + // 2. Between 2 days ago and now (missing recent backups) + if len(gaps) < 2 { + t.Errorf("Expected at least 2 gaps, got %d", len(gaps)) + } + + // Check gap severities + hasCritical := false + for _, gap := range gaps { + if gap.Severity == SeverityCritical { + hasCritical = true + } + if gap.Duration < config.ExpectedInterval { + t.Errorf("Gap duration %v is less than expected interval", gap.Duration) + } + } + + // The gap from 2 days ago to now should be critical (>48h) + if !hasCritical { + t.Log("Note: Expected at least one critical gap") + } +} + +func TestFormatSize(t *testing.T) { + tests := []struct { + bytes int64 + expected string + }{ + {0, "0 B"}, + {500, "500 B"}, + {1024, "1.0 KB"}, + {1024 * 1024, "1.0 MB"}, + {1024 * 1024 * 1024, "1.0 GB"}, + {1024 * 1024 * 1024 * 1024, "1.0 TB"}, + } + + for _, test := range tests { + result := FormatSize(test.bytes) + if result != test.expected { + t.Errorf("FormatSize(%d) = %s, expected %s", test.bytes, result, test.expected) + } + } +} + +func TestFormatDuration(t *testing.T) { + tests := []struct { + duration time.Duration + expected string + }{ + {30 * time.Second, "30s"}, + {90 * time.Second, "1m 30s"}, + {2 * time.Hour, "2h 0m"}, + } + + for _, test := range tests { + result := FormatDuration(test.duration) + if result != test.expected { + t.Errorf("FormatDuration(%v) = %s, expected %s", test.duration, result, test.expected) + } + } +} diff --git a/internal/catalog/gaps.go b/internal/catalog/gaps.go new file mode 100644 index 0000000..3c72afa --- /dev/null +++ b/internal/catalog/gaps.go @@ -0,0 +1,299 @@ +// Package catalog - Gap detection for backup schedules +package catalog + +import ( + "context" + "sort" + "time" +) + +// DetectGaps analyzes backup history and finds gaps in the schedule +func (c *SQLiteCatalog) DetectGaps(ctx context.Context, database string, config *GapDetectionConfig) ([]*Gap, error) { + if config == nil { + config = &GapDetectionConfig{ + 
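			// Editor's note (not in the original patch): these defaults assume
+			// daily backups; a gap is flagged once it exceeds interval plus
+			// tolerance (25h) and becomes critical past the 48h RPO threshold.
+			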
ExpectedInterval: 24 * time.Hour, + Tolerance: time.Hour, + RPOThreshold: 48 * time.Hour, + } + } + + // Get all backups for this database, ordered by time + query := &SearchQuery{ + Database: database, + Status: string(StatusCompleted), + OrderBy: "created_at", + OrderDesc: false, + } + + if config.StartDate != nil { + query.StartDate = config.StartDate + } + if config.EndDate != nil { + query.EndDate = config.EndDate + } + + entries, err := c.Search(ctx, query) + if err != nil { + return nil, err + } + + if len(entries) < 2 { + return nil, nil // Not enough backups to detect gaps + } + + var gaps []*Gap + + for i := 1; i < len(entries); i++ { + prev := entries[i-1] + curr := entries[i] + + actualInterval := curr.CreatedAt.Sub(prev.CreatedAt) + expectedWithTolerance := config.ExpectedInterval + config.Tolerance + + if actualInterval > expectedWithTolerance { + gap := &Gap{ + Database: database, + GapStart: prev.CreatedAt, + GapEnd: curr.CreatedAt, + Duration: actualInterval, + ExpectedAt: prev.CreatedAt.Add(config.ExpectedInterval), + } + + // Determine severity + if actualInterval > config.RPOThreshold { + gap.Severity = SeverityCritical + gap.Description = "CRITICAL: Gap exceeds RPO threshold" + } else if actualInterval > config.ExpectedInterval*2 { + gap.Severity = SeverityWarning + gap.Description = "WARNING: Gap exceeds 2x expected interval" + } else { + gap.Severity = SeverityInfo + gap.Description = "INFO: Gap exceeds expected interval" + } + + gaps = append(gaps, gap) + } + } + + // Check for gap from last backup to now + lastBackup := entries[len(entries)-1] + now := time.Now() + if config.EndDate != nil { + now = *config.EndDate + } + + sinceLastBackup := now.Sub(lastBackup.CreatedAt) + if sinceLastBackup > config.ExpectedInterval+config.Tolerance { + gap := &Gap{ + Database: database, + GapStart: lastBackup.CreatedAt, + GapEnd: now, + Duration: sinceLastBackup, + ExpectedAt: lastBackup.CreatedAt.Add(config.ExpectedInterval), + } + + if sinceLastBackup > config.RPOThreshold { + gap.Severity = SeverityCritical + gap.Description = "CRITICAL: No backup since " + FormatDuration(sinceLastBackup) + } else if sinceLastBackup > config.ExpectedInterval*2 { + gap.Severity = SeverityWarning + gap.Description = "WARNING: No backup since " + FormatDuration(sinceLastBackup) + } else { + gap.Severity = SeverityInfo + gap.Description = "INFO: Backup overdue by " + FormatDuration(sinceLastBackup-config.ExpectedInterval) + } + + gaps = append(gaps, gap) + } + + return gaps, nil +} + +// DetectAllGaps analyzes all databases for backup gaps +func (c *SQLiteCatalog) DetectAllGaps(ctx context.Context, config *GapDetectionConfig) (map[string][]*Gap, error) { + databases, err := c.ListDatabases(ctx) + if err != nil { + return nil, err + } + + allGaps := make(map[string][]*Gap) + + for _, db := range databases { + gaps, err := c.DetectGaps(ctx, db, config) + if err != nil { + continue // Skip errors for individual databases + } + if len(gaps) > 0 { + allGaps[db] = gaps + } + } + + return allGaps, nil +} + +// BackupFrequencyAnalysis provides analysis of backup frequency +type BackupFrequencyAnalysis struct { + Database string `json:"database"` + TotalBackups int `json:"total_backups"` + AnalysisPeriod time.Duration `json:"analysis_period"` + AverageInterval time.Duration `json:"average_interval"` + MinInterval time.Duration `json:"min_interval"` + MaxInterval time.Duration `json:"max_interval"` + StdDeviation time.Duration `json:"std_deviation"` + Regularity float64 `json:"regularity"` // 0-1, higher is 
more regular
+	GapsDetected    int           `json:"gaps_detected"`
+	MissedBackups   int           `json:"missed_backups"` // Estimated based on expected interval
+}
+
+// AnalyzeFrequency analyzes backup frequency for a database
+func (c *SQLiteCatalog) AnalyzeFrequency(ctx context.Context, database string, expectedInterval time.Duration) (*BackupFrequencyAnalysis, error) {
+	query := &SearchQuery{
+		Database:  database,
+		Status:    string(StatusCompleted),
+		OrderBy:   "created_at",
+		OrderDesc: false,
+	}
+
+	entries, err := c.Search(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(entries) < 2 {
+		return &BackupFrequencyAnalysis{
+			Database:     database,
+			TotalBackups: len(entries),
+		}, nil
+	}
+
+	analysis := &BackupFrequencyAnalysis{
+		Database:     database,
+		TotalBackups: len(entries),
+	}
+
+	// Calculate intervals between consecutive backups
+	var intervals []time.Duration
+	for i := 1; i < len(entries); i++ {
+		interval := entries[i].CreatedAt.Sub(entries[i-1].CreatedAt)
+		intervals = append(intervals, interval)
+	}
+
+	analysis.AnalysisPeriod = entries[len(entries)-1].CreatedAt.Sub(entries[0].CreatedAt)
+
+	// Calculate min, max, average
+	sort.Slice(intervals, func(i, j int) bool {
+		return intervals[i] < intervals[j]
+	})
+
+	analysis.MinInterval = intervals[0]
+	analysis.MaxInterval = intervals[len(intervals)-1]
+
+	var total time.Duration
+	for _, interval := range intervals {
+		total += interval
+	}
+	analysis.AverageInterval = total / time.Duration(len(intervals))
+
+	// Calculate standard deviation: the variance is in nanoseconds squared,
+	// so take its square root before converting back to a Duration
+	var sumSquares float64
+	avgNanos := float64(analysis.AverageInterval.Nanoseconds())
+	for _, interval := range intervals {
+		diff := float64(interval.Nanoseconds()) - avgNanos
+		sumSquares += diff * diff
+	}
+	variance := sumSquares / float64(len(intervals))
+	analysis.StdDeviation = time.Duration(int64(math.Sqrt(variance)))
+
+	// Calculate regularity score (lower deviation = higher regularity)
+	if analysis.AverageInterval > 0 {
+		deviationRatio := float64(analysis.StdDeviation) / float64(analysis.AverageInterval)
+		analysis.Regularity = 1.0 - min(deviationRatio, 1.0)
+	}
+
+	// Detect gaps and missed backups
+	config := &GapDetectionConfig{
+		ExpectedInterval: expectedInterval,
+		Tolerance:        expectedInterval / 4,
+		RPOThreshold:     expectedInterval * 2,
+	}
+
+	gaps, _ := c.DetectGaps(ctx, database, config)
+	analysis.GapsDetected = len(gaps)
+
+	// Estimate missed backups
+	if expectedInterval > 0 {
+		expectedBackups := int(analysis.AnalysisPeriod / expectedInterval)
+		if expectedBackups > analysis.TotalBackups {
+			analysis.MissedBackups = expectedBackups - analysis.TotalBackups
+		}
+	}
+
+	return analysis, nil
+}
+
+// RPOStatus captures the current recovery-point-objective posture for a database
+type RPOStatus struct {
+	Database         string        `json:"database"`
+	LastBackup       time.Time     `json:"last_backup"`
+	TimeSinceBackup  time.Duration `json:"time_since_backup"`
+	TargetRPO        time.Duration `json:"target_rpo"`
+	CurrentRPO       time.Duration `json:"current_rpo"`
+	RPOMet           bool          `json:"rpo_met"`
+	NextBackupDue    time.Time     `json:"next_backup_due"`
+	BackupsIn24Hours int           `json:"backups_in_24h"`
+	BackupsIn7Days   int           `json:"backups_in_7d"`
+}
+
+// CalculateRPOStatus calculates RPO status for a database
+func (c *SQLiteCatalog) CalculateRPOStatus(ctx context.Context, database string, targetRPO time.Duration) (*RPOStatus, error) {
+	status := &RPOStatus{
+		Database:  database,
+		TargetRPO: targetRPO,
+	}
+
+	// Get most recent backup
+	entries, err := c.List(ctx, database, 1)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(entries) == 0 {
+		status.RPOMet =
false + status.CurrentRPO = time.Duration(0) + return status, nil + } + + status.LastBackup = entries[0].CreatedAt + status.TimeSinceBackup = time.Since(entries[0].CreatedAt) + status.CurrentRPO = status.TimeSinceBackup + status.RPOMet = status.TimeSinceBackup <= targetRPO + status.NextBackupDue = entries[0].CreatedAt.Add(targetRPO) + + // Count backups in time windows + now := time.Now() + last24h := now.Add(-24 * time.Hour) + last7d := now.Add(-7 * 24 * time.Hour) + + count24h, _ := c.Count(ctx, &SearchQuery{ + Database: database, + StartDate: &last24h, + Status: string(StatusCompleted), + }) + count7d, _ := c.Count(ctx, &SearchQuery{ + Database: database, + StartDate: &last7d, + Status: string(StatusCompleted), + }) + + status.BackupsIn24Hours = int(count24h) + status.BackupsIn7Days = int(count7d) + + return status, nil +} + +func min(a, b float64) float64 { + if a < b { + return a + } + return b +} diff --git a/internal/catalog/sqlite.go b/internal/catalog/sqlite.go new file mode 100644 index 0000000..2287d21 --- /dev/null +++ b/internal/catalog/sqlite.go @@ -0,0 +1,632 @@ +// Package catalog - SQLite storage implementation +package catalog + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +// SQLiteCatalog implements Catalog interface with SQLite storage +type SQLiteCatalog struct { + db *sql.DB + path string +} + +// NewSQLiteCatalog creates a new SQLite-backed catalog +func NewSQLiteCatalog(dbPath string) (*SQLiteCatalog, error) { + // Ensure directory exists + dir := filepath.Dir(dbPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create catalog directory: %w", err) + } + + db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=ON") + if err != nil { + return nil, fmt.Errorf("failed to open catalog database: %w", err) + } + + catalog := &SQLiteCatalog{ + db: db, + path: dbPath, + } + + if err := catalog.initialize(); err != nil { + db.Close() + return nil, err + } + + return catalog, nil +} + +// initialize creates the database schema +func (c *SQLiteCatalog) initialize() error { + schema := ` + CREATE TABLE IF NOT EXISTS backups ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + database TEXT NOT NULL, + database_type TEXT NOT NULL, + host TEXT, + port INTEGER, + backup_path TEXT NOT NULL UNIQUE, + backup_type TEXT DEFAULT 'full', + size_bytes INTEGER, + sha256 TEXT, + compression TEXT, + encrypted INTEGER DEFAULT 0, + created_at DATETIME NOT NULL, + duration REAL, + status TEXT DEFAULT 'completed', + verified_at DATETIME, + verify_valid INTEGER, + drill_tested_at DATETIME, + drill_success INTEGER, + cloud_location TEXT, + retention_policy TEXT, + tags TEXT, + metadata TEXT, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_backups_database ON backups(database); + CREATE INDEX IF NOT EXISTS idx_backups_created_at ON backups(created_at); + CREATE INDEX IF NOT EXISTS idx_backups_status ON backups(status); + CREATE INDEX IF NOT EXISTS idx_backups_host ON backups(host); + CREATE INDEX IF NOT EXISTS idx_backups_database_type ON backups(database_type); + + CREATE TABLE IF NOT EXISTS catalog_meta ( + key TEXT PRIMARY KEY, + value TEXT, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + + -- Store schema version for migrations + INSERT OR IGNORE INTO catalog_meta (key, value) VALUES ('schema_version', '1'); + ` + + _, err := c.db.Exec(schema) + if err != nil { + return fmt.Errorf("failed to 
initialize schema: %w", err) + } + + return nil +} + +// Add inserts a new backup entry +func (c *SQLiteCatalog) Add(ctx context.Context, entry *Entry) error { + tagsJSON, _ := json.Marshal(entry.Tags) + metaJSON, _ := json.Marshal(entry.Metadata) + + result, err := c.db.ExecContext(ctx, ` + INSERT INTO backups ( + database, database_type, host, port, backup_path, backup_type, + size_bytes, sha256, compression, encrypted, created_at, duration, + status, cloud_location, retention_policy, tags, metadata + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `, + entry.Database, entry.DatabaseType, entry.Host, entry.Port, + entry.BackupPath, entry.BackupType, entry.SizeBytes, entry.SHA256, + entry.Compression, entry.Encrypted, entry.CreatedAt, entry.Duration, + entry.Status, entry.CloudLocation, entry.RetentionPolicy, + string(tagsJSON), string(metaJSON), + ) + if err != nil { + return fmt.Errorf("failed to add catalog entry: %w", err) + } + + id, _ := result.LastInsertId() + entry.ID = id + return nil +} + +// Update updates an existing backup entry +func (c *SQLiteCatalog) Update(ctx context.Context, entry *Entry) error { + tagsJSON, _ := json.Marshal(entry.Tags) + metaJSON, _ := json.Marshal(entry.Metadata) + + _, err := c.db.ExecContext(ctx, ` + UPDATE backups SET + database = ?, database_type = ?, host = ?, port = ?, + backup_type = ?, size_bytes = ?, sha256 = ?, compression = ?, + encrypted = ?, duration = ?, status = ?, verified_at = ?, + verify_valid = ?, drill_tested_at = ?, drill_success = ?, + cloud_location = ?, retention_policy = ?, tags = ?, metadata = ?, + updated_at = CURRENT_TIMESTAMP + WHERE id = ? + `, + entry.Database, entry.DatabaseType, entry.Host, entry.Port, + entry.BackupType, entry.SizeBytes, entry.SHA256, entry.Compression, + entry.Encrypted, entry.Duration, entry.Status, entry.VerifiedAt, + entry.VerifyValid, entry.DrillTestedAt, entry.DrillSuccess, + entry.CloudLocation, entry.RetentionPolicy, + string(tagsJSON), string(metaJSON), entry.ID, + ) + if err != nil { + return fmt.Errorf("failed to update catalog entry: %w", err) + } + return nil +} + +// Delete removes a backup entry +func (c *SQLiteCatalog) Delete(ctx context.Context, id int64) error { + _, err := c.db.ExecContext(ctx, "DELETE FROM backups WHERE id = ?", id) + if err != nil { + return fmt.Errorf("failed to delete catalog entry: %w", err) + } + return nil +} + +// Get retrieves a backup entry by ID +func (c *SQLiteCatalog) Get(ctx context.Context, id int64) (*Entry, error) { + row := c.db.QueryRowContext(ctx, ` + SELECT id, database, database_type, host, port, backup_path, backup_type, + size_bytes, sha256, compression, encrypted, created_at, duration, + status, verified_at, verify_valid, drill_tested_at, drill_success, + cloud_location, retention_policy, tags, metadata + FROM backups WHERE id = ? + `, id) + + return c.scanEntry(row) +} + +// GetByPath retrieves a backup entry by file path +func (c *SQLiteCatalog) GetByPath(ctx context.Context, path string) (*Entry, error) { + row := c.db.QueryRowContext(ctx, ` + SELECT id, database, database_type, host, port, backup_path, backup_type, + size_bytes, sha256, compression, encrypted, created_at, duration, + status, verified_at, verify_valid, drill_tested_at, drill_success, + cloud_location, retention_policy, tags, metadata + FROM backups WHERE backup_path = ? 
+ `, path) + + return c.scanEntry(row) +} + +// scanEntry scans a row into an Entry struct +func (c *SQLiteCatalog) scanEntry(row *sql.Row) (*Entry, error) { + var entry Entry + var tagsJSON, metaJSON sql.NullString + var verifiedAt, drillTestedAt sql.NullTime + var verifyValid, drillSuccess sql.NullBool + + err := row.Scan( + &entry.ID, &entry.Database, &entry.DatabaseType, &entry.Host, &entry.Port, + &entry.BackupPath, &entry.BackupType, &entry.SizeBytes, &entry.SHA256, + &entry.Compression, &entry.Encrypted, &entry.CreatedAt, &entry.Duration, + &entry.Status, &verifiedAt, &verifyValid, &drillTestedAt, &drillSuccess, + &entry.CloudLocation, &entry.RetentionPolicy, &tagsJSON, &metaJSON, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to scan entry: %w", err) + } + + if verifiedAt.Valid { + entry.VerifiedAt = &verifiedAt.Time + } + if verifyValid.Valid { + entry.VerifyValid = &verifyValid.Bool + } + if drillTestedAt.Valid { + entry.DrillTestedAt = &drillTestedAt.Time + } + if drillSuccess.Valid { + entry.DrillSuccess = &drillSuccess.Bool + } + + if tagsJSON.Valid && tagsJSON.String != "" { + json.Unmarshal([]byte(tagsJSON.String), &entry.Tags) + } + if metaJSON.Valid && metaJSON.String != "" { + json.Unmarshal([]byte(metaJSON.String), &entry.Metadata) + } + + return &entry, nil +} + +// Search finds backup entries matching the query +func (c *SQLiteCatalog) Search(ctx context.Context, query *SearchQuery) ([]*Entry, error) { + where, args := c.buildSearchQuery(query) + + orderBy := "created_at DESC" + if query.OrderBy != "" { + orderBy = query.OrderBy + if query.OrderDesc { + orderBy += " DESC" + } + } + + sql := fmt.Sprintf(` + SELECT id, database, database_type, host, port, backup_path, backup_type, + size_bytes, sha256, compression, encrypted, created_at, duration, + status, verified_at, verify_valid, drill_tested_at, drill_success, + cloud_location, retention_policy, tags, metadata + FROM backups + %s + ORDER BY %s + `, where, orderBy) + + if query.Limit > 0 { + sql += fmt.Sprintf(" LIMIT %d", query.Limit) + if query.Offset > 0 { + sql += fmt.Sprintf(" OFFSET %d", query.Offset) + } + } + + rows, err := c.db.QueryContext(ctx, sql, args...) 
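+	// NOTE: Limit and Offset are integers, so formatting them inline is
+	// injection-safe. OrderBy, however, is interpolated verbatim into the
+	// statement above and must only come from trusted code paths, never
+	// from user input. A typical (hypothetical) call into Search:
+	//
+	//	since := time.Now().AddDate(0, 0, -7)
+	//	entries, err := cat.Search(ctx, &SearchQuery{
+	//		Database:  "orders",
+	//		Status:    string(StatusCompleted),
+	//		StartDate: &since,
+	//		Limit:     20,
+	//	})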
+ if err != nil { + return nil, fmt.Errorf("search query failed: %w", err) + } + defer rows.Close() + + return c.scanEntries(rows) +} + +// scanEntries scans multiple rows into Entry slices +func (c *SQLiteCatalog) scanEntries(rows *sql.Rows) ([]*Entry, error) { + var entries []*Entry + + for rows.Next() { + var entry Entry + var tagsJSON, metaJSON sql.NullString + var verifiedAt, drillTestedAt sql.NullTime + var verifyValid, drillSuccess sql.NullBool + + err := rows.Scan( + &entry.ID, &entry.Database, &entry.DatabaseType, &entry.Host, &entry.Port, + &entry.BackupPath, &entry.BackupType, &entry.SizeBytes, &entry.SHA256, + &entry.Compression, &entry.Encrypted, &entry.CreatedAt, &entry.Duration, + &entry.Status, &verifiedAt, &verifyValid, &drillTestedAt, &drillSuccess, + &entry.CloudLocation, &entry.RetentionPolicy, &tagsJSON, &metaJSON, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + + if verifiedAt.Valid { + entry.VerifiedAt = &verifiedAt.Time + } + if verifyValid.Valid { + entry.VerifyValid = &verifyValid.Bool + } + if drillTestedAt.Valid { + entry.DrillTestedAt = &drillTestedAt.Time + } + if drillSuccess.Valid { + entry.DrillSuccess = &drillSuccess.Bool + } + + if tagsJSON.Valid && tagsJSON.String != "" { + json.Unmarshal([]byte(tagsJSON.String), &entry.Tags) + } + if metaJSON.Valid && metaJSON.String != "" { + json.Unmarshal([]byte(metaJSON.String), &entry.Metadata) + } + + entries = append(entries, &entry) + } + + return entries, rows.Err() +} + +// buildSearchQuery builds the WHERE clause from a SearchQuery +func (c *SQLiteCatalog) buildSearchQuery(query *SearchQuery) (string, []interface{}) { + var conditions []string + var args []interface{} + + if query.Database != "" { + if strings.Contains(query.Database, "*") { + conditions = append(conditions, "database LIKE ?") + args = append(args, strings.ReplaceAll(query.Database, "*", "%")) + } else { + conditions = append(conditions, "database = ?") + args = append(args, query.Database) + } + } + + if query.DatabaseType != "" { + conditions = append(conditions, "database_type = ?") + args = append(args, query.DatabaseType) + } + + if query.Host != "" { + conditions = append(conditions, "host = ?") + args = append(args, query.Host) + } + + if query.Status != "" { + conditions = append(conditions, "status = ?") + args = append(args, query.Status) + } + + if query.StartDate != nil { + conditions = append(conditions, "created_at >= ?") + args = append(args, *query.StartDate) + } + + if query.EndDate != nil { + conditions = append(conditions, "created_at <= ?") + args = append(args, *query.EndDate) + } + + if query.MinSize > 0 { + conditions = append(conditions, "size_bytes >= ?") + args = append(args, query.MinSize) + } + + if query.MaxSize > 0 { + conditions = append(conditions, "size_bytes <= ?") + args = append(args, query.MaxSize) + } + + if query.BackupType != "" { + conditions = append(conditions, "backup_type = ?") + args = append(args, query.BackupType) + } + + if query.Encrypted != nil { + conditions = append(conditions, "encrypted = ?") + args = append(args, *query.Encrypted) + } + + if query.Verified != nil { + if *query.Verified { + conditions = append(conditions, "verified_at IS NOT NULL AND verify_valid = 1") + } else { + conditions = append(conditions, "verified_at IS NULL") + } + } + + if query.DrillTested != nil { + if *query.DrillTested { + conditions = append(conditions, "drill_tested_at IS NOT NULL AND drill_success = 1") + } else { + conditions = append(conditions, "drill_tested_at IS 
NULL") + } + } + + if len(conditions) == 0 { + return "", nil + } + + return "WHERE " + strings.Join(conditions, " AND "), args +} + +// List returns recent backups for a database +func (c *SQLiteCatalog) List(ctx context.Context, database string, limit int) ([]*Entry, error) { + query := &SearchQuery{ + Database: database, + Limit: limit, + OrderBy: "created_at", + OrderDesc: true, + } + return c.Search(ctx, query) +} + +// ListDatabases returns all unique database names +func (c *SQLiteCatalog) ListDatabases(ctx context.Context) ([]string, error) { + rows, err := c.db.QueryContext(ctx, "SELECT DISTINCT database FROM backups ORDER BY database") + if err != nil { + return nil, fmt.Errorf("failed to list databases: %w", err) + } + defer rows.Close() + + var databases []string + for rows.Next() { + var db string + if err := rows.Scan(&db); err != nil { + return nil, err + } + databases = append(databases, db) + } + + return databases, rows.Err() +} + +// Count returns the number of entries matching the query +func (c *SQLiteCatalog) Count(ctx context.Context, query *SearchQuery) (int64, error) { + where, args := c.buildSearchQuery(query) + + sql := "SELECT COUNT(*) FROM backups " + where + + var count int64 + err := c.db.QueryRowContext(ctx, sql, args...).Scan(&count) + if err != nil { + return 0, fmt.Errorf("count query failed: %w", err) + } + + return count, nil +} + +// Stats returns overall catalog statistics +func (c *SQLiteCatalog) Stats(ctx context.Context) (*Stats, error) { + stats := &Stats{ + ByDatabase: make(map[string]int64), + ByType: make(map[string]int64), + ByStatus: make(map[string]int64), + } + + // Basic stats + row := c.db.QueryRowContext(ctx, ` + SELECT + COUNT(*), + COALESCE(SUM(size_bytes), 0), + MIN(created_at), + MAX(created_at), + COALESCE(AVG(duration), 0), + CAST(COALESCE(AVG(size_bytes), 0) AS INTEGER), + SUM(CASE WHEN verified_at IS NOT NULL THEN 1 ELSE 0 END), + SUM(CASE WHEN drill_tested_at IS NOT NULL THEN 1 ELSE 0 END) + FROM backups WHERE status != 'deleted' + `) + + var oldest, newest sql.NullString + err := row.Scan( + &stats.TotalBackups, &stats.TotalSize, &oldest, &newest, + &stats.AvgDuration, &stats.AvgSize, + &stats.VerifiedCount, &stats.DrillTestedCount, + ) + if err != nil { + return nil, fmt.Errorf("failed to get stats: %w", err) + } + + if oldest.Valid { + if t, err := time.Parse(time.RFC3339Nano, oldest.String); err == nil { + stats.OldestBackup = &t + } else if t, err := time.Parse("2006-01-02 15:04:05.999999999-07:00", oldest.String); err == nil { + stats.OldestBackup = &t + } else if t, err := time.Parse("2006-01-02T15:04:05Z", oldest.String); err == nil { + stats.OldestBackup = &t + } + } + if newest.Valid { + if t, err := time.Parse(time.RFC3339Nano, newest.String); err == nil { + stats.NewestBackup = &t + } else if t, err := time.Parse("2006-01-02 15:04:05.999999999-07:00", newest.String); err == nil { + stats.NewestBackup = &t + } else if t, err := time.Parse("2006-01-02T15:04:05Z", newest.String); err == nil { + stats.NewestBackup = &t + } + } + stats.TotalSizeHuman = FormatSize(stats.TotalSize) + + // By database + rows, _ := c.db.QueryContext(ctx, "SELECT database, COUNT(*) FROM backups GROUP BY database") + defer rows.Close() + for rows.Next() { + var db string + var count int64 + rows.Scan(&db, &count) + stats.ByDatabase[db] = count + } + + // By type + rows, _ = c.db.QueryContext(ctx, "SELECT backup_type, COUNT(*) FROM backups GROUP BY backup_type") + defer rows.Close() + for rows.Next() { + var t string + var count int64 + 
rows.Scan(&t, &count) + stats.ByType[t] = count + } + + // By status + rows, _ = c.db.QueryContext(ctx, "SELECT status, COUNT(*) FROM backups GROUP BY status") + defer rows.Close() + for rows.Next() { + var s string + var count int64 + rows.Scan(&s, &count) + stats.ByStatus[s] = count + } + + return stats, nil +} + +// StatsByDatabase returns statistics for a specific database +func (c *SQLiteCatalog) StatsByDatabase(ctx context.Context, database string) (*Stats, error) { + stats := &Stats{ + ByDatabase: make(map[string]int64), + ByType: make(map[string]int64), + ByStatus: make(map[string]int64), + } + + row := c.db.QueryRowContext(ctx, ` + SELECT + COUNT(*), + COALESCE(SUM(size_bytes), 0), + MIN(created_at), + MAX(created_at), + COALESCE(AVG(duration), 0), + COALESCE(AVG(size_bytes), 0), + SUM(CASE WHEN verified_at IS NOT NULL THEN 1 ELSE 0 END), + SUM(CASE WHEN drill_tested_at IS NOT NULL THEN 1 ELSE 0 END) + FROM backups WHERE database = ? AND status != 'deleted' + `, database) + + var oldest, newest sql.NullTime + err := row.Scan( + &stats.TotalBackups, &stats.TotalSize, &oldest, &newest, + &stats.AvgDuration, &stats.AvgSize, + &stats.VerifiedCount, &stats.DrillTestedCount, + ) + if err != nil { + return nil, fmt.Errorf("failed to get database stats: %w", err) + } + + if oldest.Valid { + stats.OldestBackup = &oldest.Time + } + if newest.Valid { + stats.NewestBackup = &newest.Time + } + stats.TotalSizeHuman = FormatSize(stats.TotalSize) + + return stats, nil +} + +// MarkVerified updates the verification status of a backup +func (c *SQLiteCatalog) MarkVerified(ctx context.Context, id int64, valid bool) error { + status := StatusVerified + if !valid { + status = StatusCorrupted + } + + _, err := c.db.ExecContext(ctx, ` + UPDATE backups SET + verified_at = CURRENT_TIMESTAMP, + verify_valid = ?, + status = ?, + updated_at = CURRENT_TIMESTAMP + WHERE id = ? + `, valid, status, id) + + return err +} + +// MarkDrillTested updates the drill test status of a backup +func (c *SQLiteCatalog) MarkDrillTested(ctx context.Context, id int64, success bool) error { + _, err := c.db.ExecContext(ctx, ` + UPDATE backups SET + drill_tested_at = CURRENT_TIMESTAMP, + drill_success = ?, + updated_at = CURRENT_TIMESTAMP + WHERE id = ? + `, success, id) + + return err +} + +// Prune removes entries older than the given time +func (c *SQLiteCatalog) Prune(ctx context.Context, before time.Time) (int, error) { + result, err := c.db.ExecContext(ctx, + "DELETE FROM backups WHERE created_at < ? 
AND status = 'deleted'", + before, + ) + if err != nil { + return 0, fmt.Errorf("prune failed: %w", err) + } + + affected, _ := result.RowsAffected() + return int(affected), nil +} + +// Vacuum optimizes the database +func (c *SQLiteCatalog) Vacuum(ctx context.Context) error { + _, err := c.db.ExecContext(ctx, "VACUUM") + return err +} + +// Close closes the database connection +func (c *SQLiteCatalog) Close() error { + return c.db.Close() +} diff --git a/internal/catalog/sync.go b/internal/catalog/sync.go new file mode 100644 index 0000000..ca52086 --- /dev/null +++ b/internal/catalog/sync.go @@ -0,0 +1,234 @@ +// Package catalog - Sync functionality for importing backups into catalog +package catalog + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "dbbackup/internal/metadata" +) + +// SyncFromDirectory scans a directory and imports backup metadata into the catalog +func (c *SQLiteCatalog) SyncFromDirectory(ctx context.Context, dir string) (*SyncResult, error) { + start := time.Now() + result := &SyncResult{} + + // Find all metadata files + pattern := filepath.Join(dir, "*.meta.json") + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, fmt.Errorf("failed to scan directory: %w", err) + } + + // Also check subdirectories + subPattern := filepath.Join(dir, "*", "*.meta.json") + subMatches, _ := filepath.Glob(subPattern) + matches = append(matches, subMatches...) + + for _, metaPath := range matches { + // Derive backup file path from metadata path + backupPath := strings.TrimSuffix(metaPath, ".meta.json") + + // Check if backup file exists + if _, err := os.Stat(backupPath); os.IsNotExist(err) { + result.Details = append(result.Details, + fmt.Sprintf("SKIP: %s (backup file missing)", filepath.Base(backupPath))) + continue + } + + // Load metadata + meta, err := metadata.Load(backupPath) + if err != nil { + result.Errors++ + result.Details = append(result.Details, + fmt.Sprintf("ERROR: %s - %v", filepath.Base(backupPath), err)) + continue + } + + // Check if already in catalog + existing, _ := c.GetByPath(ctx, backupPath) + if existing != nil { + // Update if metadata changed + if existing.SHA256 != meta.SHA256 || existing.SizeBytes != meta.SizeBytes { + entry := metadataToEntry(meta, backupPath) + entry.ID = existing.ID + if err := c.Update(ctx, entry); err != nil { + result.Errors++ + result.Details = append(result.Details, + fmt.Sprintf("ERROR updating: %s - %v", filepath.Base(backupPath), err)) + } else { + result.Updated++ + } + } + continue + } + + // Add new entry + entry := metadataToEntry(meta, backupPath) + if err := c.Add(ctx, entry); err != nil { + result.Errors++ + result.Details = append(result.Details, + fmt.Sprintf("ERROR adding: %s - %v", filepath.Base(backupPath), err)) + } else { + result.Added++ + result.Details = append(result.Details, + fmt.Sprintf("ADDED: %s (%s)", filepath.Base(backupPath), FormatSize(meta.SizeBytes))) + } + } + + // Check for removed backups (backups in catalog but not on disk) + entries, _ := c.Search(ctx, &SearchQuery{}) + for _, entry := range entries { + if !strings.HasPrefix(entry.BackupPath, dir) { + continue + } + if _, err := os.Stat(entry.BackupPath); os.IsNotExist(err) { + // Mark as deleted + entry.Status = StatusDeleted + c.Update(ctx, entry) + result.Removed++ + result.Details = append(result.Details, + fmt.Sprintf("REMOVED: %s (file not found)", filepath.Base(entry.BackupPath))) + } + } + + result.Duration = time.Since(start).Seconds() + return result, nil +} + +// 
SyncFromCloud imports backups from cloud storage +func (c *SQLiteCatalog) SyncFromCloud(ctx context.Context, provider, bucket, prefix string) (*SyncResult, error) { + // This will be implemented when integrating with cloud package + // For now, return a placeholder + return &SyncResult{ + Details: []string{"Cloud sync not yet implemented - use directory sync instead"}, + }, nil +} + +// metadataToEntry converts backup metadata to a catalog entry +func metadataToEntry(meta *metadata.BackupMetadata, backupPath string) *Entry { + entry := &Entry{ + Database: meta.Database, + DatabaseType: meta.DatabaseType, + Host: meta.Host, + Port: meta.Port, + BackupPath: backupPath, + BackupType: meta.BackupType, + SizeBytes: meta.SizeBytes, + SHA256: meta.SHA256, + Compression: meta.Compression, + Encrypted: meta.Encrypted, + CreatedAt: meta.Timestamp, + Duration: meta.Duration, + Status: StatusCompleted, + Metadata: meta.ExtraInfo, + } + + if entry.BackupType == "" { + entry.BackupType = "full" + } + + return entry +} + +// ImportEntry creates a catalog entry directly from backup file info +func (c *SQLiteCatalog) ImportEntry(ctx context.Context, backupPath string, info os.FileInfo, dbName, dbType string) error { + entry := &Entry{ + Database: dbName, + DatabaseType: dbType, + BackupPath: backupPath, + BackupType: "full", + SizeBytes: info.Size(), + CreatedAt: info.ModTime(), + Status: StatusCompleted, + } + + // Detect compression from extension + switch { + case strings.HasSuffix(backupPath, ".gz"): + entry.Compression = "gzip" + case strings.HasSuffix(backupPath, ".lz4"): + entry.Compression = "lz4" + case strings.HasSuffix(backupPath, ".zst"): + entry.Compression = "zstd" + } + + // Check if encrypted + if strings.Contains(backupPath, ".enc") { + entry.Encrypted = true + } + + // Try to load metadata if exists + if meta, err := metadata.Load(backupPath); err == nil { + entry.SHA256 = meta.SHA256 + entry.Duration = meta.Duration + entry.Host = meta.Host + entry.Port = meta.Port + entry.Metadata = meta.ExtraInfo + } + + return c.Add(ctx, entry) +} + +// SyncStatus returns the sync status summary +type SyncStatus struct { + LastSync *time.Time `json:"last_sync,omitempty"` + TotalEntries int64 `json:"total_entries"` + ActiveEntries int64 `json:"active_entries"` + DeletedEntries int64 `json:"deleted_entries"` + Directories []string `json:"directories"` +} + +// GetSyncStatus returns the current sync status +func (c *SQLiteCatalog) GetSyncStatus(ctx context.Context) (*SyncStatus, error) { + status := &SyncStatus{} + + // Get last sync time + var lastSync sql.NullString + c.db.QueryRowContext(ctx, "SELECT value FROM catalog_meta WHERE key = 'last_sync'").Scan(&lastSync) + if lastSync.Valid { + if t, err := time.Parse(time.RFC3339, lastSync.String); err == nil { + status.LastSync = &t + } + } + + // Count entries + c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM backups").Scan(&status.TotalEntries) + c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM backups WHERE status != 'deleted'").Scan(&status.ActiveEntries) + c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM backups WHERE status = 'deleted'").Scan(&status.DeletedEntries) + + // Get unique directories + rows, _ := c.db.QueryContext(ctx, ` + SELECT DISTINCT + CASE + WHEN instr(backup_path, '/') > 0 + THEN substr(backup_path, 1, length(backup_path) - length(replace(backup_path, '/', '')) - length(substr(backup_path, length(backup_path) - length(replace(backup_path, '/', '')) + 2))) + ELSE backup_path + END as dir + FROM backups WHERE status != 'deleted' + `) 
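+	// The CASE expression above approximates a dirname() function, which
+	// SQLite does not provide natively: it strips the last path component
+	// so each distinct backup directory is listed once.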
+ if rows != nil { + defer rows.Close() + for rows.Next() { + var dir string + rows.Scan(&dir) + status.Directories = append(status.Directories, dir) + } + } + + return status, nil +} + +// SetLastSync updates the last sync timestamp +func (c *SQLiteCatalog) SetLastSync(ctx context.Context) error { + _, err := c.db.ExecContext(ctx, ` + INSERT OR REPLACE INTO catalog_meta (key, value, updated_at) + VALUES ('last_sync', ?, CURRENT_TIMESTAMP) + `, time.Now().Format(time.RFC3339)) + return err +} diff --git a/internal/config/config.go b/internal/config/config.go index 528e9c6..b339cfc 100755 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -91,6 +91,13 @@ type Config struct { WALCompression bool // Compress WAL files WALEncryption bool // Encrypt WAL files + // MySQL PITR options + BinlogDir string // MySQL binary log directory + BinlogArchiveDir string // Directory to archive binlogs + BinlogArchiveInterval string // Interval for binlog archiving (e.g., "30s") + RequireRowFormat bool // Require ROW format for binlog + RequireGTID bool // Require GTID mode enabled + // TUI automation options (for testing) TUIAutoSelect int // Auto-select menu option (-1 = disabled) TUIAutoDatabase string // Pre-fill database name diff --git a/internal/drill/docker.go b/internal/drill/docker.go new file mode 100644 index 0000000..e07b33d --- /dev/null +++ b/internal/drill/docker.go @@ -0,0 +1,298 @@ +// Package drill - Docker container management for DR drills +package drill + +import ( + "context" + "fmt" + "os/exec" + "strings" + "time" +) + +// DockerManager handles Docker container operations for DR drills +type DockerManager struct { + verbose bool +} + +// NewDockerManager creates a new Docker manager +func NewDockerManager(verbose bool) *DockerManager { + return &DockerManager{verbose: verbose} +} + +// ContainerConfig holds Docker container configuration +type ContainerConfig struct { + Image string // Docker image (e.g., "postgres:15") + Name string // Container name + Port int // Host port to map + ContainerPort int // Container port + Environment map[string]string // Environment variables + Volumes []string // Volume mounts + Network string // Docker network + Timeout int // Startup timeout in seconds +} + +// ContainerInfo holds information about a running container +type ContainerInfo struct { + ID string + Name string + Image string + Port int + Status string + Started time.Time + Healthy bool +} + +// CheckDockerAvailable verifies Docker is installed and running +func (dm *DockerManager) CheckDockerAvailable(ctx context.Context) error { + cmd := exec.CommandContext(ctx, "docker", "version") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("docker not available: %w (output: %s)", err, string(output)) + } + return nil +} + +// PullImage pulls a Docker image if not present +func (dm *DockerManager) PullImage(ctx context.Context, image string) error { + // Check if image exists locally + checkCmd := exec.CommandContext(ctx, "docker", "image", "inspect", image) + if err := checkCmd.Run(); err == nil { + // Image exists + return nil + } + + // Pull the image + pullCmd := exec.CommandContext(ctx, "docker", "pull", image) + output, err := pullCmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to pull image %s: %w (output: %s)", image, err, string(output)) + } + + return nil +} + +// CreateContainer creates and starts a database container +func (dm *DockerManager) CreateContainer(ctx context.Context, config *ContainerConfig) (*ContainerInfo, 
error) { + args := []string{ + "run", "-d", + "--name", config.Name, + "-p", fmt.Sprintf("%d:%d", config.Port, config.ContainerPort), + } + + // Add environment variables + for k, v := range config.Environment { + args = append(args, "-e", fmt.Sprintf("%s=%s", k, v)) + } + + // Add volumes + for _, v := range config.Volumes { + args = append(args, "-v", v) + } + + // Add network if specified + if config.Network != "" { + args = append(args, "--network", config.Network) + } + + // Add image + args = append(args, config.Image) + + cmd := exec.CommandContext(ctx, "docker", args...) + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("failed to create container: %w (output: %s)", err, string(output)) + } + + containerID := strings.TrimSpace(string(output)) + + return &ContainerInfo{ + ID: containerID, + Name: config.Name, + Image: config.Image, + Port: config.Port, + Status: "created", + Started: time.Now(), + }, nil +} + +// WaitForHealth waits for container to be healthy +func (dm *DockerManager) WaitForHealth(ctx context.Context, containerID string, dbType string, timeout int) error { + deadline := time.Now().Add(time.Duration(timeout) * time.Second) + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if time.Now().After(deadline) { + return fmt.Errorf("timeout waiting for container to be healthy") + } + + // Check container health + healthCmd := dm.healthCheckCommand(dbType) + args := append([]string{"exec", containerID}, healthCmd...) + cmd := exec.CommandContext(ctx, "docker", args...) + if err := cmd.Run(); err == nil { + return nil // Container is healthy + } + } + } +} + +// healthCheckCommand returns the health check command for a database type +func (dm *DockerManager) healthCheckCommand(dbType string) []string { + switch dbType { + case "postgresql", "postgres": + return []string{"pg_isready", "-U", "postgres"} + case "mysql": + return []string{"mysqladmin", "ping", "-h", "localhost", "-u", "root", "--password=root"} + case "mariadb": + return []string{"mariadb-admin", "ping", "-h", "localhost", "-u", "root", "--password=root"} + default: + return []string{"echo", "ok"} + } +} + +// ExecCommand executes a command inside the container +func (dm *DockerManager) ExecCommand(ctx context.Context, containerID string, command []string) (string, error) { + args := append([]string{"exec", containerID}, command...) + cmd := exec.CommandContext(ctx, "docker", args...) 
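+	// CombinedOutput (below) merges stdout and stderr so callers receive
+	// the database client's actual error text rather than a bare exit
+	// status. A hypothetical use, counting tables in a restored database:
+	//
+	//	out, err := dm.ExecCommand(ctx, id,
+	//		[]string{"psql", "-U", "postgres", "-tAc", "SELECT count(*) FROM pg_tables"})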
+ output, err := cmd.CombinedOutput() + if err != nil { + return string(output), fmt.Errorf("exec failed: %w", err) + } + return string(output), nil +} + +// CopyToContainer copies a file to the container +func (dm *DockerManager) CopyToContainer(ctx context.Context, containerID, src, dest string) error { + cmd := exec.CommandContext(ctx, "docker", "cp", src, fmt.Sprintf("%s:%s", containerID, dest)) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("copy failed: %w (output: %s)", err, string(output)) + } + return nil +} + +// StopContainer stops a running container +func (dm *DockerManager) StopContainer(ctx context.Context, containerID string) error { + cmd := exec.CommandContext(ctx, "docker", "stop", containerID) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to stop container: %w (output: %s)", err, string(output)) + } + return nil +} + +// RemoveContainer removes a container +func (dm *DockerManager) RemoveContainer(ctx context.Context, containerID string) error { + cmd := exec.CommandContext(ctx, "docker", "rm", "-f", containerID) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to remove container: %w (output: %s)", err, string(output)) + } + return nil +} + +// GetContainerLogs retrieves container logs +func (dm *DockerManager) GetContainerLogs(ctx context.Context, containerID string, tail int) (string, error) { + args := []string{"logs"} + if tail > 0 { + args = append(args, "--tail", fmt.Sprintf("%d", tail)) + } + args = append(args, containerID) + + cmd := exec.CommandContext(ctx, "docker", args...) + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("failed to get logs: %w", err) + } + return string(output), nil +} + +// ListDrillContainers lists all containers created by drill operations +func (dm *DockerManager) ListDrillContainers(ctx context.Context) ([]*ContainerInfo, error) { + cmd := exec.CommandContext(ctx, "docker", "ps", "-a", + "--filter", "name=drill_", + "--format", "{{.ID}}\t{{.Names}}\t{{.Image}}\t{{.Status}}") + + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("failed to list containers: %w", err) + } + + var containers []*ContainerInfo + lines := strings.Split(strings.TrimSpace(string(output)), "\n") + for _, line := range lines { + if line == "" { + continue + } + parts := strings.Split(line, "\t") + if len(parts) >= 4 { + containers = append(containers, &ContainerInfo{ + ID: parts[0], + Name: parts[1], + Image: parts[2], + Status: parts[3], + }) + } + } + + return containers, nil +} + +// GetDefaultImage returns the default Docker image for a database type +func GetDefaultImage(dbType, version string) string { + if version == "" { + version = "latest" + } + + switch dbType { + case "postgresql", "postgres": + return fmt.Sprintf("postgres:%s", version) + case "mysql": + return fmt.Sprintf("mysql:%s", version) + case "mariadb": + return fmt.Sprintf("mariadb:%s", version) + default: + return "" + } +} + +// GetDefaultPort returns the default port for a database type +func GetDefaultPort(dbType string) int { + switch dbType { + case "postgresql", "postgres": + return 5432 + case "mysql", "mariadb": + return 3306 + default: + return 0 + } +} + +// GetDefaultEnvironment returns default environment variables for a database container +func GetDefaultEnvironment(dbType string) map[string]string { + switch dbType { + case "postgresql", "postgres": + return map[string]string{ + "POSTGRES_PASSWORD": 
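+			// Throwaway credential for the short-lived drill container
+			// only; it never touches a production instance.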
"drill_test_password", + "POSTGRES_USER": "postgres", + "POSTGRES_DB": "postgres", + } + case "mysql": + return map[string]string{ + "MYSQL_ROOT_PASSWORD": "root", + "MYSQL_DATABASE": "test", + } + case "mariadb": + return map[string]string{ + "MARIADB_ROOT_PASSWORD": "root", + "MARIADB_DATABASE": "test", + } + default: + return map[string]string{} + } +} diff --git a/internal/drill/drill.go b/internal/drill/drill.go new file mode 100644 index 0000000..e3bf2a8 --- /dev/null +++ b/internal/drill/drill.go @@ -0,0 +1,247 @@ +// Package drill provides Disaster Recovery drill functionality +// for testing backup restorability in isolated environments +package drill + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" +) + +// DrillConfig holds configuration for a DR drill +type DrillConfig struct { + // Backup configuration + BackupPath string `json:"backup_path"` + DatabaseName string `json:"database_name"` + DatabaseType string `json:"database_type"` // postgresql, mysql, mariadb + + // Docker configuration + ContainerImage string `json:"container_image"` // e.g., "postgres:15" + ContainerName string `json:"container_name"` // Generated if empty + ContainerPort int `json:"container_port"` // Host port mapping + ContainerTimeout int `json:"container_timeout"` // Startup timeout in seconds + CleanupOnExit bool `json:"cleanup_on_exit"` // Remove container after drill + KeepOnFailure bool `json:"keep_on_failure"` // Keep container if drill fails + + // Validation configuration + ValidationQueries []ValidationQuery `json:"validation_queries"` + MinRowCount int64 `json:"min_row_count"` // Minimum rows expected + ExpectedTables []string `json:"expected_tables"` // Tables that must exist + CustomChecks []CustomCheck `json:"custom_checks"` + + // Encryption (if backup is encrypted) + EncryptionKeyFile string `json:"encryption_key_file,omitempty"` + EncryptionKeyEnv string `json:"encryption_key_env,omitempty"` + + // Performance thresholds + MaxRestoreSeconds int `json:"max_restore_seconds"` // RTO threshold + MaxQuerySeconds int `json:"max_query_seconds"` // Query timeout + + // Output + OutputDir string `json:"output_dir"` // Directory for drill reports + ReportFormat string `json:"report_format"` // json, markdown, html + Verbose bool `json:"verbose"` +} + +// ValidationQuery represents a SQL query to validate restored data +type ValidationQuery struct { + Name string `json:"name"` // Human-readable name + Query string `json:"query"` // SQL query + ExpectedValue string `json:"expected_value"` // Expected result (optional) + MinValue int64 `json:"min_value"` // Minimum expected value + MaxValue int64 `json:"max_value"` // Maximum expected value + MustSucceed bool `json:"must_succeed"` // Fail drill if query fails +} + +// CustomCheck represents a custom validation check +type CustomCheck struct { + Name string `json:"name"` + Type string `json:"type"` // row_count, table_exists, column_check + Table string `json:"table"` + Column string `json:"column,omitempty"` + Condition string `json:"condition,omitempty"` // SQL condition + MinValue int64 `json:"min_value,omitempty"` + MustSucceed bool `json:"must_succeed"` +} + +// DrillResult contains the complete result of a DR drill +type DrillResult struct { + // Identification + DrillID string `json:"drill_id"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + Duration float64 `json:"duration_seconds"` + + // Configuration + BackupPath string `json:"backup_path"` + DatabaseName string 
`json:"database_name"` + DatabaseType string `json:"database_type"` + + // Overall status + Success bool `json:"success"` + Status DrillStatus `json:"status"` + Message string `json:"message"` + + // Phase timings + Phases []DrillPhase `json:"phases"` + + // Validation results + ValidationResults []ValidationResult `json:"validation_results"` + CheckResults []CheckResult `json:"check_results"` + + // Database metrics + TableCount int `json:"table_count"` + TotalRows int64 `json:"total_rows"` + DatabaseSize int64 `json:"database_size_bytes"` + + // Performance metrics + RestoreTime float64 `json:"restore_time_seconds"` + ValidationTime float64 `json:"validation_time_seconds"` + QueryTimeAvg float64 `json:"query_time_avg_ms"` + + // RTO/RPO metrics + ActualRTO float64 `json:"actual_rto_seconds"` // Total time to usable database + TargetRTO float64 `json:"target_rto_seconds"` + RTOMet bool `json:"rto_met"` + + // Container info + ContainerID string `json:"container_id,omitempty"` + ContainerKept bool `json:"container_kept"` + + // Errors and warnings + Errors []string `json:"errors,omitempty"` + Warnings []string `json:"warnings,omitempty"` +} + +// DrillStatus represents the current status of a drill +type DrillStatus string + +const ( + StatusPending DrillStatus = "pending" + StatusRunning DrillStatus = "running" + StatusCompleted DrillStatus = "completed" + StatusFailed DrillStatus = "failed" + StatusAborted DrillStatus = "aborted" + StatusPartial DrillStatus = "partial" // Some validations failed +) + +// DrillPhase represents a phase in the drill process +type DrillPhase struct { + Name string `json:"name"` + Status string `json:"status"` // pending, running, completed, failed, skipped + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + Duration float64 `json:"duration_seconds"` + Message string `json:"message,omitempty"` +} + +// ValidationResult holds the result of a validation query +type ValidationResult struct { + Name string `json:"name"` + Query string `json:"query"` + Success bool `json:"success"` + Result string `json:"result,omitempty"` + Expected string `json:"expected,omitempty"` + Duration float64 `json:"duration_ms"` + Error string `json:"error,omitempty"` +} + +// CheckResult holds the result of a custom check +type CheckResult struct { + Name string `json:"name"` + Type string `json:"type"` + Success bool `json:"success"` + Actual int64 `json:"actual,omitempty"` + Expected int64 `json:"expected,omitempty"` + Message string `json:"message"` +} + +// DefaultConfig returns a DrillConfig with sensible defaults +func DefaultConfig() *DrillConfig { + return &DrillConfig{ + ContainerTimeout: 60, + CleanupOnExit: true, + KeepOnFailure: true, + MaxRestoreSeconds: 300, // 5 minutes + MaxQuerySeconds: 30, + ReportFormat: "json", + Verbose: false, + ValidationQueries: []ValidationQuery{}, + ExpectedTables: []string{}, + CustomChecks: []CustomCheck{}, + } +} + +// NewDrillID generates a unique drill ID +func NewDrillID() string { + return fmt.Sprintf("drill_%s", time.Now().Format("20060102_150405")) +} + +// SaveResult saves the drill result to a file +func (r *DrillResult) SaveResult(outputDir string) error { + if err := os.MkdirAll(outputDir, 0755); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + + filename := fmt.Sprintf("%s_report.json", r.DrillID) + filepath := filepath.Join(outputDir, filename) + + data, err := json.MarshalIndent(r, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal result: %w", 
err) + } + + if err := os.WriteFile(filepath, data, 0644); err != nil { + return fmt.Errorf("failed to write result file: %w", err) + } + + return nil +} + +// LoadResult loads a drill result from a file +func LoadResult(filepath string) (*DrillResult, error) { + data, err := os.ReadFile(filepath) + if err != nil { + return nil, fmt.Errorf("failed to read result file: %w", err) + } + + var result DrillResult + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to parse result: %w", err) + } + + return &result, nil +} + +// IsSuccess returns true if the drill was successful +func (r *DrillResult) IsSuccess() bool { + return r.Success && r.Status == StatusCompleted +} + +// Summary returns a human-readable summary of the drill +func (r *DrillResult) Summary() string { + status := "✅ PASSED" + if !r.Success { + status = "❌ FAILED" + } else if r.Status == StatusPartial { + status = "⚠️ PARTIAL" + } + + return fmt.Sprintf("%s - %s (%.2fs) - %d tables, %d rows", + status, r.DatabaseName, r.Duration, r.TableCount, r.TotalRows) +} + +// Drill is the interface for DR drill operations +type Drill interface { + // Run executes the full DR drill + Run(ctx context.Context, config *DrillConfig) (*DrillResult, error) + + // Validate runs validation queries against an existing database + Validate(ctx context.Context, config *DrillConfig) ([]ValidationResult, error) + + // Cleanup removes drill resources (containers, temp files) + Cleanup(ctx context.Context, drillID string) error +} diff --git a/internal/drill/engine.go b/internal/drill/engine.go new file mode 100644 index 0000000..5ad6c80 --- /dev/null +++ b/internal/drill/engine.go @@ -0,0 +1,532 @@ +// Package drill - Main drill execution engine +package drill + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "dbbackup/internal/logger" +) + +// Engine executes DR drills +type Engine struct { + docker *DockerManager + log logger.Logger + verbose bool +} + +// NewEngine creates a new drill engine +func NewEngine(log logger.Logger, verbose bool) *Engine { + return &Engine{ + docker: NewDockerManager(verbose), + log: log, + verbose: verbose, + } +} + +// Run executes a complete DR drill +func (e *Engine) Run(ctx context.Context, config *DrillConfig) (*DrillResult, error) { + result := &DrillResult{ + DrillID: NewDrillID(), + StartTime: time.Now(), + BackupPath: config.BackupPath, + DatabaseName: config.DatabaseName, + DatabaseType: config.DatabaseType, + Status: StatusRunning, + Phases: make([]DrillPhase, 0), + TargetRTO: float64(config.MaxRestoreSeconds), + } + + e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + e.log.Info(" 🧪 DR Drill: " + result.DrillID) + e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + e.log.Info("") + + // Cleanup function for error cases + var containerID string + cleanup := func() { + if containerID != "" && config.CleanupOnExit && (result.Success || !config.KeepOnFailure) { + e.log.Info("🗑️ Cleaning up container...") + e.docker.RemoveContainer(context.Background(), containerID) + } else if containerID != "" { + result.ContainerKept = true + e.log.Info("📦 Container kept for debugging: " + containerID) + } + } + defer cleanup() + + // Phase 1: Preflight checks + phase := e.startPhase("Preflight Checks") + if err := e.preflightChecks(ctx, config); err != nil { + e.failPhase(&phase, err.Error()) + result.Phases = append(result.Phases, phase) + result.Status = StatusFailed + result.Message = "Preflight checks failed: " + 
err.Error() + result.Errors = append(result.Errors, err.Error()) + e.finalize(result) + return result, nil + } + e.completePhase(&phase, "All checks passed") + result.Phases = append(result.Phases, phase) + + // Phase 2: Start container + phase = e.startPhase("Start Container") + containerConfig := e.buildContainerConfig(config) + container, err := e.docker.CreateContainer(ctx, containerConfig) + if err != nil { + e.failPhase(&phase, err.Error()) + result.Phases = append(result.Phases, phase) + result.Status = StatusFailed + result.Message = "Failed to start container: " + err.Error() + result.Errors = append(result.Errors, err.Error()) + e.finalize(result) + return result, nil + } + containerID = container.ID + result.ContainerID = containerID + e.log.Info("📦 Container started: " + containerID[:12]) + + // Wait for container to be healthy + if err := e.docker.WaitForHealth(ctx, containerID, config.DatabaseType, config.ContainerTimeout); err != nil { + e.failPhase(&phase, "Container health check failed: "+err.Error()) + result.Phases = append(result.Phases, phase) + result.Status = StatusFailed + result.Message = "Container failed to start" + result.Errors = append(result.Errors, err.Error()) + e.finalize(result) + return result, nil + } + e.completePhase(&phase, "Container healthy") + result.Phases = append(result.Phases, phase) + + // Phase 3: Restore backup + phase = e.startPhase("Restore Backup") + restoreStart := time.Now() + if err := e.restoreBackup(ctx, config, containerID, containerConfig); err != nil { + e.failPhase(&phase, err.Error()) + result.Phases = append(result.Phases, phase) + result.Status = StatusFailed + result.Message = "Restore failed: " + err.Error() + result.Errors = append(result.Errors, err.Error()) + e.finalize(result) + return result, nil + } + result.RestoreTime = time.Since(restoreStart).Seconds() + e.completePhase(&phase, fmt.Sprintf("Restored in %.2fs", result.RestoreTime)) + result.Phases = append(result.Phases, phase) + e.log.Info(fmt.Sprintf("✅ Backup restored in %.2fs", result.RestoreTime)) + + // Phase 4: Validate + phase = e.startPhase("Validate Database") + validateStart := time.Now() + validationErrors := e.validateDatabase(ctx, config, result, containerConfig) + result.ValidationTime = time.Since(validateStart).Seconds() + if validationErrors > 0 { + e.completePhase(&phase, fmt.Sprintf("Completed with %d errors", validationErrors)) + } else { + e.completePhase(&phase, "All validations passed") + } + result.Phases = append(result.Phases, phase) + + // Determine overall status + result.ActualRTO = result.RestoreTime + result.ValidationTime + result.RTOMet = result.ActualRTO <= result.TargetRTO + + criticalFailures := 0 + for _, vr := range result.ValidationResults { + if !vr.Success { + criticalFailures++ + } + } + for _, cr := range result.CheckResults { + if !cr.Success { + criticalFailures++ + } + } + + if criticalFailures == 0 { + result.Success = true + result.Status = StatusCompleted + result.Message = "DR drill completed successfully" + } else if criticalFailures < len(result.ValidationResults)+len(result.CheckResults) { + result.Success = false + result.Status = StatusPartial + result.Message = fmt.Sprintf("DR drill completed with %d validation failures", criticalFailures) + } else { + result.Success = false + result.Status = StatusFailed + result.Message = "All validations failed" + } + + e.finalize(result) + + // Save result if output dir specified + if config.OutputDir != "" { + if err := result.SaveResult(config.OutputDir); err != nil { + 
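+			// Best-effort: a failed report write is only logged, so it
+			// cannot turn an otherwise successful drill into a failure.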
e.log.Warn("Failed to save drill result", "error", err) + } else { + e.log.Info("📄 Report saved to: " + filepath.Join(config.OutputDir, result.DrillID+"_report.json")) + } + } + + return result, nil +} + +// preflightChecks runs preflight checks before the drill +func (e *Engine) preflightChecks(ctx context.Context, config *DrillConfig) error { + // Check Docker is available + if err := e.docker.CheckDockerAvailable(ctx); err != nil { + return fmt.Errorf("docker not available: %w", err) + } + e.log.Info("✓ Docker is available") + + // Check backup file exists + if _, err := os.Stat(config.BackupPath); err != nil { + return fmt.Errorf("backup file not found: %s", config.BackupPath) + } + e.log.Info("✓ Backup file exists: " + filepath.Base(config.BackupPath)) + + // Pull Docker image + image := config.ContainerImage + if image == "" { + image = GetDefaultImage(config.DatabaseType, "") + } + e.log.Info("⬇️ Pulling image: " + image) + if err := e.docker.PullImage(ctx, image); err != nil { + return fmt.Errorf("failed to pull image: %w", err) + } + e.log.Info("✓ Image ready: " + image) + + return nil +} + +// buildContainerConfig creates container configuration +func (e *Engine) buildContainerConfig(config *DrillConfig) *ContainerConfig { + containerName := config.ContainerName + if containerName == "" { + containerName = fmt.Sprintf("drill_%s_%s", config.DatabaseName, time.Now().Format("20060102_150405")) + } + + image := config.ContainerImage + if image == "" { + image = GetDefaultImage(config.DatabaseType, "") + } + + port := config.ContainerPort + if port == 0 { + port = 15432 // Default drill port (different from production) + if config.DatabaseType == "mysql" || config.DatabaseType == "mariadb" { + port = 13306 + } + } + + containerPort := GetDefaultPort(config.DatabaseType) + env := GetDefaultEnvironment(config.DatabaseType) + + return &ContainerConfig{ + Image: image, + Name: containerName, + Port: port, + ContainerPort: containerPort, + Environment: env, + Timeout: config.ContainerTimeout, + } +} + +// restoreBackup restores the backup into the container +func (e *Engine) restoreBackup(ctx context.Context, config *DrillConfig, containerID string, containerConfig *ContainerConfig) error { + // Copy backup to container + backupName := filepath.Base(config.BackupPath) + containerBackupPath := "/tmp/" + backupName + + e.log.Info("📁 Copying backup to container...") + if err := e.docker.CopyToContainer(ctx, containerID, config.BackupPath, containerBackupPath); err != nil { + return fmt.Errorf("failed to copy backup: %w", err) + } + + // Handle encrypted backups + if config.EncryptionKeyFile != "" { + // For encrypted backups, we'd need to decrypt first + // This is a simplified implementation + e.log.Warn("Encrypted backup handling not fully implemented in drill mode") + } + + // Restore based on database type and format + e.log.Info("🔄 Restoring backup...") + return e.executeRestore(ctx, config, containerID, containerBackupPath, containerConfig) +} + +// executeRestore runs the actual restore command +func (e *Engine) executeRestore(ctx context.Context, config *DrillConfig, containerID, backupPath string, containerConfig *ContainerConfig) error { + var cmd []string + + switch config.DatabaseType { + case "postgresql", "postgres": + // Decompress if needed + if strings.HasSuffix(backupPath, ".gz") { + decompressedPath := strings.TrimSuffix(backupPath, ".gz") + _, err := e.docker.ExecCommand(ctx, containerID, []string{ + "sh", "-c", fmt.Sprintf("gunzip -c %s > %s", backupPath, 
decompressedPath), + }) + if err != nil { + return fmt.Errorf("decompression failed: %w", err) + } + backupPath = decompressedPath + } + + // Create database + _, err := e.docker.ExecCommand(ctx, containerID, []string{ + "psql", "-U", "postgres", "-c", fmt.Sprintf("CREATE DATABASE %s", config.DatabaseName), + }) + if err != nil { + // Database might already exist + e.log.Debug("Create database returned (may already exist)") + } + + // Detect restore method based on file content + isCustomFormat := strings.Contains(backupPath, ".dump") || strings.Contains(backupPath, ".custom") + if isCustomFormat { + cmd = []string{"pg_restore", "-U", "postgres", "-d", config.DatabaseName, "-v", backupPath} + } else { + cmd = []string{"sh", "-c", fmt.Sprintf("psql -U postgres -d %s < %s", config.DatabaseName, backupPath)} + } + + case "mysql": + // Decompress if needed + if strings.HasSuffix(backupPath, ".gz") { + decompressedPath := strings.TrimSuffix(backupPath, ".gz") + _, err := e.docker.ExecCommand(ctx, containerID, []string{ + "sh", "-c", fmt.Sprintf("gunzip -c %s > %s", backupPath, decompressedPath), + }) + if err != nil { + return fmt.Errorf("decompression failed: %w", err) + } + backupPath = decompressedPath + } + + cmd = []string{"sh", "-c", fmt.Sprintf("mysql -u root --password=root %s < %s", config.DatabaseName, backupPath)} + + case "mariadb": + if strings.HasSuffix(backupPath, ".gz") { + decompressedPath := strings.TrimSuffix(backupPath, ".gz") + _, err := e.docker.ExecCommand(ctx, containerID, []string{ + "sh", "-c", fmt.Sprintf("gunzip -c %s > %s", backupPath, decompressedPath), + }) + if err != nil { + return fmt.Errorf("decompression failed: %w", err) + } + backupPath = decompressedPath + } + + cmd = []string{"sh", "-c", fmt.Sprintf("mariadb -u root --password=root %s < %s", config.DatabaseName, backupPath)} + + default: + return fmt.Errorf("unsupported database type: %s", config.DatabaseType) + } + + output, err := e.docker.ExecCommand(ctx, containerID, cmd) + if err != nil { + return fmt.Errorf("restore failed: %w (output: %s)", err, output) + } + + return nil +} + +// validateDatabase runs validation against the restored database +func (e *Engine) validateDatabase(ctx context.Context, config *DrillConfig, result *DrillResult, containerConfig *ContainerConfig) int { + errorCount := 0 + + // Connect to database + var user, password string + switch config.DatabaseType { + case "postgresql", "postgres": + user = "postgres" + password = containerConfig.Environment["POSTGRES_PASSWORD"] + case "mysql": + user = "root" + password = "root" + case "mariadb": + user = "root" + password = "root" + } + + validator, err := NewValidator(config.DatabaseType, "localhost", containerConfig.Port, user, password, config.DatabaseName, e.verbose) + if err != nil { + e.log.Error("Failed to connect for validation", "error", err) + result.Errors = append(result.Errors, "Validation connection failed: "+err.Error()) + return 1 + } + defer validator.Close() + + // Get database metrics + tables, err := validator.GetTableList(ctx) + if err == nil { + result.TableCount = len(tables) + e.log.Info(fmt.Sprintf("📊 Tables found: %d", result.TableCount)) + } + + totalRows, err := validator.GetTotalRowCount(ctx) + if err == nil { + result.TotalRows = totalRows + e.log.Info(fmt.Sprintf("📊 Total rows: %d", result.TotalRows)) + } + + dbSize, err := validator.GetDatabaseSize(ctx, config.DatabaseName) + if err == nil { + result.DatabaseSize = dbSize + } + + // Run expected tables check + if len(config.ExpectedTables) > 0 { + 
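+		// A drill config exercising these checks might look like this
+		// (hypothetical table and query names; JSON form, matching the
+		// DrillConfig struct tags):
+		//
+		//	{
+		//	  "expected_tables": ["users", "orders"],
+		//	  "validation_queries": [
+		//	    {"name": "users present",
+		//	     "query": "SELECT COUNT(*) FROM users",
+		//	     "min_value": 1, "must_succeed": true}
+		//	  ]
+		//	}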
tableResults := validator.ValidateExpectedTables(ctx, config.ExpectedTables) + for _, tr := range tableResults { + result.CheckResults = append(result.CheckResults, tr) + if !tr.Success { + errorCount++ + e.log.Warn("❌ " + tr.Message) + } else { + e.log.Info("✓ " + tr.Message) + } + } + } + + // Run validation queries + if len(config.ValidationQueries) > 0 { + queryResults := validator.RunValidationQueries(ctx, config.ValidationQueries) + result.ValidationResults = append(result.ValidationResults, queryResults...) + + var totalQueryTime float64 + for _, qr := range queryResults { + totalQueryTime += qr.Duration + if !qr.Success { + errorCount++ + e.log.Warn(fmt.Sprintf("❌ %s: %s", qr.Name, qr.Error)) + } else { + e.log.Info(fmt.Sprintf("✓ %s: %s (%.0fms)", qr.Name, qr.Result, qr.Duration)) + } + } + if len(queryResults) > 0 { + result.QueryTimeAvg = totalQueryTime / float64(len(queryResults)) + } + } + + // Run custom checks + if len(config.CustomChecks) > 0 { + checkResults := validator.RunCustomChecks(ctx, config.CustomChecks) + for _, cr := range checkResults { + result.CheckResults = append(result.CheckResults, cr) + if !cr.Success { + errorCount++ + e.log.Warn("❌ " + cr.Message) + } else { + e.log.Info("✓ " + cr.Message) + } + } + } + + // Check minimum row count if specified + if config.MinRowCount > 0 && result.TotalRows < config.MinRowCount { + errorCount++ + msg := fmt.Sprintf("Total rows (%d) below minimum (%d)", result.TotalRows, config.MinRowCount) + result.Warnings = append(result.Warnings, msg) + e.log.Warn("⚠️ " + msg) + } + + return errorCount +} + +// startPhase starts a new drill phase +func (e *Engine) startPhase(name string) DrillPhase { + e.log.Info("▶️ " + name) + return DrillPhase{ + Name: name, + Status: "running", + StartTime: time.Now(), + } +} + +// completePhase marks a phase as completed +func (e *Engine) completePhase(phase *DrillPhase, message string) { + phase.EndTime = time.Now() + phase.Duration = phase.EndTime.Sub(phase.StartTime).Seconds() + phase.Status = "completed" + phase.Message = message +} + +// failPhase marks a phase as failed +func (e *Engine) failPhase(phase *DrillPhase, message string) { + phase.EndTime = time.Now() + phase.Duration = phase.EndTime.Sub(phase.StartTime).Seconds() + phase.Status = "failed" + phase.Message = message + e.log.Error("❌ Phase failed: " + message) +} + +// finalize completes the drill result +func (e *Engine) finalize(result *DrillResult) { + result.EndTime = time.Now() + result.Duration = result.EndTime.Sub(result.StartTime).Seconds() + + e.log.Info("") + e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + e.log.Info(" " + result.Summary()) + e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + + if result.Success { + e.log.Info(fmt.Sprintf(" RTO: %.2fs (target: %.0fs) %s", + result.ActualRTO, result.TargetRTO, boolIcon(result.RTOMet))) + } +} + +func boolIcon(b bool) string { + if b { + return "✅" + } + return "❌" +} + +// Cleanup removes drill resources +func (e *Engine) Cleanup(ctx context.Context, drillID string) error { + containers, err := e.docker.ListDrillContainers(ctx) + if err != nil { + return err + } + + for _, c := range containers { + if strings.Contains(c.Name, drillID) || (drillID == "" && strings.HasPrefix(c.Name, "drill_")) { + e.log.Info("🗑️ Removing container: " + c.Name) + if err := e.docker.RemoveContainer(ctx, c.ID); err != nil { + e.log.Warn("Failed to remove container", "id", c.ID, "error", err) + } + } + } + + return nil +} + +// QuickTest runs a quick 
restore test without full validation +func (e *Engine) QuickTest(ctx context.Context, backupPath, dbType, dbName string) (*DrillResult, error) { + config := DefaultConfig() + config.BackupPath = backupPath + config.DatabaseType = dbType + config.DatabaseName = dbName + config.CleanupOnExit = true + config.MaxRestoreSeconds = 600 + + return e.Run(ctx, config) +} + +// Validate runs validation queries against an existing database (non-Docker) +func (e *Engine) Validate(ctx context.Context, config *DrillConfig, host string, port int, user, password string) ([]ValidationResult, error) { + validator, err := NewValidator(config.DatabaseType, host, port, user, password, config.DatabaseName, e.verbose) + if err != nil { + return nil, err + } + defer validator.Close() + + return validator.RunValidationQueries(ctx, config.ValidationQueries), nil +} diff --git a/internal/drill/validate.go b/internal/drill/validate.go new file mode 100644 index 0000000..d156e5e --- /dev/null +++ b/internal/drill/validate.go @@ -0,0 +1,358 @@ +// Package drill - Validation logic for DR drills +package drill + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/jackc/pgx/v5/stdlib" +) + +// Validator handles database validation during DR drills +type Validator struct { + db *sql.DB + dbType string + verbose bool +} + +// NewValidator creates a new database validator +func NewValidator(dbType string, host string, port int, user, password, dbname string, verbose bool) (*Validator, error) { + var dsn string + var driver string + + switch dbType { + case "postgresql", "postgres": + driver = "pgx" + dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + host, port, user, password, dbname) + case "mysql": + driver = "mysql" + dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true", + user, password, host, port, dbname) + case "mariadb": + driver = "mysql" + dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true", + user, password, host, port, dbname) + default: + return nil, fmt.Errorf("unsupported database type: %s", dbType) + } + + db, err := sql.Open(driver, dsn) + if err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + // Test connection + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := db.PingContext(ctx); err != nil { + db.Close() + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + return &Validator{ + db: db, + dbType: dbType, + verbose: verbose, + }, nil +} + +// Close closes the database connection +func (v *Validator) Close() error { + return v.db.Close() +} + +// RunValidationQueries executes validation queries and returns results +func (v *Validator) RunValidationQueries(ctx context.Context, queries []ValidationQuery) []ValidationResult { + var results []ValidationResult + + for _, q := range queries { + result := v.runQuery(ctx, q) + results = append(results, result) + } + + return results +} + +// runQuery executes a single validation query +func (v *Validator) runQuery(ctx context.Context, query ValidationQuery) ValidationResult { + result := ValidationResult{ + Name: query.Name, + Query: query.Query, + Expected: query.ExpectedValue, + } + + start := time.Now() + rows, err := v.db.QueryContext(ctx, query.Query) + result.Duration = float64(time.Since(start).Milliseconds()) + + if err != nil { + result.Success = false + result.Error = err.Error() + return result + } + defer rows.Close() + + // Get result + if 
rows.Next() { + var value interface{} + if err := rows.Scan(&value); err != nil { + result.Success = false + result.Error = fmt.Sprintf("scan error: %v", err) + return result + } + result.Result = fmt.Sprintf("%v", value) + } + + // Validate result + result.Success = true + if query.ExpectedValue != "" && result.Result != query.ExpectedValue { + result.Success = false + result.Error = fmt.Sprintf("expected %s, got %s", query.ExpectedValue, result.Result) + } + + // Check min/max if specified + if query.MinValue > 0 || query.MaxValue > 0 { + var numValue int64 + fmt.Sscanf(result.Result, "%d", &numValue) + + if query.MinValue > 0 && numValue < query.MinValue { + result.Success = false + result.Error = fmt.Sprintf("value %d below minimum %d", numValue, query.MinValue) + } + if query.MaxValue > 0 && numValue > query.MaxValue { + result.Success = false + result.Error = fmt.Sprintf("value %d above maximum %d", numValue, query.MaxValue) + } + } + + return result +} + +// RunCustomChecks executes custom validation checks +func (v *Validator) RunCustomChecks(ctx context.Context, checks []CustomCheck) []CheckResult { + var results []CheckResult + + for _, check := range checks { + result := v.runCheck(ctx, check) + results = append(results, result) + } + + return results +} + +// runCheck executes a single custom check +func (v *Validator) runCheck(ctx context.Context, check CustomCheck) CheckResult { + result := CheckResult{ + Name: check.Name, + Type: check.Type, + Expected: check.MinValue, + } + + switch check.Type { + case "row_count": + count, err := v.getRowCount(ctx, check.Table, check.Condition) + if err != nil { + result.Success = false + result.Message = fmt.Sprintf("failed to get row count: %v", err) + return result + } + result.Actual = count + result.Success = count >= check.MinValue + if result.Success { + result.Message = fmt.Sprintf("Table %s has %d rows (min: %d)", check.Table, count, check.MinValue) + } else { + result.Message = fmt.Sprintf("Table %s has %d rows, expected at least %d", check.Table, count, check.MinValue) + } + + case "table_exists": + exists, err := v.tableExists(ctx, check.Table) + if err != nil { + result.Success = false + result.Message = fmt.Sprintf("failed to check table: %v", err) + return result + } + result.Success = exists + if exists { + result.Actual = 1 + result.Message = fmt.Sprintf("Table %s exists", check.Table) + } else { + result.Actual = 0 + result.Message = fmt.Sprintf("Table %s does not exist", check.Table) + } + + case "column_check": + exists, err := v.columnExists(ctx, check.Table, check.Column) + if err != nil { + result.Success = false + result.Message = fmt.Sprintf("failed to check column: %v", err) + return result + } + result.Success = exists + if exists { + result.Actual = 1 + result.Message = fmt.Sprintf("Column %s.%s exists", check.Table, check.Column) + } else { + result.Actual = 0 + result.Message = fmt.Sprintf("Column %s.%s does not exist", check.Table, check.Column) + } + + default: + result.Success = false + result.Message = fmt.Sprintf("unknown check type: %s", check.Type) + } + + return result +} + +// getRowCount returns the row count for a table +func (v *Validator) getRowCount(ctx context.Context, table, condition string) (int64, error) { + query := fmt.Sprintf("SELECT COUNT(*) FROM %s", v.quoteIdentifier(table)) + if condition != "" { + query += " WHERE " + condition + } + + var count int64 + err := v.db.QueryRowContext(ctx, query).Scan(&count) + return count, err +} + +// tableExists checks if a table exists +func (v 
*Validator) tableExists(ctx context.Context, table string) (bool, error) { + var query string + switch v.dbType { + case "postgresql", "postgres": + query = `SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = $1 + )` + case "mysql", "mariadb": + query = `SELECT COUNT(*) > 0 FROM information_schema.tables + WHERE table_name = ?` + } + + var exists bool + err := v.db.QueryRowContext(ctx, query, table).Scan(&exists) + return exists, err +} + +// columnExists checks if a column exists +func (v *Validator) columnExists(ctx context.Context, table, column string) (bool, error) { + var query string + switch v.dbType { + case "postgresql", "postgres": + query = `SELECT EXISTS ( + SELECT FROM information_schema.columns + WHERE table_name = $1 AND column_name = $2 + )` + case "mysql", "mariadb": + query = `SELECT COUNT(*) > 0 FROM information_schema.columns + WHERE table_name = ? AND column_name = ?` + } + + var exists bool + err := v.db.QueryRowContext(ctx, query, table, column).Scan(&exists) + return exists, err +} + +// GetTableList returns all tables in the database +func (v *Validator) GetTableList(ctx context.Context) ([]string, error) { + var query string + switch v.dbType { + case "postgresql", "postgres": + query = `SELECT table_name FROM information_schema.tables + WHERE table_schema = 'public' AND table_type = 'BASE TABLE'` + case "mysql", "mariadb": + query = `SELECT table_name FROM information_schema.tables + WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE'` + } + + rows, err := v.db.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []string + for rows.Next() { + var table string + if err := rows.Scan(&table); err != nil { + return nil, err + } + tables = append(tables, table) + } + + return tables, rows.Err() +} + +// GetTotalRowCount returns total row count across all tables +func (v *Validator) GetTotalRowCount(ctx context.Context) (int64, error) { + tables, err := v.GetTableList(ctx) + if err != nil { + return 0, err + } + + var total int64 + for _, table := range tables { + count, err := v.getRowCount(ctx, table, "") + if err != nil { + continue // Skip tables that can't be counted + } + total += count + } + + return total, nil +} + +// GetDatabaseSize returns the database size in bytes +func (v *Validator) GetDatabaseSize(ctx context.Context, dbname string) (int64, error) { + var query string + switch v.dbType { + case "postgresql", "postgres": + query = fmt.Sprintf("SELECT pg_database_size('%s')", dbname) + case "mysql", "mariadb": + query = fmt.Sprintf(`SELECT SUM(data_length + index_length) + FROM information_schema.tables WHERE table_schema = '%s'`, dbname) + } + + var size sql.NullInt64 + err := v.db.QueryRowContext(ctx, query).Scan(&size) + if err != nil { + return 0, err + } + + return size.Int64, nil +} + +// ValidateExpectedTables checks that all expected tables exist +func (v *Validator) ValidateExpectedTables(ctx context.Context, expectedTables []string) []CheckResult { + var results []CheckResult + + for _, table := range expectedTables { + check := CustomCheck{ + Name: fmt.Sprintf("Table '%s' exists", table), + Type: "table_exists", + Table: table, + } + results = append(results, v.runCheck(ctx, check)) + } + + return results +} + +// quoteIdentifier quotes a database identifier +func (v *Validator) quoteIdentifier(id string) string { + switch v.dbType { + case "postgresql", "postgres": + return fmt.Sprintf(`"%s"`, strings.ReplaceAll(id, `"`, `""`)) + case "mysql", "mariadb": + 
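+		// MySQL and MariaDB quote identifiers with backticks; a literal
+		// backtick in the name is escaped by doubling it, e.g. a`b becomes `a``b`.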
return fmt.Sprintf("`%s`", strings.ReplaceAll(id, "`", "``")) + default: + return id + } +} diff --git a/internal/notify/batch.go b/internal/notify/batch.go new file mode 100644 index 0000000..85ec5bf --- /dev/null +++ b/internal/notify/batch.go @@ -0,0 +1,261 @@ +// Package notify - Event batching for aggregated notifications +package notify + +import ( + "context" + "fmt" + "sync" + "time" +) + +// BatchConfig configures notification batching +type BatchConfig struct { + Enabled bool // Enable batching + Window time.Duration // Batch window (e.g., 5 minutes) + MaxEvents int // Maximum events per batch before forced send + GroupBy string // Group by: "database", "type", "severity", "host" + DigestFormat string // Format: "summary", "detailed", "compact" +} + +// DefaultBatchConfig returns sensible batch defaults +func DefaultBatchConfig() BatchConfig { + return BatchConfig{ + Enabled: false, + Window: 5 * time.Minute, + MaxEvents: 50, + GroupBy: "database", + DigestFormat: "summary", + } +} + +// Batcher collects events and sends them in batches +type Batcher struct { + config BatchConfig + manager *Manager + events []*Event + mu sync.Mutex + timer *time.Timer + ctx context.Context + cancel context.CancelFunc + startTime time.Time +} + +// NewBatcher creates a new event batcher +func NewBatcher(config BatchConfig, manager *Manager) *Batcher { + ctx, cancel := context.WithCancel(context.Background()) + return &Batcher{ + config: config, + manager: manager, + events: make([]*Event, 0), + ctx: ctx, + cancel: cancel, + } +} + +// Add adds an event to the batch +func (b *Batcher) Add(event *Event) { + if !b.config.Enabled { + // Batching disabled, send immediately + b.manager.Notify(event) + return + } + + b.mu.Lock() + defer b.mu.Unlock() + + // Start timer on first event + if len(b.events) == 0 { + b.startTime = time.Now() + b.timer = time.AfterFunc(b.config.Window, func() { + b.Flush() + }) + } + + b.events = append(b.events, event) + + // Check if we've hit max events + if len(b.events) >= b.config.MaxEvents { + b.flushLocked() + } +} + +// Flush sends all batched events +func (b *Batcher) Flush() { + b.mu.Lock() + defer b.mu.Unlock() + b.flushLocked() +} + +// flushLocked sends batched events (must hold mutex) +func (b *Batcher) flushLocked() { + if len(b.events) == 0 { + return + } + + // Cancel pending timer + if b.timer != nil { + b.timer.Stop() + b.timer = nil + } + + // Group events + groups := b.groupEvents() + + // Create digest event for each group + for key, events := range groups { + digest := b.createDigest(key, events) + b.manager.Notify(digest) + } + + // Clear events + b.events = make([]*Event, 0) +} + +// groupEvents groups events by configured criteria +func (b *Batcher) groupEvents() map[string][]*Event { + groups := make(map[string][]*Event) + + for _, event := range b.events { + var key string + switch b.config.GroupBy { + case "database": + key = event.Database + case "type": + key = string(event.Type) + case "severity": + key = string(event.Severity) + case "host": + key = event.Hostname + default: + key = "all" + } + if key == "" { + key = "unknown" + } + groups[key] = append(groups[key], event) + } + + return groups +} + +// createDigest creates a digest event from multiple events +func (b *Batcher) createDigest(groupKey string, events []*Event) *Event { + // Calculate summary stats + var ( + successCount int + failureCount int + highestSev = SeverityInfo + totalDuration time.Duration + databases = make(map[string]bool) + ) + + for _, e := range events { + switch 
e.Type { + case EventBackupCompleted, EventRestoreCompleted, EventVerifyCompleted: + successCount++ + case EventBackupFailed, EventRestoreFailed, EventVerifyFailed: + failureCount++ + } + + if severityOrder(e.Severity) > severityOrder(highestSev) { + highestSev = e.Severity + } + + totalDuration += e.Duration + if e.Database != "" { + databases[e.Database] = true + } + } + + // Create digest message + var message string + switch b.config.DigestFormat { + case "detailed": + message = b.formatDetailedDigest(events) + case "compact": + message = b.formatCompactDigest(events, successCount, failureCount) + default: // summary + message = b.formatSummaryDigest(events, successCount, failureCount, len(databases)) + } + + digest := NewEvent(EventType("digest"), highestSev, message) + digest.WithDetail("group", groupKey) + digest.WithDetail("event_count", fmt.Sprintf("%d", len(events))) + digest.WithDetail("success_count", fmt.Sprintf("%d", successCount)) + digest.WithDetail("failure_count", fmt.Sprintf("%d", failureCount)) + digest.WithDetail("batch_duration", fmt.Sprintf("%.0fs", time.Since(b.startTime).Seconds())) + + if len(databases) == 1 { + for db := range databases { + digest.Database = db + } + } + + return digest +} + +func (b *Batcher) formatSummaryDigest(events []*Event, success, failure, dbCount int) string { + total := len(events) + return fmt.Sprintf("Batch Summary: %d events (%d success, %d failed) across %d database(s)", + total, success, failure, dbCount) +} + +func (b *Batcher) formatCompactDigest(events []*Event, success, failure int) string { + if failure > 0 { + return fmt.Sprintf("⚠️ %d/%d operations failed", failure, len(events)) + } + return fmt.Sprintf("✅ All %d operations successful", success) +} + +func (b *Batcher) formatDetailedDigest(events []*Event) string { + var msg string + msg += fmt.Sprintf("=== Batch Digest (%d events) ===\n\n", len(events)) + + for _, e := range events { + icon := "•" + switch e.Severity { + case SeverityError, SeverityCritical: + icon = "❌" + case SeverityWarning: + icon = "⚠️" + } + + msg += fmt.Sprintf("%s [%s] %s: %s\n", + icon, + e.Timestamp.Format("15:04:05"), + e.Type, + e.Message) + } + + return msg +} + +// Stop stops the batcher and flushes remaining events +func (b *Batcher) Stop() { + b.cancel() + b.Flush() +} + +// BatcherStats returns current batcher statistics +type BatcherStats struct { + PendingEvents int `json:"pending_events"` + BatchAge time.Duration `json:"batch_age"` + Config BatchConfig `json:"config"` +} + +// Stats returns current batcher statistics +func (b *Batcher) Stats() BatcherStats { + b.mu.Lock() + defer b.mu.Unlock() + + var age time.Duration + if len(b.events) > 0 { + age = time.Since(b.startTime) + } + + return BatcherStats{ + PendingEvents: len(b.events), + BatchAge: age, + Config: b.config, + } +} diff --git a/internal/notify/escalate.go b/internal/notify/escalate.go new file mode 100644 index 0000000..6042f35 --- /dev/null +++ b/internal/notify/escalate.go @@ -0,0 +1,363 @@ +// Package notify - Escalation for critical events +package notify + +import ( + "context" + "fmt" + "sync" + "time" +) + +// EscalationConfig configures notification escalation +type EscalationConfig struct { + Enabled bool // Enable escalation + Levels []EscalationLevel // Escalation levels + AcknowledgeURL string // URL to acknowledge alerts + CooldownPeriod time.Duration // Cooldown between escalations + RepeatInterval time.Duration // Repeat unacknowledged alerts + MaxRepeats int // Maximum repeat attempts + TrackingEnabled bool 
// Track escalation state +} + +// EscalationLevel defines an escalation tier +type EscalationLevel struct { + Name string // Level name (e.g., "primary", "secondary", "manager") + Delay time.Duration // Delay before escalating to this level + Recipients []string // Email recipients for this level + Webhook string // Webhook URL for this level + Severity Severity // Minimum severity to escalate + Message string // Custom message template +} + +// DefaultEscalationConfig returns sensible defaults +func DefaultEscalationConfig() EscalationConfig { + return EscalationConfig{ + Enabled: false, + CooldownPeriod: 15 * time.Minute, + RepeatInterval: 30 * time.Minute, + MaxRepeats: 3, + Levels: []EscalationLevel{ + { + Name: "primary", + Delay: 0, + Severity: SeverityError, + }, + { + Name: "secondary", + Delay: 15 * time.Minute, + Severity: SeverityError, + }, + { + Name: "critical", + Delay: 30 * time.Minute, + Severity: SeverityCritical, + }, + }, + } +} + +// EscalationState tracks escalation for an alert +type EscalationState struct { + AlertID string `json:"alert_id"` + Event *Event `json:"event"` + CurrentLevel int `json:"current_level"` + StartedAt time.Time `json:"started_at"` + LastEscalation time.Time `json:"last_escalation"` + RepeatCount int `json:"repeat_count"` + Acknowledged bool `json:"acknowledged"` + AcknowledgedBy string `json:"acknowledged_by,omitempty"` + AcknowledgedAt *time.Time `json:"acknowledged_at,omitempty"` + Resolved bool `json:"resolved"` +} + +// Escalator manages alert escalation +type Escalator struct { + config EscalationConfig + manager *Manager + alerts map[string]*EscalationState + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + ticker *time.Ticker +} + +// NewEscalator creates a new escalation manager +func NewEscalator(config EscalationConfig, manager *Manager) *Escalator { + ctx, cancel := context.WithCancel(context.Background()) + e := &Escalator{ + config: config, + manager: manager, + alerts: make(map[string]*EscalationState), + ctx: ctx, + cancel: cancel, + } + + if config.Enabled { + e.ticker = time.NewTicker(time.Minute) + go e.runEscalationLoop() + } + + return e +} + +// Handle processes an event for potential escalation +func (e *Escalator) Handle(event *Event) { + if !e.config.Enabled { + return + } + + // Only escalate errors and critical events + if severityOrder(event.Severity) < severityOrder(SeverityError) { + return + } + + // Generate alert ID + alertID := e.generateAlertID(event) + + e.mu.Lock() + defer e.mu.Unlock() + + // Check if alert already exists + if existing, ok := e.alerts[alertID]; ok { + if !existing.Acknowledged && !existing.Resolved { + // Alert already being escalated + return + } + } + + // Create new escalation state + state := &EscalationState{ + AlertID: alertID, + Event: event, + CurrentLevel: 0, + StartedAt: time.Now(), + LastEscalation: time.Now(), + } + + e.alerts[alertID] = state + + // Send immediate notification to first level + e.notifyLevel(state, 0) +} + +// generateAlertID creates a unique ID for an alert +func (e *Escalator) generateAlertID(event *Event) string { + return fmt.Sprintf("%s_%s_%s", + event.Type, + event.Database, + event.Hostname) +} + +// notifyLevel sends notification for a specific escalation level +func (e *Escalator) notifyLevel(state *EscalationState, level int) { + if level >= len(e.config.Levels) { + return + } + + lvl := e.config.Levels[level] + + // Create escalated event + escalatedEvent := &Event{ + Type: state.Event.Type, + Severity: state.Event.Severity, + 
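+		// The escalated copy keeps the original event type and severity so any
+		// existing routing/filter rules still apply; the escalation metadata is
+		// attached via the Details map below.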
Timestamp: time.Now(), + Database: state.Event.Database, + Hostname: state.Event.Hostname, + Message: e.formatEscalationMessage(state, lvl), + Details: make(map[string]string), + } + + escalatedEvent.Details["escalation_level"] = lvl.Name + escalatedEvent.Details["alert_id"] = state.AlertID + escalatedEvent.Details["escalation_time"] = fmt.Sprintf("%d", int(time.Since(state.StartedAt).Minutes())) + escalatedEvent.Details["original_message"] = state.Event.Message + + if state.Event.Error != "" { + escalatedEvent.Error = state.Event.Error + } + + // Send via manager + e.manager.Notify(escalatedEvent) + + state.CurrentLevel = level + state.LastEscalation = time.Now() +} + +// formatEscalationMessage creates an escalation message +func (e *Escalator) formatEscalationMessage(state *EscalationState, level EscalationLevel) string { + if level.Message != "" { + return level.Message + } + + elapsed := time.Since(state.StartedAt) + return fmt.Sprintf("🚨 ESCALATION [%s] - Alert unacknowledged for %s\n\n%s", + level.Name, + formatDuration(elapsed), + state.Event.Message) +} + +// runEscalationLoop checks for alerts that need escalation +func (e *Escalator) runEscalationLoop() { + for { + select { + case <-e.ctx.Done(): + return + case <-e.ticker.C: + e.checkEscalations() + } + } +} + +// checkEscalations checks all alerts for needed escalation +func (e *Escalator) checkEscalations() { + e.mu.Lock() + defer e.mu.Unlock() + + now := time.Now() + + for _, state := range e.alerts { + if state.Acknowledged || state.Resolved { + continue + } + + // Check if we need to escalate to next level + nextLevel := state.CurrentLevel + 1 + if nextLevel < len(e.config.Levels) { + lvl := e.config.Levels[nextLevel] + if now.Sub(state.StartedAt) >= lvl.Delay { + e.notifyLevel(state, nextLevel) + } + } + + // Check if we need to repeat the alert + if state.RepeatCount < e.config.MaxRepeats { + if now.Sub(state.LastEscalation) >= e.config.RepeatInterval { + e.notifyLevel(state, state.CurrentLevel) + state.RepeatCount++ + } + } + } +} + +// Acknowledge acknowledges an alert +func (e *Escalator) Acknowledge(alertID, user string) error { + e.mu.Lock() + defer e.mu.Unlock() + + state, ok := e.alerts[alertID] + if !ok { + return fmt.Errorf("alert not found: %s", alertID) + } + + now := time.Now() + state.Acknowledged = true + state.AcknowledgedBy = user + state.AcknowledgedAt = &now + + return nil +} + +// Resolve resolves an alert +func (e *Escalator) Resolve(alertID string) error { + e.mu.Lock() + defer e.mu.Unlock() + + state, ok := e.alerts[alertID] + if !ok { + return fmt.Errorf("alert not found: %s", alertID) + } + + state.Resolved = true + return nil +} + +// GetActiveAlerts returns all active (unacknowledged, unresolved) alerts +func (e *Escalator) GetActiveAlerts() []*EscalationState { + e.mu.RLock() + defer e.mu.RUnlock() + + var active []*EscalationState + for _, state := range e.alerts { + if !state.Acknowledged && !state.Resolved { + active = append(active, state) + } + } + return active +} + +// GetAlert returns a specific alert +func (e *Escalator) GetAlert(alertID string) (*EscalationState, bool) { + e.mu.RLock() + defer e.mu.RUnlock() + + state, ok := e.alerts[alertID] + return state, ok +} + +// CleanupOld removes old resolved/acknowledged alerts +func (e *Escalator) CleanupOld(maxAge time.Duration) int { + e.mu.Lock() + defer e.mu.Unlock() + + now := time.Now() + removed := 0 + + for id, state := range e.alerts { + if (state.Acknowledged || state.Resolved) && now.Sub(state.StartedAt) > maxAge { + 
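+			// Only finished alerts (acknowledged or resolved) older than maxAge
+			// are dropped; active alerts are never removed by cleanup.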
delete(e.alerts, id) + removed++ + } + } + + return removed +} + +// Stop stops the escalator +func (e *Escalator) Stop() { + e.cancel() + if e.ticker != nil { + e.ticker.Stop() + } +} + +// EscalatorStats returns escalator statistics +type EscalatorStats struct { + ActiveAlerts int `json:"active_alerts"` + AcknowledgedAlerts int `json:"acknowledged_alerts"` + ResolvedAlerts int `json:"resolved_alerts"` + EscalationEnabled bool `json:"escalation_enabled"` + LevelCount int `json:"level_count"` +} + +// Stats returns escalator statistics +func (e *Escalator) Stats() EscalatorStats { + e.mu.RLock() + defer e.mu.RUnlock() + + stats := EscalatorStats{ + EscalationEnabled: e.config.Enabled, + LevelCount: len(e.config.Levels), + } + + for _, state := range e.alerts { + if state.Resolved { + stats.ResolvedAlerts++ + } else if state.Acknowledged { + stats.AcknowledgedAlerts++ + } else { + stats.ActiveAlerts++ + } + } + + return stats +} + +func formatDuration(d time.Duration) string { + if d < time.Minute { + return fmt.Sprintf("%.0fs", d.Seconds()) + } + if d < time.Hour { + return fmt.Sprintf("%.0fm", d.Minutes()) + } + return fmt.Sprintf("%.0fh %.0fm", d.Hours(), d.Minutes()-d.Hours()*60) +} diff --git a/internal/notify/notify.go b/internal/notify/notify.go index 0421968..c5ab34e 100644 --- a/internal/notify/notify.go +++ b/internal/notify/notify.go @@ -11,41 +11,66 @@ import ( type EventType string const ( - EventBackupStarted EventType = "backup_started" - EventBackupCompleted EventType = "backup_completed" - EventBackupFailed EventType = "backup_failed" - EventRestoreStarted EventType = "restore_started" - EventRestoreCompleted EventType = "restore_completed" - EventRestoreFailed EventType = "restore_failed" - EventCleanupCompleted EventType = "cleanup_completed" - EventVerifyCompleted EventType = "verify_completed" - EventVerifyFailed EventType = "verify_failed" - EventPITRRecovery EventType = "pitr_recovery" + EventBackupStarted EventType = "backup_started" + EventBackupCompleted EventType = "backup_completed" + EventBackupFailed EventType = "backup_failed" + EventRestoreStarted EventType = "restore_started" + EventRestoreCompleted EventType = "restore_completed" + EventRestoreFailed EventType = "restore_failed" + EventCleanupCompleted EventType = "cleanup_completed" + EventVerifyCompleted EventType = "verify_completed" + EventVerifyFailed EventType = "verify_failed" + EventPITRRecovery EventType = "pitr_recovery" + EventVerificationPassed EventType = "verification_passed" + EventVerificationFailed EventType = "verification_failed" + EventDRDrillPassed EventType = "dr_drill_passed" + EventDRDrillFailed EventType = "dr_drill_failed" + EventGapDetected EventType = "gap_detected" + EventRPOViolation EventType = "rpo_violation" ) // Severity represents the severity level of a notification type Severity string const ( - SeverityInfo Severity = "info" - SeverityWarning Severity = "warning" - SeverityError Severity = "error" + SeverityInfo Severity = "info" + SeveritySuccess Severity = "success" + SeverityWarning Severity = "warning" + SeverityError Severity = "error" SeverityCritical Severity = "critical" ) +// severityOrder returns numeric order for severity comparison +func severityOrder(s Severity) int { + switch s { + case SeverityInfo: + return 0 + case SeveritySuccess: + return 1 + case SeverityWarning: + return 2 + case SeverityError: + return 3 + case SeverityCritical: + return 4 + default: + return 0 + } +} + // Event represents a notification event type Event struct { - Type EventType 
`json:"type"` - Severity Severity `json:"severity"` - Timestamp time.Time `json:"timestamp"` - Database string `json:"database,omitempty"` - Message string `json:"message"` - Details map[string]string `json:"details,omitempty"` - Error string `json:"error,omitempty"` - Duration time.Duration `json:"duration,omitempty"` - BackupFile string `json:"backup_file,omitempty"` - BackupSize int64 `json:"backup_size,omitempty"` - Hostname string `json:"hostname,omitempty"` + Type EventType `json:"type"` + Severity Severity `json:"severity"` + Timestamp time.Time `json:"timestamp"` + Database string `json:"database,omitempty"` + Message string `json:"message"` + Details map[string]string `json:"details,omitempty"` + Error string `json:"error,omitempty"` + Duration time.Duration `json:"duration,omitempty"` + BackupFile string `json:"backup_file,omitempty"` + BackupSize int64 `json:"backup_size,omitempty"` + Hostname string `json:"hostname,omitempty"` } // NewEvent creates a new notification event @@ -132,27 +157,27 @@ type Config struct { WebhookSecret string // For signing payloads // General settings - OnSuccess bool // Send notifications on successful operations - OnFailure bool // Send notifications on failed operations - OnWarning bool // Send notifications on warnings - MinSeverity Severity - Retries int // Number of retry attempts - RetryDelay time.Duration // Delay between retries + OnSuccess bool // Send notifications on successful operations + OnFailure bool // Send notifications on failed operations + OnWarning bool // Send notifications on warnings + MinSeverity Severity + Retries int // Number of retry attempts + RetryDelay time.Duration // Delay between retries } // DefaultConfig returns a configuration with sensible defaults func DefaultConfig() Config { return Config{ - SMTPPort: 587, - SMTPTLS: false, - SMTPStartTLS: true, + SMTPPort: 587, + SMTPTLS: false, + SMTPStartTLS: true, WebhookMethod: "POST", - OnSuccess: true, - OnFailure: true, - OnWarning: true, - MinSeverity: SeverityInfo, - Retries: 3, - RetryDelay: 5 * time.Second, + OnSuccess: true, + OnFailure: true, + OnWarning: true, + MinSeverity: SeverityInfo, + Retries: 3, + RetryDelay: 5 * time.Second, } } diff --git a/internal/notify/templates.go b/internal/notify/templates.go new file mode 100644 index 0000000..929d744 --- /dev/null +++ b/internal/notify/templates.go @@ -0,0 +1,497 @@ +// Package notify - Notification templates +package notify + +import ( + "bytes" + "fmt" + "html/template" + "strings" + "time" +) + +// TemplateType represents the notification format type +type TemplateType string + +const ( + TemplateText TemplateType = "text" + TemplateHTML TemplateType = "html" + TemplateMarkdown TemplateType = "markdown" + TemplateSlack TemplateType = "slack" +) + +// Templates holds notification templates +type Templates struct { + Subject string + TextBody string + HTMLBody string +} + +// DefaultTemplates returns default notification templates +func DefaultTemplates() map[EventType]Templates { + return map[EventType]Templates{ + EventBackupStarted: { + Subject: "🔄 Backup Started: {{.Database}} on {{.Hostname}}", + TextBody: backupStartedText, + HTMLBody: backupStartedHTML, + }, + EventBackupCompleted: { + Subject: "✅ Backup Completed: {{.Database}} on {{.Hostname}}", + TextBody: backupCompletedText, + HTMLBody: backupCompletedHTML, + }, + EventBackupFailed: { + Subject: "❌ Backup FAILED: {{.Database}} on {{.Hostname}}", + TextBody: backupFailedText, + HTMLBody: backupFailedHTML, + }, + EventRestoreStarted: { + Subject: 
"🔄 Restore Started: {{.Database}} on {{.Hostname}}", + TextBody: restoreStartedText, + HTMLBody: restoreStartedHTML, + }, + EventRestoreCompleted: { + Subject: "✅ Restore Completed: {{.Database}} on {{.Hostname}}", + TextBody: restoreCompletedText, + HTMLBody: restoreCompletedHTML, + }, + EventRestoreFailed: { + Subject: "❌ Restore FAILED: {{.Database}} on {{.Hostname}}", + TextBody: restoreFailedText, + HTMLBody: restoreFailedHTML, + }, + EventVerificationPassed: { + Subject: "✅ Verification Passed: {{.Database}}", + TextBody: verificationPassedText, + HTMLBody: verificationPassedHTML, + }, + EventVerificationFailed: { + Subject: "❌ Verification FAILED: {{.Database}}", + TextBody: verificationFailedText, + HTMLBody: verificationFailedHTML, + }, + EventDRDrillPassed: { + Subject: "✅ DR Drill Passed: {{.Database}}", + TextBody: drDrillPassedText, + HTMLBody: drDrillPassedHTML, + }, + EventDRDrillFailed: { + Subject: "❌ DR Drill FAILED: {{.Database}}", + TextBody: drDrillFailedText, + HTMLBody: drDrillFailedHTML, + }, + } +} + +// Template strings +const backupStartedText = ` +Backup Operation Started + +Database: {{.Database}} +Hostname: {{.Hostname}} +Started At: {{formatTime .Timestamp}} + +{{if .Message}}{{.Message}}{{end}} +` + +const backupStartedHTML = ` +
<div>
+  <h2>🔄 Backup Started</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Started At:</td><td>{{formatTime .Timestamp}}</td></tr>
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+</div>
+` + +const backupCompletedText = ` +Backup Operation Completed Successfully + +Database: {{.Database}} +Hostname: {{.Hostname}} +Completed: {{formatTime .Timestamp}} +{{with .Details}} +{{if .size}}Size: {{.size}}{{end}} +{{if .duration}}Duration: {{.duration}}{{end}} +{{if .path}}Path: {{.path}}{{end}} +{{end}} +{{if .Message}}{{.Message}}{{end}} +` + +const backupCompletedHTML = ` +
<div>
+  <h2>✅ Backup Completed</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Completed:</td><td>{{formatTime .Timestamp}}</td></tr>
+    {{with .Details}}
+    {{if .size}}<tr><td>Size:</td><td>{{.size}}</td></tr>{{end}}
+    {{if .duration}}<tr><td>Duration:</td><td>{{.duration}}</td></tr>{{end}}
+    {{if .path}}<tr><td>Path:</td><td>{{.path}}</td></tr>{{end}}
+    {{end}}
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+</div>
+` + +const backupFailedText = ` +⚠️ BACKUP FAILED ⚠️ + +Database: {{.Database}} +Hostname: {{.Hostname}} +Failed At: {{formatTime .Timestamp}} +{{if .Error}} +Error: {{.Error}} +{{end}} +{{if .Message}}{{.Message}}{{end}} + +Please investigate immediately. +` + +const backupFailedHTML = ` +
<div>
+  <h2>❌ Backup FAILED</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Failed At:</td><td>{{formatTime .Timestamp}}</td></tr>
+    {{if .Error}}<tr><td>Error:</td><td>{{.Error}}</td></tr>{{end}}
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+  <p><strong>Please investigate immediately.</strong></p>
+</div>
+` + +const restoreStartedText = ` +Restore Operation Started + +Database: {{.Database}} +Hostname: {{.Hostname}} +Started At: {{formatTime .Timestamp}} + +{{if .Message}}{{.Message}}{{end}} +` + +const restoreStartedHTML = ` +
<div>
+  <h2>🔄 Restore Started</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Started At:</td><td>{{formatTime .Timestamp}}</td></tr>
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+</div>
+` + +const restoreCompletedText = ` +Restore Operation Completed Successfully + +Database: {{.Database}} +Hostname: {{.Hostname}} +Completed: {{formatTime .Timestamp}} +{{with .Details}} +{{if .duration}}Duration: {{.duration}}{{end}} +{{end}} +{{if .Message}}{{.Message}}{{end}} +` + +const restoreCompletedHTML = ` +
<div>
+  <h2>✅ Restore Completed</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Completed:</td><td>{{formatTime .Timestamp}}</td></tr>
+    {{with .Details}}
+    {{if .duration}}<tr><td>Duration:</td><td>{{.duration}}</td></tr>{{end}}
+    {{end}}
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+</div>
+` + +const restoreFailedText = ` +⚠️ RESTORE FAILED ⚠️ + +Database: {{.Database}} +Hostname: {{.Hostname}} +Failed At: {{formatTime .Timestamp}} +{{if .Error}} +Error: {{.Error}} +{{end}} +{{if .Message}}{{.Message}}{{end}} + +Please investigate immediately. +` + +const restoreFailedHTML = ` +
<div>
+  <h2>❌ Restore FAILED</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Failed At:</td><td>{{formatTime .Timestamp}}</td></tr>
+    {{if .Error}}<tr><td>Error:</td><td>{{.Error}}</td></tr>{{end}}
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+  <p><strong>Please investigate immediately.</strong></p>
+</div>
+` + +const verificationPassedText = ` +Backup Verification Passed + +Database: {{.Database}} +Hostname: {{.Hostname}} +Verified: {{formatTime .Timestamp}} +{{with .Details}} +{{if .checksum}}Checksum: {{.checksum}}{{end}} +{{end}} +{{if .Message}}{{.Message}}{{end}} +` + +const verificationPassedHTML = ` +
<div>
+  <h2>✅ Verification Passed</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Verified:</td><td>{{formatTime .Timestamp}}</td></tr>
+    {{with .Details}}
+    {{if .checksum}}<tr><td>Checksum:</td><td>{{.checksum}}</td></tr>{{end}}
+    {{end}}
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+</div>
+` + +const verificationFailedText = ` +⚠️ VERIFICATION FAILED ⚠️ + +Database: {{.Database}} +Hostname: {{.Hostname}} +Failed At: {{formatTime .Timestamp}} +{{if .Error}} +Error: {{.Error}} +{{end}} +{{if .Message}}{{.Message}}{{end}} + +Backup integrity may be compromised. Please investigate. +` + +const verificationFailedHTML = ` +
<div>
+  <h2>❌ Verification FAILED</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Failed At:</td><td>{{formatTime .Timestamp}}</td></tr>
+    {{if .Error}}<tr><td>Error:</td><td>{{.Error}}</td></tr>{{end}}
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+  <p><strong>Backup integrity may be compromised. Please investigate.</strong></p>
+</div>
+` + +const drDrillPassedText = ` +DR Drill Test Passed + +Database: {{.Database}} +Hostname: {{.Hostname}} +Tested At: {{formatTime .Timestamp}} +{{with .Details}} +{{if .tables_restored}}Tables: {{.tables_restored}}{{end}} +{{if .rows_validated}}Rows: {{.rows_validated}}{{end}} +{{if .duration}}Duration: {{.duration}}{{end}} +{{end}} +{{if .Message}}{{.Message}}{{end}} + +Backup restore capability verified. +` + +const drDrillPassedHTML = ` +
<div>
+  <h2>✅ DR Drill Passed</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Tested At:</td><td>{{formatTime .Timestamp}}</td></tr>
+    {{with .Details}}
+    {{if .tables_restored}}<tr><td>Tables:</td><td>{{.tables_restored}}</td></tr>{{end}}
+    {{if .rows_validated}}<tr><td>Rows:</td><td>{{.rows_validated}}</td></tr>{{end}}
+    {{if .duration}}<tr><td>Duration:</td><td>{{.duration}}</td></tr>{{end}}
+    {{end}}
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+  <p>✓ Backup restore capability verified</p>
+</div>
+` + +const drDrillFailedText = ` +⚠️ DR DRILL FAILED ⚠️ + +Database: {{.Database}} +Hostname: {{.Hostname}} +Failed At: {{formatTime .Timestamp}} +{{if .Error}} +Error: {{.Error}} +{{end}} +{{if .Message}}{{.Message}}{{end}} + +Backup may not be restorable. Please investigate immediately. +` + +const drDrillFailedHTML = ` +
<div>
+  <h2>❌ DR Drill FAILED</h2>
+  <table>
+    <tr><td>Database:</td><td>{{.Database}}</td></tr>
+    <tr><td>Hostname:</td><td>{{.Hostname}}</td></tr>
+    <tr><td>Failed At:</td><td>{{formatTime .Timestamp}}</td></tr>
+    {{if .Error}}<tr><td>Error:</td><td>{{.Error}}</td></tr>{{end}}
+  </table>
+  {{if .Message}}
+  <p>{{.Message}}</p>
+  {{end}}
+  <p><strong>Backup may not be restorable. Please investigate immediately.</strong></p>
+</div>
+` + +// TemplateRenderer renders notification templates +type TemplateRenderer struct { + templates map[EventType]Templates + funcMap template.FuncMap +} + +// NewTemplateRenderer creates a new template renderer +func NewTemplateRenderer() *TemplateRenderer { + return &TemplateRenderer{ + templates: DefaultTemplates(), + funcMap: template.FuncMap{ + "formatTime": func(t time.Time) string { + return t.Format("2006-01-02 15:04:05 MST") + }, + "upper": strings.ToUpper, + "lower": strings.ToLower, + }, + } +} + +// RenderSubject renders the subject template for an event +func (r *TemplateRenderer) RenderSubject(event *Event) (string, error) { + tmpl, ok := r.templates[event.Type] + if !ok { + return fmt.Sprintf("[%s] %s: %s", event.Severity, event.Type, event.Database), nil + } + + return r.render(tmpl.Subject, event) +} + +// RenderText renders the text body template for an event +func (r *TemplateRenderer) RenderText(event *Event) (string, error) { + tmpl, ok := r.templates[event.Type] + if !ok { + return event.Message, nil + } + + return r.render(tmpl.TextBody, event) +} + +// RenderHTML renders the HTML body template for an event +func (r *TemplateRenderer) RenderHTML(event *Event) (string, error) { + tmpl, ok := r.templates[event.Type] + if !ok { + return fmt.Sprintf("
<div>%s</div>
", event.Message), nil + } + + return r.render(tmpl.HTMLBody, event) +} + +// render executes a template with the given event +func (r *TemplateRenderer) render(templateStr string, event *Event) (string, error) { + tmpl, err := template.New("notification").Funcs(r.funcMap).Parse(templateStr) + if err != nil { + return "", fmt.Errorf("failed to parse template: %w", err) + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, event); err != nil { + return "", fmt.Errorf("failed to execute template: %w", err) + } + + return strings.TrimSpace(buf.String()), nil +} + +// SetTemplate sets a custom template for an event type +func (r *TemplateRenderer) SetTemplate(eventType EventType, templates Templates) { + r.templates[eventType] = templates +} + +// RenderSlackMessage creates a Slack-formatted message +func (r *TemplateRenderer) RenderSlackMessage(event *Event) map[string]interface{} { + color := "#3498db" // blue + switch event.Severity { + case SeveritySuccess: + color = "#27ae60" // green + case SeverityWarning: + color = "#f39c12" // orange + case SeverityError, SeverityCritical: + color = "#e74c3c" // red + } + + fields := []map[string]interface{}{ + { + "title": "Database", + "value": event.Database, + "short": true, + }, + { + "title": "Hostname", + "value": event.Hostname, + "short": true, + }, + { + "title": "Event", + "value": string(event.Type), + "short": true, + }, + { + "title": "Severity", + "value": string(event.Severity), + "short": true, + }, + } + + if event.Error != "" { + fields = append(fields, map[string]interface{}{ + "title": "Error", + "value": event.Error, + "short": false, + }) + } + + for key, value := range event.Details { + fields = append(fields, map[string]interface{}{ + "title": key, + "value": value, + "short": true, + }) + } + + subject, _ := r.RenderSubject(event) + + return map[string]interface{}{ + "attachments": []map[string]interface{}{ + { + "color": color, + "title": subject, + "text": event.Message, + "fields": fields, + "footer": "dbbackup", + "ts": event.Timestamp.Unix(), + "mrkdwn_in": []string{"text", "fields"}, + }, + }, + } +} diff --git a/internal/parallel/engine.go b/internal/parallel/engine.go new file mode 100644 index 0000000..c566b3a --- /dev/null +++ b/internal/parallel/engine.go @@ -0,0 +1,619 @@ +// Package parallel provides parallel table backup functionality +package parallel + +import ( + "context" + "database/sql" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "sync" + "sync/atomic" + "time" +) + +// Table represents a database table +type Table struct { + Schema string `json:"schema"` + Name string `json:"name"` + RowCount int64 `json:"row_count"` + SizeBytes int64 `json:"size_bytes"` + HasPK bool `json:"has_pk"` + Partitioned bool `json:"partitioned"` +} + +// FullName returns the fully qualified table name +func (t *Table) FullName() string { + if t.Schema != "" { + return fmt.Sprintf("%s.%s", t.Schema, t.Name) + } + return t.Name +} + +// Config configures parallel backup +type Config struct { + MaxWorkers int `json:"max_workers"` + MaxConcurrency int `json:"max_concurrency"` // Max concurrent dumps + ChunkSize int64 `json:"chunk_size"` // Rows per chunk for large tables + LargeTableThreshold int64 `json:"large_table_threshold"` // Bytes to consider a table "large" + OutputDir string `json:"output_dir"` + Compression string `json:"compression"` // gzip, lz4, zstd, none + TempDir string `json:"temp_dir"` + Timeout time.Duration `json:"timeout"` + IncludeSchemas []string `json:"include_schemas,omitempty"` + ExcludeSchemas 
[]string `json:"exclude_schemas,omitempty"` + IncludeTables []string `json:"include_tables,omitempty"` + ExcludeTables []string `json:"exclude_tables,omitempty"` + EstimateSizes bool `json:"estimate_sizes"` + OrderBySize bool `json:"order_by_size"` // Start with largest tables first +} + +// DefaultConfig returns sensible defaults +func DefaultConfig() Config { + return Config{ + MaxWorkers: 4, + MaxConcurrency: 4, + ChunkSize: 100000, + LargeTableThreshold: 1 << 30, // 1GB + Compression: "gzip", + Timeout: 24 * time.Hour, + EstimateSizes: true, + OrderBySize: true, + } +} + +// TableResult contains the result of backing up a single table +type TableResult struct { + Table *Table `json:"table"` + OutputFile string `json:"output_file"` + SizeBytes int64 `json:"size_bytes"` + RowsWritten int64 `json:"rows_written"` + Duration time.Duration `json:"duration"` + Error error `json:"error,omitempty"` + Checksum string `json:"checksum,omitempty"` +} + +// Result contains the overall parallel backup result +type Result struct { + Tables []*TableResult `json:"tables"` + TotalTables int `json:"total_tables"` + SuccessTables int `json:"success_tables"` + FailedTables int `json:"failed_tables"` + TotalBytes int64 `json:"total_bytes"` + TotalRows int64 `json:"total_rows"` + Duration time.Duration `json:"duration"` + Workers int `json:"workers"` + OutputDir string `json:"output_dir"` +} + +// Progress tracks backup progress +type Progress struct { + TotalTables int32 `json:"total_tables"` + CompletedTables int32 `json:"completed_tables"` + CurrentTable string `json:"current_table"` + BytesWritten int64 `json:"bytes_written"` + RowsWritten int64 `json:"rows_written"` +} + +// ProgressCallback is called with progress updates +type ProgressCallback func(progress *Progress) + +// Engine orchestrates parallel table backups +type Engine struct { + config Config + db *sql.DB + dbType string + progress *Progress + callback ProgressCallback + mu sync.Mutex +} + +// NewEngine creates a new parallel backup engine +func NewEngine(db *sql.DB, dbType string, config Config) *Engine { + return &Engine{ + config: config, + db: db, + dbType: dbType, + progress: &Progress{}, + } +} + +// SetProgressCallback sets the progress callback +func (e *Engine) SetProgressCallback(cb ProgressCallback) { + e.callback = cb +} + +// Run executes the parallel backup +func (e *Engine) Run(ctx context.Context) (*Result, error) { + start := time.Now() + + // Discover tables + tables, err := e.discoverTables(ctx) + if err != nil { + return nil, fmt.Errorf("failed to discover tables: %w", err) + } + + if len(tables) == 0 { + return &Result{ + Tables: []*TableResult{}, + Duration: time.Since(start), + OutputDir: e.config.OutputDir, + }, nil + } + + // Order tables by size (largest first for better load distribution) + if e.config.OrderBySize { + sort.Slice(tables, func(i, j int) bool { + return tables[i].SizeBytes > tables[j].SizeBytes + }) + } + + // Create output directory + if err := os.MkdirAll(e.config.OutputDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create output directory: %w", err) + } + + // Setup progress + atomic.StoreInt32(&e.progress.TotalTables, int32(len(tables))) + + // Create worker pool + results := make([]*TableResult, len(tables)) + jobs := make(chan int, len(tables)) + var wg sync.WaitGroup + + workers := e.config.MaxWorkers + if workers > len(tables) { + workers = len(tables) + } + + // Start workers + for w := 0; w < workers; w++ { + wg.Add(1) + go func() { + defer wg.Done() + for idx := range jobs { + 
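+				// Each worker drains the shared jobs channel; the select below lets
+				// a cancelled context stop the worker between tables without losing
+				// results that were already recorded.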
select { + case <-ctx.Done(): + return + default: + results[idx] = e.backupTable(ctx, tables[idx]) + atomic.AddInt32(&e.progress.CompletedTables, 1) + if e.callback != nil { + e.callback(e.progress) + } + } + } + }() + } + + // Enqueue jobs + for i := range tables { + jobs <- i + } + close(jobs) + + // Wait for completion + wg.Wait() + + // Compile result + result := &Result{ + Tables: results, + TotalTables: len(tables), + Workers: workers, + Duration: time.Since(start), + OutputDir: e.config.OutputDir, + } + + for _, r := range results { + if r.Error == nil { + result.SuccessTables++ + result.TotalBytes += r.SizeBytes + result.TotalRows += r.RowsWritten + } else { + result.FailedTables++ + } + } + + return result, nil +} + +// discoverTables discovers tables to backup +func (e *Engine) discoverTables(ctx context.Context) ([]*Table, error) { + switch e.dbType { + case "postgresql", "postgres": + return e.discoverPostgresqlTables(ctx) + case "mysql", "mariadb": + return e.discoverMySQLTables(ctx) + default: + return nil, fmt.Errorf("unsupported database type: %s", e.dbType) + } +} + +func (e *Engine) discoverPostgresqlTables(ctx context.Context) ([]*Table, error) { + query := ` + SELECT + schemaname, + tablename, + COALESCE(n_live_tup, 0) as row_count, + COALESCE(pg_total_relation_size(schemaname || '.' || tablename), 0) as size_bytes + FROM pg_stat_user_tables + WHERE schemaname NOT IN ('pg_catalog', 'information_schema') + ORDER BY schemaname, tablename + ` + + rows, err := e.db.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []*Table + for rows.Next() { + var t Table + if err := rows.Scan(&t.Schema, &t.Name, &t.RowCount, &t.SizeBytes); err != nil { + continue + } + + if e.shouldInclude(&t) { + tables = append(tables, &t) + } + } + + return tables, rows.Err() +} + +func (e *Engine) discoverMySQLTables(ctx context.Context) ([]*Table, error) { + query := ` + SELECT + TABLE_SCHEMA, + TABLE_NAME, + COALESCE(TABLE_ROWS, 0) as row_count, + COALESCE(DATA_LENGTH + INDEX_LENGTH, 0) as size_bytes + FROM information_schema.TABLES + WHERE TABLE_SCHEMA NOT IN ('mysql', 'information_schema', 'performance_schema', 'sys') + AND TABLE_TYPE = 'BASE TABLE' + ORDER BY TABLE_SCHEMA, TABLE_NAME + ` + + rows, err := e.db.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []*Table + for rows.Next() { + var t Table + if err := rows.Scan(&t.Schema, &t.Name, &t.RowCount, &t.SizeBytes); err != nil { + continue + } + + if e.shouldInclude(&t) { + tables = append(tables, &t) + } + } + + return tables, rows.Err() +} + +// shouldInclude checks if a table should be included +func (e *Engine) shouldInclude(t *Table) bool { + // Check schema exclusions + for _, s := range e.config.ExcludeSchemas { + if t.Schema == s { + return false + } + } + + // Check table exclusions + for _, name := range e.config.ExcludeTables { + if t.Name == name || t.FullName() == name { + return false + } + } + + // Check schema inclusions (if specified) + if len(e.config.IncludeSchemas) > 0 { + found := false + for _, s := range e.config.IncludeSchemas { + if t.Schema == s { + found = true + break + } + } + if !found { + return false + } + } + + // Check table inclusions (if specified) + if len(e.config.IncludeTables) > 0 { + found := false + for _, name := range e.config.IncludeTables { + if t.Name == name || t.FullName() == name { + found = true + break + } + } + if !found { + return false + } + } + + return true +} + +// backupTable backs up 
a single table +func (e *Engine) backupTable(ctx context.Context, table *Table) *TableResult { + start := time.Now() + result := &TableResult{ + Table: table, + } + + e.mu.Lock() + e.progress.CurrentTable = table.FullName() + e.mu.Unlock() + + // Determine output filename + ext := ".sql" + switch e.config.Compression { + case "gzip": + ext = ".sql.gz" + case "lz4": + ext = ".sql.lz4" + case "zstd": + ext = ".sql.zst" + } + + filename := fmt.Sprintf("%s_%s%s", table.Schema, table.Name, ext) + result.OutputFile = filepath.Join(e.config.OutputDir, filename) + + // Create output file + file, err := os.Create(result.OutputFile) + if err != nil { + result.Error = fmt.Errorf("failed to create output file: %w", err) + result.Duration = time.Since(start) + return result + } + defer file.Close() + + // Wrap with compression if needed + var writer io.WriteCloser = file + if e.config.Compression == "gzip" { + gzWriter, err := newGzipWriter(file) + if err != nil { + result.Error = fmt.Errorf("failed to create gzip writer: %w", err) + result.Duration = time.Since(start) + return result + } + defer gzWriter.Close() + writer = gzWriter + } + + // Dump table + rowsWritten, err := e.dumpTable(ctx, table, writer) + if err != nil { + result.Error = fmt.Errorf("failed to dump table: %w", err) + result.Duration = time.Since(start) + return result + } + + result.RowsWritten = rowsWritten + atomic.AddInt64(&e.progress.RowsWritten, rowsWritten) + + // Get file size + if stat, err := file.Stat(); err == nil { + result.SizeBytes = stat.Size() + atomic.AddInt64(&e.progress.BytesWritten, result.SizeBytes) + } + + result.Duration = time.Since(start) + return result +} + +// dumpTable dumps a single table to the writer +func (e *Engine) dumpTable(ctx context.Context, table *Table, w io.Writer) (int64, error) { + switch e.dbType { + case "postgresql", "postgres": + return e.dumpPostgresTable(ctx, table, w) + case "mysql", "mariadb": + return e.dumpMySQLTable(ctx, table, w) + default: + return 0, fmt.Errorf("unsupported database type: %s", e.dbType) + } +} + +func (e *Engine) dumpPostgresTable(ctx context.Context, table *Table, w io.Writer) (int64, error) { + // Write header + fmt.Fprintf(w, "-- Table: %s\n", table.FullName()) + fmt.Fprintf(w, "-- Dumped at: %s\n\n", time.Now().Format(time.RFC3339)) + + // Get column info for COPY command + cols, err := e.getPostgresColumns(ctx, table) + if err != nil { + return 0, err + } + + // Use COPY TO STDOUT for efficiency + copyQuery := fmt.Sprintf("COPY %s TO STDOUT WITH (FORMAT csv, HEADER true)", table.FullName()) + + rows, err := e.db.QueryContext(ctx, copyQuery) + if err != nil { + // Fallback to regular SELECT + return e.dumpViaSelect(ctx, table, cols, w) + } + defer rows.Close() + + var rowCount int64 + for rows.Next() { + var line string + if err := rows.Scan(&line); err != nil { + continue + } + fmt.Fprintln(w, line) + rowCount++ + } + + return rowCount, rows.Err() +} + +func (e *Engine) dumpMySQLTable(ctx context.Context, table *Table, w io.Writer) (int64, error) { + // Write header + fmt.Fprintf(w, "-- Table: %s\n", table.FullName()) + fmt.Fprintf(w, "-- Dumped at: %s\n\n", time.Now().Format(time.RFC3339)) + + // Get column names + cols, err := e.getMySQLColumns(ctx, table) + if err != nil { + return 0, err + } + + return e.dumpViaSelect(ctx, table, cols, w) +} + +func (e *Engine) dumpViaSelect(ctx context.Context, table *Table, cols []string, w io.Writer) (int64, error) { + query := fmt.Sprintf("SELECT * FROM %s", table.FullName()) + rows, err := e.db.QueryContext(ctx, 
+		query)
+	if err != nil {
+		return 0, err
+	}
+	defer rows.Close()
+
+	var rowCount int64
+
+	// Write column header
+	fmt.Fprintf(w, "-- Columns: %v\n\n", cols)
+
+	// Prepare value holders
+	values := make([]interface{}, len(cols))
+	valuePtrs := make([]interface{}, len(cols))
+	for i := range values {
+		valuePtrs[i] = &values[i]
+	}
+
+	for rows.Next() {
+		// A scan failure would silently drop the row from the backup, so
+		// treat it as fatal rather than skipping it.
+		if err := rows.Scan(valuePtrs...); err != nil {
+			return rowCount, fmt.Errorf("scanning row: %w", err)
+		}
+
+		// Write INSERT statement
+		fmt.Fprintf(w, "INSERT INTO %s VALUES (", table.FullName())
+		for i, v := range values {
+			if i > 0 {
+				fmt.Fprint(w, ", ")
+			}
+			fmt.Fprint(w, formatValue(v))
+		}
+		fmt.Fprintln(w, ");")
+		rowCount++
+	}
+
+	return rowCount, rows.Err()
+}
+
+func (e *Engine) getPostgresColumns(ctx context.Context, table *Table) ([]string, error) {
+	query := `
+		SELECT column_name
+		FROM information_schema.columns
+		WHERE table_schema = $1 AND table_name = $2
+		ORDER BY ordinal_position
+	`
+	rows, err := e.db.QueryContext(ctx, query, table.Schema, table.Name)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var cols []string
+	for rows.Next() {
+		var col string
+		// Skipping a column here would misalign every dumped row, so fail fast.
+		if err := rows.Scan(&col); err != nil {
+			return nil, err
+		}
+		cols = append(cols, col)
+	}
+	return cols, rows.Err()
+}
+
+func (e *Engine) getMySQLColumns(ctx context.Context, table *Table) ([]string, error) {
+	query := `
+		SELECT COLUMN_NAME
+		FROM information_schema.COLUMNS
+		WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
+		ORDER BY ORDINAL_POSITION
+	`
+	rows, err := e.db.QueryContext(ctx, query, table.Schema, table.Name)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var cols []string
+	for rows.Next() {
+		var col string
+		if err := rows.Scan(&col); err != nil {
+			return nil, err
+		}
+		cols = append(cols, col)
+	}
+	return cols, rows.Err()
+}
+
+// formatValue renders a Go value as a SQL literal for the generated INSERTs.
+func formatValue(v interface{}) string {
+	if v == nil {
+		return "NULL"
+	}
+	switch val := v.(type) {
+	case []byte:
+		return fmt.Sprintf("'%s'", escapeString(string(val)))
+	case string:
+		return fmt.Sprintf("'%s'", escapeString(val))
+	case time.Time:
+		return fmt.Sprintf("'%s'", val.Format("2006-01-02 15:04:05"))
+	case int, int32, int64, float32, float64:
+		return fmt.Sprintf("%v", val)
+	case bool:
+		if val {
+			return "TRUE"
+		}
+		return "FALSE"
+	default:
+		return fmt.Sprintf("'%v'", v)
+	}
+}
+
+// escapeString doubles single quotes and backslashes; other control
+// characters are passed through unchanged.
+func escapeString(s string) string {
+	result := make([]byte, 0, len(s)*2)
+	for i := 0; i < len(s); i++ {
+		switch s[i] {
+		case '\'':
+			result = append(result, '\'', '\'')
+		case '\\':
+			result = append(result, '\\', '\\')
+		default:
+			result = append(result, s[i])
+		}
+	}
+	return string(result)
+}
+
+// gzipWriter wraps *gzip.Writer so callers get an io.WriteCloser whose
+// Close flushes the compressed stream (requires "compress/gzip" in this
+// file's import block). Closing the gzipWriter does not close the
+// underlying file; the caller remains responsible for that.
+type gzipWriter struct {
+	*gzip.Writer
+}
+
+func newGzipWriter(w io.Writer) (*gzipWriter, error) {
+	return &gzipWriter{gzip.NewWriter(w)}, nil
+}
+
+type nopCloser struct {
+	io.Writer
+}
+
+func (n *nopCloser) Close() error { return nil }
diff --git a/internal/pitr/binlog.go b/internal/pitr/binlog.go
new file mode 100644
index 0000000..8c9adcf
--- /dev/null
+++ b/internal/pitr/binlog.go
@@ -0,0 +1,865 @@
+// Package pitr provides Point-in-Time Recovery functionality
+// This file contains MySQL/MariaDB binary log handling
+package pitr
+
+import (
+	"bufio"
+	"compress/gzip"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"regexp"
+	"sort"
+	"strconv"
+	"strings"
+	"time"
+)
+
+// BinlogPosition represents a MySQL binary log
position +type BinlogPosition struct { + File string `json:"file"` // Binary log filename (e.g., "mysql-bin.000042") + Position uint64 `json:"position"` // Byte position in the file + GTID string `json:"gtid,omitempty"` // GTID set (if available) + ServerID uint32 `json:"server_id,omitempty"` +} + +// String returns a string representation of the binlog position +func (p *BinlogPosition) String() string { + if p.GTID != "" { + return fmt.Sprintf("%s:%d (GTID: %s)", p.File, p.Position, p.GTID) + } + return fmt.Sprintf("%s:%d", p.File, p.Position) +} + +// IsZero returns true if the position is unset +func (p *BinlogPosition) IsZero() bool { + return p.File == "" && p.Position == 0 && p.GTID == "" +} + +// Compare compares two binlog positions +// Returns -1 if p < other, 0 if equal, 1 if p > other +func (p *BinlogPosition) Compare(other LogPosition) int { + o, ok := other.(*BinlogPosition) + if !ok { + return 0 + } + + // Compare by file first + fileComp := compareBinlogFiles(p.File, o.File) + if fileComp != 0 { + return fileComp + } + + // Then by position within file + if p.Position < o.Position { + return -1 + } else if p.Position > o.Position { + return 1 + } + return 0 +} + +// ParseBinlogPosition parses a binlog position string +// Format: "filename:position" or "filename:position:gtid" +func ParseBinlogPosition(s string) (*BinlogPosition, error) { + parts := strings.SplitN(s, ":", 3) + if len(parts) < 2 { + return nil, fmt.Errorf("invalid binlog position format: %s (expected file:position)", s) + } + + pos, err := strconv.ParseUint(parts[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid position value: %s", parts[1]) + } + + bp := &BinlogPosition{ + File: parts[0], + Position: pos, + } + + if len(parts) == 3 { + bp.GTID = parts[2] + } + + return bp, nil +} + +// MarshalJSON serializes the binlog position to JSON +func (p *BinlogPosition) MarshalJSON() ([]byte, error) { + type Alias BinlogPosition + return json.Marshal((*Alias)(p)) +} + +// compareBinlogFiles compares two binlog filenames numerically +func compareBinlogFiles(a, b string) int { + numA := extractBinlogNumber(a) + numB := extractBinlogNumber(b) + + if numA < numB { + return -1 + } else if numA > numB { + return 1 + } + return 0 +} + +// extractBinlogNumber extracts the numeric suffix from a binlog filename +func extractBinlogNumber(filename string) int { + // Match pattern like mysql-bin.000042 + re := regexp.MustCompile(`\.(\d+)$`) + matches := re.FindStringSubmatch(filename) + if len(matches) < 2 { + return 0 + } + num, _ := strconv.Atoi(matches[1]) + return num +} + +// BinlogFile represents a binary log file with metadata +type BinlogFile struct { + Name string `json:"name"` + Path string `json:"path"` + Size int64 `json:"size"` + ModTime time.Time `json:"mod_time"` + StartTime time.Time `json:"start_time,omitempty"` // First event timestamp + EndTime time.Time `json:"end_time,omitempty"` // Last event timestamp + StartPos uint64 `json:"start_pos"` + EndPos uint64 `json:"end_pos"` + GTID string `json:"gtid,omitempty"` + ServerID uint32 `json:"server_id,omitempty"` + Format string `json:"format,omitempty"` // ROW, STATEMENT, MIXED + Archived bool `json:"archived"` + ArchiveDir string `json:"archive_dir,omitempty"` +} + +// BinlogArchiveInfo contains metadata about an archived binlog +type BinlogArchiveInfo struct { + OriginalFile string `json:"original_file"` + ArchivePath string `json:"archive_path"` + Size int64 `json:"size"` + Compressed bool `json:"compressed"` + Encrypted bool `json:"encrypted"` + 
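+	// Note (added): ArchiveBinlog currently fills Checksum with a plain
+	// "size:<bytes>" marker rather than a digest. A sketch of a SHA-256
+	// upgrade that hashes the uncompressed binlog while copying it
+	// (illustrative; needs "crypto/sha256" and "encoding/hex"):
+	//
+	//	h := sha256.New()
+	//	written, err := io.Copy(io.MultiWriter(writer, h), src)
+	//	checksum := "sha256:" + hex.EncodeToString(h.Sum(nil))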
+	Checksum     string    `json:"checksum"`
+	ArchivedAt   time.Time `json:"archived_at"`
+	StartPos     uint64    `json:"start_pos"`
+	EndPos       uint64    `json:"end_pos"`
+	StartTime    time.Time `json:"start_time"`
+	EndTime      time.Time `json:"end_time"`
+	GTID         string    `json:"gtid,omitempty"`
+}
+
+// BinlogManager handles binary log operations
+type BinlogManager struct {
+	mysqlbinlogPath string
+	binlogDir       string
+	archiveDir      string
+	compression     bool
+	encryption      bool
+	encryptionKey   []byte
+	serverType      DatabaseType // mysql or mariadb
+}
+
+// BinlogManagerConfig holds configuration for BinlogManager
+type BinlogManagerConfig struct {
+	BinlogDir     string
+	ArchiveDir    string
+	Compression   bool
+	Encryption    bool
+	EncryptionKey []byte
+}
+
+// NewBinlogManager creates a new BinlogManager
+func NewBinlogManager(config BinlogManagerConfig) (*BinlogManager, error) {
+	m := &BinlogManager{
+		binlogDir:     config.BinlogDir,
+		archiveDir:    config.ArchiveDir,
+		compression:   config.Compression,
+		encryption:    config.Encryption,
+		encryptionKey: config.EncryptionKey,
+	}
+
+	// Find mysqlbinlog executable
+	if err := m.detectTools(); err != nil {
+		return nil, err
+	}
+
+	return m, nil
+}
+
+// detectTools finds MySQL/MariaDB tools and determines server type
+func (m *BinlogManager) detectTools() error {
+	// Try mariadb-binlog first (MariaDB)
+	if path, err := exec.LookPath("mariadb-binlog"); err == nil {
+		m.mysqlbinlogPath = path
+		m.serverType = DatabaseMariaDB
+		return nil
+	}
+
+	// Fall back to mysqlbinlog (MySQL or older MariaDB)
+	if path, err := exec.LookPath("mysqlbinlog"); err == nil {
+		m.mysqlbinlogPath = path
+		// Check if it's actually MariaDB's version
+		m.serverType = m.detectServerType()
+		return nil
+	}
+
+	return fmt.Errorf("mysqlbinlog or mariadb-binlog not found in PATH")
+}
+
+// detectServerType determines if we're working with MySQL or MariaDB
+func (m *BinlogManager) detectServerType() DatabaseType {
+	cmd := exec.Command(m.mysqlbinlogPath, "--version")
+	output, err := cmd.Output()
+	if err != nil {
+		return DatabaseMySQL // Default to MySQL
+	}
+
+	if strings.Contains(strings.ToLower(string(output)), "mariadb") {
+		return DatabaseMariaDB
+	}
+	return DatabaseMySQL
+}
+
+// ServerType returns the detected server type
+func (m *BinlogManager) ServerType() DatabaseType {
+	return m.serverType
+}
+
+// DiscoverBinlogs finds all binary log files in the configured directory
+func (m *BinlogManager) DiscoverBinlogs(ctx context.Context) ([]BinlogFile, error) {
+	if m.binlogDir == "" {
+		return nil, fmt.Errorf("binlog directory not configured")
+	}
+
+	entries, err := os.ReadDir(m.binlogDir)
+	if err != nil {
+		return nil, fmt.Errorf("reading binlog directory: %w", err)
+	}
+
+	var binlogs []BinlogFile
+	// Accept both classic "<base>-bin.000042" names and the bare
+	// "binlog.000001" naming that is the MySQL 8.0+ default. The binlog
+	// index file (see ParseBinlogIndex) is the authoritative list; this
+	// pattern is a heuristic for directories scanned without one.
+	binlogPattern := regexp.MustCompile(`^[a-zA-Z0-9_-]+\.\d{6}$`)
+
+	for _, entry := range entries {
+		if entry.IsDir() {
+			continue
+		}
+
+		// Check if it matches binlog naming convention
+		if !binlogPattern.MatchString(entry.Name()) {
+			continue
+		}
+
+		info, err := entry.Info()
+		if err != nil {
+			continue
+		}
+
+		binlog := BinlogFile{
+			Name:    entry.Name(),
+			Path:    filepath.Join(m.binlogDir, entry.Name()),
+			Size:    info.Size(),
+			ModTime: info.ModTime(),
+		}
+
+		// Get binlog metadata using mysqlbinlog
+		if err := m.enrichBinlogMetadata(ctx, &binlog); err != nil {
+			// Log but don't fail - we can still use basic info
+			binlog.StartPos = 4 // Magic number size
+		}
+
+		binlogs = append(binlogs, binlog)
+	}
+
+	// Sort by file number
+	sort.Slice(binlogs, func(i, j int) bool {
+		return
compareBinlogFiles(binlogs[i].Name, binlogs[j].Name) < 0 + }) + + return binlogs, nil +} + +// enrichBinlogMetadata extracts metadata from a binlog file +func (m *BinlogManager) enrichBinlogMetadata(ctx context.Context, binlog *BinlogFile) error { + // Use mysqlbinlog to read header and extract timestamps + cmd := exec.CommandContext(ctx, m.mysqlbinlogPath, + "--no-defaults", + "--start-position=4", + "--stop-position=1000", // Just read header area + binlog.Path, + ) + + output, err := cmd.Output() + if err != nil { + // Try without position limits + cmd = exec.CommandContext(ctx, m.mysqlbinlogPath, + "--no-defaults", + "-v", // Verbose mode for more info + binlog.Path, + ) + output, _ = cmd.Output() + } + + // Parse output for metadata + m.parseBinlogOutput(string(output), binlog) + + // Get file size for end position + if binlog.EndPos == 0 { + binlog.EndPos = uint64(binlog.Size) + } + + return nil +} + +// parseBinlogOutput parses mysqlbinlog output to extract metadata +func (m *BinlogManager) parseBinlogOutput(output string, binlog *BinlogFile) { + lines := strings.Split(output, "\n") + + // Pattern for timestamp: #YYMMDD HH:MM:SS + timestampRe := regexp.MustCompile(`#(\d{6})\s+(\d{1,2}:\d{2}:\d{2})`) + // Pattern for server_id + serverIDRe := regexp.MustCompile(`server id\s+(\d+)`) + // Pattern for end_log_pos + endPosRe := regexp.MustCompile(`end_log_pos\s+(\d+)`) + // Pattern for binlog format + formatRe := regexp.MustCompile(`binlog_format=(\w+)`) + // Pattern for GTID + gtidRe := regexp.MustCompile(`SET @@SESSION.GTID_NEXT=\s*'([^']+)'`) + mariaGtidRe := regexp.MustCompile(`GTID\s+(\d+-\d+-\d+)`) + + var firstTimestamp, lastTimestamp time.Time + var maxEndPos uint64 + + for _, line := range lines { + // Extract timestamps + if matches := timestampRe.FindStringSubmatch(line); len(matches) == 3 { + // Parse YYMMDD format + dateStr := matches[1] + timeStr := matches[2] + if t, err := time.Parse("060102 15:04:05", dateStr+" "+timeStr); err == nil { + if firstTimestamp.IsZero() { + firstTimestamp = t + } + lastTimestamp = t + } + } + + // Extract server_id + if matches := serverIDRe.FindStringSubmatch(line); len(matches) == 2 { + if id, err := strconv.ParseUint(matches[1], 10, 32); err == nil { + binlog.ServerID = uint32(id) + } + } + + // Extract end_log_pos (track max for EndPos) + if matches := endPosRe.FindStringSubmatch(line); len(matches) == 2 { + if pos, err := strconv.ParseUint(matches[1], 10, 64); err == nil { + if pos > maxEndPos { + maxEndPos = pos + } + } + } + + // Extract format + if matches := formatRe.FindStringSubmatch(line); len(matches) == 2 { + binlog.Format = matches[1] + } + + // Extract GTID (MySQL format) + if matches := gtidRe.FindStringSubmatch(line); len(matches) == 2 { + binlog.GTID = matches[1] + } + + // Extract GTID (MariaDB format) + if matches := mariaGtidRe.FindStringSubmatch(line); len(matches) == 2 { + binlog.GTID = matches[1] + } + } + + if !firstTimestamp.IsZero() { + binlog.StartTime = firstTimestamp + } + if !lastTimestamp.IsZero() { + binlog.EndTime = lastTimestamp + } + if maxEndPos > 0 { + binlog.EndPos = maxEndPos + } +} + +// GetCurrentPosition retrieves the current binary log position from MySQL +func (m *BinlogManager) GetCurrentPosition(ctx context.Context, dsn string) (*BinlogPosition, error) { + // This would typically connect to MySQL and run SHOW MASTER STATUS + // For now, return an error indicating it needs to be called with a connection + return nil, fmt.Errorf("GetCurrentPosition requires a database connection - use 
MySQLPITR.GetCurrentPosition instead") +} + +// ArchiveBinlog archives a single binlog file to the archive directory +func (m *BinlogManager) ArchiveBinlog(ctx context.Context, binlog *BinlogFile) (*BinlogArchiveInfo, error) { + if m.archiveDir == "" { + return nil, fmt.Errorf("archive directory not configured") + } + + // Ensure archive directory exists + if err := os.MkdirAll(m.archiveDir, 0750); err != nil { + return nil, fmt.Errorf("creating archive directory: %w", err) + } + + archiveName := binlog.Name + if m.compression { + archiveName += ".gz" + } + archivePath := filepath.Join(m.archiveDir, archiveName) + + // Check if already archived + if _, err := os.Stat(archivePath); err == nil { + return nil, fmt.Errorf("binlog already archived: %s", archivePath) + } + + // Open source file + src, err := os.Open(binlog.Path) + if err != nil { + return nil, fmt.Errorf("opening binlog: %w", err) + } + defer src.Close() + + // Create destination file + dst, err := os.OpenFile(archivePath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0640) + if err != nil { + return nil, fmt.Errorf("creating archive file: %w", err) + } + defer dst.Close() + + var writer io.Writer = dst + var gzWriter *gzip.Writer + + if m.compression { + gzWriter = gzip.NewWriter(dst) + writer = gzWriter + defer gzWriter.Close() + } + + // TODO: Add encryption layer if enabled + if m.encryption && len(m.encryptionKey) > 0 { + // Encryption would be added here + } + + // Copy file content + written, err := io.Copy(writer, src) + if err != nil { + os.Remove(archivePath) // Cleanup on error + return nil, fmt.Errorf("copying binlog: %w", err) + } + + // Close gzip writer to flush + if gzWriter != nil { + if err := gzWriter.Close(); err != nil { + os.Remove(archivePath) + return nil, fmt.Errorf("closing gzip writer: %w", err) + } + } + + // Get final archive size + archiveInfo, err := os.Stat(archivePath) + if err != nil { + return nil, fmt.Errorf("getting archive info: %w", err) + } + + // Calculate checksum (simple for now - could use SHA256) + checksum := fmt.Sprintf("size:%d", written) + + return &BinlogArchiveInfo{ + OriginalFile: binlog.Name, + ArchivePath: archivePath, + Size: archiveInfo.Size(), + Compressed: m.compression, + Encrypted: m.encryption, + Checksum: checksum, + ArchivedAt: time.Now(), + StartPos: binlog.StartPos, + EndPos: binlog.EndPos, + StartTime: binlog.StartTime, + EndTime: binlog.EndTime, + GTID: binlog.GTID, + }, nil +} + +// ListArchivedBinlogs returns all archived binlog files +func (m *BinlogManager) ListArchivedBinlogs(ctx context.Context) ([]BinlogArchiveInfo, error) { + if m.archiveDir == "" { + return nil, fmt.Errorf("archive directory not configured") + } + + entries, err := os.ReadDir(m.archiveDir) + if err != nil { + if os.IsNotExist(err) { + return []BinlogArchiveInfo{}, nil + } + return nil, fmt.Errorf("reading archive directory: %w", err) + } + + var archives []BinlogArchiveInfo + metadataPath := filepath.Join(m.archiveDir, "metadata.json") + + // Try to load metadata file for enriched info + metadata := m.loadArchiveMetadata(metadataPath) + + for _, entry := range entries { + if entry.IsDir() || entry.Name() == "metadata.json" { + continue + } + + info, err := entry.Info() + if err != nil { + continue + } + + originalName := entry.Name() + compressed := false + if strings.HasSuffix(originalName, ".gz") { + originalName = strings.TrimSuffix(originalName, ".gz") + compressed = true + } + + archive := BinlogArchiveInfo{ + OriginalFile: originalName, + ArchivePath: filepath.Join(m.archiveDir, 
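+			// (added) Naming round-trip: ArchiveBinlog writes "<name>.gz"
+			// when compression is on, so stripping ".gz" above recovers the
+			// original binlog name, e.g. "mysql-bin.000042.gz" ->
+			// "mysql-bin.000042". The Encrypted flag is not recoverable
+			// from the filename and is left unset here.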
entry.Name()), + Size: info.Size(), + Compressed: compressed, + ArchivedAt: info.ModTime(), + } + + // Enrich from metadata if available + if meta, ok := metadata[originalName]; ok { + archive.StartPos = meta.StartPos + archive.EndPos = meta.EndPos + archive.StartTime = meta.StartTime + archive.EndTime = meta.EndTime + archive.GTID = meta.GTID + archive.Checksum = meta.Checksum + } + + archives = append(archives, archive) + } + + // Sort by file number + sort.Slice(archives, func(i, j int) bool { + return compareBinlogFiles(archives[i].OriginalFile, archives[j].OriginalFile) < 0 + }) + + return archives, nil +} + +// loadArchiveMetadata loads the metadata.json file if it exists +func (m *BinlogManager) loadArchiveMetadata(path string) map[string]BinlogArchiveInfo { + result := make(map[string]BinlogArchiveInfo) + + data, err := os.ReadFile(path) + if err != nil { + return result + } + + var archives []BinlogArchiveInfo + if err := json.Unmarshal(data, &archives); err != nil { + return result + } + + for _, a := range archives { + result[a.OriginalFile] = a + } + + return result +} + +// SaveArchiveMetadata saves metadata for all archived binlogs +func (m *BinlogManager) SaveArchiveMetadata(archives []BinlogArchiveInfo) error { + if m.archiveDir == "" { + return fmt.Errorf("archive directory not configured") + } + + metadataPath := filepath.Join(m.archiveDir, "metadata.json") + data, err := json.MarshalIndent(archives, "", " ") + if err != nil { + return fmt.Errorf("marshaling metadata: %w", err) + } + + return os.WriteFile(metadataPath, data, 0640) +} + +// ValidateBinlogChain validates the integrity of the binlog chain +func (m *BinlogManager) ValidateBinlogChain(ctx context.Context, binlogs []BinlogFile) (*ChainValidation, error) { + result := &ChainValidation{ + Valid: true, + LogCount: len(binlogs), + } + + if len(binlogs) == 0 { + result.Warnings = append(result.Warnings, "no binlog files found") + return result, nil + } + + // Sort binlogs by file number + sorted := make([]BinlogFile, len(binlogs)) + copy(sorted, binlogs) + sort.Slice(sorted, func(i, j int) bool { + return compareBinlogFiles(sorted[i].Name, sorted[j].Name) < 0 + }) + + result.StartPos = &BinlogPosition{ + File: sorted[0].Name, + Position: sorted[0].StartPos, + GTID: sorted[0].GTID, + } + result.EndPos = &BinlogPosition{ + File: sorted[len(sorted)-1].Name, + Position: sorted[len(sorted)-1].EndPos, + GTID: sorted[len(sorted)-1].GTID, + } + + // Check for gaps in sequence + var prevNum int + var prevName string + var prevServerID uint32 + + for i, binlog := range sorted { + result.TotalSize += binlog.Size + + num := extractBinlogNumber(binlog.Name) + + if i > 0 { + // Check sequence continuity + if num != prevNum+1 { + gap := LogGap{ + After: prevName, + Before: binlog.Name, + Reason: fmt.Sprintf("missing binlog file(s) %d to %d", prevNum+1, num-1), + } + result.Gaps = append(result.Gaps, gap) + result.Valid = false + } + + // Check server_id consistency + if binlog.ServerID != 0 && prevServerID != 0 && binlog.ServerID != prevServerID { + result.Warnings = append(result.Warnings, + fmt.Sprintf("server_id changed from %d to %d at %s (possible master failover)", + prevServerID, binlog.ServerID, binlog.Name)) + } + } + + prevNum = num + prevName = binlog.Name + if binlog.ServerID != 0 { + prevServerID = binlog.ServerID + } + } + + if len(result.Gaps) > 0 { + result.Errors = append(result.Errors, + fmt.Sprintf("found %d gap(s) in binlog chain", len(result.Gaps))) + } + + return result, nil +} + +// ReplayBinlogs replays 
binlog events to a target time or position +func (m *BinlogManager) ReplayBinlogs(ctx context.Context, opts ReplayOptions) error { + if len(opts.BinlogFiles) == 0 { + return fmt.Errorf("no binlog files specified") + } + + // Build mysqlbinlog command + args := []string{"--no-defaults"} + + // Add start position if specified + if opts.StartPosition != nil && !opts.StartPosition.IsZero() { + startPos, ok := opts.StartPosition.(*BinlogPosition) + if ok && startPos.Position > 0 { + args = append(args, fmt.Sprintf("--start-position=%d", startPos.Position)) + } + } + + // Add stop time or position + if opts.StopTime != nil && !opts.StopTime.IsZero() { + args = append(args, fmt.Sprintf("--stop-datetime=%s", opts.StopTime.Format("2006-01-02 15:04:05"))) + } + + if opts.StopPosition != nil && !opts.StopPosition.IsZero() { + stopPos, ok := opts.StopPosition.(*BinlogPosition) + if ok && stopPos.Position > 0 { + args = append(args, fmt.Sprintf("--stop-position=%d", stopPos.Position)) + } + } + + // Add binlog files + args = append(args, opts.BinlogFiles...) + + if opts.DryRun { + // Just decode and show SQL + args = append([]string{args[0]}, append([]string{"-v"}, args[1:]...)...) + cmd := exec.CommandContext(ctx, m.mysqlbinlogPath, args...) + output, err := cmd.Output() + if err != nil { + return fmt.Errorf("parsing binlogs: %w", err) + } + if opts.Output != nil { + opts.Output.Write(output) + } + return nil + } + + // Pipe to mysql for replay + mysqlCmd := exec.CommandContext(ctx, "mysql", + "-u", opts.MySQLUser, + "-p"+opts.MySQLPass, + "-h", opts.MySQLHost, + "-P", strconv.Itoa(opts.MySQLPort), + ) + + binlogCmd := exec.CommandContext(ctx, m.mysqlbinlogPath, args...) + + // Pipe mysqlbinlog output to mysql + pipe, err := binlogCmd.StdoutPipe() + if err != nil { + return fmt.Errorf("creating pipe: %w", err) + } + mysqlCmd.Stdin = pipe + + // Capture stderr for error reporting + var binlogStderr, mysqlStderr strings.Builder + binlogCmd.Stderr = &binlogStderr + mysqlCmd.Stderr = &mysqlStderr + + // Start commands + if err := binlogCmd.Start(); err != nil { + return fmt.Errorf("starting mysqlbinlog: %w", err) + } + if err := mysqlCmd.Start(); err != nil { + binlogCmd.Process.Kill() + return fmt.Errorf("starting mysql: %w", err) + } + + // Wait for completion + binlogErr := binlogCmd.Wait() + mysqlErr := mysqlCmd.Wait() + + if binlogErr != nil { + return fmt.Errorf("mysqlbinlog failed: %w\nstderr: %s", binlogErr, binlogStderr.String()) + } + if mysqlErr != nil { + return fmt.Errorf("mysql replay failed: %w\nstderr: %s", mysqlErr, mysqlStderr.String()) + } + + return nil +} + +// ReplayOptions holds options for replaying binlog files +type ReplayOptions struct { + BinlogFiles []string // Files to replay (in order) + StartPosition LogPosition // Start from this position + StopTime *time.Time // Stop at this time + StopPosition LogPosition // Stop at this position + DryRun bool // Just show what would be done + Output io.Writer // For dry-run output + MySQLHost string // MySQL host for replay + MySQLPort int // MySQL port + MySQLUser string // MySQL user + MySQLPass string // MySQL password + Database string // Limit to specific database + StopOnError bool // Stop on first error +} + +// FindBinlogsInRange finds binlog files containing events within a time range +func (m *BinlogManager) FindBinlogsInRange(ctx context.Context, binlogs []BinlogFile, start, end time.Time) []BinlogFile { + var result []BinlogFile + + for _, b := range binlogs { + // Include if binlog time range overlaps with requested range + 
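+		// (added) Overlap test: [binlogStart, binlogEnd] intersects
+		// [start, end] iff binlogStart <= end && binlogEnd >= start, which
+		// is what the two negated comparisons below implement.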
if b.EndTime.IsZero() && b.StartTime.IsZero() { + // No timestamp info, include to be safe + result = append(result, b) + continue + } + + // Check for overlap + binlogStart := b.StartTime + binlogEnd := b.EndTime + if binlogEnd.IsZero() { + binlogEnd = time.Now() // Assume current file goes to now + } + + if !binlogStart.After(end) && !binlogEnd.Before(start) { + result = append(result, b) + } + } + + return result +} + +// WatchBinlogs monitors for new binlog files and archives them +func (m *BinlogManager) WatchBinlogs(ctx context.Context, interval time.Duration, callback func(*BinlogFile)) error { + if m.binlogDir == "" { + return fmt.Errorf("binlog directory not configured") + } + + // Get initial list + known := make(map[string]struct{}) + binlogs, err := m.DiscoverBinlogs(ctx) + if err != nil { + return err + } + for _, b := range binlogs { + known[b.Name] = struct{}{} + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + binlogs, err := m.DiscoverBinlogs(ctx) + if err != nil { + continue // Log error but keep watching + } + + for _, b := range binlogs { + if _, exists := known[b.Name]; !exists { + // New binlog found + known[b.Name] = struct{}{} + if callback != nil { + callback(&b) + } + } + } + } + } +} + +// ParseBinlogIndex reads the binlog index file +func (m *BinlogManager) ParseBinlogIndex(indexPath string) ([]string, error) { + file, err := os.Open(indexPath) + if err != nil { + return nil, fmt.Errorf("opening index file: %w", err) + } + defer file.Close() + + var binlogs []string + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line != "" { + binlogs = append(binlogs, line) + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("reading index file: %w", err) + } + + return binlogs, nil +} diff --git a/internal/pitr/binlog_test.go b/internal/pitr/binlog_test.go new file mode 100644 index 0000000..9a21d4f --- /dev/null +++ b/internal/pitr/binlog_test.go @@ -0,0 +1,585 @@ +package pitr + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestBinlogPosition_String(t *testing.T) { + tests := []struct { + name string + position BinlogPosition + expected string + }{ + { + name: "basic position", + position: BinlogPosition{ + File: "mysql-bin.000042", + Position: 1234, + }, + expected: "mysql-bin.000042:1234", + }, + { + name: "with GTID", + position: BinlogPosition{ + File: "mysql-bin.000042", + Position: 1234, + GTID: "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5", + }, + expected: "mysql-bin.000042:1234 (GTID: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5)", + }, + { + name: "MariaDB GTID", + position: BinlogPosition{ + File: "mariadb-bin.000010", + Position: 500, + GTID: "0-1-100", + }, + expected: "mariadb-bin.000010:500 (GTID: 0-1-100)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.position.String() + if result != tt.expected { + t.Errorf("got %q, want %q", result, tt.expected) + } + }) + } +} + +func TestBinlogPosition_IsZero(t *testing.T) { + tests := []struct { + name string + position BinlogPosition + expected bool + }{ + { + name: "empty position", + position: BinlogPosition{}, + expected: true, + }, + { + name: "has file", + position: BinlogPosition{ + File: "mysql-bin.000001", + }, + expected: false, + }, + { + name: "has position only", + position: BinlogPosition{ + Position: 100, + }, + expected: false, + }, + { + name: 
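+		// (added note) A GTID-only position is still actionable for replay:
+		// MariaDB GTIDs use the domain-server-sequence form, e.g. "0-1-100",
+		// and need no file:offset pair, so IsZero must treat it as set.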
"has GTID only", + position: BinlogPosition{ + GTID: "3E11FA47-71CA-11E1-9E33-C80AA9429562:1", + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.position.IsZero() + if result != tt.expected { + t.Errorf("got %v, want %v", result, tt.expected) + } + }) + } +} + +func TestBinlogPosition_Compare(t *testing.T) { + tests := []struct { + name string + a *BinlogPosition + b *BinlogPosition + expected int + }{ + { + name: "equal positions", + a: &BinlogPosition{ + File: "mysql-bin.000010", + Position: 1000, + }, + b: &BinlogPosition{ + File: "mysql-bin.000010", + Position: 1000, + }, + expected: 0, + }, + { + name: "a before b - same file", + a: &BinlogPosition{ + File: "mysql-bin.000010", + Position: 100, + }, + b: &BinlogPosition{ + File: "mysql-bin.000010", + Position: 200, + }, + expected: -1, + }, + { + name: "a after b - same file", + a: &BinlogPosition{ + File: "mysql-bin.000010", + Position: 300, + }, + b: &BinlogPosition{ + File: "mysql-bin.000010", + Position: 200, + }, + expected: 1, + }, + { + name: "a before b - different files", + a: &BinlogPosition{ + File: "mysql-bin.000009", + Position: 9999, + }, + b: &BinlogPosition{ + File: "mysql-bin.000010", + Position: 100, + }, + expected: -1, + }, + { + name: "a after b - different files", + a: &BinlogPosition{ + File: "mysql-bin.000011", + Position: 100, + }, + b: &BinlogPosition{ + File: "mysql-bin.000010", + Position: 9999, + }, + expected: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.a.Compare(tt.b) + if result != tt.expected { + t.Errorf("got %d, want %d", result, tt.expected) + } + }) + } +} + +func TestParseBinlogPosition(t *testing.T) { + tests := []struct { + name string + input string + expected *BinlogPosition + expectError bool + }{ + { + name: "basic position", + input: "mysql-bin.000042:1234", + expected: &BinlogPosition{ + File: "mysql-bin.000042", + Position: 1234, + }, + expectError: false, + }, + { + name: "with GTID", + input: "mysql-bin.000042:1234:3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5", + expected: &BinlogPosition{ + File: "mysql-bin.000042", + Position: 1234, + GTID: "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5", + }, + expectError: false, + }, + { + name: "invalid format", + input: "invalid", + expected: nil, + expectError: true, + }, + { + name: "invalid position", + input: "mysql-bin.000042:notanumber", + expected: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseBinlogPosition(tt.input) + + if tt.expectError { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result.File != tt.expected.File { + t.Errorf("File: got %q, want %q", result.File, tt.expected.File) + } + if result.Position != tt.expected.Position { + t.Errorf("Position: got %d, want %d", result.Position, tt.expected.Position) + } + if result.GTID != tt.expected.GTID { + t.Errorf("GTID: got %q, want %q", result.GTID, tt.expected.GTID) + } + }) + } +} + +func TestExtractBinlogNumber(t *testing.T) { + tests := []struct { + name string + filename string + expected int + }{ + {"mysql binlog", "mysql-bin.000042", 42}, + {"mariadb binlog", "mariadb-bin.000100", 100}, + {"first binlog", "mysql-bin.000001", 1}, + {"large number", "mysql-bin.999999", 999999}, + {"no number", "mysql-bin", 0}, + {"invalid format", "binlog", 0}, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + result := extractBinlogNumber(tt.filename) + if result != tt.expected { + t.Errorf("got %d, want %d", result, tt.expected) + } + }) + } +} + +func TestCompareBinlogFiles(t *testing.T) { + tests := []struct { + name string + a string + b string + expected int + }{ + {"equal", "mysql-bin.000010", "mysql-bin.000010", 0}, + {"a < b", "mysql-bin.000009", "mysql-bin.000010", -1}, + {"a > b", "mysql-bin.000011", "mysql-bin.000010", 1}, + {"large difference", "mysql-bin.000001", "mysql-bin.000100", -1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := compareBinlogFiles(tt.a, tt.b) + if result != tt.expected { + t.Errorf("got %d, want %d", result, tt.expected) + } + }) + } +} + +func TestValidateBinlogChain(t *testing.T) { + ctx := context.Background() + bm := &BinlogManager{} + + tests := []struct { + name string + binlogs []BinlogFile + expectValid bool + expectGaps int + expectWarnings bool + }{ + { + name: "empty chain", + binlogs: []BinlogFile{}, + expectValid: true, + expectGaps: 0, + }, + { + name: "continuous chain", + binlogs: []BinlogFile{ + {Name: "mysql-bin.000001", ServerID: 1}, + {Name: "mysql-bin.000002", ServerID: 1}, + {Name: "mysql-bin.000003", ServerID: 1}, + }, + expectValid: true, + expectGaps: 0, + }, + { + name: "chain with gap", + binlogs: []BinlogFile{ + {Name: "mysql-bin.000001", ServerID: 1}, + {Name: "mysql-bin.000003", ServerID: 1}, // 000002 missing + {Name: "mysql-bin.000004", ServerID: 1}, + }, + expectValid: false, + expectGaps: 1, + }, + { + name: "chain with multiple gaps", + binlogs: []BinlogFile{ + {Name: "mysql-bin.000001", ServerID: 1}, + {Name: "mysql-bin.000005", ServerID: 1}, // 000002-000004 missing + {Name: "mysql-bin.000010", ServerID: 1}, // 000006-000009 missing + }, + expectValid: false, + expectGaps: 2, + }, + { + name: "server_id change warning", + binlogs: []BinlogFile{ + {Name: "mysql-bin.000001", ServerID: 1}, + {Name: "mysql-bin.000002", ServerID: 2}, // Server ID changed + {Name: "mysql-bin.000003", ServerID: 2}, + }, + expectValid: true, + expectGaps: 0, + expectWarnings: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := bm.ValidateBinlogChain(ctx, tt.binlogs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Valid != tt.expectValid { + t.Errorf("Valid: got %v, want %v", result.Valid, tt.expectValid) + } + + if len(result.Gaps) != tt.expectGaps { + t.Errorf("Gaps: got %d, want %d", len(result.Gaps), tt.expectGaps) + } + + if tt.expectWarnings && len(result.Warnings) == 0 { + t.Error("expected warnings, got none") + } + }) + } +} + +func TestFindBinlogsInRange(t *testing.T) { + ctx := context.Background() + bm := &BinlogManager{} + + now := time.Now() + hour := time.Hour + + binlogs := []BinlogFile{ + { + Name: "mysql-bin.000001", + StartTime: now.Add(-5 * hour), + EndTime: now.Add(-4 * hour), + }, + { + Name: "mysql-bin.000002", + StartTime: now.Add(-4 * hour), + EndTime: now.Add(-3 * hour), + }, + { + Name: "mysql-bin.000003", + StartTime: now.Add(-3 * hour), + EndTime: now.Add(-2 * hour), + }, + { + Name: "mysql-bin.000004", + StartTime: now.Add(-2 * hour), + EndTime: now.Add(-1 * hour), + }, + { + Name: "mysql-bin.000005", + StartTime: now.Add(-1 * hour), + EndTime: now, + }, + } + + tests := []struct { + name string + start time.Time + end time.Time + expected int + }{ + { + name: "all binlogs", + start: now.Add(-6 * hour), + end: now.Add(1 * hour), + expected: 5, + }, + { + name: "middle 
range", + start: now.Add(-4 * hour), + end: now.Add(-2 * hour), + expected: 4, // binlogs 1-4 overlap (1 ends at -4h, 4 starts at -2h) + }, + { + name: "last two", + start: now.Add(-2 * hour), + end: now, + expected: 3, // binlogs 3-5 overlap (3 ends at -2h, 5 ends at now) + }, + { + name: "exact match one binlog", + start: now.Add(-3 * hour), + end: now.Add(-2 * hour), + expected: 3, // binlogs 2,3,4 overlap with this range + }, + { + name: "no overlap - before", + start: now.Add(-10 * hour), + end: now.Add(-6 * hour), + expected: 0, + }, + { + name: "no overlap - after", + start: now.Add(1 * hour), + end: now.Add(2 * hour), + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := bm.FindBinlogsInRange(ctx, binlogs, tt.start, tt.end) + if len(result) != tt.expected { + t.Errorf("got %d binlogs, want %d", len(result), tt.expected) + } + }) + } +} + +func TestBinlogArchiveInfo_Metadata(t *testing.T) { + // Test that archive metadata is properly saved and loaded + tempDir, err := os.MkdirTemp("", "binlog_test") + if err != nil { + t.Fatalf("creating temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + bm := &BinlogManager{ + archiveDir: tempDir, + } + + archives := []BinlogArchiveInfo{ + { + OriginalFile: "mysql-bin.000001", + ArchivePath: filepath.Join(tempDir, "mysql-bin.000001.gz"), + Size: 1024, + Compressed: true, + ArchivedAt: time.Now().Add(-2 * time.Hour), + StartPos: 4, + EndPos: 1024, + StartTime: time.Now().Add(-3 * time.Hour), + EndTime: time.Now().Add(-2 * time.Hour), + }, + { + OriginalFile: "mysql-bin.000002", + ArchivePath: filepath.Join(tempDir, "mysql-bin.000002.gz"), + Size: 2048, + Compressed: true, + ArchivedAt: time.Now().Add(-1 * time.Hour), + StartPos: 4, + EndPos: 2048, + StartTime: time.Now().Add(-2 * time.Hour), + EndTime: time.Now().Add(-1 * time.Hour), + }, + } + + // Save metadata + err = bm.SaveArchiveMetadata(archives) + if err != nil { + t.Fatalf("saving metadata: %v", err) + } + + // Verify metadata file exists + metadataPath := filepath.Join(tempDir, "metadata.json") + if _, err := os.Stat(metadataPath); os.IsNotExist(err) { + t.Fatal("metadata file was not created") + } + + // Load and verify + loaded := bm.loadArchiveMetadata(metadataPath) + if len(loaded) != 2 { + t.Errorf("got %d archives, want 2", len(loaded)) + } + + if loaded["mysql-bin.000001"].Size != 1024 { + t.Errorf("wrong size for first archive") + } + + if loaded["mysql-bin.000002"].Size != 2048 { + t.Errorf("wrong size for second archive") + } +} + +func TestLimitedScanner(t *testing.T) { + // Test the limited scanner used for reading dump headers + input := "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10\n" + reader := NewLimitedScanner(strings.NewReader(input), 5) + + var lines []string + for reader.Scan() { + lines = append(lines, reader.Text()) + } + + if len(lines) != 5 { + t.Errorf("got %d lines, want 5", len(lines)) + } +} + +// TestDatabaseType tests database type constants +func TestDatabaseType(t *testing.T) { + tests := []struct { + name string + dbType DatabaseType + expected string + }{ + {"PostgreSQL", DatabasePostgreSQL, "postgres"}, + {"MySQL", DatabaseMySQL, "mysql"}, + {"MariaDB", DatabaseMariaDB, "mariadb"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if string(tt.dbType) != tt.expected { + t.Errorf("got %q, want %q", tt.dbType, tt.expected) + } + }) + } +} + +// TestRestoreTargetType tests restore target type constants +func TestRestoreTargetType(t *testing.T) { + 
tests := []struct { + name string + target RestoreTargetType + expected string + }{ + {"Time", RestoreTargetTime, "time"}, + {"Position", RestoreTargetPosition, "position"}, + {"Immediate", RestoreTargetImmediate, "immediate"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if string(tt.target) != tt.expected { + t.Errorf("got %q, want %q", tt.target, tt.expected) + } + }) + } +} diff --git a/internal/pitr/interface.go b/internal/pitr/interface.go new file mode 100644 index 0000000..7acfb18 --- /dev/null +++ b/internal/pitr/interface.go @@ -0,0 +1,155 @@ +// Package pitr provides Point-in-Time Recovery functionality +// This file contains shared interfaces and types for multi-database PITR support +package pitr + +import ( + "context" + "time" +) + +// DatabaseType represents the type of database for PITR +type DatabaseType string + +const ( + DatabasePostgreSQL DatabaseType = "postgres" + DatabaseMySQL DatabaseType = "mysql" + DatabaseMariaDB DatabaseType = "mariadb" +) + +// PITRProvider is the interface for database-specific PITR implementations +type PITRProvider interface { + // DatabaseType returns the database type this provider handles + DatabaseType() DatabaseType + + // Enable enables PITR for the database + Enable(ctx context.Context, config PITREnableConfig) error + + // Disable disables PITR for the database + Disable(ctx context.Context) error + + // Status returns the current PITR status + Status(ctx context.Context) (*PITRStatus, error) + + // CreateBackup creates a PITR-capable backup with position recording + CreateBackup(ctx context.Context, opts BackupOptions) (*PITRBackupInfo, error) + + // Restore performs a point-in-time restore + Restore(ctx context.Context, backup *PITRBackupInfo, target RestoreTarget) error + + // ListRecoveryPoints lists available recovery points/ranges + ListRecoveryPoints(ctx context.Context) ([]RecoveryWindow, error) + + // ValidateChain validates the log chain integrity + ValidateChain(ctx context.Context, from, to time.Time) (*ChainValidation, error) +} + +// PITREnableConfig holds configuration for enabling PITR +type PITREnableConfig struct { + ArchiveDir string // Directory to store archived logs + RetentionDays int // Days to keep archives + ArchiveInterval time.Duration // How often to check for new logs (MySQL) + Compression bool // Compress archived logs + Encryption bool // Encrypt archived logs + EncryptionKey []byte // Encryption key +} + +// PITRStatus represents the current PITR configuration status +type PITRStatus struct { + Enabled bool + DatabaseType DatabaseType + ArchiveDir string + LogLevel string // WAL level (postgres) or binlog format (mysql) + ArchiveMethod string // archive_command (postgres) or manual (mysql) + Position LogPosition + LastArchived time.Time + ArchiveCount int + ArchiveSize int64 +} + +// LogPosition is a generic interface for database-specific log positions +type LogPosition interface { + // String returns a string representation of the position + String() string + // IsZero returns true if the position is unset + IsZero() bool + // Compare returns -1 if p < other, 0 if equal, 1 if p > other + Compare(other LogPosition) int +} + +// BackupOptions holds options for creating a PITR backup +type BackupOptions struct { + Database string // Database name (empty for all) + OutputPath string // Where to save the backup + Compression bool + CompressionLvl int + Encryption bool + EncryptionKey []byte + FlushLogs bool // Flush logs before backup (mysql) + SingleTxn bool // Single 
transaction mode +} + +// PITRBackupInfo contains metadata about a PITR-capable backup +type PITRBackupInfo struct { + BackupFile string `json:"backup_file"` + DatabaseType DatabaseType `json:"database_type"` + DatabaseName string `json:"database_name,omitempty"` + Timestamp time.Time `json:"timestamp"` + ServerVersion string `json:"server_version"` + ServerID int `json:"server_id,omitempty"` // MySQL server_id + Position LogPosition `json:"-"` // Start position (type-specific) + PositionJSON string `json:"position"` // Serialized position + SizeBytes int64 `json:"size_bytes"` + Compressed bool `json:"compressed"` + Encrypted bool `json:"encrypted"` +} + +// RestoreTarget specifies the point-in-time to restore to +type RestoreTarget struct { + Type RestoreTargetType + Time *time.Time // For RestoreTargetTime + Position LogPosition // For RestoreTargetPosition (LSN, binlog pos, GTID) + Inclusive bool // Include target transaction + DryRun bool // Only show what would be done + StopOnErr bool // Stop replay on first error +} + +// RestoreTargetType defines the type of restore target +type RestoreTargetType string + +const ( + RestoreTargetTime RestoreTargetType = "time" + RestoreTargetPosition RestoreTargetType = "position" + RestoreTargetImmediate RestoreTargetType = "immediate" +) + +// RecoveryWindow represents a time range available for recovery +type RecoveryWindow struct { + BaseBackup string `json:"base_backup"` + BackupTime time.Time `json:"backup_time"` + StartPosition LogPosition `json:"-"` + EndPosition LogPosition `json:"-"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + LogFiles []string `json:"log_files"` // WAL segments or binlog files + HasGaps bool `json:"has_gaps"` + GapDetails []string `json:"gap_details,omitempty"` +} + +// ChainValidation contains results of log chain validation +type ChainValidation struct { + Valid bool + StartPos LogPosition + EndPos LogPosition + LogCount int + TotalSize int64 + Gaps []LogGap + Errors []string + Warnings []string +} + +// LogGap represents a gap in the log chain +type LogGap struct { + After string // Log file/position after which gap occurs + Before string // Log file/position where chain resumes + Reason string // Reason for gap if known +} diff --git a/internal/pitr/mysql.go b/internal/pitr/mysql.go new file mode 100644 index 0000000..dfc35bb --- /dev/null +++ b/internal/pitr/mysql.go @@ -0,0 +1,924 @@ +// Package pitr provides Point-in-Time Recovery functionality +// This file contains the MySQL/MariaDB PITR provider implementation +package pitr + +import ( + "bufio" + "compress/gzip" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" +) + +// MySQLPITR implements PITRProvider for MySQL and MariaDB +type MySQLPITR struct { + db *sql.DB + config MySQLPITRConfig + binlogManager *BinlogManager + serverType DatabaseType + serverVersion string + serverID uint32 + gtidMode bool +} + +// MySQLPITRConfig holds configuration for MySQL PITR +type MySQLPITRConfig struct { + // Connection settings + Host string `json:"host"` + Port int `json:"port"` + User string `json:"user"` + Password string `json:"password,omitempty"` + Socket string `json:"socket,omitempty"` + + // Paths + DataDir string `json:"data_dir"` + BinlogDir string `json:"binlog_dir"` + ArchiveDir string `json:"archive_dir"` + RestoreDir string `json:"restore_dir"` + + // Archive settings + ArchiveInterval time.Duration `json:"archive_interval"` + 
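+	// Example (added): a minimal config wiring these fields together
+	// (illustrative values only):
+	//
+	//	cfg := MySQLPITRConfig{
+	//		Host: "127.0.0.1", Port: 3306, User: "backup",
+	//		BinlogDir:       "/var/lib/mysql",
+	//		ArchiveDir:      "/var/backups/binlog-archive",
+	//		ArchiveInterval: 5 * time.Minute,
+	//		RetentionDays:   7,
+	//		Compression:     true,
+	//	}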
RetentionDays int `json:"retention_days"` + Compression bool `json:"compression"` + CompressionLevel int `json:"compression_level"` + Encryption bool `json:"encryption"` + EncryptionKey []byte `json:"-"` + + // Behavior settings + RequireRowFormat bool `json:"require_row_format"` + RequireGTID bool `json:"require_gtid"` + FlushLogsOnBackup bool `json:"flush_logs_on_backup"` + LockTables bool `json:"lock_tables"` + SingleTransaction bool `json:"single_transaction"` +} + +// NewMySQLPITR creates a new MySQL PITR provider +func NewMySQLPITR(db *sql.DB, config MySQLPITRConfig) (*MySQLPITR, error) { + m := &MySQLPITR{ + db: db, + config: config, + } + + // Detect server type and version + if err := m.detectServerInfo(); err != nil { + return nil, fmt.Errorf("detecting server info: %w", err) + } + + // Initialize binlog manager + binlogConfig := BinlogManagerConfig{ + BinlogDir: config.BinlogDir, + ArchiveDir: config.ArchiveDir, + Compression: config.Compression, + Encryption: config.Encryption, + EncryptionKey: config.EncryptionKey, + } + var err error + m.binlogManager, err = NewBinlogManager(binlogConfig) + if err != nil { + return nil, fmt.Errorf("creating binlog manager: %w", err) + } + + return m, nil +} + +// detectServerInfo detects MySQL/MariaDB version and configuration +func (m *MySQLPITR) detectServerInfo() error { + // Get version + var version string + err := m.db.QueryRow("SELECT VERSION()").Scan(&version) + if err != nil { + return fmt.Errorf("getting version: %w", err) + } + m.serverVersion = version + + // Detect MariaDB vs MySQL + if strings.Contains(strings.ToLower(version), "mariadb") { + m.serverType = DatabaseMariaDB + } else { + m.serverType = DatabaseMySQL + } + + // Get server_id + var serverID int + err = m.db.QueryRow("SELECT @@server_id").Scan(&serverID) + if err == nil { + m.serverID = uint32(serverID) + } + + // Check GTID mode + if m.serverType == DatabaseMySQL { + var gtidMode string + err = m.db.QueryRow("SELECT @@gtid_mode").Scan(>idMode) + if err == nil { + m.gtidMode = strings.ToUpper(gtidMode) == "ON" + } + } else { + // MariaDB uses different variables + var gtidPos string + err = m.db.QueryRow("SELECT @@gtid_current_pos").Scan(>idPos) + m.gtidMode = err == nil && gtidPos != "" + } + + return nil +} + +// DatabaseType returns the database type this provider handles +func (m *MySQLPITR) DatabaseType() DatabaseType { + return m.serverType +} + +// Enable enables PITR for the MySQL database +func (m *MySQLPITR) Enable(ctx context.Context, config PITREnableConfig) error { + // Check current binlog settings + status, err := m.Status(ctx) + if err != nil { + return fmt.Errorf("checking status: %w", err) + } + + var issues []string + + // Check if binlog is enabled + var logBin string + if err := m.db.QueryRowContext(ctx, "SELECT @@log_bin").Scan(&logBin); err != nil { + return fmt.Errorf("checking log_bin: %w", err) + } + if logBin != "1" && strings.ToUpper(logBin) != "ON" { + issues = append(issues, "binary logging is not enabled (log_bin=OFF)") + issues = append(issues, " Add to my.cnf: log_bin = mysql-bin") + } + + // Check binlog format + if m.config.RequireRowFormat && status.LogLevel != "ROW" { + issues = append(issues, fmt.Sprintf("binlog_format is %s, not ROW", status.LogLevel)) + issues = append(issues, " Add to my.cnf: binlog_format = ROW") + } + + // Check GTID mode if required + if m.config.RequireGTID && !m.gtidMode { + issues = append(issues, "GTID mode is not enabled") + if m.serverType == DatabaseMySQL { + issues = append(issues, " Add to my.cnf: 
gtid_mode = ON, enforce_gtid_consistency = ON") + } else { + issues = append(issues, " MariaDB: GTIDs are automatically managed with log_slave_updates") + } + } + + // Check expire_logs_days (don't want logs expiring before we archive them) + var expireDays int + m.db.QueryRowContext(ctx, "SELECT @@expire_logs_days").Scan(&expireDays) + if expireDays > 0 && expireDays < config.RetentionDays { + issues = append(issues, + fmt.Sprintf("expire_logs_days (%d) is less than retention days (%d)", + expireDays, config.RetentionDays)) + } + + if len(issues) > 0 { + return fmt.Errorf("PITR requirements not met:\n - %s", strings.Join(issues, "\n - ")) + } + + // Update archive configuration + m.config.ArchiveDir = config.ArchiveDir + m.config.RetentionDays = config.RetentionDays + m.config.ArchiveInterval = config.ArchiveInterval + m.config.Compression = config.Compression + m.config.Encryption = config.Encryption + m.config.EncryptionKey = config.EncryptionKey + + // Create archive directory + if err := os.MkdirAll(config.ArchiveDir, 0750); err != nil { + return fmt.Errorf("creating archive directory: %w", err) + } + + // Save configuration + configPath := filepath.Join(config.ArchiveDir, "pitr_config.json") + configData, _ := json.MarshalIndent(map[string]interface{}{ + "enabled": true, + "server_type": m.serverType, + "server_version": m.serverVersion, + "server_id": m.serverID, + "gtid_mode": m.gtidMode, + "archive_dir": config.ArchiveDir, + "retention_days": config.RetentionDays, + "archive_interval": config.ArchiveInterval.String(), + "compression": config.Compression, + "encryption": config.Encryption, + "created_at": time.Now().Format(time.RFC3339), + }, "", " ") + if err := os.WriteFile(configPath, configData, 0640); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + return nil +} + +// Disable disables PITR for the MySQL database +func (m *MySQLPITR) Disable(ctx context.Context) error { + configPath := filepath.Join(m.config.ArchiveDir, "pitr_config.json") + + // Check if config exists + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return fmt.Errorf("PITR is not enabled (no config file found)") + } + + // Update config to disabled + configData, _ := json.MarshalIndent(map[string]interface{}{ + "enabled": false, + "disabled_at": time.Now().Format(time.RFC3339), + }, "", " ") + + if err := os.WriteFile(configPath, configData, 0640); err != nil { + return fmt.Errorf("updating config: %w", err) + } + + return nil +} + +// Status returns the current PITR status +func (m *MySQLPITR) Status(ctx context.Context) (*PITRStatus, error) { + status := &PITRStatus{ + DatabaseType: m.serverType, + ArchiveDir: m.config.ArchiveDir, + } + + // Check if PITR is enabled via config file + configPath := filepath.Join(m.config.ArchiveDir, "pitr_config.json") + if data, err := os.ReadFile(configPath); err == nil { + var config map[string]interface{} + if json.Unmarshal(data, &config) == nil { + if enabled, ok := config["enabled"].(bool); ok { + status.Enabled = enabled + } + } + } + + // Get binlog format + var binlogFormat string + if err := m.db.QueryRowContext(ctx, "SELECT @@binlog_format").Scan(&binlogFormat); err == nil { + status.LogLevel = binlogFormat + } + + // Get current position + pos, err := m.GetCurrentPosition(ctx) + if err == nil { + status.Position = pos + } + + // Get archive stats + if m.config.ArchiveDir != "" { + archives, err := m.binlogManager.ListArchivedBinlogs(ctx) + if err == nil { + status.ArchiveCount = len(archives) + for _, a := range archives { + 
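+			// Note (added): ListArchivedBinlogs sets ArchivedAt from the
+			// archive file's ModTime (metadata.json enrichment does not
+			// overwrite it), so LastArchived is a best-effort timestamp.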
status.ArchiveSize += a.Size + if a.ArchivedAt.After(status.LastArchived) { + status.LastArchived = a.ArchivedAt + } + } + } + } + + status.ArchiveMethod = "manual" // MySQL doesn't have automatic archiving like PostgreSQL + + return status, nil +} + +// GetCurrentPosition retrieves the current binary log position +func (m *MySQLPITR) GetCurrentPosition(ctx context.Context) (*BinlogPosition, error) { + pos := &BinlogPosition{} + + // Use SHOW MASTER STATUS for current position + rows, err := m.db.QueryContext(ctx, "SHOW MASTER STATUS") + if err != nil { + return nil, fmt.Errorf("getting master status: %w", err) + } + defer rows.Close() + + if rows.Next() { + var file string + var position uint64 + var binlogDoDB, binlogIgnoreDB, executedGtidSet sql.NullString + + cols, _ := rows.Columns() + switch len(cols) { + case 5: // MySQL 5.6+ + err = rows.Scan(&file, &position, &binlogDoDB, &binlogIgnoreDB, &executedGtidSet) + case 4: // Older versions + err = rows.Scan(&file, &position, &binlogDoDB, &binlogIgnoreDB) + default: + err = rows.Scan(&file, &position) + } + + if err != nil { + return nil, fmt.Errorf("scanning master status: %w", err) + } + + pos.File = file + pos.Position = position + pos.ServerID = m.serverID + + if executedGtidSet.Valid { + pos.GTID = executedGtidSet.String + } + } else { + return nil, fmt.Errorf("no master status available (is binary logging enabled?)") + } + + // For MariaDB, get GTID position differently + if m.serverType == DatabaseMariaDB && pos.GTID == "" { + var gtidPos string + if err := m.db.QueryRowContext(ctx, "SELECT @@gtid_current_pos").Scan(>idPos); err == nil { + pos.GTID = gtidPos + } + } + + return pos, nil +} + +// CreateBackup creates a PITR-capable backup with position recording +func (m *MySQLPITR) CreateBackup(ctx context.Context, opts BackupOptions) (*PITRBackupInfo, error) { + // Get position BEFORE flushing logs + startPos, err := m.GetCurrentPosition(ctx) + if err != nil { + return nil, fmt.Errorf("getting start position: %w", err) + } + + // Optionally flush logs to start a new binlog file + if opts.FlushLogs || m.config.FlushLogsOnBackup { + if _, err := m.db.ExecContext(ctx, "FLUSH BINARY LOGS"); err != nil { + return nil, fmt.Errorf("flushing binary logs: %w", err) + } + // Get new position after flush + startPos, err = m.GetCurrentPosition(ctx) + if err != nil { + return nil, fmt.Errorf("getting position after flush: %w", err) + } + } + + // Build mysqldump command + dumpArgs := []string{ + "--single-transaction", + "--routines", + "--triggers", + "--events", + "--master-data=2", // Include binlog position as comment + } + + if m.config.FlushLogsOnBackup { + dumpArgs = append(dumpArgs, "--flush-logs") + } + + // Add connection params + if m.config.Host != "" { + dumpArgs = append(dumpArgs, "-h", m.config.Host) + } + if m.config.Port > 0 { + dumpArgs = append(dumpArgs, "-P", strconv.Itoa(m.config.Port)) + } + if m.config.User != "" { + dumpArgs = append(dumpArgs, "-u", m.config.User) + } + if m.config.Password != "" { + dumpArgs = append(dumpArgs, "-p"+m.config.Password) + } + if m.config.Socket != "" { + dumpArgs = append(dumpArgs, "-S", m.config.Socket) + } + + // Add database selection + if opts.Database != "" { + dumpArgs = append(dumpArgs, opts.Database) + } else { + dumpArgs = append(dumpArgs, "--all-databases") + } + + // Create output file + timestamp := time.Now().Format("20060102_150405") + backupName := fmt.Sprintf("mysql_pitr_%s.sql", timestamp) + if opts.Compression { + backupName += ".gz" + } + backupPath := 
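+	// Notes (added) on the mysqldump invocation below:
+	//  - --master-data=2 embeds the binlog coordinates as a comment; MySQL
+	//    8.0.26+ deprecates the flag in favour of --source-data.
+	//  - passing the password as -p<password> exposes it in the process
+	//    list; MYSQL_PWD or --defaults-extra-file is safer in production.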
filepath.Join(opts.OutputPath, backupName) + + if err := os.MkdirAll(opts.OutputPath, 0750); err != nil { + return nil, fmt.Errorf("creating output directory: %w", err) + } + + // Run mysqldump + cmd := exec.CommandContext(ctx, "mysqldump", dumpArgs...) + + // Create output file + outFile, err := os.Create(backupPath) + if err != nil { + return nil, fmt.Errorf("creating backup file: %w", err) + } + defer outFile.Close() + + var writer io.WriteCloser = outFile + + if opts.Compression { + gzWriter := NewGzipWriter(outFile, opts.CompressionLvl) + writer = gzWriter + defer gzWriter.Close() + } + + cmd.Stdout = writer + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + os.Remove(backupPath) + return nil, fmt.Errorf("mysqldump failed: %w", err) + } + + // Close writers + if opts.Compression { + writer.Close() + } + + // Get file size + info, err := os.Stat(backupPath) + if err != nil { + return nil, fmt.Errorf("getting backup info: %w", err) + } + + // Serialize position for JSON storage + posJSON, _ := json.Marshal(startPos) + + backupInfo := &PITRBackupInfo{ + BackupFile: backupPath, + DatabaseType: m.serverType, + DatabaseName: opts.Database, + Timestamp: time.Now(), + ServerVersion: m.serverVersion, + ServerID: int(m.serverID), + Position: startPos, + PositionJSON: string(posJSON), + SizeBytes: info.Size(), + Compressed: opts.Compression, + Encrypted: opts.Encryption, + } + + // Save metadata alongside backup + metadataPath := backupPath + ".meta" + metaData, _ := json.MarshalIndent(backupInfo, "", " ") + os.WriteFile(metadataPath, metaData, 0640) + + return backupInfo, nil +} + +// Restore performs a point-in-time restore +func (m *MySQLPITR) Restore(ctx context.Context, backup *PITRBackupInfo, target RestoreTarget) error { + // Step 1: Restore base backup + if err := m.restoreBaseBackup(ctx, backup); err != nil { + return fmt.Errorf("restoring base backup: %w", err) + } + + // Step 2: If target time is after backup time, replay binlogs + if target.Type == RestoreTargetImmediate { + return nil // Just restore to backup point + } + + // Parse start position from backup + var startPos BinlogPosition + if err := json.Unmarshal([]byte(backup.PositionJSON), &startPos); err != nil { + return fmt.Errorf("parsing backup position: %w", err) + } + + // Step 3: Find binlogs to replay + binlogs, err := m.binlogManager.DiscoverBinlogs(ctx) + if err != nil { + return fmt.Errorf("discovering binlogs: %w", err) + } + + // Find archived binlogs too + archivedBinlogs, _ := m.binlogManager.ListArchivedBinlogs(ctx) + + var filesToReplay []string + + // Determine which binlogs to replay based on target + switch target.Type { + case RestoreTargetTime: + if target.Time == nil { + return fmt.Errorf("target time not specified") + } + // Find binlogs in range + relevantBinlogs := m.binlogManager.FindBinlogsInRange(ctx, binlogs, backup.Timestamp, *target.Time) + for _, b := range relevantBinlogs { + filesToReplay = append(filesToReplay, b.Path) + } + // Also check archives + for _, a := range archivedBinlogs { + if compareBinlogFiles(a.OriginalFile, startPos.File) >= 0 { + if !a.EndTime.IsZero() && !a.EndTime.Before(backup.Timestamp) && !a.StartTime.After(*target.Time) { + filesToReplay = append(filesToReplay, a.ArchivePath) + } + } + } + + case RestoreTargetPosition: + if target.Position == nil { + return fmt.Errorf("target position not specified") + } + targetPos, ok := target.Position.(*BinlogPosition) + if !ok { + return fmt.Errorf("invalid target position type") + } + // Find binlogs from start to 
target position + for _, b := range binlogs { + if compareBinlogFiles(b.Name, startPos.File) >= 0 && + compareBinlogFiles(b.Name, targetPos.File) <= 0 { + filesToReplay = append(filesToReplay, b.Path) + } + } + } + + if len(filesToReplay) == 0 { + // Nothing to replay, backup is already at or past target + return nil + } + + // Step 4: Replay binlogs + replayOpts := ReplayOptions{ + BinlogFiles: filesToReplay, + StartPosition: &startPos, + DryRun: target.DryRun, + MySQLHost: m.config.Host, + MySQLPort: m.config.Port, + MySQLUser: m.config.User, + MySQLPass: m.config.Password, + StopOnError: target.StopOnErr, + } + + if target.Type == RestoreTargetTime && target.Time != nil { + replayOpts.StopTime = target.Time + } + if target.Type == RestoreTargetPosition && target.Position != nil { + replayOpts.StopPosition = target.Position + } + + if target.DryRun { + replayOpts.Output = os.Stdout + } + + return m.binlogManager.ReplayBinlogs(ctx, replayOpts) +} + +// restoreBaseBackup restores the base MySQL backup +func (m *MySQLPITR) restoreBaseBackup(ctx context.Context, backup *PITRBackupInfo) error { + // Build mysql command + mysqlArgs := []string{} + + if m.config.Host != "" { + mysqlArgs = append(mysqlArgs, "-h", m.config.Host) + } + if m.config.Port > 0 { + mysqlArgs = append(mysqlArgs, "-P", strconv.Itoa(m.config.Port)) + } + if m.config.User != "" { + mysqlArgs = append(mysqlArgs, "-u", m.config.User) + } + if m.config.Password != "" { + mysqlArgs = append(mysqlArgs, "-p"+m.config.Password) + } + if m.config.Socket != "" { + mysqlArgs = append(mysqlArgs, "-S", m.config.Socket) + } + + // Prepare input + var input io.Reader + backupFile, err := os.Open(backup.BackupFile) + if err != nil { + return fmt.Errorf("opening backup file: %w", err) + } + defer backupFile.Close() + + input = backupFile + + // Handle compressed backups + if backup.Compressed || strings.HasSuffix(backup.BackupFile, ".gz") { + gzReader, err := NewGzipReader(backupFile) + if err != nil { + return fmt.Errorf("creating gzip reader: %w", err) + } + defer gzReader.Close() + input = gzReader + } + + // Run mysql + cmd := exec.CommandContext(ctx, "mysql", mysqlArgs...) + cmd.Stdin = input + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + return cmd.Run() +} + +// ListRecoveryPoints lists available recovery points/ranges +func (m *MySQLPITR) ListRecoveryPoints(ctx context.Context) ([]RecoveryWindow, error) { + var windows []RecoveryWindow + + // Find all backup metadata files + backupPattern := filepath.Join(m.config.ArchiveDir, "..", "*", "*.meta") + metaFiles, _ := filepath.Glob(backupPattern) + + // Also check default backup locations + additionalPaths := []string{ + filepath.Join(m.config.ArchiveDir, "*.meta"), + filepath.Join(m.config.RestoreDir, "*.meta"), + } + for _, p := range additionalPaths { + matches, _ := filepath.Glob(p) + metaFiles = append(metaFiles, matches...) 
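		// Each discovered *.meta file is the JSON-encoded PITRBackupInfo
+		// sidecar written by CreateBackup. Illustrative shape only, assuming
+		// Go's default field-name keys (actual keys depend on the struct's
+		// JSON tags):
+		//
+		//	{"BackupFile": "/backups/mysql_pitr_20250102_030405.sql.gz",
+		//	 "Position":   {"File": "mysql-bin.000042", "Position": 1234},
+		//	 "Timestamp":  "2025-01-02T03:04:05Z",
+		//	 "Compressed": true}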
+ } + + // Get current binlogs + binlogs, err := m.binlogManager.DiscoverBinlogs(ctx) + if err != nil { + binlogs = []BinlogFile{} + } + + // Get archived binlogs + archivedBinlogs, _ := m.binlogManager.ListArchivedBinlogs(ctx) + + for _, metaFile := range metaFiles { + data, err := os.ReadFile(metaFile) + if err != nil { + continue + } + + var backup PITRBackupInfo + if err := json.Unmarshal(data, &backup); err != nil { + continue + } + + // Parse position + var startPos BinlogPosition + json.Unmarshal([]byte(backup.PositionJSON), &startPos) + + window := RecoveryWindow{ + BaseBackup: backup.BackupFile, + BackupTime: backup.Timestamp, + StartTime: backup.Timestamp, + StartPosition: &startPos, + } + + // Find binlogs available after this backup + var relevantBinlogs []string + var latestTime time.Time + var latestPos *BinlogPosition + + for _, b := range binlogs { + if compareBinlogFiles(b.Name, startPos.File) >= 0 { + relevantBinlogs = append(relevantBinlogs, b.Name) + if !b.EndTime.IsZero() && b.EndTime.After(latestTime) { + latestTime = b.EndTime + latestPos = &BinlogPosition{ + File: b.Name, + Position: b.EndPos, + GTID: b.GTID, + } + } + } + } + + for _, a := range archivedBinlogs { + if compareBinlogFiles(a.OriginalFile, startPos.File) >= 0 { + relevantBinlogs = append(relevantBinlogs, a.OriginalFile) + if !a.EndTime.IsZero() && a.EndTime.After(latestTime) { + latestTime = a.EndTime + latestPos = &BinlogPosition{ + File: a.OriginalFile, + Position: a.EndPos, + GTID: a.GTID, + } + } + } + } + + window.LogFiles = relevantBinlogs + if !latestTime.IsZero() { + window.EndTime = latestTime + } else { + window.EndTime = time.Now() + } + window.EndPosition = latestPos + + // Check for gaps + validation, _ := m.binlogManager.ValidateBinlogChain(ctx, binlogs) + if validation != nil { + window.HasGaps = !validation.Valid + for _, gap := range validation.Gaps { + window.GapDetails = append(window.GapDetails, gap.Reason) + } + } + + windows = append(windows, window) + } + + return windows, nil +} + +// ValidateChain validates the log chain integrity +func (m *MySQLPITR) ValidateChain(ctx context.Context, from, to time.Time) (*ChainValidation, error) { + // Discover all binlogs + binlogs, err := m.binlogManager.DiscoverBinlogs(ctx) + if err != nil { + return nil, fmt.Errorf("discovering binlogs: %w", err) + } + + // Filter to time range + relevant := m.binlogManager.FindBinlogsInRange(ctx, binlogs, from, to) + + // Validate chain + return m.binlogManager.ValidateBinlogChain(ctx, relevant) +} + +// ArchiveNewBinlogs archives any binlog files that haven't been archived yet +func (m *MySQLPITR) ArchiveNewBinlogs(ctx context.Context) ([]BinlogArchiveInfo, error) { + // Get current binlogs + binlogs, err := m.binlogManager.DiscoverBinlogs(ctx) + if err != nil { + return nil, fmt.Errorf("discovering binlogs: %w", err) + } + + // Get already archived + archived, _ := m.binlogManager.ListArchivedBinlogs(ctx) + archivedSet := make(map[string]struct{}) + for _, a := range archived { + archivedSet[a.OriginalFile] = struct{}{} + } + + // Get current binlog file (don't archive the active one) + currentPos, _ := m.GetCurrentPosition(ctx) + currentFile := "" + if currentPos != nil { + currentFile = currentPos.File + } + + var newArchives []BinlogArchiveInfo + for i := range binlogs { + b := &binlogs[i] + + // Skip if already archived + if _, exists := archivedSet[b.Name]; exists { + continue + } + + // Skip the current active binlog + if b.Name == currentFile { + continue + } + + // Archive + archiveInfo, err := 
m.binlogManager.ArchiveBinlog(ctx, b)
+		if err != nil {
+			// Skip this binlog and continue with the rest
+			continue
+		}
+		newArchives = append(newArchives, *archiveInfo)
+	}
+
+	// Update metadata
+	if len(newArchives) > 0 {
+		allArchived, _ := m.binlogManager.ListArchivedBinlogs(ctx)
+		m.binlogManager.SaveArchiveMetadata(allArchived)
+	}
+
+	return newArchives, nil
+}
+
+// PurgeBinlogs purges old binlog files based on retention policy
+func (m *MySQLPITR) PurgeBinlogs(ctx context.Context) error {
+	if m.config.RetentionDays <= 0 {
+		return fmt.Errorf("retention days not configured")
+	}
+
+	cutoff := time.Now().AddDate(0, 0, -m.config.RetentionDays)
+
+	// Get archived binlogs
+	archived, err := m.binlogManager.ListArchivedBinlogs(ctx)
+	if err != nil {
+		return fmt.Errorf("listing archived binlogs: %w", err)
+	}
+
+	for _, a := range archived {
+		if a.ArchivedAt.Before(cutoff) {
+			os.Remove(a.ArchivePath)
+		}
+	}
+
+	return nil
+}
+
+// GzipWriter is a helper for gzip compression
+type GzipWriter struct {
+	w *gzip.Writer
+}
+
+func NewGzipWriter(w io.Writer, level int) *GzipWriter {
+	// Clamp to a valid gzip level so NewWriterLevel cannot fail
+	// (out-of-range levels would otherwise return a nil writer)
+	if level < gzip.BestSpeed || level > gzip.BestCompression {
+		level = gzip.DefaultCompression
+	}
+	gw, _ := gzip.NewWriterLevel(w, level)
+	return &GzipWriter{w: gw}
+}
+
+func (g *GzipWriter) Write(p []byte) (int, error) {
+	return g.w.Write(p)
+}
+
+func (g *GzipWriter) Close() error {
+	return g.w.Close()
+}
+
+// GzipReader is a helper for gzip decompression
+type GzipReader struct {
+	r *gzip.Reader
+}
+
+func NewGzipReader(r io.Reader) (*GzipReader, error) {
+	gr, err := gzip.NewReader(r)
+	if err != nil {
+		return nil, err
+	}
+	return &GzipReader{r: gr}, nil
+}
+
+func (g *GzipReader) Read(p []byte) (int, error) {
+	return g.r.Read(p)
+}
+
+func (g *GzipReader) Close() error {
+	return g.r.Close()
+}
+
+// ExtractBinlogPositionFromDump extracts the binlog position from a mysqldump file
+func ExtractBinlogPositionFromDump(dumpPath string) (*BinlogPosition, error) {
+	file, err := os.Open(dumpPath)
+	if err != nil {
+		return nil, fmt.Errorf("opening dump file: %w", err)
+	}
+	defer file.Close()
+
+	var reader io.Reader = file
+	if strings.HasSuffix(dumpPath, ".gz") {
+		gzReader, err := gzip.NewReader(file)
+		if err != nil {
+			return nil, fmt.Errorf("creating gzip reader: %w", err)
+		}
+		defer gzReader.Close()
+		reader = gzReader
+	}
+
+	// Look for CHANGE MASTER TO or -- CHANGE MASTER TO comment
+	// Pattern: -- CHANGE MASTER TO MASTER_LOG_FILE='mysql-bin.000042', MASTER_LOG_POS=1234;
+	scanner := NewLimitedScanner(reader, 1000) // Only scan first 1000 lines
+	posPattern := regexp.MustCompile(`MASTER_LOG_FILE='([^']+)',\s*MASTER_LOG_POS=(\d+)`)
+
+	for scanner.Scan() {
+		line := scanner.Text()
+		if matches := posPattern.FindStringSubmatch(line); len(matches) == 3 {
+			pos, _ := strconv.ParseUint(matches[2], 10, 64)
+			return &BinlogPosition{
+				File:     matches[1],
+				Position: pos,
+			}, nil
+		}
+	}
+
+	return nil, fmt.Errorf("binlog position not found in dump file")
+}
+
+// LimitedScanner wraps bufio.Scanner with a line limit
+type LimitedScanner struct {
+	scanner *bufio.Scanner
+	limit   int
+	count   int
+}
+
+func NewLimitedScanner(r io.Reader, limit int) *LimitedScanner {
+	return &LimitedScanner{
+		scanner: bufio.NewScanner(r),
+		limit:   limit,
+	}
+}
+
+func (s *LimitedScanner) Scan() bool {
+	if s.count >= s.limit {
+		return false
+	}
+	s.count++
+	return s.scanner.Scan()
+}
+
+func (s *LimitedScanner) Text() string {
+	return s.scanner.Text()
+}
diff --git a/internal/replica/selector.go b/internal/replica/selector.go
new file mode 100644
index 0000000..764d236
--- 
/dev/null +++ b/internal/replica/selector.go @@ -0,0 +1,499 @@ +// Package replica provides replica-aware backup functionality +package replica + +import ( + "context" + "database/sql" + "fmt" + "sort" + "time" +) + +// Role represents the replication role of a database +type Role string + +const ( + RolePrimary Role = "primary" + RoleReplica Role = "replica" + RoleStandalone Role = "standalone" + RoleUnknown Role = "unknown" +) + +// Status represents the health status of a replica +type Status string + +const ( + StatusHealthy Status = "healthy" + StatusLagging Status = "lagging" + StatusDisconnected Status = "disconnected" + StatusUnknown Status = "unknown" +) + +// Node represents a database node in a replication topology +type Node struct { + Host string `json:"host"` + Port int `json:"port"` + Role Role `json:"role"` + Status Status `json:"status"` + ReplicationLag time.Duration `json:"replication_lag"` + IsAvailable bool `json:"is_available"` + LastChecked time.Time `json:"last_checked"` + Priority int `json:"priority"` // Lower = higher priority + Weight int `json:"weight"` // For load balancing + Metadata map[string]string `json:"metadata,omitempty"` +} + +// Topology represents the replication topology +type Topology struct { + Primary *Node `json:"primary,omitempty"` + Replicas []*Node `json:"replicas"` + Timestamp time.Time `json:"timestamp"` +} + +// Config configures replica-aware backup behavior +type Config struct { + PreferReplica bool `json:"prefer_replica"` + MaxReplicationLag time.Duration `json:"max_replication_lag"` + FallbackToPrimary bool `json:"fallback_to_primary"` + RequireHealthy bool `json:"require_healthy"` + SelectionStrategy Strategy `json:"selection_strategy"` + Nodes []NodeConfig `json:"nodes"` +} + +// NodeConfig configures a known node +type NodeConfig struct { + Host string `json:"host"` + Port int `json:"port"` + Priority int `json:"priority"` + Weight int `json:"weight"` +} + +// Strategy for selecting a node +type Strategy string + +const ( + StrategyPreferReplica Strategy = "prefer_replica" // Always prefer replica + StrategyLowestLag Strategy = "lowest_lag" // Choose node with lowest lag + StrategyRoundRobin Strategy = "round_robin" // Rotate between replicas + StrategyPriority Strategy = "priority" // Use configured priorities + StrategyWeighted Strategy = "weighted" // Weighted random selection +) + +// DefaultConfig returns default replica configuration +func DefaultConfig() Config { + return Config{ + PreferReplica: true, + MaxReplicationLag: 1 * time.Minute, + FallbackToPrimary: true, + RequireHealthy: true, + SelectionStrategy: StrategyLowestLag, + } +} + +// Selector selects the best node for backup +type Selector struct { + config Config + lastSelected int // For round-robin +} + +// NewSelector creates a new replica selector +func NewSelector(config Config) *Selector { + return &Selector{ + config: config, + } +} + +// SelectNode selects the best node for backup from the topology +func (s *Selector) SelectNode(topology *Topology) (*Node, error) { + var candidates []*Node + + // Collect available candidates + if s.config.PreferReplica { + // Prefer replicas + for _, r := range topology.Replicas { + if s.isAcceptable(r) { + candidates = append(candidates, r) + } + } + + // Fallback to primary if no replicas available + if len(candidates) == 0 && s.config.FallbackToPrimary { + if topology.Primary != nil && topology.Primary.IsAvailable { + return topology.Primary, nil + } + } + } else { + // Allow all nodes + if topology.Primary != nil && 
topology.Primary.IsAvailable {
+			candidates = append(candidates, topology.Primary)
+		}
+		for _, r := range topology.Replicas {
+			if s.isAcceptable(r) {
+				candidates = append(candidates, r)
+			}
+		}
+	}
+
+	if len(candidates) == 0 {
+		return nil, fmt.Errorf("no acceptable nodes available for backup")
+	}
+
+	// Apply selection strategy
+	return s.applyStrategy(candidates)
+}
+
+// isAcceptable checks if a node is acceptable for backup
+func (s *Selector) isAcceptable(node *Node) bool {
+	if !node.IsAvailable {
+		return false
+	}
+
+	if s.config.RequireHealthy && node.Status != StatusHealthy {
+		return false
+	}
+
+	if s.config.MaxReplicationLag > 0 && node.ReplicationLag > s.config.MaxReplicationLag {
+		return false
+	}
+
+	return true
+}
+
+// applyStrategy selects a node using the configured strategy
+func (s *Selector) applyStrategy(candidates []*Node) (*Node, error) {
+	switch s.config.SelectionStrategy {
+	case StrategyLowestLag:
+		return s.selectLowestLag(candidates), nil
+
+	case StrategyPriority:
+		return s.selectByPriority(candidates), nil
+
+	case StrategyRoundRobin:
+		return s.selectRoundRobin(candidates), nil
+
+	default:
+		// StrategyPreferReplica and StrategyWeighted fall through to
+		// lowest lag as well
+		return s.selectLowestLag(candidates), nil
+	}
+}
+
+func (s *Selector) selectLowestLag(candidates []*Node) *Node {
+	sort.Slice(candidates, func(i, j int) bool {
+		return candidates[i].ReplicationLag < candidates[j].ReplicationLag
+	})
+	return candidates[0]
+}
+
+func (s *Selector) selectByPriority(candidates []*Node) *Node {
+	sort.Slice(candidates, func(i, j int) bool {
+		return candidates[i].Priority < candidates[j].Priority
+	})
+	return candidates[0]
+}
+
+func (s *Selector) selectRoundRobin(candidates []*Node) *Node {
+	node := candidates[s.lastSelected%len(candidates)]
+	s.lastSelected++
+	return node
+}
+
+// Detector detects replication topology
+type Detector interface {
+	Detect(ctx context.Context, db *sql.DB) (*Topology, error)
+	GetRole(ctx context.Context, db *sql.DB) (Role, error)
+	GetReplicationLag(ctx context.Context, db *sql.DB) (time.Duration, error)
+}
+
+// PostgreSQLDetector detects PostgreSQL replication topology
+type PostgreSQLDetector struct{}
+
+// Detect discovers PostgreSQL replication topology
+func (d *PostgreSQLDetector) Detect(ctx context.Context, db *sql.DB) (*Topology, error) {
+	topology := &Topology{
+		Timestamp: time.Now(),
+		Replicas:  make([]*Node, 0),
+	}
+
+	// Check if we're on primary
+	var isRecovery bool
+	err := db.QueryRowContext(ctx, "SELECT pg_is_in_recovery()").Scan(&isRecovery)
+	if err != nil {
+		return nil, fmt.Errorf("failed to check recovery status: %w", err)
+	}
+
+	if !isRecovery {
+		// We're on primary - get replicas from pg_stat_replication
+		// (replay_lag is already an interval, so take its epoch directly)
+		rows, err := db.QueryContext(ctx, `
+			SELECT
+				client_addr,
+				client_port,
+				state,
+				EXTRACT(EPOCH FROM replay_lag)::integer as lag_seconds
+			FROM pg_stat_replication
+		`)
+		if err != nil {
+			return nil, fmt.Errorf("failed to query replication status: %w", err)
+		}
+		defer rows.Close()
+
+		for rows.Next() {
+			var addr sql.NullString
+			var port sql.NullInt64
+			var state sql.NullString
+			var lagSeconds sql.NullInt64
+
+			if err := rows.Scan(&addr, &port, &state, &lagSeconds); err != nil {
+				continue
+			}
+
+			node := &Node{
+				Host:        addr.String,
+				Port:        int(port.Int64),
+				Role:        RoleReplica,
+				IsAvailable: true,
+				LastChecked: time.Now(),
+			}
+
+			if lagSeconds.Valid {
+				node.ReplicationLag = time.Duration(lagSeconds.Int64) * time.Second
+			}
+
+			if state.String == "streaming" {
+				node.Status = StatusHealthy
+			} else {
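				// Any WAL sender state other than "streaming" (e.g.
+				// "startup", "catchup", "backup") means the standby is
+				// still syncing, so treat it as lagging for selection.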
node.Status = StatusLagging + } + + topology.Replicas = append(topology.Replicas, node) + } + } + + return topology, nil +} + +// GetRole returns the replication role +func (d *PostgreSQLDetector) GetRole(ctx context.Context, db *sql.DB) (Role, error) { + var isRecovery bool + err := db.QueryRowContext(ctx, "SELECT pg_is_in_recovery()").Scan(&isRecovery) + if err != nil { + return RoleUnknown, fmt.Errorf("failed to check recovery status: %w", err) + } + + if isRecovery { + return RoleReplica, nil + } + return RolePrimary, nil +} + +// GetReplicationLag returns the replication lag +func (d *PostgreSQLDetector) GetReplicationLag(ctx context.Context, db *sql.DB) (time.Duration, error) { + var lagSeconds sql.NullFloat64 + err := db.QueryRowContext(ctx, ` + SELECT EXTRACT(EPOCH FROM (now() - pg_last_xact_replay_timestamp())) + `).Scan(&lagSeconds) + if err != nil { + return 0, fmt.Errorf("failed to get replication lag: %w", err) + } + + if !lagSeconds.Valid { + return 0, nil // Not a replica or no lag data + } + + return time.Duration(lagSeconds.Float64) * time.Second, nil +} + +// MySQLDetector detects MySQL/MariaDB replication topology +type MySQLDetector struct{} + +// Detect discovers MySQL replication topology +func (d *MySQLDetector) Detect(ctx context.Context, db *sql.DB) (*Topology, error) { + topology := &Topology{ + Timestamp: time.Now(), + Replicas: make([]*Node, 0), + } + + // Check slave status first + rows, err := db.QueryContext(ctx, "SHOW SLAVE STATUS") + if err != nil { + // Not a slave, check if we're a master + rows, err = db.QueryContext(ctx, "SHOW SLAVE HOSTS") + if err != nil { + return topology, nil // Standalone or error + } + defer rows.Close() + + // Parse slave hosts + cols, _ := rows.Columns() + values := make([]interface{}, len(cols)) + valuePtrs := make([]interface{}, len(cols)) + for i := range values { + valuePtrs[i] = &values[i] + } + + for rows.Next() { + if err := rows.Scan(valuePtrs...); err != nil { + continue + } + + // Extract host and port + var host string + var port int + for i, col := range cols { + switch col { + case "Host": + if v, ok := values[i].([]byte); ok { + host = string(v) + } + case "Port": + if v, ok := values[i].(int64); ok { + port = int(v) + } + } + } + + if host != "" { + topology.Replicas = append(topology.Replicas, &Node{ + Host: host, + Port: port, + Role: RoleReplica, + IsAvailable: true, + Status: StatusUnknown, + LastChecked: time.Now(), + }) + } + } + + return topology, nil + } + defer rows.Close() + + return topology, nil +} + +// GetRole returns the MySQL replication role +func (d *MySQLDetector) GetRole(ctx context.Context, db *sql.DB) (Role, error) { + // Check if this is a slave + rows, err := db.QueryContext(ctx, "SHOW SLAVE STATUS") + if err != nil { + return RoleUnknown, err + } + defer rows.Close() + + if rows.Next() { + return RoleReplica, nil + } + + // Check if this is a master with slaves + rows2, err := db.QueryContext(ctx, "SHOW SLAVE HOSTS") + if err != nil { + return RoleStandalone, nil + } + defer rows2.Close() + + if rows2.Next() { + return RolePrimary, nil + } + + return RoleStandalone, nil +} + +// GetReplicationLag returns MySQL replication lag +func (d *MySQLDetector) GetReplicationLag(ctx context.Context, db *sql.DB) (time.Duration, error) { + var lagSeconds sql.NullInt64 + + rows, err := db.QueryContext(ctx, "SHOW SLAVE STATUS") + if err != nil { + return 0, err + } + defer rows.Close() + + if !rows.Next() { + return 0, nil // Not a replica + } + + cols, _ := rows.Columns() + values := 
make([]interface{}, len(cols)) + valuePtrs := make([]interface{}, len(cols)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + return 0, err + } + + // Find Seconds_Behind_Master column + for i, col := range cols { + if col == "Seconds_Behind_Master" { + switch v := values[i].(type) { + case int64: + lagSeconds.Int64 = v + lagSeconds.Valid = true + case []byte: + fmt.Sscanf(string(v), "%d", &lagSeconds.Int64) + lagSeconds.Valid = true + } + break + } + } + + if !lagSeconds.Valid { + return 0, nil + } + + return time.Duration(lagSeconds.Int64) * time.Second, nil +} + +// GetDetector returns the appropriate detector for a database type +func GetDetector(dbType string) Detector { + switch dbType { + case "postgresql", "postgres": + return &PostgreSQLDetector{} + case "mysql", "mariadb": + return &MySQLDetector{} + default: + return nil + } +} + +// Result contains the result of replica selection +type Result struct { + SelectedNode *Node `json:"selected_node"` + Topology *Topology `json:"topology"` + Reason string `json:"reason"` + Duration time.Duration `json:"detection_duration"` +} + +// SelectForBackup performs topology detection and node selection +func SelectForBackup(ctx context.Context, db *sql.DB, dbType string, config Config) (*Result, error) { + start := time.Now() + result := &Result{} + + detector := GetDetector(dbType) + if detector == nil { + return nil, fmt.Errorf("unsupported database type: %s", dbType) + } + + topology, err := detector.Detect(ctx, db) + if err != nil { + return nil, fmt.Errorf("topology detection failed: %w", err) + } + result.Topology = topology + + selector := NewSelector(config) + node, err := selector.SelectNode(topology) + if err != nil { + return nil, err + } + + result.SelectedNode = node + result.Duration = time.Since(start) + + if node.Role == RoleReplica { + result.Reason = fmt.Sprintf("Selected replica %s:%d with %s lag", + node.Host, node.Port, node.ReplicationLag) + } else { + result.Reason = fmt.Sprintf("Using primary %s:%d", node.Host, node.Port) + } + + return result, nil +} diff --git a/internal/report/frameworks.go b/internal/report/frameworks.go new file mode 100644 index 0000000..81c6817 --- /dev/null +++ b/internal/report/frameworks.go @@ -0,0 +1,424 @@ +// Package report - SOC2 framework controls +package report + +import ( + "time" +) + +// SOC2Framework returns SOC2 Trust Service Criteria controls +func SOC2Framework() []Category { + return []Category{ + soc2Security(), + soc2Availability(), + soc2ProcessingIntegrity(), + soc2Confidentiality(), + } +} + +func soc2Security() Category { + return Category{ + ID: "soc2-security", + Name: "Security", + Description: "Protection of system resources against unauthorized access", + Weight: 1.0, + Controls: []Control{ + { + ID: "CC6.1", + Reference: "SOC2 CC6.1", + Name: "Encryption at Rest", + Description: "Data is protected at rest using encryption", + }, + { + ID: "CC6.7", + Reference: "SOC2 CC6.7", + Name: "Encryption in Transit", + Description: "Data is protected in transit using encryption", + }, + { + ID: "CC6.2", + Reference: "SOC2 CC6.2", + Name: "Access Control", + Description: "Logical access to data and system components is restricted", + }, + { + ID: "CC6.3", + Reference: "SOC2 CC6.3", + Name: "Authorized Access", + Description: "Only authorized users can access data and systems", + }, + }, + } +} + +func soc2Availability() Category { + return Category{ + ID: "soc2-availability", + Name: "Availability", + Description: 
"System availability for operation and use as agreed", + Weight: 1.0, + Controls: []Control{ + { + ID: "A1.1", + Reference: "SOC2 A1.1", + Name: "Backup Policy", + Description: "Backup policies and procedures are established and operating", + }, + { + ID: "A1.2", + Reference: "SOC2 A1.2", + Name: "Backup Testing", + Description: "Backups are tested for recoverability", + }, + { + ID: "A1.3", + Reference: "SOC2 A1.3", + Name: "Recovery Procedures", + Description: "Recovery procedures are documented and tested", + }, + { + ID: "A1.4", + Reference: "SOC2 A1.4", + Name: "Disaster Recovery", + Description: "DR plans are maintained and tested", + }, + }, + } +} + +func soc2ProcessingIntegrity() Category { + return Category{ + ID: "soc2-processing-integrity", + Name: "Processing Integrity", + Description: "System processing is complete, valid, accurate, timely, and authorized", + Weight: 0.75, + Controls: []Control{ + { + ID: "PI1.1", + Reference: "SOC2 PI1.1", + Name: "Data Integrity", + Description: "Checksums and verification ensure data integrity", + }, + { + ID: "PI1.2", + Reference: "SOC2 PI1.2", + Name: "Error Handling", + Description: "Errors are identified and corrected in a timely manner", + }, + }, + } +} + +func soc2Confidentiality() Category { + return Category{ + ID: "soc2-confidentiality", + Name: "Confidentiality", + Description: "Information designated as confidential is protected", + Weight: 1.0, + Controls: []Control{ + { + ID: "C1.1", + Reference: "SOC2 C1.1", + Name: "Data Classification", + Description: "Confidential data is identified and classified", + }, + { + ID: "C1.2", + Reference: "SOC2 C1.2", + Name: "Data Retention", + Description: "Data retention policies are implemented", + }, + { + ID: "C1.3", + Reference: "SOC2 C1.3", + Name: "Data Disposal", + Description: "Data is securely disposed when no longer needed", + }, + }, + } +} + +// GDPRFramework returns GDPR-related controls +func GDPRFramework() []Category { + return []Category{ + { + ID: "gdpr-data-protection", + Name: "Data Protection", + Description: "Protection of personal data", + Weight: 1.0, + Controls: []Control{ + { + ID: "GDPR-25", + Reference: "GDPR Article 25", + Name: "Data Protection by Design", + Description: "Data protection measures are implemented by design", + }, + { + ID: "GDPR-32", + Reference: "GDPR Article 32", + Name: "Security of Processing", + Description: "Appropriate technical measures to ensure data security", + }, + { + ID: "GDPR-33", + Reference: "GDPR Article 33", + Name: "Breach Notification", + Description: "Procedures for breach detection and notification", + }, + }, + }, + { + ID: "gdpr-data-retention", + Name: "Data Retention", + Description: "Lawful data retention practices", + Weight: 1.0, + Controls: []Control{ + { + ID: "GDPR-5.1e", + Reference: "GDPR Article 5(1)(e)", + Name: "Storage Limitation", + Description: "Personal data not kept longer than necessary", + }, + { + ID: "GDPR-17", + Reference: "GDPR Article 17", + Name: "Right to Erasure", + Description: "Ability to delete personal data on request", + }, + }, + }, + } +} + +// HIPAAFramework returns HIPAA-related controls +func HIPAAFramework() []Category { + return []Category{ + { + ID: "hipaa-administrative", + Name: "Administrative Safeguards", + Description: "Administrative policies and procedures", + Weight: 1.0, + Controls: []Control{ + { + ID: "164.308a7", + Reference: "HIPAA 164.308(a)(7)", + Name: "Contingency Plan", + Description: "Data backup and disaster recovery procedures", + }, + { + ID: "164.308a7iA", 
+ Reference: "HIPAA 164.308(a)(7)(ii)(A)", + Name: "Data Backup Plan", + Description: "Procedures for retrievable exact copies of ePHI", + }, + { + ID: "164.308a7iB", + Reference: "HIPAA 164.308(a)(7)(ii)(B)", + Name: "Disaster Recovery Plan", + Description: "Procedures to restore any loss of data", + }, + { + ID: "164.308a7iD", + Reference: "HIPAA 164.308(a)(7)(ii)(D)", + Name: "Testing and Revision", + Description: "Testing of contingency plans", + }, + }, + }, + { + ID: "hipaa-technical", + Name: "Technical Safeguards", + Description: "Technical security measures", + Weight: 1.0, + Controls: []Control{ + { + ID: "164.312a2iv", + Reference: "HIPAA 164.312(a)(2)(iv)", + Name: "Encryption", + Description: "Encryption of ePHI", + }, + { + ID: "164.312c1", + Reference: "HIPAA 164.312(c)(1)", + Name: "Integrity Controls", + Description: "Mechanisms to ensure ePHI is not improperly altered", + }, + { + ID: "164.312e1", + Reference: "HIPAA 164.312(e)(1)", + Name: "Transmission Security", + Description: "Technical measures to guard against unauthorized access", + }, + }, + }, + } +} + +// PCIDSSFramework returns PCI-DSS related controls +func PCIDSSFramework() []Category { + return []Category{ + { + ID: "pci-protect", + Name: "Protect Stored Data", + Description: "Protect stored cardholder data", + Weight: 1.0, + Controls: []Control{ + { + ID: "PCI-3.1", + Reference: "PCI-DSS 3.1", + Name: "Data Retention Policy", + Description: "Retention policy limits storage time", + }, + { + ID: "PCI-3.4", + Reference: "PCI-DSS 3.4", + Name: "Encryption", + Description: "Render PAN unreadable anywhere it is stored", + }, + { + ID: "PCI-3.5", + Reference: "PCI-DSS 3.5", + Name: "Key Management", + Description: "Protect cryptographic keys", + }, + }, + }, + { + ID: "pci-maintain", + Name: "Maintain Security", + Description: "Maintain security policies and procedures", + Weight: 1.0, + Controls: []Control{ + { + ID: "PCI-12.10.1", + Reference: "PCI-DSS 12.10.1", + Name: "Incident Response Plan", + Description: "Incident response plan includes data recovery", + }, + }, + }, + } +} + +// ISO27001Framework returns ISO 27001 related controls +func ISO27001Framework() []Category { + return []Category{ + { + ID: "iso-operations", + Name: "Operations Security", + Description: "A.12 Operations Security controls", + Weight: 1.0, + Controls: []Control{ + { + ID: "A.12.3.1", + Reference: "ISO 27001 A.12.3.1", + Name: "Information Backup", + Description: "Backup copies taken and tested regularly", + }, + }, + }, + { + ID: "iso-continuity", + Name: "Business Continuity", + Description: "A.17 Business Continuity controls", + Weight: 1.0, + Controls: []Control{ + { + ID: "A.17.1.1", + Reference: "ISO 27001 A.17.1.1", + Name: "Planning Continuity", + Description: "Information security continuity planning", + }, + { + ID: "A.17.1.2", + Reference: "ISO 27001 A.17.1.2", + Name: "Implementing Continuity", + Description: "Implementation of security continuity", + }, + { + ID: "A.17.1.3", + Reference: "ISO 27001 A.17.1.3", + Name: "Verify and Review", + Description: "Verify and review continuity controls", + }, + }, + }, + { + ID: "iso-cryptography", + Name: "Cryptography", + Description: "A.10 Cryptographic controls", + Weight: 1.0, + Controls: []Control{ + { + ID: "A.10.1.1", + Reference: "ISO 27001 A.10.1.1", + Name: "Cryptographic Controls", + Description: "Policy on use of cryptographic controls", + }, + { + ID: "A.10.1.2", + Reference: "ISO 27001 A.10.1.2", + Name: "Key Management", + Description: "Policy on cryptographic key 
management", + }, + }, + }, + } +} + +// GetFramework returns the appropriate framework for a report type +func GetFramework(reportType ReportType) []Category { + switch reportType { + case ReportSOC2: + return SOC2Framework() + case ReportGDPR: + return GDPRFramework() + case ReportHIPAA: + return HIPAAFramework() + case ReportPCIDSS: + return PCIDSSFramework() + case ReportISO27001: + return ISO27001Framework() + default: + return nil + } +} + +// CreatePeriodReport creates a report for a specific time period +func CreatePeriodReport(reportType ReportType, start, end time.Time) *Report { + title := "" + desc := "" + + switch reportType { + case ReportSOC2: + title = "SOC 2 Type II Compliance Report" + desc = "Trust Service Criteria compliance assessment" + case ReportGDPR: + title = "GDPR Data Protection Compliance Report" + desc = "General Data Protection Regulation compliance assessment" + case ReportHIPAA: + title = "HIPAA Security Compliance Report" + desc = "Health Insurance Portability and Accountability Act compliance assessment" + case ReportPCIDSS: + title = "PCI-DSS Compliance Report" + desc = "Payment Card Industry Data Security Standard compliance assessment" + case ReportISO27001: + title = "ISO 27001 Compliance Report" + desc = "Information Security Management System compliance assessment" + default: + title = "Custom Compliance Report" + desc = "Custom compliance assessment" + } + + report := NewReport(reportType, title) + report.Description = desc + report.PeriodStart = start + report.PeriodEnd = end + + // Load framework controls + framework := GetFramework(reportType) + for _, cat := range framework { + report.AddCategory(cat) + } + + return report +} diff --git a/internal/report/generator.go b/internal/report/generator.go new file mode 100644 index 0000000..ed3e469 --- /dev/null +++ b/internal/report/generator.go @@ -0,0 +1,420 @@ +// Package report - Report generator +package report + +import ( + "context" + "fmt" + "time" + + "dbbackup/internal/catalog" +) + +// Generator generates compliance reports +type Generator struct { + catalog catalog.Catalog + config ReportConfig +} + +// NewGenerator creates a new report generator +func NewGenerator(cat catalog.Catalog, config ReportConfig) *Generator { + return &Generator{ + catalog: cat, + config: config, + } +} + +// Generate creates a compliance report +func (g *Generator) Generate() (*Report, error) { + report := CreatePeriodReport(g.config.Type, g.config.PeriodStart, g.config.PeriodEnd) + report.Title = g.config.Title + if g.config.Description != "" { + report.Description = g.config.Description + } + + // Collect evidence from catalog + evidence, err := g.collectEvidence() + if err != nil { + return nil, fmt.Errorf("failed to collect evidence: %w", err) + } + + for _, e := range evidence { + report.AddEvidence(e) + } + + // Evaluate controls + if err := g.evaluateControls(report, evidence); err != nil { + return nil, fmt.Errorf("failed to evaluate controls: %w", err) + } + + // Calculate summary + report.Calculate() + + return report, nil +} + +// collectEvidence gathers evidence from the backup catalog +func (g *Generator) collectEvidence() ([]Evidence, error) { + var evidence []Evidence + ctx := context.Background() + + // Get backup entries in the report period + query := &catalog.SearchQuery{ + StartDate: &g.config.PeriodStart, + EndDate: &g.config.PeriodEnd, + Limit: 1000, + } + + entries, err := g.catalog.Search(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to search catalog: %w", err) + } + + 
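	// Each entry below can contribute several evidence items: the backup
+	// record itself (BKP-*), plus verification (VRF-*), DR-drill (DRL-*),
+	// and encryption (ENC-*) proofs when the catalog has them.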
// Create evidence for backups + for _, entry := range entries { + e := Evidence{ + ID: fmt.Sprintf("BKP-%d", entry.ID), + Type: EvidenceBackupLog, + Description: fmt.Sprintf("Backup of %s completed", entry.Database), + Source: entry.BackupPath, + CollectedAt: entry.CreatedAt, + Data: map[string]interface{}{ + "database": entry.Database, + "database_type": entry.DatabaseType, + "size": entry.SizeBytes, + "sha256": entry.SHA256, + "encrypted": entry.Encrypted, + "compression": entry.Compression, + "status": entry.Status, + }, + } + + if entry.SHA256 != "" { + e.Hash = entry.SHA256 + } + + evidence = append(evidence, e) + + // Add verification evidence + if entry.VerifiedAt != nil { + evidence = append(evidence, Evidence{ + ID: fmt.Sprintf("VRF-%d", entry.ID), + Type: EvidenceAuditLog, + Description: fmt.Sprintf("Verification of backup %s", entry.BackupPath), + Source: "verification_system", + CollectedAt: *entry.VerifiedAt, + Data: map[string]interface{}{ + "backup_id": entry.ID, + "database": entry.Database, + "verified": true, + }, + }) + } + + // Add drill evidence + if entry.DrillTestedAt != nil { + evidence = append(evidence, Evidence{ + ID: fmt.Sprintf("DRL-%d", entry.ID), + Type: EvidenceDrillResult, + Description: fmt.Sprintf("DR drill test of backup %s", entry.BackupPath), + Source: "drill_system", + CollectedAt: *entry.DrillTestedAt, + Data: map[string]interface{}{ + "backup_id": entry.ID, + "database": entry.Database, + "passed": true, + }, + }) + } + + // Add encryption evidence + if entry.Encrypted { + encryption := "AES-256" + if meta, ok := entry.Metadata["encryption_method"]; ok { + encryption = meta + } + evidence = append(evidence, Evidence{ + ID: fmt.Sprintf("ENC-%d", entry.ID), + Type: EvidenceEncryptionProof, + Description: fmt.Sprintf("Encrypted backup %s", entry.BackupPath), + Source: entry.BackupPath, + CollectedAt: entry.CreatedAt, + Data: map[string]interface{}{ + "backup_id": entry.ID, + "database": entry.Database, + "encryption": encryption, + }, + }) + } + } + + // Get catalog statistics for retention evidence + stats, err := g.catalog.Stats(ctx) + if err == nil { + evidence = append(evidence, Evidence{ + ID: "RET-STATS", + Type: EvidenceRetentionProof, + Description: "Backup retention statistics", + Source: "catalog", + CollectedAt: time.Now(), + Data: map[string]interface{}{ + "total_backups": stats.TotalBackups, + "oldest_backup": stats.OldestBackup, + "newest_backup": stats.NewestBackup, + "average_size": stats.AvgSize, + "total_size": stats.TotalSize, + "databases": len(stats.ByDatabase), + }, + }) + } + + // Check for gaps + gapConfig := &catalog.GapDetectionConfig{ + ExpectedInterval: 24 * time.Hour, + Tolerance: 2 * time.Hour, + StartDate: &g.config.PeriodStart, + EndDate: &g.config.PeriodEnd, + } + + allGaps, err := g.catalog.DetectAllGaps(ctx, gapConfig) + if err == nil { + totalGaps := 0 + for _, gaps := range allGaps { + totalGaps += len(gaps) + } + if totalGaps > 0 { + evidence = append(evidence, Evidence{ + ID: "GAP-ANALYSIS", + Type: EvidenceAuditLog, + Description: "Backup gap analysis", + Source: "catalog", + CollectedAt: time.Now(), + Data: map[string]interface{}{ + "gaps_detected": totalGaps, + "gaps": allGaps, + }, + }) + } + } + + return evidence, nil +} + +// evaluateControls evaluates compliance controls based on evidence +func (g *Generator) evaluateControls(report *Report, evidence []Evidence) error { + // Index evidence by type for quick lookup + evidenceByType := make(map[EvidenceType][]Evidence) + for _, e := range evidence { + 
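		// Bucket evidence by type once so the per-control checks below
+		// can look up backup, encryption, and drill items without
+		// rescanning the full evidence list.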
evidenceByType[e.Type] = append(evidenceByType[e.Type], e)
+	}
+
+	// Get backup statistics
+	backupEvidence := evidenceByType[EvidenceBackupLog]
+	encryptionEvidence := evidenceByType[EvidenceEncryptionProof]
+	drillEvidence := evidenceByType[EvidenceDrillResult]
+	retentionEvidence := evidenceByType[EvidenceRetentionProof]
+
+	// Evaluate each control
+	for i := range report.Categories {
+		cat := &report.Categories[i]
+		catCompliant := 0
+		catTotal := 0
+
+		for j := range cat.Controls {
+			ctrl := &cat.Controls[j]
+			ctrl.LastChecked = time.Now()
+			catTotal++
+
+			// Evaluate based on control type
+			status, notes, evidenceIDs := g.evaluateControl(ctrl, backupEvidence, encryptionEvidence, drillEvidence, retentionEvidence)
+			ctrl.Status = status
+			ctrl.Notes = notes
+			ctrl.Evidence = evidenceIDs
+
+			if status == StatusCompliant {
+				catCompliant++
+			} else if status != StatusNotApplicable {
+				// Create finding for non-compliant controls
+				finding := g.createFinding(ctrl, report)
+				if finding != nil {
+					report.AddFinding(*finding)
+					ctrl.Findings = append(ctrl.Findings, finding.ID)
+				}
+			}
+		}
+
+		// Calculate category score
+		if catTotal > 0 {
+			cat.Score = float64(catCompliant) / float64(catTotal) * 100
+			if cat.Score >= 100 {
+				cat.Status = StatusCompliant
+			} else if cat.Score >= 70 {
+				cat.Status = StatusPartial
+			} else {
+				cat.Status = StatusNonCompliant
+			}
+		}
+	}
+
+	return nil
+}
+
+// evaluateControl evaluates a single control
+func (g *Generator) evaluateControl(ctrl *Control, backups, encryption, drills, retention []Evidence) (ComplianceStatus, string, []string) {
+	var evidenceIDs []string
+
+	switch ctrl.ID {
+	// SOC2 Controls
+	case "CC6.1", "GDPR-32", "164.312a2iv", "PCI-3.4", "A.10.1.1":
+		// Encryption at rest
+		if len(encryption) == 0 {
+			return StatusNonCompliant, "No encrypted backups found", nil
+		}
+		encryptedCount := len(encryption)
+		totalCount := len(backups)
+		if totalCount == 0 {
+			return StatusNotApplicable, "No backups in period", nil
+		}
+		rate := float64(encryptedCount) / float64(totalCount) * 100
+		for _, e := range encryption {
+			evidenceIDs = append(evidenceIDs, e.ID)
+		}
+		if rate >= 100 {
+			return StatusCompliant, fmt.Sprintf("100%% of backups encrypted (%d/%d)", encryptedCount, totalCount), evidenceIDs
+		}
+		if rate >= 90 {
+			return StatusPartial, fmt.Sprintf("%.1f%% of backups encrypted (%d/%d)", rate, encryptedCount, totalCount), evidenceIDs
+		}
+		return StatusNonCompliant, fmt.Sprintf("Only %.1f%% of backups encrypted", rate), evidenceIDs
+
+	case "A1.1", "164.308a7iA", "A.12.3.1":
+		// Backup policy/plan
+		if len(backups) == 0 {
+			return StatusNonCompliant, "No backups found in period", nil
+		}
+		for _, e := range backups[:min(5, len(backups))] {
+			evidenceIDs = append(evidenceIDs, e.ID)
+		}
+		return StatusCompliant, fmt.Sprintf("%d backups created in period", len(backups)), evidenceIDs
+
+	case "A1.2", "164.308a7iD", "A.17.1.3":
+		// Backup testing
+		if len(drills) == 0 {
+			return StatusNonCompliant, "No DR drill tests performed", nil
+		}
+		for _, e := range drills {
+			evidenceIDs = append(evidenceIDs, e.ID)
+		}
+		return StatusCompliant, fmt.Sprintf("%d DR drill tests completed", len(drills)), evidenceIDs
+
+	case "A1.3", "A1.4", "164.308a7iB", "A.17.1.1", "A.17.1.2", "PCI-12.10.1":
+		// DR procedures
+		if len(drills) > 0 {
+			for _, e := range drills {
+				evidenceIDs = append(evidenceIDs, e.ID)
+			}
+			return StatusCompliant, "DR procedures tested", evidenceIDs
+		}
+		return StatusPartial, "DR procedures exist but not tested", nil
+
+	case "PI1.1", "164.312c1":
+		// Data integrity
+		integrityCount := 0
+		for _, e := range backups {
+			if data, ok := e.Data.(map[string]interface{}); ok {
+				// collectEvidence stores the checksum under the "sha256" key
+				if checksum, ok := data["sha256"].(string); ok && checksum != "" {
+					integrityCount++
+					evidenceIDs = append(evidenceIDs, e.ID)
+				}
+			}
+		}
+		if integrityCount == len(backups) && len(backups) > 0 {
+			return StatusCompliant, "All backups have integrity checksums", evidenceIDs
+		}
+		if integrityCount > 0 {
+			return StatusPartial, fmt.Sprintf("%d/%d backups have checksums", integrityCount, len(backups)), evidenceIDs
+		}
+		return StatusNonCompliant, "No integrity checksums found", nil
+
+	case "C1.2", "GDPR-5.1e", "PCI-3.1":
+		// Data retention
+		for _, e := range retention {
+			evidenceIDs = append(evidenceIDs, e.ID)
+		}
+		if len(backups) > 0 {
+			return StatusCompliant, "Retention policy in effect", evidenceIDs
+		}
+		return StatusPartial, "Retention policy needs review", nil
+
+	default:
+		// Generic evaluation
+		if len(backups) > 0 {
+			return StatusCompliant, "Evidence available", nil
+		}
+		return StatusUnknown, "Requires manual review", nil
+	}
+}
+
+// createFinding creates a finding for a non-compliant control
+func (g *Generator) createFinding(ctrl *Control, report *Report) *Finding {
+	if ctrl.Status == StatusCompliant || ctrl.Status == StatusNotApplicable {
+		return nil
+	}
+
+	severity := SeverityMedium
+	findingType := FindingGap
+
+	// Determine severity based on control
+	switch ctrl.ID {
+	case "CC6.1", "164.312a2iv", "PCI-3.4":
+		severity = SeverityHigh
+		findingType = FindingViolation
+	case "A1.2", "164.308a7iD":
+		severity = SeverityMedium
+		findingType = FindingGap
+	}
+
+	return &Finding{
+		ID:             fmt.Sprintf("FND-%s-%d", ctrl.ID, time.Now().UnixNano()),
+		ControlID:      ctrl.ID,
+		Type:           findingType,
+		Severity:       severity,
+		Title:          fmt.Sprintf("%s: %s", ctrl.Reference, ctrl.Name),
+		Description:    ctrl.Notes,
+		Impact:         fmt.Sprintf("Non-compliance with %s requirements", report.Type),
+		Recommendation: g.getRecommendation(ctrl.ID),
+		Status:         FindingOpen,
+		DetectedAt:     time.Now(),
+		Evidence:       ctrl.Evidence,
+	}
+}
+
+// getRecommendation returns remediation recommendation for a control
+func (g *Generator) getRecommendation(controlID string) string {
+	recommendations := map[string]string{
+		"CC6.1":       "Enable encryption for all backups using AES-256",
+		"CC6.7":       "Ensure all backup transfers use TLS",
+		"A1.1":        "Establish and document backup schedule",
+		"A1.2":        "Schedule and perform regular DR drill tests",
+		"A1.3":        "Document and test recovery procedures",
+		"A1.4":        "Develop and test disaster recovery plan",
+		"PI1.1":       "Enable checksum verification for all backups",
+		"C1.2":        "Implement and document retention policies",
+		"164.312a2iv": "Enable HIPAA-compliant encryption (AES-256)",
+		"164.308a7iD": "Test backup recoverability quarterly",
+		"PCI-3.4":     "Encrypt all backups containing cardholder data",
+	}
+
+	if rec, ok := recommendations[controlID]; ok {
+		return rec
+	}
+	return "Review and remediate as per compliance requirements"
+}
+
+func min(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
diff --git a/internal/report/output.go b/internal/report/output.go
new file mode 100644
index 0000000..dbfb2f0
--- /dev/null
+++ b/internal/report/output.go
@@ -0,0 +1,544 @@
+// Package report - Output formatters
+package report
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"strings"
+	"text/template"
+	"time"
+)
+
+// Formatter formats reports for output
+type Formatter interface {
+	Format(report 
*Report, w io.Writer) error +} + +// JSONFormatter formats reports as JSON +type JSONFormatter struct { + Indent bool +} + +// Format writes the report as JSON +func (f *JSONFormatter) Format(report *Report, w io.Writer) error { + var data []byte + var err error + + if f.Indent { + data, err = json.MarshalIndent(report, "", " ") + } else { + data, err = json.Marshal(report) + } + + if err != nil { + return err + } + + _, err = w.Write(data) + return err +} + +// MarkdownFormatter formats reports as Markdown +type MarkdownFormatter struct{} + +// Format writes the report as Markdown +func (f *MarkdownFormatter) Format(report *Report, w io.Writer) error { + tmpl := template.Must(template.New("report").Funcs(template.FuncMap{ + "statusIcon": StatusIcon, + "severityIcon": SeverityIcon, + "formatTime": func(t time.Time) string { return t.Format("2006-01-02 15:04:05") }, + "formatDate": func(t time.Time) string { return t.Format("2006-01-02") }, + "upper": strings.ToUpper, + }).Parse(markdownTemplate)) + + return tmpl.Execute(w, report) +} + +const markdownTemplate = `# {{.Title}} + +**Generated:** {{formatTime .GeneratedAt}} +**Period:** {{formatDate .PeriodStart}} to {{formatDate .PeriodEnd}} +**Overall Status:** {{statusIcon .Status}} {{.Status}} +**Compliance Score:** {{printf "%.1f" .Score}}% + +--- + +## Executive Summary + +{{.Description}} + +| Metric | Value | +|--------|-------| +| Total Controls | {{.Summary.TotalControls}} | +| Compliant | {{.Summary.CompliantControls}} | +| Non-Compliant | {{.Summary.NonCompliantControls}} | +| Partial | {{.Summary.PartialControls}} | +| Compliance Rate | {{printf "%.1f" .Summary.ComplianceRate}}% | +| Open Findings | {{.Summary.OpenFindings}} | +| Risk Score | {{printf "%.1f" .Summary.RiskScore}} | + +--- + +## Compliance Categories + +{{range .Categories}} +### {{statusIcon .Status}} {{.Name}} + +**Status:** {{.Status}} | **Score:** {{printf "%.1f" .Score}}% + +{{.Description}} + +| Control | Reference | Status | Notes | +|---------|-----------|--------|-------| +{{range .Controls}}| {{.Name}} | {{.Reference}} | {{statusIcon .Status}} | {{.Notes}} | +{{end}} + +{{end}} + +--- + +## Findings + +{{if .Findings}} +| ID | Severity | Title | Status | +|----|----------|-------|--------| +{{range .Findings}}| {{.ID}} | {{severityIcon .Severity}} {{.Severity}} | {{.Title}} | {{.Status}} | +{{end}} + +### Finding Details + +{{range .Findings}} +#### {{severityIcon .Severity}} {{.Title}} + +- **ID:** {{.ID}} +- **Control:** {{.ControlID}} +- **Severity:** {{.Severity}} +- **Type:** {{.Type}} +- **Status:** {{.Status}} +- **Detected:** {{formatTime .DetectedAt}} + +**Description:** {{.Description}} + +**Impact:** {{.Impact}} + +**Recommendation:** {{.Recommendation}} + +--- + +{{end}} +{{else}} +No open findings. +{{end}} + +--- + +## Evidence Summary + +{{if .Evidence}} +| ID | Type | Description | Collected | +|----|------|-------------|-----------| +{{range .Evidence}}| {{.ID}} | {{.Type}} | {{.Description}} | {{formatTime .CollectedAt}} | +{{end}} +{{else}} +No evidence collected. 
+{{end}}
+
+---
+
+*Report generated by dbbackup compliance module*
+`
+
+// HTMLFormatter formats reports as HTML
+type HTMLFormatter struct{}
+
+// Format writes the report as HTML
+func (f *HTMLFormatter) Format(report *Report, w io.Writer) error {
+	tmpl := template.Must(template.New("report").Funcs(template.FuncMap{
+		"statusIcon":    StatusIcon,
+		"statusClass":   statusClass,
+		"severityIcon":  SeverityIcon,
+		"severityClass": severityClass,
+		"formatTime":    func(t time.Time) string { return t.Format("2006-01-02 15:04:05") },
+		"formatDate":    func(t time.Time) string { return t.Format("2006-01-02") },
+	}).Parse(htmlTemplate))
+
+	return tmpl.Execute(w, report)
+}
+
+func statusClass(s ComplianceStatus) string {
+	switch s {
+	case StatusCompliant:
+		return "status-compliant"
+	case StatusNonCompliant:
+		return "status-noncompliant"
+	case StatusPartial:
+		return "status-partial"
+	default:
+		return "status-unknown"
+	}
+}
+
+func severityClass(s FindingSeverity) string {
+	switch s {
+	case SeverityCritical:
+		return "severity-critical"
+	case SeverityHigh:
+		return "severity-high"
+	case SeverityMedium:
+		return "severity-medium"
+	case SeverityLow:
+		return "severity-low"
+	default:
+		return "severity-unknown"
+	}
+}
+
+const htmlTemplate = `<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="utf-8">
+<title>{{.Title}}</title>
+<style>
+body { font-family: sans-serif; margin: 2em; color: #222; }
+.meta { color: #666; }
+.score { font-size: 2.5em; font-weight: bold; }
+.card { display: inline-block; border: 1px solid #ddd; border-radius: 4px; margin: 4px; padding: 12px 18px; text-align: center; }
+.card .value { font-size: 1.6em; font-weight: bold; }
+table { border-collapse: collapse; width: 100%; margin: 1em 0; }
+th, td { border: 1px solid #ddd; padding: 6px 10px; text-align: left; }
+.status-compliant { color: #2e7d32; }
+.status-noncompliant { color: #c62828; }
+.status-partial { color: #f9a825; }
+.status-unknown { color: #757575; }
+.severity-critical { color: #b71c1c; }
+.severity-high { color: #e65100; }
+.severity-medium { color: #f9a825; }
+.severity-low { color: #2e7d32; }
+.severity-unknown { color: #757575; }
+</style>
+</head>
+<body>
+
+<h1>{{.Title}}</h1>
+<p class="meta">
+  Generated: {{formatTime .GeneratedAt}} |
+  Period: {{formatDate .PeriodStart}} to {{formatDate .PeriodEnd}}
+</p>
+<p>
+  <span class="score">{{printf "%.0f" .Score}}%</span>
+  <span class="{{statusClass .Status}}">{{.Status}}</span>
+</p>
+<p>{{.Description}}</p>
+
+<h2>Summary</h2>
+<div>
+  <div class="card"><div class="value">{{.Summary.TotalControls}}</div>Total Controls</div>
+  <div class="card"><div class="value">{{.Summary.CompliantControls}}</div>Compliant</div>
+  <div class="card"><div class="value">{{.Summary.NonCompliantControls}}</div>Non-Compliant</div>
+  <div class="card"><div class="value">{{.Summary.PartialControls}}</div>Partial</div>
+  <div class="card"><div class="value">{{.Summary.OpenFindings}}</div>Open Findings</div>
+  <div class="card"><div class="value">{{printf "%.1f" .Summary.RiskScore}}</div>Risk Score</div>
+</div>
+
+<h2>Compliance Categories</h2>
+{{range .Categories}}
+<h3>{{statusIcon .Status}} {{.Name}} <span class="{{statusClass .Status}}">{{printf "%.0f" .Score}}%</span></h3>
+<p>{{.Description}}</p>
+<table>
+  <tr><th>Control</th><th>Reference</th><th>Status</th><th>Notes</th></tr>
+  {{range .Controls}}
+  <tr><td>{{.Name}}</td><td>{{.Reference}}</td><td>{{statusIcon .Status}}</td><td>{{.Notes}}</td></tr>
+  {{end}}
+</table>
+{{end}}
+
+{{if .Findings}}
+<h2>Findings ({{len .Findings}})</h2>
+{{range .Findings}}
+<h3 class="{{severityClass .Severity}}">{{severityIcon .Severity}} {{.Title}}</h3>
+<p class="meta">
+  ID: {{.ID}} |
+  Severity: {{.Severity}} |
+  Status: {{.Status}} |
+  Detected: {{formatTime .DetectedAt}}
+</p>
+<p><strong>Description:</strong> {{.Description}}</p>
+<p><strong>Impact:</strong> {{.Impact}}</p>
+<p><strong>Recommendation:</strong> {{.Recommendation}}</p>
+{{end}}
+{{end}}
+
+{{if .Evidence}}
+<h2>Evidence ({{len .Evidence}} items)</h2>
+<table>
+  <tr><th>ID</th><th>Type</th><th>Description</th><th>Collected</th></tr>
+  {{range .Evidence}}
+  <tr><td>{{.ID}}</td><td>{{.Type}}</td><td>{{.Description}}</td><td>{{formatTime .CollectedAt}}</td></tr>
+  {{end}}
+</table>
+ {{end}} + + + +` + +// GetFormatter returns a formatter for the given format +func GetFormatter(format OutputFormat) Formatter { + switch format { + case FormatJSON: + return &JSONFormatter{Indent: true} + case FormatMarkdown: + return &MarkdownFormatter{} + case FormatHTML: + return &HTMLFormatter{} + default: + return &JSONFormatter{Indent: true} + } +} + +// ConsoleFormatter formats reports for terminal output +type ConsoleFormatter struct{} + +// Format writes the report to console +func (f *ConsoleFormatter) Format(report *Report, w io.Writer) error { + // Header + fmt.Fprintf(w, "\n%s\n", strings.Repeat("=", 60)) + fmt.Fprintf(w, " %s\n", report.Title) + fmt.Fprintf(w, "%s\n\n", strings.Repeat("=", 60)) + + fmt.Fprintf(w, " Generated: %s\n", report.GeneratedAt.Format("2006-01-02 15:04:05")) + fmt.Fprintf(w, " Period: %s to %s\n", + report.PeriodStart.Format("2006-01-02"), + report.PeriodEnd.Format("2006-01-02")) + fmt.Fprintf(w, " Status: %s %s\n", StatusIcon(report.Status), report.Status) + fmt.Fprintf(w, " Score: %.1f%%\n\n", report.Score) + + // Summary + fmt.Fprintf(w, " SUMMARY\n") + fmt.Fprintf(w, " %s\n", strings.Repeat("-", 40)) + fmt.Fprintf(w, " Controls: %d total, %d compliant, %d non-compliant\n", + report.Summary.TotalControls, + report.Summary.CompliantControls, + report.Summary.NonCompliantControls) + fmt.Fprintf(w, " Compliance: %.1f%%\n", report.Summary.ComplianceRate) + fmt.Fprintf(w, " Open Findings: %d (critical: %d, high: %d)\n", + report.Summary.OpenFindings, + report.Summary.CriticalFindings, + report.Summary.HighFindings) + fmt.Fprintf(w, " Risk Score: %.1f\n\n", report.Summary.RiskScore) + + // Categories + fmt.Fprintf(w, " CATEGORIES\n") + fmt.Fprintf(w, " %s\n", strings.Repeat("-", 40)) + for _, cat := range report.Categories { + fmt.Fprintf(w, " %s %-25s %.0f%%\n", StatusIcon(cat.Status), cat.Name, cat.Score) + } + fmt.Fprintln(w) + + // Findings + if len(report.Findings) > 0 { + fmt.Fprintf(w, " FINDINGS\n") + fmt.Fprintf(w, " %s\n", strings.Repeat("-", 40)) + for _, f := range report.Findings { + fmt.Fprintf(w, " %s [%s] %s\n", SeverityIcon(f.Severity), f.Severity, f.Title) + fmt.Fprintf(w, " %s\n", f.Description) + } + fmt.Fprintln(w) + } + + fmt.Fprintf(w, "%s\n", strings.Repeat("=", 60)) + return nil +} diff --git a/internal/report/report.go b/internal/report/report.go new file mode 100644 index 0000000..f09d4b7 --- /dev/null +++ b/internal/report/report.go @@ -0,0 +1,325 @@ +// Package report provides compliance report generation +package report + +import ( + "encoding/json" + "fmt" + "time" +) + +// ReportType represents the compliance framework type +type ReportType string + +const ( + ReportSOC2 ReportType = "soc2" + ReportGDPR ReportType = "gdpr" + ReportHIPAA ReportType = "hipaa" + ReportPCIDSS ReportType = "pci-dss" + ReportISO27001 ReportType = "iso27001" + ReportCustom ReportType = "custom" +) + +// ComplianceStatus represents the status of a compliance check +type ComplianceStatus string + +const ( + StatusCompliant ComplianceStatus = "compliant" + StatusNonCompliant ComplianceStatus = "non_compliant" + StatusPartial ComplianceStatus = "partial" + StatusNotApplicable ComplianceStatus = "not_applicable" + StatusUnknown ComplianceStatus = "unknown" +) + +// Report represents a compliance report +type Report struct { + ID string `json:"id"` + Type ReportType `json:"type"` + Title string `json:"title"` + Description string `json:"description"` + GeneratedAt time.Time `json:"generated_at"` + GeneratedBy string `json:"generated_by"` + 
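	// PeriodStart and PeriodEnd bound the evidence window the report covers.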
PeriodStart time.Time `json:"period_start"` + PeriodEnd time.Time `json:"period_end"` + Status ComplianceStatus `json:"overall_status"` + Score float64 `json:"score"` // 0-100 + Categories []Category `json:"categories"` + Summary Summary `json:"summary"` + Findings []Finding `json:"findings"` + Evidence []Evidence `json:"evidence"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// Category represents a compliance category +type Category struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status ComplianceStatus `json:"status"` + Score float64 `json:"score"` + Weight float64 `json:"weight"` + Controls []Control `json:"controls"` +} + +// Control represents a compliance control +type Control struct { + ID string `json:"id"` + Reference string `json:"reference"` // e.g., "SOC2 CC6.1" + Name string `json:"name"` + Description string `json:"description"` + Status ComplianceStatus `json:"status"` + Evidence []string `json:"evidence_ids,omitempty"` + Findings []string `json:"finding_ids,omitempty"` + LastChecked time.Time `json:"last_checked"` + Notes string `json:"notes,omitempty"` +} + +// Finding represents a compliance finding +type Finding struct { + ID string `json:"id"` + ControlID string `json:"control_id"` + Type FindingType `json:"type"` + Severity FindingSeverity `json:"severity"` + Title string `json:"title"` + Description string `json:"description"` + Impact string `json:"impact"` + Recommendation string `json:"recommendation"` + Status FindingStatus `json:"status"` + DetectedAt time.Time `json:"detected_at"` + ResolvedAt *time.Time `json:"resolved_at,omitempty"` + Evidence []string `json:"evidence_ids,omitempty"` +} + +// FindingType represents the type of finding +type FindingType string + +const ( + FindingGap FindingType = "gap" + FindingViolation FindingType = "violation" + FindingObservation FindingType = "observation" + FindingRecommendation FindingType = "recommendation" +) + +// FindingSeverity represents finding severity +type FindingSeverity string + +const ( + SeverityLow FindingSeverity = "low" + SeverityMedium FindingSeverity = "medium" + SeverityHigh FindingSeverity = "high" + SeverityCritical FindingSeverity = "critical" +) + +// FindingStatus represents finding status +type FindingStatus string + +const ( + FindingOpen FindingStatus = "open" + FindingAccepted FindingStatus = "accepted" + FindingResolved FindingStatus = "resolved" +) + +// Evidence represents compliance evidence +type Evidence struct { + ID string `json:"id"` + Type EvidenceType `json:"type"` + Description string `json:"description"` + Source string `json:"source"` + CollectedAt time.Time `json:"collected_at"` + Hash string `json:"hash,omitempty"` + Data interface{} `json:"data,omitempty"` +} + +// EvidenceType represents the type of evidence +type EvidenceType string + +const ( + EvidenceBackupLog EvidenceType = "backup_log" + EvidenceRestoreLog EvidenceType = "restore_log" + EvidenceDrillResult EvidenceType = "drill_result" + EvidenceEncryptionProof EvidenceType = "encryption_proof" + EvidenceRetentionProof EvidenceType = "retention_proof" + EvidenceAccessLog EvidenceType = "access_log" + EvidenceAuditLog EvidenceType = "audit_log" + EvidenceConfiguration EvidenceType = "configuration" + EvidenceScreenshot EvidenceType = "screenshot" + EvidenceOther EvidenceType = "other" +) + +// Summary provides a high-level overview +type Summary struct { + TotalControls int `json:"total_controls"` + CompliantControls int 
`json:"compliant_controls"` + NonCompliantControls int `json:"non_compliant_controls"` + PartialControls int `json:"partial_controls"` + NotApplicable int `json:"not_applicable"` + OpenFindings int `json:"open_findings"` + CriticalFindings int `json:"critical_findings"` + HighFindings int `json:"high_findings"` + MediumFindings int `json:"medium_findings"` + LowFindings int `json:"low_findings"` + ComplianceRate float64 `json:"compliance_rate"` + RiskScore float64 `json:"risk_score"` +} + +// ReportConfig configures report generation +type ReportConfig struct { + Type ReportType `json:"type"` + Title string `json:"title"` + Description string `json:"description"` + PeriodStart time.Time `json:"period_start"` + PeriodEnd time.Time `json:"period_end"` + IncludeDatabases []string `json:"include_databases,omitempty"` + ExcludeDatabases []string `json:"exclude_databases,omitempty"` + CatalogPath string `json:"catalog_path"` + OutputFormat OutputFormat `json:"output_format"` + OutputPath string `json:"output_path"` + IncludeEvidence bool `json:"include_evidence"` + CustomControls []Control `json:"custom_controls,omitempty"` +} + +// OutputFormat represents report output format +type OutputFormat string + +const ( + FormatJSON OutputFormat = "json" + FormatHTML OutputFormat = "html" + FormatPDF OutputFormat = "pdf" + FormatMarkdown OutputFormat = "markdown" +) + +// NewReport creates a new report +func NewReport(reportType ReportType, title string) *Report { + return &Report{ + ID: generateID(), + Type: reportType, + Title: title, + GeneratedAt: time.Now(), + Categories: make([]Category, 0), + Findings: make([]Finding, 0), + Evidence: make([]Evidence, 0), + Metadata: make(map[string]string), + } +} + +// AddCategory adds a category to the report +func (r *Report) AddCategory(cat Category) { + r.Categories = append(r.Categories, cat) +} + +// AddFinding adds a finding to the report +func (r *Report) AddFinding(f Finding) { + r.Findings = append(r.Findings, f) +} + +// AddEvidence adds evidence to the report +func (r *Report) AddEvidence(e Evidence) { + r.Evidence = append(r.Evidence, e) +} + +// Calculate computes the summary and overall status +func (r *Report) Calculate() { + var totalWeight float64 + var weightedScore float64 + + for _, cat := range r.Categories { + totalWeight += cat.Weight + weightedScore += cat.Score * cat.Weight + + for _, ctrl := range cat.Controls { + r.Summary.TotalControls++ + switch ctrl.Status { + case StatusCompliant: + r.Summary.CompliantControls++ + case StatusNonCompliant: + r.Summary.NonCompliantControls++ + case StatusPartial: + r.Summary.PartialControls++ + case StatusNotApplicable: + r.Summary.NotApplicable++ + } + } + } + + for _, f := range r.Findings { + if f.Status == FindingOpen { + r.Summary.OpenFindings++ + switch f.Severity { + case SeverityCritical: + r.Summary.CriticalFindings++ + case SeverityHigh: + r.Summary.HighFindings++ + case SeverityMedium: + r.Summary.MediumFindings++ + case SeverityLow: + r.Summary.LowFindings++ + } + } + } + + if totalWeight > 0 { + r.Score = weightedScore / totalWeight + } + + applicable := r.Summary.TotalControls - r.Summary.NotApplicable + if applicable > 0 { + r.Summary.ComplianceRate = float64(r.Summary.CompliantControls) / float64(applicable) * 100 + } + + // Calculate risk score based on findings + r.Summary.RiskScore = float64(r.Summary.CriticalFindings)*10 + + float64(r.Summary.HighFindings)*5 + + float64(r.Summary.MediumFindings)*2 + + float64(r.Summary.LowFindings)*1 + + // Determine overall status + if 
+
+func generateID() string {
+	return fmt.Sprintf("RPT-%d", time.Now().UnixNano())
+}
+
+// StatusIcon returns an icon for a compliance status
+func StatusIcon(s ComplianceStatus) string {
+	switch s {
+	case StatusCompliant:
+		return "✅"
+	case StatusNonCompliant:
+		return "❌"
+	case StatusPartial:
+		return "⚠️"
+	case StatusNotApplicable:
+		return "➖"
+	default:
+		return "❓"
+	}
+}
+
+// SeverityIcon returns an icon for a finding severity
+func SeverityIcon(s FindingSeverity) string {
+	switch s {
+	case SeverityCritical:
+		return "🔴"
+	case SeverityHigh:
+		return "🟠"
+	case SeverityMedium:
+		return "🟡"
+	case SeverityLow:
+		return "🟢"
+	default:
+		return "⚪"
+	}
+}
diff --git a/internal/rto/calculator.go b/internal/rto/calculator.go
new file mode 100644
index 0000000..228af68
--- /dev/null
+++ b/internal/rto/calculator.go
@@ -0,0 +1,481 @@
+// Package rto provides RTO/RPO calculation and analysis
+package rto
+
+import (
+	"context"
+	"fmt"
+	"sort"
+	"time"
+
+	"dbbackup/internal/catalog"
+)
+
+// Calculator calculates RTO and RPO metrics
+type Calculator struct {
+	catalog catalog.Catalog
+	config  Config
+}
+
+// Config configures RTO/RPO calculations
+type Config struct {
+	TargetRTO time.Duration `json:"target_rto"` // Target Recovery Time Objective
+	TargetRPO time.Duration `json:"target_rpo"` // Target Recovery Point Objective
+
+	// Throughput assumptions used by the calculation
+	NetworkSpeedMbps       float64 `json:"network_speed_mbps"`        // Network speed for cloud restores (megabits/s)
+	DiskReadSpeedMBps      float64 `json:"disk_read_speed_mbps"`      // Disk read speed (megabytes/s)
+	DiskWriteSpeedMBps     float64 `json:"disk_write_speed_mbps"`     // Disk write speed (megabytes/s)
+	CloudDownloadSpeedMbps float64 `json:"cloud_download_speed_mbps"` // Cloud download speed (megabits/s)
+
+	// Fixed time estimates for various operations
+	StartupTimeMinutes    int `json:"startup_time_minutes"`    // DB startup time
+	ValidationTimeMinutes int `json:"validation_time_minutes"` // Post-restore validation
+	SwitchoverTimeMinutes int `json:"switchover_time_minutes"` // Application switchover time
+}
+
+// DefaultConfig returns sensible defaults
+func DefaultConfig() Config {
+	return Config{
+		TargetRTO:              4 * time.Hour,
+		TargetRPO:              1 * time.Hour,
+		NetworkSpeedMbps:       100,
+		DiskReadSpeedMBps:      100,
+		DiskWriteSpeedMBps:     50,
+		CloudDownloadSpeedMbps: 100,
+		StartupTimeMinutes:     2,
+		ValidationTimeMinutes:  5,
+		SwitchoverTimeMinutes:  5,
+	}
+}
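+
+// Illustrative sketch: tightening targets relative to the defaults for a
+// database with stricter recovery objectives. The values are placeholders.
+//
+//	cfg := DefaultConfig()
+//	cfg.TargetRPO = 15 * time.Minute
+//	cfg.TargetRTO = 1 * time.Hour
+//	cfg.DiskWriteSpeedMBps = 200 // e.g. local NVMe storage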
+
+// Analysis contains RTO/RPO analysis results
+type Analysis struct {
+	Database  string    `json:"database"`
+	Timestamp time.Time `json:"timestamp"`
+
+	// Current state
+	CurrentRPO time.Duration `json:"current_rpo"`
+	CurrentRTO time.Duration `json:"current_rto"`
+
+	// Target state
+	TargetRPO time.Duration `json:"target_rpo"`
+	TargetRTO time.Duration `json:"target_rto"`
+
+	// Compliance
+	RPOCompliant bool `json:"rpo_compliant"`
+	RTOCompliant bool `json:"rto_compliant"`
+
+	// Details
+	LastBackup     *time.Time    `json:"last_backup,omitempty"`
+	NextScheduled  *time.Time    `json:"next_scheduled,omitempty"`
+	BackupInterval time.Duration `json:"backup_interval"`
+
+	// RTO breakdown
+	RTOBreakdown RTOBreakdown `json:"rto_breakdown"`
+
+	// Recommendations
+	Recommendations []Recommendation `json:"recommendations,omitempty"`
+
+	// Historical
+	History []HistoricalPoint `json:"history,omitempty"`
+}
+
+// RTOBreakdown shows components of RTO calculation
+type RTOBreakdown struct {
+	DetectionTime  time.Duration `json:"detection_time"`
+	DecisionTime   time.Duration `json:"decision_time"`
+	DownloadTime   time.Duration `json:"download_time"`
+	RestoreTime    time.Duration `json:"restore_time"`
+	StartupTime    time.Duration `json:"startup_time"`
+	ValidationTime time.Duration `json:"validation_time"`
+	SwitchoverTime time.Duration `json:"switchover_time"`
+	TotalTime      time.Duration `json:"total_time"`
+}
+
+// Recommendation suggests improvements
+type Recommendation struct {
+	Type        RecommendationType `json:"type"`
+	Priority    Priority           `json:"priority"`
+	Title       string             `json:"title"`
+	Description string             `json:"description"`
+	Impact      string             `json:"impact"`
+	Effort      Effort             `json:"effort"`
+}
+
+// RecommendationType categorizes recommendations
+type RecommendationType string
+
+const (
+	RecommendBackupFrequency   RecommendationType = "backup_frequency"
+	RecommendIncrementalBackup RecommendationType = "incremental_backup"
+	RecommendCompression       RecommendationType = "compression"
+	RecommendLocalCache        RecommendationType = "local_cache"
+	RecommendParallelRestore   RecommendationType = "parallel_restore"
+	RecommendWALArchiving      RecommendationType = "wal_archiving"
+	RecommendReplication       RecommendationType = "replication"
+)
+
+// Priority levels
+type Priority string
+
+const (
+	PriorityCritical Priority = "critical"
+	PriorityHigh     Priority = "high"
+	PriorityMedium   Priority = "medium"
+	PriorityLow      Priority = "low"
+)
+
+// Effort levels
+type Effort string
+
+const (
+	EffortLow    Effort = "low"
+	EffortMedium Effort = "medium"
+	EffortHigh   Effort = "high"
+)
+
+// HistoricalPoint tracks RTO/RPO over time
+type HistoricalPoint struct {
+	Timestamp time.Time     `json:"timestamp"`
+	RPO       time.Duration `json:"rpo"`
+	RTO       time.Duration `json:"rto"`
+}
+
+// NewCalculator creates a new RTO/RPO calculator
+func NewCalculator(cat catalog.Catalog, config Config) *Calculator {
+	return &Calculator{
+		catalog: cat,
+		config:  config,
+	}
+}
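+
+// Illustrative sketch: analyzing one database against the configured targets.
+// ctx, the catalog value cat (satisfying catalog.Catalog), and cfg are assumed
+// from the surrounding application; "orders" is a placeholder database name.
+//
+//	calc := NewCalculator(cat, cfg)
+//	a, err := calc.Analyze(ctx, "orders")
+//	if err == nil && !a.RPOCompliant {
+//		// act on a.Recommendations
+//	}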
+
+// Analyze performs RTO/RPO analysis for a database
+func (c *Calculator) Analyze(ctx context.Context, database string) (*Analysis, error) {
+	analysis := &Analysis{
+		Database:  database,
+		Timestamp: time.Now(),
+		TargetRPO: c.config.TargetRPO,
+		TargetRTO: c.config.TargetRTO,
+	}
+
+	// Get recent backups
+	entries, err := c.catalog.List(ctx, database, 100)
+	if err != nil {
+		return nil, fmt.Errorf("failed to list backups: %w", err)
+	}
+
+	if len(entries) == 0 {
+		// No backups - worst case: RPO and RTO are undefined, reported as zero
+		analysis.CurrentRPO = 0
+		analysis.CurrentRTO = 0
+		analysis.Recommendations = append(analysis.Recommendations, Recommendation{
+			Type:        RecommendBackupFrequency,
+			Priority:    PriorityCritical,
+			Title:       "No Backups Found",
+			Description: "No backups exist for this database",
+			Impact:      "Cannot recover in case of failure",
+			Effort:      EffortLow,
+		})
+		return analysis, nil
+	}
+
+	// Calculate current RPO (time since last backup)
+	lastBackup := entries[0].CreatedAt
+	analysis.LastBackup = &lastBackup
+	analysis.CurrentRPO = time.Since(lastBackup)
+	analysis.RPOCompliant = analysis.CurrentRPO <= c.config.TargetRPO
+
+	// Calculate backup interval
+	if len(entries) >= 2 {
+		analysis.BackupInterval = calculateAverageInterval(entries)
+	}
+
+	// Calculate RTO
+	latestEntry := entries[0]
+	analysis.RTOBreakdown = c.calculateRTOBreakdown(latestEntry)
+	analysis.CurrentRTO = analysis.RTOBreakdown.TotalTime
+	analysis.RTOCompliant = analysis.CurrentRTO <= c.config.TargetRTO
+
+	// Generate recommendations
+	analysis.Recommendations = c.generateRecommendations(analysis, entries)
+
+	// Calculate history
+	analysis.History = c.calculateHistory(entries)
+
+	return analysis, nil
+}
+
+// AnalyzeAll analyzes all databases
+func (c *Calculator) AnalyzeAll(ctx context.Context) ([]*Analysis, error) {
+	databases, err := c.catalog.ListDatabases(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to list databases: %w", err)
+	}
+
+	var analyses []*Analysis
+	for _, db := range databases {
+		analysis, err := c.Analyze(ctx, db)
+		if err != nil {
+			continue // Skip errors for individual databases
+		}
+		analyses = append(analyses, analysis)
+	}
+
+	return analyses, nil
+}
+
+// calculateRTOBreakdown calculates RTO components
+func (c *Calculator) calculateRTOBreakdown(entry *catalog.Entry) RTOBreakdown {
+	breakdown := RTOBreakdown{
+		// Detection time - assume monitoring catches issues quickly
+		DetectionTime: 5 * time.Minute,
+
+		// Decision time - human decision making
+		DecisionTime: 10 * time.Minute,
+
+		// Startup time
+		StartupTime: time.Duration(c.config.StartupTimeMinutes) * time.Minute,
+
+		// Validation time
+		ValidationTime: time.Duration(c.config.ValidationTimeMinutes) * time.Minute,
+
+		// Switchover time
+		SwitchoverTime: time.Duration(c.config.SwitchoverTimeMinutes) * time.Minute,
+	}
+
+	// Calculate download time (if cloud backup)
+	if entry.CloudLocation != "" {
+		// Cloud download
+		bytesPerSecond := c.config.CloudDownloadSpeedMbps * 125000 // Mbps to bytes/sec
+		downloadSeconds := float64(entry.SizeBytes) / bytesPerSecond
+		breakdown.DownloadTime = time.Duration(downloadSeconds * float64(time.Second))
+	}
+
+	// Calculate restore time, estimated from disk write speed
+	bytesPerSecond := c.config.DiskWriteSpeedMBps * 1000000 // MB/s to bytes/sec
+	restoreSeconds := float64(entry.SizeBytes) / bytesPerSecond
+
+	// Add overhead for decompression if compressed
+	if entry.Compression != "" && entry.Compression != "none" {
+		restoreSeconds *= 1.3 // 30% overhead for decompression
+	}
+
+	// Add overhead for decryption if encrypted
+	if entry.Encrypted {
+		restoreSeconds *= 1.1 // 10% overhead for decryption
+	}
+
+	breakdown.RestoreTime = time.Duration(restoreSeconds * float64(time.Second))
+
+	// Calculate total
+	breakdown.TotalTime = breakdown.DetectionTime +
+		breakdown.DecisionTime +
+		breakdown.DownloadTime +
+		breakdown.RestoreTime +
+		breakdown.StartupTime +
+		breakdown.ValidationTime +
+		breakdown.SwitchoverTime
+
+	return breakdown
+}
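+
+// Worked example (under the DefaultConfig assumptions): a 10 GB compressed,
+// unencrypted local backup restores at 50 MB/s, so
+//
+//	restore ≈ 10,000,000,000 B / 50,000,000 B/s × 1.3 ≈ 260 s (~4.3 min)
+//	total   ≈ 5 + 10 + 0 (no download) + 4.3 + 2 + 5 + 5     ≈ 31 min
+//
+// i.e. the fixed detection/decision/startup/validation/switchover time
+// dominates until backups reach tens of gigabytes.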
fmt.Sprintf("Current RPO (%s) exceeds target (%s) by %s", + formatDuration(analysis.CurrentRPO), + formatDuration(c.config.TargetRPO), + formatDuration(gap)), + Impact: "Potential data loss exceeds acceptable threshold", + Effort: EffortLow, + }) + } + + // RTO violations + if !analysis.RTOCompliant { + recommendations = append(recommendations, Recommendation{ + Type: RecommendParallelRestore, + Priority: PriorityHigh, + Title: "RTO Target Not Met", + Description: fmt.Sprintf("Estimated recovery time (%s) exceeds target (%s)", + formatDuration(analysis.CurrentRTO), + formatDuration(c.config.TargetRTO)), + Impact: "Recovery may take longer than acceptable", + Effort: EffortMedium, + }) + } + + // Large download time + if analysis.RTOBreakdown.DownloadTime > 30*time.Minute { + recommendations = append(recommendations, Recommendation{ + Type: RecommendLocalCache, + Priority: PriorityMedium, + Title: "Consider Local Backup Cache", + Description: fmt.Sprintf("Cloud download takes %s, local cache would reduce this", + formatDuration(analysis.RTOBreakdown.DownloadTime)), + Impact: "Faster recovery from local storage", + Effort: EffortMedium, + }) + } + + // No incremental backups + hasIncremental := false + for _, e := range entries { + if e.BackupType == "incremental" { + hasIncremental = true + break + } + } + if !hasIncremental && analysis.BackupInterval > 6*time.Hour { + recommendations = append(recommendations, Recommendation{ + Type: RecommendIncrementalBackup, + Priority: PriorityMedium, + Title: "Enable Incremental Backups", + Description: "Incremental backups can reduce backup time and storage", + Impact: "Better RPO with less resource usage", + Effort: EffortLow, + }) + } + + // WAL archiving for PostgreSQL + if len(entries) > 0 && entries[0].DatabaseType == "postgresql" { + recommendations = append(recommendations, Recommendation{ + Type: RecommendWALArchiving, + Priority: PriorityMedium, + Title: "Consider WAL Archiving", + Description: "Enable WAL archiving for point-in-time recovery", + Impact: "Achieve near-zero RPO with PITR", + Effort: EffortMedium, + }) + } + + return recommendations +} + +// calculateHistory generates historical RTO/RPO points +func (c *Calculator) calculateHistory(entries []*catalog.Entry) []HistoricalPoint { + var history []HistoricalPoint + + // Sort entries by date (oldest first) + sorted := make([]*catalog.Entry, len(entries)) + copy(sorted, entries) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].CreatedAt.Before(sorted[j].CreatedAt) + }) + + for i, entry := range sorted { + point := HistoricalPoint{ + Timestamp: entry.CreatedAt, + RTO: c.calculateRTOBreakdown(entry).TotalTime, + } + + // Calculate RPO at that point (time until next backup) + if i < len(sorted)-1 { + point.RPO = sorted[i+1].CreatedAt.Sub(entry.CreatedAt) + } else { + point.RPO = time.Since(entry.CreatedAt) + } + + history = append(history, point) + } + + return history +} + +// Summary provides aggregate RTO/RPO status +type Summary struct { + TotalDatabases int `json:"total_databases"` + RPOCompliant int `json:"rpo_compliant"` + RTOCompliant int `json:"rto_compliant"` + FullyCompliant int `json:"fully_compliant"` + CriticalIssues int `json:"critical_issues"` + WorstRPO time.Duration `json:"worst_rpo"` + WorstRTO time.Duration `json:"worst_rto"` + WorstRPODatabase string `json:"worst_rpo_database"` + WorstRTODatabase string `json:"worst_rto_database"` + AverageRPO time.Duration `json:"average_rpo"` + AverageRTO time.Duration `json:"average_rto"` +} + +// Summarize creates a 
+
+// Summarize creates a summary from analyses
+func Summarize(analyses []*Analysis) *Summary {
+	summary := &Summary{}
+
+	var totalRPO, totalRTO time.Duration
+
+	for _, a := range analyses {
+		summary.TotalDatabases++
+
+		if a.RPOCompliant {
+			summary.RPOCompliant++
+		}
+		if a.RTOCompliant {
+			summary.RTOCompliant++
+		}
+		if a.RPOCompliant && a.RTOCompliant {
+			summary.FullyCompliant++
+		}
+
+		for _, r := range a.Recommendations {
+			if r.Priority == PriorityCritical {
+				summary.CriticalIssues++
+				break
+			}
+		}
+
+		if a.CurrentRPO > summary.WorstRPO {
+			summary.WorstRPO = a.CurrentRPO
+			summary.WorstRPODatabase = a.Database
+		}
+		if a.CurrentRTO > summary.WorstRTO {
+			summary.WorstRTO = a.CurrentRTO
+			summary.WorstRTODatabase = a.Database
+		}
+
+		totalRPO += a.CurrentRPO
+		totalRTO += a.CurrentRTO
+	}
+
+	if len(analyses) > 0 {
+		summary.AverageRPO = totalRPO / time.Duration(len(analyses))
+		summary.AverageRTO = totalRTO / time.Duration(len(analyses))
+	}
+
+	return summary
+}
+
+func formatDuration(d time.Duration) string {
+	if d < time.Minute {
+		return fmt.Sprintf("%.0fs", d.Seconds())
+	}
+	if d < time.Hour {
+		return fmt.Sprintf("%.0fm", d.Minutes())
+	}
+	hours := int(d.Hours())
+	mins := int(d.Minutes()) - hours*60
+	return fmt.Sprintf("%dh %dm", hours, mins)
+}
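+
+// Illustrative sketch of the fleet-wide flow this file supports: analyze every
+// database in the catalog, then aggregate. ctx, cat, and cfg are assumed from
+// the surrounding application, as in the earlier sketches.
+//
+//	calc := NewCalculator(cat, cfg)
+//	analyses, err := calc.AnalyzeAll(ctx)
+//	if err == nil {
+//		s := Summarize(analyses)
+//		fmt.Printf("%d/%d databases fully compliant, worst RPO %s (%s)\n",
+//			s.FullyCompliant, s.TotalDatabases,
+//			formatDuration(s.WorstRPO), s.WorstRPODatabase)
+//	}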