feat: Add enterprise DBA features for production reliability

New features implemented:

1. Backup Catalog (internal/catalog/)
   - SQLite-based backup tracking
   - Gap detection and RPO monitoring
   - Search and statistics
   - Filesystem sync

2. DR Drill Testing (internal/drill/)
   - Automated restore testing in Docker containers
   - Database validation with custom queries
   - Catalog integration for drill-tested status

3. Smart Notifications (internal/notify/)
   - Event batching with configurable intervals
   - Time-based escalation policies
   - HTML/text/Slack templates

4. Compliance Reports (internal/report/)
   - SOC2, GDPR, HIPAA, PCI-DSS, ISO27001 frameworks
   - Evidence collection from catalog
   - JSON, Markdown, HTML output formats

5. RTO/RPO Calculator (internal/rto/)
   - Recovery objective analysis
   - RTO breakdown by phase
   - Recommendations for improvement

6. Replica-Aware Backup (internal/replica/)
   - Topology detection for PostgreSQL/MySQL
   - Automatic replica selection
   - Configurable selection strategies

7. Parallel Table Backup (internal/parallel/)
   - Concurrent table dumps
   - Worker pool with progress tracking
   - Large table optimization

8. MySQL/MariaDB PITR (internal/pitr/)
   - Binary log parsing and replay
   - Point-in-time recovery support
   - Transaction filtering

CLI commands added: catalog, drill, report, rto

All changes support the goal: reliable 3 AM database recovery.
Commit f69bfe7071 (parent d0d83b61ef), 2025-12-13 20:28:55 +01:00
34 changed files with 13469 additions and 41 deletions

MYSQL_PITR.md (new file, 401 lines)

@@ -0,0 +1,401 @@
# MySQL/MariaDB Point-in-Time Recovery (PITR)
This guide explains how to use dbbackup for Point-in-Time Recovery with MySQL and MariaDB databases.
## Overview
Point-in-Time Recovery (PITR) allows you to restore your database to any specific moment in time, not just to when a backup was taken. This is essential for:
- Recovering from accidental data deletion or corruption
- Restoring to a state just before a problematic change
- Meeting regulatory compliance requirements for data recovery
### How MySQL PITR Works
MySQL PITR uses binary logs (binlogs) which record all changes to the database:
1. **Base Backup**: A full database backup with the binlog position recorded
2. **Binary Log Archiving**: Continuous archiving of binlog files
3. **Recovery**: Restore base backup, then replay binlogs up to the target time
```
┌─────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ Base Backup │ --> │ binlog.00001 │ --> │ binlog.00002 │ --> │ binlog.00003 │
│ (pos: 1234) │ │ │ │ │ │ (current) │
└─────────────┘ └──────────────┘ └──────────────┘ └──────────────┘
│ │ │ │
▼ ▼ ▼ ▼
10:00 AM 10:30 AM 11:00 AM 11:30 AM
Target: 11:15 AM
```
## Prerequisites
### MySQL Configuration
Binary logging must be enabled in MySQL. Add to `my.cnf`:
```ini
[mysqld]
# Enable binary logging
log_bin = mysql-bin
server_id = 1
# Recommended: Use ROW format for PITR
binlog_format = ROW
# Optional but recommended: Enable GTID for easier replication and recovery
gtid_mode = ON
enforce_gtid_consistency = ON
# Keep binlogs for at least 7 days (adjust as needed)
expire_logs_days = 7
# Or for MySQL 8.0+:
# binlog_expire_logs_seconds = 604800
```
After changing configuration, restart MySQL:
```bash
sudo systemctl restart mysql
```
### MariaDB Configuration
MariaDB configuration is similar:
```ini
[mysqld]
log_bin = mariadb-bin
server_id = 1
binlog_format = ROW
# MariaDB GTIDs are always enabled; log_slave_updates makes replicas write
# replicated events to their own binlog, keeping the chain complete
log_slave_updates = ON
```
## Quick Start
### 1. Check PITR Status
```bash
# Check if MySQL is properly configured for PITR
dbbackup pitr mysql-status
```
Example output:
```
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
MySQL/MariaDB PITR Status (mysql)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
PITR Status: ❌ NOT CONFIGURED
Binary Logging: ✅ ENABLED
Binlog Format: ROW
GTID Mode: ON
Current Position: mysql-bin.000042:1234
PITR Requirements:
✅ Binary logging enabled
✅ Row-based logging (recommended)
```
### 2. Enable PITR
```bash
# Enable PITR and configure archive directory
dbbackup pitr mysql-enable --archive-dir /backups/binlog_archive
```
### 3. Create a Base Backup
```bash
# Create a PITR-capable backup
dbbackup backup single mydb --pitr
```
### 4. Start Binlog Archiving
```bash
# Run binlog archiver in the background
dbbackup binlog watch --binlog-dir /var/lib/mysql --archive-dir /backups/binlog_archive --interval 30s
```
Or set up a cron job for periodic archiving:
```bash
# Archive new binlogs every 5 minutes
*/5 * * * * dbbackup binlog archive --binlog-dir /var/lib/mysql --archive-dir /backups/binlog_archive
```
### 5. Restore to Point in Time
```bash
# Restore to a specific time
dbbackup restore pitr mydb_backup.sql.gz --target-time '2024-01-15 14:30:00'
```
## Commands Reference
### PITR Commands
#### `pitr mysql-status`
Show MySQL/MariaDB PITR configuration and status.
```bash
dbbackup pitr mysql-status
```
#### `pitr mysql-enable`
Enable PITR for MySQL/MariaDB.
```bash
dbbackup pitr mysql-enable \
--archive-dir /backups/binlog_archive \
--retention-days 7 \
--require-row-format \
--require-gtid
```
Options:
- `--archive-dir`: Directory to store archived binlogs (required)
- `--retention-days`: Days to keep archived binlogs (default: 7)
- `--require-row-format`: Require ROW binlog format (default: true)
- `--require-gtid`: Require GTID mode enabled (default: false)
### Binlog Commands
#### `binlog list`
List available binary log files.
```bash
# List binlogs from MySQL data directory
dbbackup binlog list --binlog-dir /var/lib/mysql
# List archived binlogs
dbbackup binlog list --archive-dir /backups/binlog_archive
```
#### `binlog archive`
Archive binary log files.
```bash
dbbackup binlog archive \
--binlog-dir /var/lib/mysql \
--archive-dir /backups/binlog_archive \
--compress
```
Options:
- `--binlog-dir`: MySQL binary log directory
- `--archive-dir`: Destination for archived binlogs (required)
- `--compress`: Compress archived binlogs with gzip
- `--encrypt`: Encrypt archived binlogs
- `--encryption-key-file`: Path to encryption key file
#### `binlog watch`
Continuously monitor and archive new binlog files.
```bash
dbbackup binlog watch \
--binlog-dir /var/lib/mysql \
--archive-dir /backups/binlog_archive \
--interval 30s \
--compress
```
Options:
- `--interval`: How often to check for new binlogs (default: 30s)
#### `binlog validate`
Validate binlog chain integrity.
```bash
dbbackup binlog validate --binlog-dir /var/lib/mysql
```
Output shows:
- Whether the chain is complete (no missing files)
- Any gaps in the sequence
- Server ID changes (indicating possible failover)
- Total size and file count
#### `binlog position`
Show current binary log position.
```bash
dbbackup binlog position
```
Output:
```
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Current Binary Log Position
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
File: mysql-bin.000042
Position: 123456
GTID Set: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-1000
Position String: mysql-bin.000042:123456
```
## Restore Scenarios
### Restore to Specific Time
```bash
# Restore to January 15, 2024 at 2:30 PM
dbbackup restore pitr mydb_backup.sql.gz \
--target-time '2024-01-15 14:30:00'
```
### Restore to Specific Position
```bash
# Restore to a specific binlog position
dbbackup restore pitr mydb_backup.sql.gz \
--target-position 'mysql-bin.000042:12345'
```
### Dry Run (Preview)
```bash
# See what SQL would be replayed without applying
dbbackup restore pitr mydb_backup.sql.gz \
--target-time '2024-01-15 14:30:00' \
--dry-run
```
### Restore to Backup Point Only
```bash
# Restore just the base backup without replaying binlogs
dbbackup restore pitr mydb_backup.sql.gz --immediate
```
## Best Practices
### 1. Archiving Strategy
- Archive binlogs frequently (every 5-30 minutes)
- Use compression to save disk space
- Store archives on separate storage from the database
### 2. Retention Policy
- Keep archives for at least as long as your oldest valid base backup
- Consider regulatory requirements for data retention
- Use the cleanup command to purge old archives:
```bash
dbbackup binlog cleanup --archive-dir /backups/binlog_archive --retention-days 30
```
### 3. Validation
- Regularly validate your binlog chain:
```bash
dbbackup binlog validate --binlog-dir /var/lib/mysql
```
- Test restoration periodically on a test environment
### 4. Monitoring
- Monitor the `dbbackup binlog watch` process
- Set up alerts for:
- Binlog archiver failures
- Gaps in binlog chain
- Low disk space on archive directory
### 5. GTID Mode
Enable GTID for:
- Easier tracking of replication position
- Automatic failover in replication setups
- Simpler point-in-time recovery
## Troubleshooting
### Binary Logging Not Enabled
**Error**: "Binary logging appears to be disabled"
**Solution**: Add to my.cnf and restart MySQL:
```ini
[mysqld]
log_bin = mysql-bin
server_id = 1
```
### Missing Binlog Files
**Error**: "Gaps detected in binlog chain"
**Causes**:
- `RESET MASTER` was executed
- `expire_logs_days` is too short
- Binlogs were manually deleted
**Solution**:
- Take a new base backup immediately
- Adjust retention settings to prevent future gaps
### Permission Denied
**Error**: "Failed to read binlog directory"
**Solution**:
```bash
# Add dbbackup user to mysql group
sudo usermod -aG mysql dbbackup_user
# Or set appropriate permissions
sudo chmod g+r /var/lib/mysql/mysql-bin.*
```
### Wrong Binlog Format
**Warning**: "binlog_format = STATEMENT (ROW recommended)"
**Impact**: With STATEMENT format, non-deterministic statements (e.g. those using `UUID()` or `LIMIT` without `ORDER BY`) may replay with different results, so recovered data can diverge from the original
**Solution**: Change to ROW format (requires restart):
```ini
[mysqld]
binlog_format = ROW
```
### Server ID Changes
**Warning**: "server_id changed from X to Y (possible master failover)"
This warning indicates the binlog chain contains events from different servers, which may happen during:
- Failover in a replication setup
- Restoring from a different server's backup
This is usually informational but review your topology if unexpected.
## MariaDB-Specific Notes
### GTID Format
MariaDB uses a different GTID format than MySQL:
- **MySQL**: `3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5`
- **MariaDB**: `0-1-100` (domain-server_id-sequence)
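The MariaDB format splits cleanly on `-` into three integers. A small parser sketch, purely illustrative (not part of dbbackup's API):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// mariaDBGTID holds the three components of a MariaDB GTID
// (domain-server_id-sequence), e.g. "0-1-100".
type mariaDBGTID struct {
	Domain   uint32
	ServerID uint32
	Sequence uint64
}

func parseMariaDBGTID(s string) (mariaDBGTID, error) {
	parts := strings.Split(s, "-")
	if len(parts) != 3 {
		return mariaDBGTID{}, fmt.Errorf("invalid MariaDB GTID %q: want domain-server_id-sequence", s)
	}
	domain, err := strconv.ParseUint(parts[0], 10, 32)
	if err != nil {
		return mariaDBGTID{}, fmt.Errorf("invalid domain in %q: %w", s, err)
	}
	serverID, err := strconv.ParseUint(parts[1], 10, 32)
	if err != nil {
		return mariaDBGTID{}, fmt.Errorf("invalid server_id in %q: %w", s, err)
	}
	seq, err := strconv.ParseUint(parts[2], 10, 64)
	if err != nil {
		return mariaDBGTID{}, fmt.Errorf("invalid sequence in %q: %w", s, err)
	}
	return mariaDBGTID{Domain: uint32(domain), ServerID: uint32(serverID), Sequence: seq}, nil
}

func main() {
	g, err := parseMariaDBGTID("0-1-100")
	if err != nil {
		panic(err)
	}
	fmt.Println(g.Domain, g.ServerID, g.Sequence) // 0 1 100
}
```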
### Tool Detection
dbbackup automatically detects MariaDB and uses:
- `mariadb-binlog` if available (MariaDB 10.4+)
- Falls back to `mysqlbinlog` for older versions
### Encrypted Binlogs
MariaDB supports binlog encryption. If enabled, ensure the key is available during archive and restore operations.
## See Also
- [PITR.md](PITR.md) - PostgreSQL PITR documentation
- [DOCKER.md](DOCKER.md) - Running in Docker environments
- [CLOUD.md](CLOUD.md) - Cloud storage for archives

README.md (127 changed lines)

@@ -16,12 +16,22 @@ Database backup and restore utility for PostgreSQL, MySQL, and MariaDB.
- AES-256-GCM encryption
- Incremental backups
- Cloud storage: S3, MinIO, B2, Azure Blob, Google Cloud Storage
- Point-in-Time Recovery (PITR) for PostgreSQL
- Point-in-Time Recovery (PITR) for PostgreSQL and MySQL/MariaDB
- **GFS retention policies**: Grandfather-Father-Son backup rotation
- **Notifications**: SMTP email and webhook alerts
- Interactive terminal UI
- Cross-platform binaries
### Enterprise DBA Features
- **Backup Catalog**: SQLite-based catalog tracking all backups with gap detection
- **DR Drill Testing**: Automated disaster recovery testing in Docker containers
- **Smart Notifications**: Batched alerts with escalation policies
- **Compliance Reports**: SOC2, GDPR, HIPAA, PCI-DSS, ISO27001 report generation
- **RTO/RPO Calculator**: Recovery objective analysis and recommendations
- **Replica-Aware Backup**: Automatic backup from replicas to reduce primary load
- **Parallel Table Backup**: Concurrent table dumps for faster backups
## Installation
### Docker
@@ -257,6 +267,10 @@ dbbackup backup single mydb --dry-run
| `pitr` | PITR management |
| `wal` | WAL archive operations |
| `interactive` | Start interactive UI |
| `catalog` | Backup catalog management |
| `drill` | DR drill testing |
| `report` | Compliance report generation |
| `rto` | RTO/RPO analysis |
## Global Flags
@@ -478,6 +492,117 @@ dbbackup backup single mydb --notify
- `cleanup_completed`
- `verify_completed`, `verify_failed`
- `pitr_recovery`
- `dr_drill_passed`, `dr_drill_failed`
- `gap_detected`, `rpo_violation`
## Backup Catalog
Track all backups in a SQLite catalog with gap detection and search:
```bash
# Sync backups from directory to catalog
dbbackup catalog sync /backups
# List recent backups
dbbackup catalog list --database mydb --limit 10
# Show catalog statistics
dbbackup catalog stats
# Detect backup gaps (missing scheduled backups)
dbbackup catalog gaps --interval 24h --database mydb
# Search backups
dbbackup catalog search --database mydb --start 2024-01-01 --end 2024-12-31
# Get backup info
dbbackup catalog info /backups/mydb_20240115.dump.gz
```
## DR Drill Testing
Automated disaster recovery testing restores backups to Docker containers:
```bash
# Run full DR drill
dbbackup drill run /backups/mydb_latest.dump.gz \
--database mydb \
--db-type postgres \
--timeout 30m
# Quick drill (restore + basic validation)
dbbackup drill quick /backups/mydb_latest.dump.gz --database mydb
# List running drill containers
dbbackup drill list
# Cleanup old drill containers
dbbackup drill cleanup --age 24h
# Generate drill report
dbbackup drill report --format html --output drill-report.html
```
**Drill phases:**
1. Container creation
2. Backup download (if cloud)
3. Restore execution
4. Database validation
5. Custom query checks
6. Cleanup
## Compliance Reports
Generate compliance reports for regulatory frameworks:
```bash
# Generate SOC2 report
dbbackup report generate --type soc2 --days 90 --format html --output soc2-report.html
# HIPAA compliance report
dbbackup report generate --type hipaa --format markdown
# Show compliance summary
dbbackup report summary --type gdpr --days 30
# List available frameworks
dbbackup report list
# Show controls for a framework
dbbackup report controls soc2
```
**Supported frameworks:**
- SOC2 Type II (Trust Service Criteria)
- GDPR (General Data Protection Regulation)
- HIPAA (Health Insurance Portability and Accountability Act)
- PCI-DSS (Payment Card Industry Data Security Standard)
- ISO 27001 (Information Security Management)
## RTO/RPO Analysis
Calculate and monitor Recovery Time/Point Objectives:
```bash
# Analyze RTO/RPO for a database
dbbackup rto analyze mydb
# Show status for all databases
dbbackup rto status
# Check against targets
dbbackup rto check --rto 4h --rpo 1h
# Set target objectives
dbbackup rto analyze mydb --target-rto 4h --target-rpo 1h
```
**Analysis includes:**
- Current RPO (time since last backup)
- Estimated RTO (detection + download + restore + validation)
- RTO breakdown by phase
- Compliance status
- Recommendations for improvement
## Configuration

cmd/catalog.go (new file, 725 lines)

@@ -0,0 +1,725 @@
package cmd
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"dbbackup/internal/catalog"
"github.com/spf13/cobra"
)
var (
catalogDBPath string
catalogFormat string
catalogLimit int
catalogDatabase string
catalogStartDate string
catalogEndDate string
catalogInterval string
catalogVerbose bool
)
// catalogCmd represents the catalog command group
var catalogCmd = &cobra.Command{
Use: "catalog",
Short: "Backup catalog management",
Long: `Manage the backup catalog - a SQLite database tracking all backups.
The catalog provides:
- Searchable history of all backups
- Gap detection for backup schedules
- Statistics and reporting
- Integration with DR drill testing
Examples:
# Sync backups from a directory
dbbackup catalog sync /backups
# List all backups
dbbackup catalog list
# Show catalog statistics
dbbackup catalog stats
# Detect gaps in backup schedule
dbbackup catalog gaps mydb --interval 24h
# Search backups
dbbackup catalog search --database mydb --after 2024-01-01`,
}
// catalogSyncCmd syncs backups from directory
var catalogSyncCmd = &cobra.Command{
Use: "sync [directory]",
Short: "Sync backups from directory into catalog",
Long: `Scan a directory for backup files and import them into the catalog.
This command:
- Finds all .meta.json files
- Imports backup metadata into SQLite catalog
- Detects removed backups
- Updates changed entries
Examples:
# Sync from backup directory
dbbackup catalog sync /backups
# Sync with verbose output
dbbackup catalog sync /backups --verbose`,
Args: cobra.MinimumNArgs(1),
RunE: runCatalogSync,
}
// catalogListCmd lists backups
var catalogListCmd = &cobra.Command{
Use: "list",
Short: "List backups in catalog",
Long: `List all backups in the catalog with optional filtering.
Examples:
# List all backups
dbbackup catalog list
# List backups for specific database
dbbackup catalog list --database mydb
# List last 10 backups
dbbackup catalog list --limit 10
# Output as JSON
dbbackup catalog list --format json`,
RunE: runCatalogList,
}
// catalogStatsCmd shows statistics
var catalogStatsCmd = &cobra.Command{
Use: "stats",
Short: "Show catalog statistics",
Long: `Display comprehensive backup statistics.
Shows:
- Total backup count and size
- Backups by database
- Backups by type and status
- Verification and drill test coverage
Examples:
# Show overall stats
dbbackup catalog stats
# Stats for specific database
dbbackup catalog stats --database mydb
# Output as JSON
dbbackup catalog stats --format json`,
RunE: runCatalogStats,
}
// catalogGapsCmd detects schedule gaps
var catalogGapsCmd = &cobra.Command{
Use: "gaps [database]",
Short: "Detect gaps in backup schedule",
Long: `Analyze backup history and detect schedule gaps.
This helps identify:
- Missed backups
- Schedule irregularities
- RPO violations
Examples:
# Check all databases for gaps (24h expected interval)
dbbackup catalog gaps
# Check specific database with custom interval
dbbackup catalog gaps mydb --interval 6h
# Check gaps in date range
dbbackup catalog gaps --after 2024-01-01 --before 2024-02-01`,
RunE: runCatalogGaps,
}
// catalogSearchCmd searches backups
var catalogSearchCmd = &cobra.Command{
Use: "search",
Short: "Search backups in catalog",
Long: `Search for backups matching specific criteria.
Examples:
# Search by database name (supports wildcards)
dbbackup catalog search --database "prod*"
# Search by date range
dbbackup catalog search --after 2024-01-01 --before 2024-02-01
# Search verified backups only
dbbackup catalog search --verified
# Search encrypted backups
dbbackup catalog search --encrypted`,
RunE: runCatalogSearch,
}
// catalogInfoCmd shows entry details
var catalogInfoCmd = &cobra.Command{
Use: "info [backup-path]",
Short: "Show detailed info for a backup",
Long: `Display detailed information about a specific backup.
Examples:
# Show info by path
dbbackup catalog info /backups/mydb_20240115.dump.gz`,
Args: cobra.ExactArgs(1),
RunE: runCatalogInfo,
}
func init() {
rootCmd.AddCommand(catalogCmd)
// Default catalog path
defaultCatalogPath := filepath.Join(getDefaultConfigDir(), "catalog.db")
// Global catalog flags
catalogCmd.PersistentFlags().StringVar(&catalogDBPath, "catalog-db", defaultCatalogPath,
"Path to catalog SQLite database")
catalogCmd.PersistentFlags().StringVar(&catalogFormat, "format", "table",
"Output format: table, json, csv")
// Add subcommands
catalogCmd.AddCommand(catalogSyncCmd)
catalogCmd.AddCommand(catalogListCmd)
catalogCmd.AddCommand(catalogStatsCmd)
catalogCmd.AddCommand(catalogGapsCmd)
catalogCmd.AddCommand(catalogSearchCmd)
catalogCmd.AddCommand(catalogInfoCmd)
// Sync flags
catalogSyncCmd.Flags().BoolVarP(&catalogVerbose, "verbose", "v", false, "Show detailed output")
// List flags
catalogListCmd.Flags().IntVar(&catalogLimit, "limit", 50, "Maximum entries to show")
catalogListCmd.Flags().StringVar(&catalogDatabase, "database", "", "Filter by database name")
// Stats flags
catalogStatsCmd.Flags().StringVar(&catalogDatabase, "database", "", "Show stats for specific database")
// Gaps flags
catalogGapsCmd.Flags().StringVar(&catalogInterval, "interval", "24h", "Expected backup interval")
catalogGapsCmd.Flags().StringVar(&catalogStartDate, "after", "", "Start date (YYYY-MM-DD)")
catalogGapsCmd.Flags().StringVar(&catalogEndDate, "before", "", "End date (YYYY-MM-DD)")
// Search flags
catalogSearchCmd.Flags().StringVar(&catalogDatabase, "database", "", "Filter by database name (supports wildcards)")
catalogSearchCmd.Flags().StringVar(&catalogStartDate, "after", "", "Backups after date (YYYY-MM-DD)")
catalogSearchCmd.Flags().StringVar(&catalogEndDate, "before", "", "Backups before date (YYYY-MM-DD)")
catalogSearchCmd.Flags().IntVar(&catalogLimit, "limit", 100, "Maximum results")
catalogSearchCmd.Flags().Bool("verified", false, "Only verified backups")
catalogSearchCmd.Flags().Bool("encrypted", false, "Only encrypted backups")
catalogSearchCmd.Flags().Bool("drill-tested", false, "Only drill-tested backups")
}
func getDefaultConfigDir() string {
home, _ := os.UserHomeDir()
return filepath.Join(home, ".dbbackup")
}
func openCatalog() (*catalog.SQLiteCatalog, error) {
return catalog.NewSQLiteCatalog(catalogDBPath)
}
func runCatalogSync(cmd *cobra.Command, args []string) error {
dir := args[0]
// Validate directory
info, err := os.Stat(dir)
if err != nil {
return fmt.Errorf("directory not found: %s", dir)
}
if !info.IsDir() {
return fmt.Errorf("not a directory: %s", dir)
}
absDir, _ := filepath.Abs(dir)
cat, err := openCatalog()
if err != nil {
return err
}
defer cat.Close()
fmt.Printf("📁 Syncing backups from: %s\n", absDir)
fmt.Printf("📊 Catalog database: %s\n\n", catalogDBPath)
ctx := context.Background()
result, err := cat.SyncFromDirectory(ctx, absDir)
if err != nil {
return err
}
// Update last sync time
cat.SetLastSync(ctx)
// Show results
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Printf(" Sync Results\n")
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Printf(" ✅ Added: %d\n", result.Added)
fmt.Printf(" 🔄 Updated: %d\n", result.Updated)
fmt.Printf(" 🗑️ Removed: %d\n", result.Removed)
if result.Errors > 0 {
fmt.Printf(" ❌ Errors: %d\n", result.Errors)
}
fmt.Printf(" ⏱️ Duration: %.2fs\n", result.Duration)
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
// Show details if verbose
if catalogVerbose && len(result.Details) > 0 {
fmt.Printf("\nDetails:\n")
for _, detail := range result.Details {
fmt.Printf(" %s\n", detail)
}
}
return nil
}
func runCatalogList(cmd *cobra.Command, args []string) error {
cat, err := openCatalog()
if err != nil {
return err
}
defer cat.Close()
ctx := context.Background()
query := &catalog.SearchQuery{
Database: catalogDatabase,
Limit: catalogLimit,
OrderBy: "created_at",
OrderDesc: true,
}
entries, err := cat.Search(ctx, query)
if err != nil {
return err
}
if len(entries) == 0 {
fmt.Println("No backups in catalog. Run 'dbbackup catalog sync <directory>' to import backups.")
return nil
}
if catalogFormat == "json" {
data, _ := json.MarshalIndent(entries, "", " ")
fmt.Println(string(data))
return nil
}
// Table format
fmt.Printf("%-30s %-12s %-10s %-20s %-10s %s\n",
"DATABASE", "TYPE", "SIZE", "CREATED", "STATUS", "PATH")
fmt.Println(strings.Repeat("─", 120))
for _, entry := range entries {
dbName := truncateString(entry.Database, 28)
backupPath := truncateString(filepath.Base(entry.BackupPath), 40)
status := string(entry.Status)
if entry.VerifyValid != nil && *entry.VerifyValid {
status = "✓ verified"
}
if entry.DrillSuccess != nil && *entry.DrillSuccess {
status = "✓ tested"
}
fmt.Printf("%-30s %-12s %-10s %-20s %-10s %s\n",
dbName,
entry.DatabaseType,
catalog.FormatSize(entry.SizeBytes),
entry.CreatedAt.Format("2006-01-02 15:04"),
status,
backupPath,
)
}
fmt.Printf("\nShowing %d backups\n", len(entries))
return nil
}
func runCatalogStats(cmd *cobra.Command, args []string) error {
cat, err := openCatalog()
if err != nil {
return err
}
defer cat.Close()
ctx := context.Background()
var stats *catalog.Stats
if catalogDatabase != "" {
stats, err = cat.StatsByDatabase(ctx, catalogDatabase)
} else {
stats, err = cat.Stats(ctx)
}
if err != nil {
return err
}
if catalogFormat == "json" {
data, _ := json.MarshalIndent(stats, "", " ")
fmt.Println(string(data))
return nil
}
// Table format
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
if catalogDatabase != "" {
fmt.Printf(" Catalog Statistics: %s\n", catalogDatabase)
} else {
fmt.Printf(" Catalog Statistics\n")
}
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n")
fmt.Printf("📊 Total Backups: %d\n", stats.TotalBackups)
fmt.Printf("💾 Total Size: %s\n", stats.TotalSizeHuman)
fmt.Printf("📏 Average Size: %s\n", catalog.FormatSize(stats.AvgSize))
fmt.Printf("⏱️ Average Duration: %.1fs\n", stats.AvgDuration)
fmt.Printf("✅ Verified: %d\n", stats.VerifiedCount)
fmt.Printf("🧪 Drill Tested: %d\n", stats.DrillTestedCount)
if stats.OldestBackup != nil {
fmt.Printf("📅 Oldest Backup: %s\n", stats.OldestBackup.Format("2006-01-02 15:04"))
}
if stats.NewestBackup != nil {
fmt.Printf("📅 Newest Backup: %s\n", stats.NewestBackup.Format("2006-01-02 15:04"))
}
if len(stats.ByDatabase) > 0 && catalogDatabase == "" {
fmt.Printf("\n📁 By Database:\n")
for db, count := range stats.ByDatabase {
fmt.Printf(" %-30s %d\n", db, count)
}
}
if len(stats.ByType) > 0 {
fmt.Printf("\n📦 By Type:\n")
for t, count := range stats.ByType {
fmt.Printf(" %-15s %d\n", t, count)
}
}
if len(stats.ByStatus) > 0 {
fmt.Printf("\n📋 By Status:\n")
for s, count := range stats.ByStatus {
fmt.Printf(" %-15s %d\n", s, count)
}
}
fmt.Printf("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
return nil
}
func runCatalogGaps(cmd *cobra.Command, args []string) error {
cat, err := openCatalog()
if err != nil {
return err
}
defer cat.Close()
ctx := context.Background()
// Parse interval
interval, err := time.ParseDuration(catalogInterval)
if err != nil {
return fmt.Errorf("invalid interval: %w", err)
}
config := &catalog.GapDetectionConfig{
ExpectedInterval: interval,
Tolerance: interval / 4, // 25% tolerance
RPOThreshold: interval * 2, // 2x interval = critical
}
// Parse date range
if catalogStartDate != "" {
t, err := time.Parse("2006-01-02", catalogStartDate)
if err != nil {
return fmt.Errorf("invalid start date: %w", err)
}
config.StartDate = &t
}
if catalogEndDate != "" {
t, err := time.Parse("2006-01-02", catalogEndDate)
if err != nil {
return fmt.Errorf("invalid end date: %w", err)
}
config.EndDate = &t
}
var allGaps map[string][]*catalog.Gap
if len(args) > 0 {
// Specific database
database := args[0]
gaps, err := cat.DetectGaps(ctx, database, config)
if err != nil {
return err
}
if len(gaps) > 0 {
allGaps = map[string][]*catalog.Gap{database: gaps}
}
} else {
// All databases
allGaps, err = cat.DetectAllGaps(ctx, config)
if err != nil {
return err
}
}
if catalogFormat == "json" {
data, _ := json.MarshalIndent(allGaps, "", " ")
fmt.Println(string(data))
return nil
}
if len(allGaps) == 0 {
fmt.Printf("✅ No backup gaps detected (expected interval: %s)\n", interval)
return nil
}
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Printf(" Backup Gaps Detected (expected interval: %s)\n", interval)
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n")
totalGaps := 0
criticalGaps := 0
for database, gaps := range allGaps {
fmt.Printf("📁 %s (%d gaps)\n", database, len(gaps))
for _, gap := range gaps {
totalGaps++
icon := ""
switch gap.Severity {
case catalog.SeverityWarning:
icon = "⚠️"
case catalog.SeverityCritical:
icon = "🚨"
criticalGaps++
}
fmt.Printf(" %s %s\n", icon, gap.Description)
fmt.Printf(" Gap: %s → %s (%s)\n",
gap.GapStart.Format("2006-01-02 15:04"),
gap.GapEnd.Format("2006-01-02 15:04"),
catalog.FormatDuration(gap.Duration))
fmt.Printf(" Expected at: %s\n", gap.ExpectedAt.Format("2006-01-02 15:04"))
}
fmt.Println()
}
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Printf("Total: %d gaps detected", totalGaps)
if criticalGaps > 0 {
fmt.Printf(" (%d critical)", criticalGaps)
}
fmt.Println()
return nil
}
func runCatalogSearch(cmd *cobra.Command, args []string) error {
cat, err := openCatalog()
if err != nil {
return err
}
defer cat.Close()
ctx := context.Background()
query := &catalog.SearchQuery{
Database: catalogDatabase,
Limit: catalogLimit,
OrderBy: "created_at",
OrderDesc: true,
}
// Parse date range
if catalogStartDate != "" {
t, err := time.Parse("2006-01-02", catalogStartDate)
if err != nil {
return fmt.Errorf("invalid start date: %w", err)
}
query.StartDate = &t
}
if catalogEndDate != "" {
t, err := time.Parse("2006-01-02", catalogEndDate)
if err != nil {
return fmt.Errorf("invalid end date: %w", err)
}
query.EndDate = &t
}
// Boolean filters
if verified, _ := cmd.Flags().GetBool("verified"); verified {
t := true
query.Verified = &t
}
if encrypted, _ := cmd.Flags().GetBool("encrypted"); encrypted {
t := true
query.Encrypted = &t
}
if drillTested, _ := cmd.Flags().GetBool("drill-tested"); drillTested {
t := true
query.DrillTested = &t
}
entries, err := cat.Search(ctx, query)
if err != nil {
return err
}
if len(entries) == 0 {
fmt.Println("No matching backups found.")
return nil
}
if catalogFormat == "json" {
data, _ := json.MarshalIndent(entries, "", " ")
fmt.Println(string(data))
return nil
}
fmt.Printf("Found %d matching backups:\n\n", len(entries))
for _, entry := range entries {
fmt.Printf("📁 %s\n", entry.Database)
fmt.Printf(" Path: %s\n", entry.BackupPath)
fmt.Printf(" Type: %s | Size: %s | Created: %s\n",
entry.DatabaseType,
catalog.FormatSize(entry.SizeBytes),
entry.CreatedAt.Format("2006-01-02 15:04:05"))
if entry.Encrypted {
fmt.Printf(" 🔒 Encrypted\n")
}
if entry.VerifyValid != nil && *entry.VerifyValid {
fmt.Printf(" ✅ Verified: %s\n", entry.VerifiedAt.Format("2006-01-02 15:04"))
}
if entry.DrillSuccess != nil && *entry.DrillSuccess {
fmt.Printf(" 🧪 Drill Tested: %s\n", entry.DrillTestedAt.Format("2006-01-02 15:04"))
}
fmt.Println()
}
return nil
}
func runCatalogInfo(cmd *cobra.Command, args []string) error {
backupPath := args[0]
cat, err := openCatalog()
if err != nil {
return err
}
defer cat.Close()
ctx := context.Background()
// Try absolute path
absPath, _ := filepath.Abs(backupPath)
entry, err := cat.GetByPath(ctx, absPath)
if err != nil {
return err
}
if entry == nil {
// Try as provided
entry, err = cat.GetByPath(ctx, backupPath)
if err != nil {
return err
}
}
if entry == nil {
return fmt.Errorf("backup not found in catalog: %s", backupPath)
}
if catalogFormat == "json" {
data, _ := json.MarshalIndent(entry, "", " ")
fmt.Println(string(data))
return nil
}
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Printf(" Backup Details\n")
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n")
fmt.Printf("📁 Database: %s\n", entry.Database)
fmt.Printf("🔧 Type: %s\n", entry.DatabaseType)
fmt.Printf("🖥️ Host: %s:%d\n", entry.Host, entry.Port)
fmt.Printf("📂 Path: %s\n", entry.BackupPath)
fmt.Printf("📦 Backup Type: %s\n", entry.BackupType)
fmt.Printf("💾 Size: %s (%d bytes)\n", catalog.FormatSize(entry.SizeBytes), entry.SizeBytes)
fmt.Printf("🔐 SHA256: %s\n", entry.SHA256)
fmt.Printf("📅 Created: %s\n", entry.CreatedAt.Format("2006-01-02 15:04:05 MST"))
fmt.Printf("⏱️ Duration: %.2fs\n", entry.Duration)
fmt.Printf("📋 Status: %s\n", entry.Status)
if entry.Compression != "" {
fmt.Printf("📦 Compression: %s\n", entry.Compression)
}
if entry.Encrypted {
fmt.Printf("🔒 Encrypted: yes\n")
}
if entry.CloudLocation != "" {
fmt.Printf("☁️ Cloud: %s\n", entry.CloudLocation)
}
if entry.RetentionPolicy != "" {
fmt.Printf("📆 Retention: %s\n", entry.RetentionPolicy)
}
fmt.Printf("\n📊 Verification:\n")
if entry.VerifiedAt != nil {
status := "❌ Failed"
if entry.VerifyValid != nil && *entry.VerifyValid {
status = "✅ Valid"
}
fmt.Printf(" Status: %s (checked %s)\n", status, entry.VerifiedAt.Format("2006-01-02 15:04"))
} else {
fmt.Printf(" Status: ⏳ Not verified\n")
}
fmt.Printf("\n🧪 DR Drill Test:\n")
if entry.DrillTestedAt != nil {
status := "❌ Failed"
if entry.DrillSuccess != nil && *entry.DrillSuccess {
status = "✅ Passed"
}
fmt.Printf(" Status: %s (tested %s)\n", status, entry.DrillTestedAt.Format("2006-01-02 15:04"))
} else {
fmt.Printf(" Status: ⏳ Not tested\n")
}
if len(entry.Metadata) > 0 {
fmt.Printf("\n📝 Additional Metadata:\n")
for k, v := range entry.Metadata {
fmt.Printf(" %s: %s\n", k, v)
}
}
fmt.Printf("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
return nil
}
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	if maxLen <= 3 {
		return s[:maxLen]
	}
	return s[:maxLen-3] + "..."
}

cmd/drill.go Normal file

@@ -0,0 +1,500 @@
package cmd
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"dbbackup/internal/catalog"
"dbbackup/internal/drill"
"github.com/spf13/cobra"
)
var (
drillBackupPath string
drillDatabaseName string
drillDatabaseType string
drillImage string
drillPort int
drillTimeout int
drillRTOTarget int
drillKeepContainer bool
drillOutputDir string
drillFormat string
drillVerbose bool
drillExpectedTables string
drillMinRows int64
drillQueries string
)
// drillCmd represents the drill command group
var drillCmd = &cobra.Command{
Use: "drill",
Short: "Disaster Recovery drill testing",
Long: `Run DR drills to verify backup restorability.
A DR drill:
1. Spins up a temporary Docker container
2. Restores the backup into the container
3. Runs validation queries
4. Generates a detailed report
5. Cleans up the container
This answers the critical question: "Can I restore this backup at 3 AM?"
Examples:
# Run a drill on a PostgreSQL backup
dbbackup drill run backup.dump.gz --database mydb --type postgresql
# Run with validation queries
dbbackup drill run backup.dump.gz --database mydb --type postgresql \
--validate "SELECT COUNT(*) FROM users" \
--min-rows 1000
# Quick test with minimal validation
dbbackup drill quick backup.dump.gz --database mydb
# List all drill containers
dbbackup drill list
# Cleanup old drill containers
dbbackup drill cleanup`,
}
// drillRunCmd runs a DR drill
var drillRunCmd = &cobra.Command{
Use: "run [backup-file]",
Short: "Run a DR drill on a backup",
Long: `Execute a complete DR drill on a backup file.
This will:
1. Pull the appropriate database Docker image
2. Start a temporary container
3. Restore the backup
4. Run validation queries
5. Calculate RTO metrics
6. Generate a report
Examples:
# Basic drill
dbbackup drill run /backups/mydb_20240115.dump.gz --database mydb --type postgresql
# With RTO target (5 minutes)
dbbackup drill run /backups/mydb.dump.gz --database mydb --type postgresql --rto 300
# With expected tables validation
dbbackup drill run /backups/mydb.dump.gz --database mydb --type postgresql \
--tables "users,orders,products"
# Keep container on failure for debugging
dbbackup drill run /backups/mydb.dump.gz --database mydb --type postgresql --keep`,
Args: cobra.ExactArgs(1),
RunE: runDrill,
}
// drillQuickCmd runs a quick test
var drillQuickCmd = &cobra.Command{
Use: "quick [backup-file]",
Short: "Quick restore test with minimal validation",
Long: `Run a quick DR test that only verifies the backup can be restored.
This is faster than a full drill but provides less validation.
Examples:
# Quick test a PostgreSQL backup
dbbackup drill quick /backups/mydb.dump.gz --database mydb --type postgresql
# Quick test a MySQL backup
dbbackup drill quick /backups/mydb.sql.gz --database mydb --type mysql`,
Args: cobra.ExactArgs(1),
RunE: runQuickDrill,
}
// drillListCmd lists drill containers
var drillListCmd = &cobra.Command{
Use: "list",
Short: "List DR drill containers",
Long: `List all Docker containers created by DR drills.
Shows containers that may still be running or stopped from previous drills.`,
RunE: runDrillList,
}
// drillCleanupCmd cleans up drill resources
var drillCleanupCmd = &cobra.Command{
Use: "cleanup [drill-id]",
Short: "Cleanup DR drill containers",
Long: `Remove containers created by DR drills.
If no drill ID is specified, removes all drill containers.
Examples:
# Cleanup all drill containers
dbbackup drill cleanup
# Cleanup specific drill
dbbackup drill cleanup drill_20240115_120000`,
RunE: runDrillCleanup,
}
// drillReportCmd shows a drill report
var drillReportCmd = &cobra.Command{
Use: "report [report-file]",
Short: "Display a DR drill report",
Long: `Display a previously saved DR drill report.
Examples:
# Show report
dbbackup drill report drill_20240115_120000_report.json
# Show as JSON
dbbackup drill report drill_20240115_120000_report.json --format json`,
Args: cobra.ExactArgs(1),
RunE: runDrillReport,
}
func init() {
rootCmd.AddCommand(drillCmd)
// Add subcommands
drillCmd.AddCommand(drillRunCmd)
drillCmd.AddCommand(drillQuickCmd)
drillCmd.AddCommand(drillListCmd)
drillCmd.AddCommand(drillCleanupCmd)
drillCmd.AddCommand(drillReportCmd)
// Run command flags
drillRunCmd.Flags().StringVar(&drillDatabaseName, "database", "", "Target database name (required)")
drillRunCmd.Flags().StringVar(&drillDatabaseType, "type", "", "Database type: postgresql, mysql, mariadb (required)")
drillRunCmd.Flags().StringVar(&drillImage, "image", "", "Docker image (default: auto-detect)")
drillRunCmd.Flags().IntVar(&drillPort, "port", 0, "Host port for container (default: 15432/13306)")
drillRunCmd.Flags().IntVar(&drillTimeout, "timeout", 60, "Container startup timeout in seconds")
drillRunCmd.Flags().IntVar(&drillRTOTarget, "rto", 300, "RTO target in seconds")
drillRunCmd.Flags().BoolVar(&drillKeepContainer, "keep", false, "Keep container after drill")
drillRunCmd.Flags().StringVar(&drillOutputDir, "output", "", "Output directory for reports")
drillRunCmd.Flags().StringVar(&drillFormat, "format", "table", "Output format: table, json")
drillRunCmd.Flags().BoolVarP(&drillVerbose, "verbose", "v", false, "Verbose output")
drillRunCmd.Flags().StringVar(&drillExpectedTables, "tables", "", "Expected tables (comma-separated)")
drillRunCmd.Flags().Int64Var(&drillMinRows, "min-rows", 0, "Minimum expected row count")
drillRunCmd.Flags().StringVar(&drillQueries, "validate", "", "Validation SQL query")
drillRunCmd.MarkFlagRequired("database")
drillRunCmd.MarkFlagRequired("type")
// Quick command flags
drillQuickCmd.Flags().StringVar(&drillDatabaseName, "database", "", "Target database name (required)")
drillQuickCmd.Flags().StringVar(&drillDatabaseType, "type", "", "Database type: postgresql, mysql, mariadb (required)")
drillQuickCmd.Flags().BoolVarP(&drillVerbose, "verbose", "v", false, "Verbose output")
drillQuickCmd.MarkFlagRequired("database")
drillQuickCmd.MarkFlagRequired("type")
// Report command flags
drillReportCmd.Flags().StringVar(&drillFormat, "format", "table", "Output format: table, json")
}
func runDrill(cmd *cobra.Command, args []string) error {
backupPath := args[0]
// Validate backup file exists
absPath, err := filepath.Abs(backupPath)
if err != nil {
return fmt.Errorf("invalid backup path: %w", err)
}
if _, err := os.Stat(absPath); err != nil {
return fmt.Errorf("backup file not found: %s", absPath)
}
// Build drill config
config := drill.DefaultConfig()
config.BackupPath = absPath
config.DatabaseName = drillDatabaseName
config.DatabaseType = drillDatabaseType
config.ContainerImage = drillImage
config.ContainerPort = drillPort
config.ContainerTimeout = drillTimeout
config.MaxRestoreSeconds = drillRTOTarget
config.CleanupOnExit = !drillKeepContainer
config.KeepOnFailure = true
config.OutputDir = drillOutputDir
config.Verbose = drillVerbose
// Parse expected tables
if drillExpectedTables != "" {
config.ExpectedTables = strings.Split(drillExpectedTables, ",")
for i := range config.ExpectedTables {
config.ExpectedTables[i] = strings.TrimSpace(config.ExpectedTables[i])
}
}
// Set minimum row count
config.MinRowCount = drillMinRows
// Add validation query if provided
if drillQueries != "" {
config.ValidationQueries = append(config.ValidationQueries, drill.ValidationQuery{
Name: "Custom Query",
Query: drillQueries,
MustSucceed: true,
})
}
// Create drill engine
engine := drill.NewEngine(log, drillVerbose)
// Run drill
ctx := cmd.Context()
result, err := engine.Run(ctx, config)
if err != nil {
return err
}
// Update catalog if available
updateCatalogWithDrillResult(ctx, absPath, result)
// Output result
if drillFormat == "json" {
data, _ := json.MarshalIndent(result, "", " ")
fmt.Println(string(data))
} else {
printDrillResult(result)
}
if !result.Success {
return fmt.Errorf("drill failed: %s", result.Message)
}
return nil
}
func runQuickDrill(cmd *cobra.Command, args []string) error {
backupPath := args[0]
absPath, err := filepath.Abs(backupPath)
if err != nil {
return fmt.Errorf("invalid backup path: %w", err)
}
if _, err := os.Stat(absPath); err != nil {
return fmt.Errorf("backup file not found: %s", absPath)
}
engine := drill.NewEngine(log, drillVerbose)
ctx := cmd.Context()
result, err := engine.QuickTest(ctx, absPath, drillDatabaseType, drillDatabaseName)
if err != nil {
return err
}
// Update catalog
updateCatalogWithDrillResult(ctx, absPath, result)
printDrillResult(result)
if !result.Success {
return fmt.Errorf("quick test failed: %s", result.Message)
}
return nil
}
func runDrillList(cmd *cobra.Command, args []string) error {
docker := drill.NewDockerManager(false)
ctx := cmd.Context()
containers, err := docker.ListDrillContainers(ctx)
if err != nil {
return err
}
if len(containers) == 0 {
fmt.Println("No drill containers found.")
return nil
}
fmt.Printf("%-15s %-40s %-20s %s\n", "ID", "NAME", "IMAGE", "STATUS")
fmt.Println(strings.Repeat("─", 100))
for _, c := range containers {
fmt.Printf("%-15s %-40s %-20s %s\n",
c.ID[:12],
truncateString(c.Name, 38),
truncateString(c.Image, 18),
c.Status,
)
}
return nil
}
func runDrillCleanup(cmd *cobra.Command, args []string) error {
drillID := ""
if len(args) > 0 {
drillID = args[0]
}
engine := drill.NewEngine(log, true)
ctx := cmd.Context()
if err := engine.Cleanup(ctx, drillID); err != nil {
return err
}
fmt.Println("✅ Cleanup completed")
return nil
}
func runDrillReport(cmd *cobra.Command, args []string) error {
reportPath := args[0]
result, err := drill.LoadResult(reportPath)
if err != nil {
return err
}
if drillFormat == "json" {
data, _ := json.MarshalIndent(result, "", " ")
fmt.Println(string(data))
} else {
printDrillResult(result)
}
return nil
}
func printDrillResult(result *drill.DrillResult) {
fmt.Printf("\n")
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Printf(" DR Drill Report: %s\n", result.DrillID)
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n")
status := "✅ PASSED"
if !result.Success {
status = "❌ FAILED"
} else if result.Status == drill.StatusPartial {
status = "⚠️ PARTIAL"
}
fmt.Printf("📋 Status: %s\n", status)
fmt.Printf("💾 Backup: %s\n", filepath.Base(result.BackupPath))
fmt.Printf("🗄️ Database: %s (%s)\n", result.DatabaseName, result.DatabaseType)
fmt.Printf("⏱️ Duration: %.2fs\n", result.Duration)
fmt.Printf("📅 Started: %s\n", result.StartTime.Format(time.RFC3339))
fmt.Printf("\n")
// Phases
fmt.Printf("📊 Phases:\n")
for _, phase := range result.Phases {
icon := "✅"
if phase.Status == "failed" {
icon = "❌"
} else if phase.Status == "running" {
icon = "🔄"
}
fmt.Printf(" %s %-20s (%.2fs) %s\n", icon, phase.Name, phase.Duration, phase.Message)
}
fmt.Printf("\n")
// Metrics
fmt.Printf("📈 Metrics:\n")
fmt.Printf(" Tables: %d\n", result.TableCount)
fmt.Printf(" Total Rows: %d\n", result.TotalRows)
fmt.Printf(" Restore Time: %.2fs\n", result.RestoreTime)
fmt.Printf(" Validation: %.2fs\n", result.ValidationTime)
if result.QueryTimeAvg > 0 {
fmt.Printf(" Avg Query Time: %.0fms\n", result.QueryTimeAvg)
}
fmt.Printf("\n")
// RTO
fmt.Printf("⏱️ RTO Analysis:\n")
rtoIcon := "✅"
if !result.RTOMet {
rtoIcon = "❌"
}
fmt.Printf(" Actual RTO: %.2fs\n", result.ActualRTO)
fmt.Printf(" Target RTO: %.0fs\n", result.TargetRTO)
fmt.Printf(" RTO Met: %s\n", rtoIcon)
fmt.Printf("\n")
// Validation results
if len(result.ValidationResults) > 0 {
fmt.Printf("🔍 Validation Queries:\n")
for _, vr := range result.ValidationResults {
icon := "✅"
if !vr.Success {
icon = "❌"
}
fmt.Printf(" %s %s: %s\n", icon, vr.Name, vr.Result)
if vr.Error != "" {
fmt.Printf(" Error: %s\n", vr.Error)
}
}
fmt.Printf("\n")
}
// Check results
if len(result.CheckResults) > 0 {
fmt.Printf("✓ Checks:\n")
for _, cr := range result.CheckResults {
icon := "✅"
if !cr.Success {
icon = "❌"
}
fmt.Printf(" %s %s\n", icon, cr.Message)
}
fmt.Printf("\n")
}
// Errors and warnings
if len(result.Errors) > 0 {
fmt.Printf("❌ Errors:\n")
for _, e := range result.Errors {
fmt.Printf(" • %s\n", e)
}
fmt.Printf("\n")
}
if len(result.Warnings) > 0 {
fmt.Printf("⚠️ Warnings:\n")
for _, w := range result.Warnings {
fmt.Printf(" • %s\n", w)
}
fmt.Printf("\n")
}
// Container info
if result.ContainerKept {
fmt.Printf("📦 Container kept: %s\n", result.ContainerID[:12])
fmt.Printf(" Connect with: docker exec -it %s bash\n", result.ContainerID[:12])
fmt.Printf("\n")
}
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Printf(" %s\n", result.Message)
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
}
func updateCatalogWithDrillResult(ctx context.Context, backupPath string, result *drill.DrillResult) {
// Try to update the catalog with drill results
cat, err := catalog.NewSQLiteCatalog(catalogDBPath)
if err != nil {
return // Catalog not available, skip
}
defer cat.Close()
entry, err := cat.GetByPath(ctx, backupPath)
if err != nil || entry == nil {
return // Entry not in catalog
}
// Update drill status
if err := cat.MarkDrillTested(ctx, entry.ID, result.Success); err != nil {
log.Debug("Failed to update catalog drill status", "error", err)
}
}


@@ -2,10 +2,15 @@ package cmd
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"time"
"github.com/spf13/cobra"
"dbbackup/internal/pitr"
"dbbackup/internal/wal"
)
@@ -32,6 +37,14 @@ var (
pitrTargetImmediate bool
pitrRecoveryAction string
pitrWALSource string
// MySQL PITR flags
mysqlBinlogDir string
mysqlArchiveDir string
mysqlArchiveInterval string
mysqlRequireRowFormat bool
mysqlRequireGTID bool
mysqlWatchMode bool
)
// pitrCmd represents the pitr command group
@@ -183,21 +196,180 @@ Example:
RunE: runWALTimeline,
}
// ============================================================================
// MySQL/MariaDB Binlog Commands
// ============================================================================
// binlogCmd represents the binlog command group (MySQL equivalent of WAL)
var binlogCmd = &cobra.Command{
Use: "binlog",
Short: "Binary log operations for MySQL/MariaDB",
Long: `Manage MySQL/MariaDB binary log files for Point-in-Time Recovery.
Binary logs contain all changes made to the database and are essential
for Point-in-Time Recovery (PITR) with MySQL and MariaDB.
Commands:
list - List available binlog files
archive - Archive binlog files
watch - Watch for new binlog files and archive them
validate - Validate binlog chain integrity
position - Show current binlog position
`,
}
// binlogListCmd lists binary log files
var binlogListCmd = &cobra.Command{
Use: "list",
Short: "List binary log files",
Long: `List all available binary log files from the MySQL data directory
and/or the archive directory.
Shows: filename, size, timestamps, server_id, and format for each binlog.
Examples:
dbbackup binlog list --binlog-dir /var/lib/mysql
dbbackup binlog list --archive-dir /backups/binlog_archive
`,
RunE: runBinlogList,
}
// binlogArchiveCmd archives binary log files
var binlogArchiveCmd = &cobra.Command{
Use: "archive",
Short: "Archive binary log files",
Long: `Archive MySQL binary log files to a backup location.
This command copies completed binlog files (not the currently active one)
to the archive directory, optionally with compression and encryption.
Examples:
dbbackup binlog archive --binlog-dir /var/lib/mysql --archive-dir /backups/binlog
dbbackup binlog archive --compress --archive-dir /backups/binlog
`,
RunE: runBinlogArchive,
}
// binlogWatchCmd watches for new binlogs and archives them
var binlogWatchCmd = &cobra.Command{
Use: "watch",
Short: "Watch for new binlog files and archive them automatically",
Long: `Continuously monitor the binlog directory for new files and
archive them automatically when they are closed.
This runs as a background process and provides continuous binlog archiving
for PITR capability.
Example:
dbbackup binlog watch --binlog-dir /var/lib/mysql --archive-dir /backups/binlog --interval 30s
`,
RunE: runBinlogWatch,
}
// binlogValidateCmd validates binlog chain
var binlogValidateCmd = &cobra.Command{
Use: "validate",
Short: "Validate binlog chain integrity",
Long: `Check the binary log chain for gaps or inconsistencies.
Validates:
- Sequential numbering of binlog files
- No missing files in the chain
- Server ID consistency
- GTID continuity (if enabled)
Example:
dbbackup binlog validate --binlog-dir /var/lib/mysql
dbbackup binlog validate --archive-dir /backups/binlog
`,
RunE: runBinlogValidate,
}
// binlogPositionCmd shows current binlog position
var binlogPositionCmd = &cobra.Command{
Use: "position",
Short: "Show current binary log position",
Long: `Display the current MySQL binary log position.
This connects to MySQL and runs SHOW MASTER STATUS to get:
- Current binlog filename
- Current byte position
- Executed GTID set (if GTID mode is enabled)
Example:
dbbackup binlog position
`,
RunE: runBinlogPosition,
}
// mysqlPitrStatusCmd shows MySQL-specific PITR status
var mysqlPitrStatusCmd = &cobra.Command{
Use: "mysql-status",
Short: "Show MySQL/MariaDB PITR status",
Long: `Display MySQL/MariaDB-specific PITR configuration and status.
Shows:
- Binary log configuration (log_bin, binlog_format)
- GTID mode status
- Archive directory and statistics
- Current binlog position
- Recovery windows available
Example:
dbbackup pitr mysql-status
`,
RunE: runMySQLPITRStatus,
}
// mysqlPitrEnableCmd enables MySQL PITR
var mysqlPitrEnableCmd = &cobra.Command{
Use: "mysql-enable",
Short: "Enable PITR for MySQL/MariaDB",
Long: `Configure MySQL/MariaDB for Point-in-Time Recovery.
This validates MySQL settings and sets up binlog archiving:
- Checks binary logging is enabled (log_bin=ON)
- Validates binlog_format (ROW recommended)
- Creates archive directory
- Saves PITR configuration
Prerequisites in my.cnf:
[mysqld]
log_bin = mysql-bin
binlog_format = ROW
server_id = 1
Example:
dbbackup pitr mysql-enable --archive-dir /backups/binlog_archive
`,
RunE: runMySQLPITREnable,
}
func init() {
rootCmd.AddCommand(pitrCmd)
rootCmd.AddCommand(walCmd)
rootCmd.AddCommand(binlogCmd)
// PITR subcommands
pitrCmd.AddCommand(pitrEnableCmd)
pitrCmd.AddCommand(pitrDisableCmd)
pitrCmd.AddCommand(pitrStatusCmd)
pitrCmd.AddCommand(mysqlPitrStatusCmd)
pitrCmd.AddCommand(mysqlPitrEnableCmd)
// WAL subcommands (PostgreSQL)
walCmd.AddCommand(walArchiveCmd)
walCmd.AddCommand(walListCmd)
walCmd.AddCommand(walCleanupCmd)
walCmd.AddCommand(walTimelineCmd)
// Binlog subcommands (MySQL/MariaDB)
binlogCmd.AddCommand(binlogListCmd)
binlogCmd.AddCommand(binlogArchiveCmd)
binlogCmd.AddCommand(binlogWatchCmd)
binlogCmd.AddCommand(binlogValidateCmd)
binlogCmd.AddCommand(binlogPositionCmd)
// PITR enable flags
pitrEnableCmd.Flags().StringVar(&pitrArchiveDir, "archive-dir", "/var/backups/wal_archive", "Directory to store WAL archives")
pitrEnableCmd.Flags().BoolVar(&pitrForce, "force", false, "Overwrite existing PITR configuration")
@@ -219,6 +391,33 @@ func init() {
// WAL timeline flags
walTimelineCmd.Flags().StringVar(&walArchiveDir, "archive-dir", "/var/backups/wal_archive", "WAL archive directory")
// MySQL binlog flags
binlogListCmd.Flags().StringVar(&mysqlBinlogDir, "binlog-dir", "/var/lib/mysql", "MySQL binary log directory")
binlogListCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "", "Binlog archive directory")
binlogArchiveCmd.Flags().StringVar(&mysqlBinlogDir, "binlog-dir", "/var/lib/mysql", "MySQL binary log directory")
binlogArchiveCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "/var/backups/binlog_archive", "Binlog archive directory")
binlogArchiveCmd.Flags().BoolVar(&walCompress, "compress", false, "Compress binlog files")
binlogArchiveCmd.Flags().BoolVar(&walEncrypt, "encrypt", false, "Encrypt binlog files")
binlogArchiveCmd.Flags().StringVar(&walEncryptionKeyFile, "encryption-key-file", "", "Path to encryption key file")
binlogArchiveCmd.MarkFlagRequired("archive-dir")
binlogWatchCmd.Flags().StringVar(&mysqlBinlogDir, "binlog-dir", "/var/lib/mysql", "MySQL binary log directory")
binlogWatchCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "/var/backups/binlog_archive", "Binlog archive directory")
binlogWatchCmd.Flags().StringVar(&mysqlArchiveInterval, "interval", "30s", "Check interval for new binlogs")
binlogWatchCmd.Flags().BoolVar(&walCompress, "compress", false, "Compress binlog files")
binlogWatchCmd.MarkFlagRequired("archive-dir")
binlogValidateCmd.Flags().StringVar(&mysqlBinlogDir, "binlog-dir", "/var/lib/mysql", "MySQL binary log directory")
binlogValidateCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "", "Binlog archive directory")
// MySQL PITR enable flags
mysqlPitrEnableCmd.Flags().StringVar(&mysqlArchiveDir, "archive-dir", "/var/backups/binlog_archive", "Binlog archive directory")
mysqlPitrEnableCmd.Flags().IntVar(&walRetentionDays, "retention-days", 7, "Days to keep archived binlogs")
mysqlPitrEnableCmd.Flags().BoolVar(&mysqlRequireRowFormat, "require-row-format", true, "Require ROW binlog format")
mysqlPitrEnableCmd.Flags().BoolVar(&mysqlRequireGTID, "require-gtid", false, "Require GTID mode enabled")
mysqlPitrEnableCmd.MarkFlagRequired("archive-dir")
}
// Command implementations
@@ -512,3 +711,614 @@ func formatWALSize(bytes int64) string {
}
return fmt.Sprintf("%.1f KB", float64(bytes)/float64(KB))
}
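Only the `%.1f KB` fallback of `formatWALSize` survives the elided hunk above. A plausible self-contained sketch of such a human-readable size helper — the binary-unit boundaries are assumptions, not taken from the source:

```go
package main

import "fmt"

// formatSize renders a byte count with a binary-unit suffix.
// Only the "%.1f KB" fallback is visible in the diff above;
// the MB/GB thresholds here are assumed.
func formatSize(bytes int64) string {
	const (
		KB = 1 << 10
		MB = 1 << 20
		GB = 1 << 30
	)
	switch {
	case bytes >= GB:
		return fmt.Sprintf("%.1f GB", float64(bytes)/float64(GB))
	case bytes >= MB:
		return fmt.Sprintf("%.1f MB", float64(bytes)/float64(MB))
	default:
		return fmt.Sprintf("%.1f KB", float64(bytes)/float64(KB))
	}
}

func main() {
	fmt.Println(formatSize(1536))    // 1.5 KB
	fmt.Println(formatSize(5 << 20)) // 5.0 MB
}
```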
// ============================================================================
// MySQL/MariaDB Binlog Command Implementations
// ============================================================================
func runBinlogList(cmd *cobra.Command, args []string) error {
ctx := context.Background()
if !cfg.IsMySQL() {
return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB (detected: %s)", cfg.DisplayDatabaseType())
}
binlogDir := mysqlBinlogDir
if binlogDir == "" && mysqlArchiveDir != "" {
binlogDir = mysqlArchiveDir
}
if binlogDir == "" {
return fmt.Errorf("please specify --binlog-dir or --archive-dir")
}
bmConfig := pitr.BinlogManagerConfig{
BinlogDir: binlogDir,
ArchiveDir: mysqlArchiveDir,
}
bm, err := pitr.NewBinlogManager(bmConfig)
if err != nil {
return fmt.Errorf("initializing binlog manager: %w", err)
}
// List binlogs from source directory
binlogs, err := bm.DiscoverBinlogs(ctx)
if err != nil {
return fmt.Errorf("discovering binlogs: %w", err)
}
// Also list archived binlogs if archive dir is specified
var archived []pitr.BinlogArchiveInfo
if mysqlArchiveDir != "" {
archived, _ = bm.ListArchivedBinlogs(ctx)
}
if len(binlogs) == 0 && len(archived) == 0 {
fmt.Println("No binary log files found")
return nil
}
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Printf(" Binary Log Files (%s)\n", bm.ServerType())
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Println()
if len(binlogs) > 0 {
fmt.Println("Source Directory:")
fmt.Printf("%-24s %10s %-19s %-19s %s\n", "Filename", "Size", "Start Time", "End Time", "Format")
fmt.Println("────────────────────────────────────────────────────────────────────────────────")
var totalSize int64
for _, b := range binlogs {
size := formatWALSize(b.Size)
totalSize += b.Size
startTime := "unknown"
endTime := "unknown"
if !b.StartTime.IsZero() {
startTime = b.StartTime.Format("2006-01-02 15:04:05")
}
if !b.EndTime.IsZero() {
endTime = b.EndTime.Format("2006-01-02 15:04:05")
}
format := b.Format
if format == "" {
format = "-"
}
fmt.Printf("%-24s %10s %-19s %-19s %s\n", b.Name, size, startTime, endTime, format)
}
fmt.Printf("\nTotal: %d files, %s\n", len(binlogs), formatWALSize(totalSize))
}
if len(archived) > 0 {
fmt.Println()
fmt.Println("Archived Binlogs:")
fmt.Printf("%-24s %10s %-19s %s\n", "Original", "Size", "Archived At", "Flags")
fmt.Println("────────────────────────────────────────────────────────────────────────────────")
var totalSize int64
for _, a := range archived {
size := formatWALSize(a.Size)
totalSize += a.Size
archivedTime := a.ArchivedAt.Format("2006-01-02 15:04:05")
flags := ""
if a.Compressed {
flags += "C"
}
if a.Encrypted {
flags += "E"
}
if flags != "" {
flags = "[" + flags + "]"
}
fmt.Printf("%-24s %10s %-19s %s\n", a.OriginalFile, size, archivedTime, flags)
}
fmt.Printf("\nTotal archived: %d files, %s\n", len(archived), formatWALSize(totalSize))
}
return nil
}
func runBinlogArchive(cmd *cobra.Command, args []string) error {
ctx := context.Background()
if !cfg.IsMySQL() {
return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB")
}
if mysqlBinlogDir == "" {
return fmt.Errorf("--binlog-dir is required")
}
// Load encryption key if needed
var encryptionKey []byte
if walEncrypt {
key, err := loadEncryptionKey(walEncryptionKeyFile, walEncryptionKeyEnv)
if err != nil {
return fmt.Errorf("failed to load encryption key: %w", err)
}
encryptionKey = key
}
bmConfig := pitr.BinlogManagerConfig{
BinlogDir: mysqlBinlogDir,
ArchiveDir: mysqlArchiveDir,
Compression: walCompress,
Encryption: walEncrypt,
EncryptionKey: encryptionKey,
}
bm, err := pitr.NewBinlogManager(bmConfig)
if err != nil {
return fmt.Errorf("initializing binlog manager: %w", err)
}
// Discover binlogs
binlogs, err := bm.DiscoverBinlogs(ctx)
if err != nil {
return fmt.Errorf("discovering binlogs: %w", err)
}
// Get already archived
archived, _ := bm.ListArchivedBinlogs(ctx)
archivedSet := make(map[string]struct{})
for _, a := range archived {
archivedSet[a.OriginalFile] = struct{}{}
}
// Need to connect to MySQL to get current position
// For now, skip the active binlog by looking at which one was modified most recently
var latestModTime int64
var latestBinlog string
for _, b := range binlogs {
if b.ModTime.Unix() > latestModTime {
latestModTime = b.ModTime.Unix()
latestBinlog = b.Name
}
}
var newArchives []pitr.BinlogArchiveInfo
for i := range binlogs {
b := &binlogs[i]
// Skip if already archived
if _, exists := archivedSet[b.Name]; exists {
log.Info("Skipping already archived", "binlog", b.Name)
continue
}
// Skip the most recently modified (likely active)
if b.Name == latestBinlog {
log.Info("Skipping active binlog", "binlog", b.Name)
continue
}
log.Info("Archiving binlog", "binlog", b.Name, "size", formatWALSize(b.Size))
archiveInfo, err := bm.ArchiveBinlog(ctx, b)
if err != nil {
log.Error("Failed to archive binlog", "binlog", b.Name, "error", err)
continue
}
newArchives = append(newArchives, *archiveInfo)
}
// Update metadata
if len(newArchives) > 0 {
allArchived, _ := bm.ListArchivedBinlogs(ctx)
bm.SaveArchiveMetadata(allArchived)
}
log.Info("✅ Binlog archiving completed", "archived", len(newArchives))
return nil
}
func runBinlogWatch(cmd *cobra.Command, args []string) error {
ctx := context.Background()
if !cfg.IsMySQL() {
return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB")
}
interval, err := time.ParseDuration(mysqlArchiveInterval)
if err != nil {
return fmt.Errorf("invalid interval: %w", err)
}
bmConfig := pitr.BinlogManagerConfig{
BinlogDir: mysqlBinlogDir,
ArchiveDir: mysqlArchiveDir,
Compression: walCompress,
}
bm, err := pitr.NewBinlogManager(bmConfig)
if err != nil {
return fmt.Errorf("initializing binlog manager: %w", err)
}
log.Info("Starting binlog watcher",
"binlog_dir", mysqlBinlogDir,
"archive_dir", mysqlArchiveDir,
"interval", interval)
// Watch for new binlogs
err = bm.WatchBinlogs(ctx, interval, func(b *pitr.BinlogFile) {
log.Info("New binlog detected, archiving", "binlog", b.Name)
archiveInfo, err := bm.ArchiveBinlog(ctx, b)
if err != nil {
log.Error("Failed to archive binlog", "binlog", b.Name, "error", err)
return
}
log.Info("Binlog archived successfully",
"binlog", b.Name,
"archive", archiveInfo.ArchivePath,
"size", formatWALSize(archiveInfo.Size))
// Update metadata
allArchived, _ := bm.ListArchivedBinlogs(ctx)
bm.SaveArchiveMetadata(allArchived)
})
if err != nil && err != context.Canceled {
return err
}
return nil
}
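The watch loop above polls on a fixed interval and fires a callback per newly closed binlog. The core of such a poller is a set difference between files already seen and the current directory listing; a minimal sketch of that step (assumed — the real `WatchBinlogs` implementation is not shown in this diff):

```go
package main

import (
	"fmt"
	"sort"
)

// newBinlogs returns listing entries not yet present in seen,
// and records them so the next poll skips them.
func newBinlogs(seen map[string]bool, listing []string) []string {
	var fresh []string
	for _, name := range listing {
		if !seen[name] {
			seen[name] = true
			fresh = append(fresh, name)
		}
	}
	sort.Strings(fresh)
	return fresh
}

func main() {
	seen := map[string]bool{"mysql-bin.000001": true}
	fmt.Println(newBinlogs(seen, []string{"mysql-bin.000001", "mysql-bin.000002"}))
	// [mysql-bin.000002]
	fmt.Println(newBinlogs(seen, []string{"mysql-bin.000002"}))
	// []
}
```

In the real watcher each fresh name would still need the active-binlog check before archiving, since the newest file is typically still open.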
func runBinlogValidate(cmd *cobra.Command, args []string) error {
ctx := context.Background()
if !cfg.IsMySQL() {
return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB")
}
binlogDir := mysqlBinlogDir
if binlogDir == "" {
binlogDir = mysqlArchiveDir
}
if binlogDir == "" {
return fmt.Errorf("please specify --binlog-dir or --archive-dir")
}
bmConfig := pitr.BinlogManagerConfig{
BinlogDir: binlogDir,
ArchiveDir: mysqlArchiveDir,
}
bm, err := pitr.NewBinlogManager(bmConfig)
if err != nil {
return fmt.Errorf("initializing binlog manager: %w", err)
}
// Discover binlogs
binlogs, err := bm.DiscoverBinlogs(ctx)
if err != nil {
return fmt.Errorf("discovering binlogs: %w", err)
}
if len(binlogs) == 0 {
fmt.Println("No binlog files found to validate")
return nil
}
// Validate chain
validation, err := bm.ValidateBinlogChain(ctx, binlogs)
if err != nil {
return fmt.Errorf("validating binlog chain: %w", err)
}
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Println(" Binlog Chain Validation")
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Println()
if validation.Valid {
fmt.Println("Status: ✅ VALID - Binlog chain is complete")
} else {
fmt.Println("Status: ❌ INVALID - Binlog chain has gaps")
}
fmt.Printf("Files: %d binlog files\n", validation.LogCount)
fmt.Printf("Total Size: %s\n", formatWALSize(validation.TotalSize))
if validation.StartPos != nil {
fmt.Printf("Start: %s\n", validation.StartPos.String())
}
if validation.EndPos != nil {
fmt.Printf("End: %s\n", validation.EndPos.String())
}
if len(validation.Gaps) > 0 {
fmt.Println()
fmt.Println("Gaps Found:")
for _, gap := range validation.Gaps {
fmt.Printf(" • After %s, before %s: %s\n", gap.After, gap.Before, gap.Reason)
}
}
if len(validation.Warnings) > 0 {
fmt.Println()
fmt.Println("Warnings:")
for _, w := range validation.Warnings {
fmt.Printf(" ⚠ %s\n", w)
}
}
if len(validation.Errors) > 0 {
fmt.Println()
fmt.Println("Errors:")
for _, e := range validation.Errors {
fmt.Printf(" ✗ %s\n", e)
}
}
if !validation.Valid {
os.Exit(1)
}
return nil
}
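`binlog validate` checks sequential numbering, among other things. Binlog filenames end in a zero-padded sequence number (`mysql-bin.000007`), so gap detection reduces to comparing consecutive suffixes; a sketch of just that check (the real `ValidateBinlogChain` also covers server IDs and GTID continuity, which this omits):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// findGaps reports missing sequence numbers in a sorted list of
// binlog filenames of the form <base>.<NNNNNN>.
func findGaps(names []string) []string {
	var gaps []string
	prev := -1
	for _, name := range names {
		dot := strings.LastIndex(name, ".")
		n, err := strconv.Atoi(name[dot+1:])
		if err != nil {
			continue // not a numbered binlog file
		}
		if prev >= 0 && n != prev+1 {
			gaps = append(gaps, fmt.Sprintf("missing %06d..%06d", prev+1, n-1))
		}
		prev = n
	}
	return gaps
}

func main() {
	names := []string{"mysql-bin.000001", "mysql-bin.000002", "mysql-bin.000005"}
	fmt.Println(findGaps(names)) // [missing 000003..000004]
}
```

A gap in this chain means the recovery window is broken at that point: replay can only proceed up to the last binlog before the gap.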
func runBinlogPosition(cmd *cobra.Command, args []string) error {
ctx := context.Background()
if !cfg.IsMySQL() {
return fmt.Errorf("binlog commands are only supported for MySQL/MariaDB")
}
// Connect to MySQL
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/",
cfg.User, cfg.Password, cfg.Host, cfg.Port)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("connecting to MySQL: %w", err)
}
defer db.Close()
if err := db.PingContext(ctx); err != nil {
return fmt.Errorf("pinging MySQL: %w", err)
}
// Get binlog position using raw query
rows, err := db.QueryContext(ctx, "SHOW MASTER STATUS")
if err != nil {
return fmt.Errorf("getting master status: %w", err)
}
defer rows.Close()
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Println(" Current Binary Log Position")
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Println()
if rows.Next() {
var file string
var position uint64
var binlogDoDB, binlogIgnoreDB, executedGtidSet sql.NullString
cols, _ := rows.Columns()
switch len(cols) {
case 5:
err = rows.Scan(&file, &position, &binlogDoDB, &binlogIgnoreDB, &executedGtidSet)
case 4:
err = rows.Scan(&file, &position, &binlogDoDB, &binlogIgnoreDB)
default:
err = rows.Scan(&file, &position)
}
if err != nil {
return fmt.Errorf("scanning master status: %w", err)
}
fmt.Printf("File: %s\n", file)
fmt.Printf("Position: %d\n", position)
if executedGtidSet.Valid && executedGtidSet.String != "" {
fmt.Printf("GTID Set: %s\n", executedGtidSet.String)
}
// Compact format for use in restore commands
fmt.Println()
fmt.Printf("Position String: %s:%d\n", file, position)
} else {
fmt.Println("Binary logging appears to be disabled.")
fmt.Println("Enable binary logging by adding to my.cnf:")
fmt.Println(" [mysqld]")
fmt.Println(" log_bin = mysql-bin")
fmt.Println(" server_id = 1")
}
return nil
}
func runMySQLPITRStatus(cmd *cobra.Command, args []string) error {
ctx := context.Background()
if !cfg.IsMySQL() {
return fmt.Errorf("this command is only for MySQL/MariaDB (use 'pitr status' for PostgreSQL)")
}
// Connect to MySQL
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/",
cfg.User, cfg.Password, cfg.Host, cfg.Port)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("connecting to MySQL: %w", err)
}
defer db.Close()
if err := db.PingContext(ctx); err != nil {
return fmt.Errorf("pinging MySQL: %w", err)
}
pitrConfig := pitr.MySQLPITRConfig{
Host: cfg.Host,
Port: cfg.Port,
User: cfg.User,
Password: cfg.Password,
BinlogDir: mysqlBinlogDir,
ArchiveDir: mysqlArchiveDir,
}
mysqlPitr, err := pitr.NewMySQLPITR(db, pitrConfig)
if err != nil {
return fmt.Errorf("initializing MySQL PITR: %w", err)
}
status, err := mysqlPitr.Status(ctx)
if err != nil {
return fmt.Errorf("getting PITR status: %w", err)
}
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Printf(" MySQL/MariaDB PITR Status (%s)\n", status.DatabaseType)
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Println()
if status.Enabled {
fmt.Println("PITR Status: ✅ ENABLED")
} else {
fmt.Println("PITR Status: ❌ NOT CONFIGURED")
}
// Get binary logging status (Scan error intentionally ignored:
// an empty logBin is reported as disabled below)
var logBin string
_ = db.QueryRowContext(ctx, "SELECT @@log_bin").Scan(&logBin)
if logBin == "1" || logBin == "ON" {
fmt.Println("Binary Logging: ✅ ENABLED")
} else {
fmt.Println("Binary Logging: ❌ DISABLED")
}
fmt.Printf("Binlog Format: %s\n", status.LogLevel)
// Check GTID mode
var gtidMode string
if status.DatabaseType == pitr.DatabaseMariaDB {
db.QueryRowContext(ctx, "SELECT @@gtid_current_pos").Scan(&gtidMode)
if gtidMode != "" {
fmt.Println("GTID Mode: ✅ ENABLED")
} else {
fmt.Println("GTID Mode: ❌ DISABLED")
}
} else {
db.QueryRowContext(ctx, "SELECT @@gtid_mode").Scan(&gtidMode)
if gtidMode == "ON" {
fmt.Println("GTID Mode: ✅ ENABLED")
} else {
fmt.Printf("GTID Mode: %s\n", gtidMode)
}
}
if status.Position != nil {
fmt.Printf("Current Position: %s\n", status.Position.String())
}
if status.ArchiveDir != "" {
fmt.Println()
fmt.Println("Archive Statistics:")
fmt.Printf(" Directory: %s\n", status.ArchiveDir)
fmt.Printf(" File Count: %d\n", status.ArchiveCount)
fmt.Printf(" Total Size: %s\n", formatWALSize(status.ArchiveSize))
if !status.LastArchived.IsZero() {
fmt.Printf(" Last Archive: %s\n", status.LastArchived.Format("2006-01-02 15:04:05"))
}
}
// Show requirements
fmt.Println()
fmt.Println("PITR Requirements:")
if logBin == "1" || logBin == "ON" {
fmt.Println(" ✅ Binary logging enabled")
} else {
fmt.Println(" ❌ Binary logging must be enabled (log_bin = mysql-bin)")
}
if status.LogLevel == "ROW" {
fmt.Println(" ✅ Row-based logging (recommended)")
} else {
fmt.Printf(" ⚠ binlog_format = %s (ROW recommended for PITR)\n", status.LogLevel)
}
return nil
}
func runMySQLPITREnable(cmd *cobra.Command, args []string) error {
ctx := context.Background()
if !cfg.IsMySQL() {
return fmt.Errorf("this command is only for MySQL/MariaDB (use 'pitr enable' for PostgreSQL)")
}
// Connect to MySQL
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/",
cfg.User, cfg.Password, cfg.Host, cfg.Port)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("connecting to MySQL: %w", err)
}
defer db.Close()
if err := db.PingContext(ctx); err != nil {
return fmt.Errorf("pinging MySQL: %w", err)
}
pitrConfig := pitr.MySQLPITRConfig{
Host: cfg.Host,
Port: cfg.Port,
User: cfg.User,
Password: cfg.Password,
BinlogDir: mysqlBinlogDir,
ArchiveDir: mysqlArchiveDir,
RequireRowFormat: mysqlRequireRowFormat,
RequireGTID: mysqlRequireGTID,
}
mysqlPitr, err := pitr.NewMySQLPITR(db, pitrConfig)
if err != nil {
return fmt.Errorf("initializing MySQL PITR: %w", err)
}
enableConfig := pitr.PITREnableConfig{
ArchiveDir: mysqlArchiveDir,
RetentionDays: walRetentionDays,
Compression: walCompress,
}
log.Info("Enabling MySQL PITR", "archive_dir", mysqlArchiveDir)
if err := mysqlPitr.Enable(ctx, enableConfig); err != nil {
return fmt.Errorf("enabling PITR: %w", err)
}
log.Info("✅ MySQL PITR enabled successfully!")
log.Info("")
log.Info("Next steps:")
log.Info("1. Start binlog archiving: dbbackup binlog watch --archive-dir " + mysqlArchiveDir)
log.Info("2. Create a base backup: dbbackup backup single <database>")
log.Info("3. Binlogs will be archived to: " + mysqlArchiveDir)
log.Info("")
log.Info("To restore to a point in time, use:")
log.Info(" dbbackup restore pitr <backup> --target-time '2024-01-15 14:30:00'")
return nil
}
// getMySQLBinlogDir attempts to determine the binlog directory from MySQL
func getMySQLBinlogDir(ctx context.Context, db *sql.DB) (string, error) {
var logBinBasename string
err := db.QueryRowContext(ctx, "SELECT @@log_bin_basename").Scan(&logBinBasename)
if err != nil {
return "", err
}
return filepath.Dir(logBinBasename), nil
}
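The compact `file:position` string printed by `runBinlogPosition` is meant to be fed back into restore tooling. A minimal sketch of parsing it back into its components (the `parseBinlogPosition` helper is illustrative, not part of this commit):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseBinlogPosition splits a "mysql-bin.000042:1234" style string
// into the binlog file name and byte offset. Hypothetical helper.
func parseBinlogPosition(s string) (file string, pos uint64, err error) {
	// Split on the last colon so file names containing ':' still work.
	i := strings.LastIndex(s, ":")
	if i < 0 {
		return "", 0, fmt.Errorf("invalid position string: %q", s)
	}
	file = s[:i]
	pos, err = strconv.ParseUint(s[i+1:], 10, 64)
	if err != nil {
		return "", 0, fmt.Errorf("invalid offset in %q: %w", s, err)
	}
	return file, pos, nil
}

func main() {
	file, pos, err := parseBinlogPosition("mysql-bin.000042:1234")
	if err != nil {
		panic(err)
	}
	fmt.Printf("file=%s pos=%d\n", file, pos)
}
```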

cmd/report.go Normal file

@@ -0,0 +1,316 @@
package cmd
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"dbbackup/internal/catalog"
"dbbackup/internal/report"
"github.com/spf13/cobra"
)
var reportCmd = &cobra.Command{
Use: "report",
Short: "Generate compliance reports",
Long: `Generate compliance reports for various regulatory frameworks.
Supported frameworks:
- soc2 SOC 2 Type II Trust Service Criteria
- gdpr General Data Protection Regulation
- hipaa Health Insurance Portability and Accountability Act
- pci-dss Payment Card Industry Data Security Standard
- iso27001 ISO 27001 Information Security Management
Examples:
# Generate SOC2 report for the last 90 days
dbbackup report generate --type soc2 --days 90
# Generate HIPAA report as HTML
dbbackup report generate --type hipaa --format html --output report.html
# Show report summary for current period
dbbackup report summary --type soc2`,
}
var reportGenerateCmd = &cobra.Command{
Use: "generate",
Short: "Generate a compliance report",
Long: "Generate a compliance report for a specified framework and time period",
RunE: runReportGenerate,
}
var reportSummaryCmd = &cobra.Command{
Use: "summary",
Short: "Show compliance summary",
Long: "Display a quick compliance summary for the specified framework",
RunE: runReportSummary,
}
var reportListCmd = &cobra.Command{
Use: "list",
Short: "List available frameworks",
Long: "Display all available compliance frameworks",
RunE: runReportList,
}
var reportControlsCmd = &cobra.Command{
Use: "controls [framework]",
Short: "List controls for a framework",
Long: "Display all controls for a specific compliance framework",
Args: cobra.ExactArgs(1),
RunE: runReportControls,
}
var (
reportType string
reportDays int
reportStartDate string
reportEndDate string
reportFormat string
reportOutput string
reportCatalog string
reportTitle string
includeEvidence bool
)
func init() {
rootCmd.AddCommand(reportCmd)
reportCmd.AddCommand(reportGenerateCmd)
reportCmd.AddCommand(reportSummaryCmd)
reportCmd.AddCommand(reportListCmd)
reportCmd.AddCommand(reportControlsCmd)
// Generate command flags
reportGenerateCmd.Flags().StringVarP(&reportType, "type", "t", "soc2", "Report type (soc2, gdpr, hipaa, pci-dss, iso27001)")
reportGenerateCmd.Flags().IntVarP(&reportDays, "days", "d", 90, "Number of days to include in report")
reportGenerateCmd.Flags().StringVar(&reportStartDate, "start", "", "Start date (YYYY-MM-DD)")
reportGenerateCmd.Flags().StringVar(&reportEndDate, "end", "", "End date (YYYY-MM-DD)")
reportGenerateCmd.Flags().StringVarP(&reportFormat, "format", "f", "markdown", "Output format (json, markdown, html)")
reportGenerateCmd.Flags().StringVarP(&reportOutput, "output", "o", "", "Output file path")
reportGenerateCmd.Flags().StringVar(&reportCatalog, "catalog", "", "Path to backup catalog database")
reportGenerateCmd.Flags().StringVar(&reportTitle, "title", "", "Custom report title")
reportGenerateCmd.Flags().BoolVar(&includeEvidence, "evidence", true, "Include evidence in report")
// Summary command flags
reportSummaryCmd.Flags().StringVarP(&reportType, "type", "t", "soc2", "Report type")
reportSummaryCmd.Flags().IntVarP(&reportDays, "days", "d", 90, "Number of days to include")
reportSummaryCmd.Flags().StringVar(&reportCatalog, "catalog", "", "Path to backup catalog database")
}
func runReportGenerate(cmd *cobra.Command, args []string) error {
// Determine time period
var startDate, endDate time.Time
endDate = time.Now()
if reportStartDate != "" {
parsed, err := time.Parse("2006-01-02", reportStartDate)
if err != nil {
return fmt.Errorf("invalid start date: %w", err)
}
startDate = parsed
} else {
startDate = endDate.AddDate(0, 0, -reportDays)
}
if reportEndDate != "" {
parsed, err := time.Parse("2006-01-02", reportEndDate)
if err != nil {
return fmt.Errorf("invalid end date: %w", err)
}
endDate = parsed
}
// Determine report type
rptType := parseReportType(reportType)
if rptType == "" {
return fmt.Errorf("unknown report type: %s", reportType)
}
// Get catalog path
catalogPath := reportCatalog
if catalogPath == "" {
homeDir, _ := os.UserHomeDir()
catalogPath = filepath.Join(homeDir, ".dbbackup", "catalog.db")
}
// Open catalog
cat, err := catalog.NewSQLiteCatalog(catalogPath)
if err != nil {
return fmt.Errorf("failed to open catalog: %w", err)
}
defer cat.Close()
// Configure generator
config := report.ReportConfig{
Type: rptType,
PeriodStart: startDate,
PeriodEnd: endDate,
CatalogPath: catalogPath,
OutputFormat: parseOutputFormat(reportFormat),
OutputPath: reportOutput,
IncludeEvidence: includeEvidence,
}
if reportTitle != "" {
config.Title = reportTitle
}
// Generate report
gen := report.NewGenerator(cat, config)
rpt, err := gen.Generate()
if err != nil {
return fmt.Errorf("failed to generate report: %w", err)
}
// Get formatter
formatter := report.GetFormatter(config.OutputFormat)
// Write output
var output *os.File
if reportOutput != "" {
output, err = os.Create(reportOutput)
if err != nil {
return fmt.Errorf("failed to create output file: %w", err)
}
defer output.Close()
} else {
output = os.Stdout
}
if err := formatter.Format(rpt, output); err != nil {
return fmt.Errorf("failed to format report: %w", err)
}
if reportOutput != "" {
fmt.Printf("Report generated: %s\n", reportOutput)
fmt.Printf(" Type: %s\n", rpt.Type)
fmt.Printf(" Status: %s %s\n", report.StatusIcon(rpt.Status), rpt.Status)
fmt.Printf(" Score: %.1f%%\n", rpt.Score)
fmt.Printf(" Findings: %d open\n", rpt.Summary.OpenFindings)
}
return nil
}
func runReportSummary(cmd *cobra.Command, args []string) error {
endDate := time.Now()
startDate := endDate.AddDate(0, 0, -reportDays)
rptType := parseReportType(reportType)
if rptType == "" {
return fmt.Errorf("unknown report type: %s", reportType)
}
// Get catalog path
catalogPath := reportCatalog
if catalogPath == "" {
homeDir, _ := os.UserHomeDir()
catalogPath = filepath.Join(homeDir, ".dbbackup", "catalog.db")
}
// Open catalog
cat, err := catalog.NewSQLiteCatalog(catalogPath)
if err != nil {
return fmt.Errorf("failed to open catalog: %w", err)
}
defer cat.Close()
// Configure and generate
config := report.ReportConfig{
Type: rptType,
PeriodStart: startDate,
PeriodEnd: endDate,
CatalogPath: catalogPath,
}
gen := report.NewGenerator(cat, config)
rpt, err := gen.Generate()
if err != nil {
return fmt.Errorf("failed to generate report: %w", err)
}
// Display console summary
formatter := &report.ConsoleFormatter{}
return formatter.Format(rpt, os.Stdout)
}
func runReportList(cmd *cobra.Command, args []string) error {
fmt.Println("\nAvailable Compliance Frameworks:")
fmt.Println(strings.Repeat("-", 50))
fmt.Printf(" %-12s %s\n", "soc2", "SOC 2 Type II Trust Service Criteria")
fmt.Printf(" %-12s %s\n", "gdpr", "General Data Protection Regulation (EU)")
fmt.Printf(" %-12s %s\n", "hipaa", "Health Insurance Portability and Accountability Act")
fmt.Printf(" %-12s %s\n", "pci-dss", "Payment Card Industry Data Security Standard")
fmt.Printf(" %-12s %s\n", "iso27001", "ISO 27001 Information Security Management")
fmt.Println()
fmt.Println("Usage: dbbackup report generate --type <framework>")
fmt.Println()
return nil
}
func runReportControls(cmd *cobra.Command, args []string) error {
rptType := parseReportType(args[0])
if rptType == "" {
return fmt.Errorf("unknown report type: %s", args[0])
}
framework := report.GetFramework(rptType)
if framework == nil {
return fmt.Errorf("no framework defined for: %s", args[0])
}
fmt.Printf("\n%s Controls\n", strings.ToUpper(args[0]))
fmt.Println(strings.Repeat("=", 60))
for _, cat := range framework {
fmt.Printf("\n%s\n", cat.Name)
fmt.Printf("%s\n", cat.Description)
fmt.Println(strings.Repeat("-", 40))
for _, ctrl := range cat.Controls {
fmt.Printf(" [%s] %s\n", ctrl.Reference, ctrl.Name)
fmt.Printf(" %s\n", ctrl.Description)
}
}
fmt.Println()
return nil
}
func parseReportType(s string) report.ReportType {
switch strings.ToLower(s) {
case "soc2", "soc-2", "soc2-type2":
return report.ReportSOC2
case "gdpr":
return report.ReportGDPR
case "hipaa":
return report.ReportHIPAA
case "pci-dss", "pcidss", "pci":
return report.ReportPCIDSS
case "iso27001", "iso-27001", "iso":
return report.ReportISO27001
case "custom":
return report.ReportCustom
default:
return ""
}
}
func parseOutputFormat(s string) report.OutputFormat {
switch strings.ToLower(s) {
case "json":
return report.FormatJSON
case "html":
return report.FormatHTML
case "md", "markdown":
return report.FormatMarkdown
case "pdf":
return report.FormatPDF
default:
return report.FormatMarkdown
}
}

cmd/rto.go Normal file

@@ -0,0 +1,458 @@
package cmd
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"dbbackup/internal/catalog"
"dbbackup/internal/rto"
"github.com/spf13/cobra"
)
var rtoCmd = &cobra.Command{
Use: "rto",
Short: "RTO/RPO analysis and monitoring",
Long: `Analyze and monitor Recovery Time Objective (RTO) and
Recovery Point Objective (RPO) metrics.
RTO: How long to recover from a failure
RPO: How much data you can afford to lose
Examples:
# Analyze RTO/RPO for all databases
dbbackup rto analyze
# Analyze specific database
dbbackup rto analyze --database mydb
# Show summary status
dbbackup rto status
# Set targets and check compliance
dbbackup rto check --target-rto 4h --target-rpo 1h`,
}
var rtoAnalyzeCmd = &cobra.Command{
Use: "analyze",
Short: "Analyze RTO/RPO for databases",
Long: "Perform detailed RTO/RPO analysis based on backup history",
RunE: runRTOAnalyze,
}
var rtoStatusCmd = &cobra.Command{
Use: "status",
Short: "Show RTO/RPO status summary",
Long: "Display current RTO/RPO compliance status for all databases",
RunE: runRTOStatus,
}
var rtoCheckCmd = &cobra.Command{
Use: "check",
Short: "Check RTO/RPO compliance",
Long: "Check if databases meet RTO/RPO targets",
RunE: runRTOCheck,
}
var (
rtoDatabase string
rtoTargetRTO string
rtoTargetRPO string
rtoCatalog string
rtoFormat string
rtoOutput string
)
func init() {
rootCmd.AddCommand(rtoCmd)
rtoCmd.AddCommand(rtoAnalyzeCmd)
rtoCmd.AddCommand(rtoStatusCmd)
rtoCmd.AddCommand(rtoCheckCmd)
// Analyze command flags
rtoAnalyzeCmd.Flags().StringVarP(&rtoDatabase, "database", "d", "", "Database to analyze (all if not specified)")
rtoAnalyzeCmd.Flags().StringVar(&rtoTargetRTO, "target-rto", "4h", "Target RTO (e.g., 4h, 30m)")
rtoAnalyzeCmd.Flags().StringVar(&rtoTargetRPO, "target-rpo", "1h", "Target RPO (e.g., 1h, 15m)")
rtoAnalyzeCmd.Flags().StringVar(&rtoCatalog, "catalog", "", "Path to backup catalog")
rtoAnalyzeCmd.Flags().StringVarP(&rtoFormat, "format", "f", "text", "Output format (text, json)")
rtoAnalyzeCmd.Flags().StringVarP(&rtoOutput, "output", "o", "", "Output file")
// Status command flags
rtoStatusCmd.Flags().StringVar(&rtoCatalog, "catalog", "", "Path to backup catalog")
rtoStatusCmd.Flags().StringVar(&rtoTargetRTO, "target-rto", "4h", "Target RTO")
rtoStatusCmd.Flags().StringVar(&rtoTargetRPO, "target-rpo", "1h", "Target RPO")
// Check command flags
rtoCheckCmd.Flags().StringVarP(&rtoDatabase, "database", "d", "", "Database to check")
rtoCheckCmd.Flags().StringVar(&rtoTargetRTO, "target-rto", "4h", "Target RTO")
rtoCheckCmd.Flags().StringVar(&rtoTargetRPO, "target-rpo", "1h", "Target RPO")
rtoCheckCmd.Flags().StringVar(&rtoCatalog, "catalog", "", "Path to backup catalog")
}
func runRTOAnalyze(cmd *cobra.Command, args []string) error {
ctx := context.Background()
// Parse duration targets
targetRTO, err := time.ParseDuration(rtoTargetRTO)
if err != nil {
return fmt.Errorf("invalid target-rto: %w", err)
}
targetRPO, err := time.ParseDuration(rtoTargetRPO)
if err != nil {
return fmt.Errorf("invalid target-rpo: %w", err)
}
// Get catalog
cat, err := openRTOCatalog()
if err != nil {
return err
}
defer cat.Close()
// Create calculator
config := rto.DefaultConfig()
config.TargetRTO = targetRTO
config.TargetRPO = targetRPO
calc := rto.NewCalculator(cat, config)
var analyses []*rto.Analysis
if rtoDatabase != "" {
// Analyze single database
analysis, err := calc.Analyze(ctx, rtoDatabase)
if err != nil {
return fmt.Errorf("analysis failed: %w", err)
}
analyses = append(analyses, analysis)
} else {
// Analyze all databases
analyses, err = calc.AnalyzeAll(ctx)
if err != nil {
return fmt.Errorf("analysis failed: %w", err)
}
}
// Output
if rtoFormat == "json" {
return outputJSON(analyses, rtoOutput)
}
return outputAnalysisText(analyses)
}
func runRTOStatus(cmd *cobra.Command, args []string) error {
ctx := context.Background()
// Parse targets
targetRTO, err := time.ParseDuration(rtoTargetRTO)
if err != nil {
return fmt.Errorf("invalid target-rto: %w", err)
}
targetRPO, err := time.ParseDuration(rtoTargetRPO)
if err != nil {
return fmt.Errorf("invalid target-rpo: %w", err)
}
// Get catalog
cat, err := openRTOCatalog()
if err != nil {
return err
}
defer cat.Close()
// Create calculator and analyze all
config := rto.DefaultConfig()
config.TargetRTO = targetRTO
config.TargetRPO = targetRPO
calc := rto.NewCalculator(cat, config)
analyses, err := calc.AnalyzeAll(ctx)
if err != nil {
return fmt.Errorf("analysis failed: %w", err)
}
// Create summary
summary := rto.Summarize(analyses)
// Display status
fmt.Println()
fmt.Println("╔═══════════════════════════════════════════════════════════╗")
fmt.Println("║ RTO/RPO STATUS SUMMARY ║")
fmt.Println("╠═══════════════════════════════════════════════════════════╣")
fmt.Printf("║ Target RTO: %-15s Target RPO: %-15s ║\n",
formatDuration(config.TargetRTO),
formatDuration(config.TargetRPO))
fmt.Println("╠═══════════════════════════════════════════════════════════╣")
// Compliance status
rpoRate := 0.0
rtoRate := 0.0
fullRate := 0.0
if summary.TotalDatabases > 0 {
rpoRate = float64(summary.RPOCompliant) / float64(summary.TotalDatabases) * 100
rtoRate = float64(summary.RTOCompliant) / float64(summary.TotalDatabases) * 100
fullRate = float64(summary.FullyCompliant) / float64(summary.TotalDatabases) * 100
}
fmt.Printf("║ Databases: %-5d ║\n", summary.TotalDatabases)
fmt.Printf("║ RPO Compliant: %-5d (%.0f%%) ║\n", summary.RPOCompliant, rpoRate)
fmt.Printf("║ RTO Compliant: %-5d (%.0f%%) ║\n", summary.RTOCompliant, rtoRate)
fmt.Printf("║ Fully Compliant: %-3d (%.0f%%) ║\n", summary.FullyCompliant, fullRate)
if summary.CriticalIssues > 0 {
fmt.Printf("║ ⚠️ Critical Issues: %-3d ║\n", summary.CriticalIssues)
}
fmt.Println("╠═══════════════════════════════════════════════════════════╣")
fmt.Printf("║ Average RPO: %-15s Worst: %-15s ║\n",
formatDuration(summary.AverageRPO),
formatDuration(summary.WorstRPO))
fmt.Printf("║ Average RTO: %-15s Worst: %-15s ║\n",
formatDuration(summary.AverageRTO),
formatDuration(summary.WorstRTO))
if summary.WorstRPODatabase != "" {
fmt.Printf("║ Worst RPO Database: %-38s║\n", summary.WorstRPODatabase)
}
if summary.WorstRTODatabase != "" {
fmt.Printf("║ Worst RTO Database: %-38s║\n", summary.WorstRTODatabase)
}
fmt.Println("╚═══════════════════════════════════════════════════════════╝")
fmt.Println()
// Per-database status
if len(analyses) > 0 {
fmt.Println("Database Status:")
fmt.Println(strings.Repeat("-", 70))
fmt.Printf("%-25s %-12s %-12s %-12s\n", "DATABASE", "RPO", "RTO", "STATUS")
fmt.Println(strings.Repeat("-", 70))
for _, a := range analyses {
status := "✅"
if !a.RPOCompliant || !a.RTOCompliant {
status = "❌"
}
rpoStr := formatDuration(a.CurrentRPO)
rtoStr := formatDuration(a.CurrentRTO)
if !a.RPOCompliant {
rpoStr = "⚠️ " + rpoStr
}
if !a.RTOCompliant {
rtoStr = "⚠️ " + rtoStr
}
fmt.Printf("%-25s %-12s %-12s %s\n",
truncateRTO(a.Database, 24),
rpoStr,
rtoStr,
status)
}
fmt.Println(strings.Repeat("-", 70))
}
return nil
}
func runRTOCheck(cmd *cobra.Command, args []string) error {
ctx := context.Background()
// Parse targets
targetRTO, err := time.ParseDuration(rtoTargetRTO)
if err != nil {
return fmt.Errorf("invalid target-rto: %w", err)
}
targetRPO, err := time.ParseDuration(rtoTargetRPO)
if err != nil {
return fmt.Errorf("invalid target-rpo: %w", err)
}
// Get catalog
cat, err := openRTOCatalog()
if err != nil {
return err
}
defer cat.Close()
// Create calculator
config := rto.DefaultConfig()
config.TargetRTO = targetRTO
config.TargetRPO = targetRPO
calc := rto.NewCalculator(cat, config)
var analyses []*rto.Analysis
if rtoDatabase != "" {
analysis, err := calc.Analyze(ctx, rtoDatabase)
if err != nil {
return fmt.Errorf("analysis failed: %w", err)
}
analyses = append(analyses, analysis)
} else {
analyses, err = calc.AnalyzeAll(ctx)
if err != nil {
return fmt.Errorf("analysis failed: %w", err)
}
}
// Check compliance
exitCode := 0
for _, a := range analyses {
if !a.RPOCompliant {
fmt.Printf("❌ %s: RPO violation - current %s exceeds target %s\n",
a.Database,
formatDuration(a.CurrentRPO),
formatDuration(config.TargetRPO))
exitCode = 1
}
if !a.RTOCompliant {
fmt.Printf("❌ %s: RTO violation - estimated %s exceeds target %s\n",
a.Database,
formatDuration(a.CurrentRTO),
formatDuration(config.TargetRTO))
exitCode = 1
}
if a.RPOCompliant && a.RTOCompliant {
fmt.Printf("✅ %s: Compliant (RPO: %s, RTO: %s)\n",
a.Database,
formatDuration(a.CurrentRPO),
formatDuration(a.CurrentRTO))
}
}
if exitCode != 0 {
os.Exit(exitCode)
}
return nil
}
func openRTOCatalog() (*catalog.SQLiteCatalog, error) {
catalogPath := rtoCatalog
if catalogPath == "" {
homeDir, _ := os.UserHomeDir()
catalogPath = filepath.Join(homeDir, ".dbbackup", "catalog.db")
}
cat, err := catalog.NewSQLiteCatalog(catalogPath)
if err != nil {
return nil, fmt.Errorf("failed to open catalog: %w", err)
}
return cat, nil
}
func outputJSON(data interface{}, outputPath string) error {
jsonData, err := json.MarshalIndent(data, "", " ")
if err != nil {
return err
}
if outputPath != "" {
return os.WriteFile(outputPath, jsonData, 0644)
}
fmt.Println(string(jsonData))
return nil
}
func outputAnalysisText(analyses []*rto.Analysis) error {
for _, a := range analyses {
fmt.Println()
fmt.Println(strings.Repeat("=", 60))
fmt.Printf(" Database: %s\n", a.Database)
fmt.Println(strings.Repeat("=", 60))
// Status
rpoStatus := "✅ Compliant"
if !a.RPOCompliant {
rpoStatus = "❌ Violation"
}
rtoStatus := "✅ Compliant"
if !a.RTOCompliant {
rtoStatus = "❌ Violation"
}
fmt.Println()
fmt.Println(" Recovery Objectives:")
fmt.Println(strings.Repeat("-", 50))
fmt.Printf(" RPO (Current): %-15s Target: %s\n",
formatDuration(a.CurrentRPO), formatDuration(a.TargetRPO))
fmt.Printf(" RPO Status: %s\n", rpoStatus)
fmt.Printf(" RTO (Estimated): %-14s Target: %s\n",
formatDuration(a.CurrentRTO), formatDuration(a.TargetRTO))
fmt.Printf(" RTO Status: %s\n", rtoStatus)
if a.LastBackup != nil {
fmt.Printf(" Last Backup: %s\n", a.LastBackup.Format("2006-01-02 15:04:05"))
}
if a.BackupInterval > 0 {
fmt.Printf(" Backup Interval: %s\n", formatDuration(a.BackupInterval))
}
// RTO Breakdown
fmt.Println()
fmt.Println(" RTO Breakdown:")
fmt.Println(strings.Repeat("-", 50))
b := a.RTOBreakdown
fmt.Printf(" Detection: %s\n", formatDuration(b.DetectionTime))
fmt.Printf(" Decision: %s\n", formatDuration(b.DecisionTime))
if b.DownloadTime > 0 {
fmt.Printf(" Download: %s\n", formatDuration(b.DownloadTime))
}
fmt.Printf(" Restore: %s\n", formatDuration(b.RestoreTime))
fmt.Printf(" Startup: %s\n", formatDuration(b.StartupTime))
fmt.Printf(" Validation: %s\n", formatDuration(b.ValidationTime))
fmt.Printf(" Switchover: %s\n", formatDuration(b.SwitchoverTime))
fmt.Println(strings.Repeat("-", 30))
fmt.Printf(" Total: %s\n", formatDuration(b.TotalTime))
// Recommendations
if len(a.Recommendations) > 0 {
fmt.Println()
fmt.Println(" Recommendations:")
fmt.Println(strings.Repeat("-", 50))
for _, r := range a.Recommendations {
icon := "💡"
switch r.Priority {
case rto.PriorityCritical:
icon = "🔴"
case rto.PriorityHigh:
icon = "🟠"
case rto.PriorityMedium:
icon = "🟡"
}
fmt.Printf(" %s [%s] %s\n", icon, r.Priority, r.Title)
fmt.Printf(" %s\n", r.Description)
}
}
}
return nil
}
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%.0fs", d.Seconds())
}
if d < time.Hour {
return fmt.Sprintf("%.0fm", d.Minutes())
}
hours := int(d.Hours())
mins := int(d.Minutes()) - hours*60
return fmt.Sprintf("%dh %dm", hours, mins)
}
func truncateRTO(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
if maxLen <= 3 {
// Guard against a slice-bounds panic for very small widths
return s[:maxLen]
}
return s[:maxLen-3] + "..."
}

go.mod

@@ -79,6 +79,7 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect

go.sum

@@ -153,6 +153,8 @@ github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2J
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=

internal/catalog/catalog.go Normal file

@@ -0,0 +1,188 @@
// Package catalog provides backup catalog management with SQLite storage
package catalog
import (
"context"
"fmt"
"time"
)
// Entry represents a single backup in the catalog
type Entry struct {
ID int64 `json:"id"`
Database string `json:"database"`
DatabaseType string `json:"database_type"` // postgresql, mysql, mariadb
Host string `json:"host"`
Port int `json:"port"`
BackupPath string `json:"backup_path"`
BackupType string `json:"backup_type"` // full, incremental
SizeBytes int64 `json:"size_bytes"`
SHA256 string `json:"sha256"`
Compression string `json:"compression"`
Encrypted bool `json:"encrypted"`
CreatedAt time.Time `json:"created_at"`
Duration float64 `json:"duration_seconds"`
Status BackupStatus `json:"status"`
VerifiedAt *time.Time `json:"verified_at,omitempty"`
VerifyValid *bool `json:"verify_valid,omitempty"`
DrillTestedAt *time.Time `json:"drill_tested_at,omitempty"`
DrillSuccess *bool `json:"drill_success,omitempty"`
CloudLocation string `json:"cloud_location,omitempty"`
RetentionPolicy string `json:"retention_policy,omitempty"` // daily, weekly, monthly, yearly
Tags map[string]string `json:"tags,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// BackupStatus represents the state of a backup
type BackupStatus string
const (
StatusCompleted BackupStatus = "completed"
StatusFailed BackupStatus = "failed"
StatusVerified BackupStatus = "verified"
StatusCorrupted BackupStatus = "corrupted"
StatusDeleted BackupStatus = "deleted"
StatusArchived BackupStatus = "archived"
)
// Gap represents a detected backup gap
type Gap struct {
Database string `json:"database"`
GapStart time.Time `json:"gap_start"`
GapEnd time.Time `json:"gap_end"`
Duration time.Duration `json:"duration"`
ExpectedAt time.Time `json:"expected_at"`
Description string `json:"description"`
Severity GapSeverity `json:"severity"`
}
// GapSeverity indicates how serious a backup gap is
type GapSeverity string
const (
SeverityInfo GapSeverity = "info" // Gap within tolerance
SeverityWarning GapSeverity = "warning" // Gap exceeds expected interval
SeverityCritical GapSeverity = "critical" // Gap exceeds RPO
)
// Stats contains backup statistics
type Stats struct {
TotalBackups int64 `json:"total_backups"`
TotalSize int64 `json:"total_size_bytes"`
TotalSizeHuman string `json:"total_size_human"`
OldestBackup *time.Time `json:"oldest_backup,omitempty"`
NewestBackup *time.Time `json:"newest_backup,omitempty"`
ByDatabase map[string]int64 `json:"by_database"`
ByType map[string]int64 `json:"by_type"`
ByStatus map[string]int64 `json:"by_status"`
VerifiedCount int64 `json:"verified_count"`
DrillTestedCount int64 `json:"drill_tested_count"`
AvgDuration float64 `json:"avg_duration_seconds"`
AvgSize int64 `json:"avg_size_bytes"`
GapsDetected int `json:"gaps_detected"`
}
// SearchQuery represents search criteria for catalog entries
type SearchQuery struct {
Database string // Filter by database name (supports wildcards)
DatabaseType string // Filter by database type
Host string // Filter by host
Status string // Filter by status
StartDate *time.Time // Backups after this date
EndDate *time.Time // Backups before this date
MinSize int64 // Minimum size in bytes
MaxSize int64 // Maximum size in bytes
BackupType string // full, incremental
Encrypted *bool // Filter by encryption status
Verified *bool // Filter by verification status
DrillTested *bool // Filter by drill test status
Limit int // Max results (0 = no limit)
Offset int // Offset for pagination
OrderBy string // Field to order by
OrderDesc bool // Order descending
}
// GapDetectionConfig configures gap detection
type GapDetectionConfig struct {
ExpectedInterval time.Duration // Expected backup interval (e.g., 24h)
Tolerance time.Duration // Allowed variance (e.g., 1h)
RPOThreshold time.Duration // Critical threshold (RPO)
StartDate *time.Time // Start of analysis window
EndDate *time.Time // End of analysis window
}
// Catalog defines the interface for backup catalog operations
type Catalog interface {
// Entry management
Add(ctx context.Context, entry *Entry) error
Update(ctx context.Context, entry *Entry) error
Delete(ctx context.Context, id int64) error
Get(ctx context.Context, id int64) (*Entry, error)
GetByPath(ctx context.Context, path string) (*Entry, error)
// Search and listing
Search(ctx context.Context, query *SearchQuery) ([]*Entry, error)
List(ctx context.Context, database string, limit int) ([]*Entry, error)
ListDatabases(ctx context.Context) ([]string, error)
Count(ctx context.Context, query *SearchQuery) (int64, error)
// Statistics
Stats(ctx context.Context) (*Stats, error)
StatsByDatabase(ctx context.Context, database string) (*Stats, error)
// Gap detection
DetectGaps(ctx context.Context, database string, config *GapDetectionConfig) ([]*Gap, error)
DetectAllGaps(ctx context.Context, config *GapDetectionConfig) (map[string][]*Gap, error)
// Verification tracking
MarkVerified(ctx context.Context, id int64, valid bool) error
MarkDrillTested(ctx context.Context, id int64, success bool) error
// Sync with filesystem
SyncFromDirectory(ctx context.Context, dir string) (*SyncResult, error)
SyncFromCloud(ctx context.Context, provider, bucket, prefix string) (*SyncResult, error)
// Maintenance
Prune(ctx context.Context, before time.Time) (int, error)
Vacuum(ctx context.Context) error
Close() error
}
// SyncResult contains results from a catalog sync operation
type SyncResult struct {
Added int `json:"added"`
Updated int `json:"updated"`
Removed int `json:"removed"`
Errors int `json:"errors"`
Duration float64 `json:"duration_seconds"`
Details []string `json:"details,omitempty"`
}
// FormatSize formats bytes as human-readable string
func FormatSize(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
// FormatDuration formats duration as human-readable string
func FormatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%.0fs", d.Seconds())
}
if d < time.Hour {
mins := int(d.Minutes())
secs := int(d.Seconds()) - mins*60
return fmt.Sprintf("%dm %ds", mins, secs)
}
hours := int(d.Hours())
mins := int(d.Minutes()) - hours*60
return fmt.Sprintf("%dh %dm", hours, mins)
}


@@ -0,0 +1,308 @@
package catalog
import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
"time"
)
func TestSQLiteCatalog(t *testing.T) {
// Create temp directory for test database
tmpDir, err := os.MkdirTemp("", "catalog_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test_catalog.db")
// Test creation
cat, err := NewSQLiteCatalog(dbPath)
if err != nil {
t.Fatalf("Failed to create catalog: %v", err)
}
defer cat.Close()
ctx := context.Background()
// Test Add
entry := &Entry{
Database: "testdb",
DatabaseType: "postgresql",
Host: "localhost",
Port: 5432,
BackupPath: "/backups/testdb_20240115.dump.gz",
BackupType: "full",
SizeBytes: 1024 * 1024 * 100, // 100 MB
SHA256: "abc123def456",
Compression: "gzip",
Encrypted: false,
CreatedAt: time.Now().Add(-24 * time.Hour),
Duration: 45.5,
Status: StatusCompleted,
}
err = cat.Add(ctx, entry)
if err != nil {
t.Fatalf("Failed to add entry: %v", err)
}
if entry.ID == 0 {
t.Error("Expected entry ID to be set after Add")
}
// Test Get
retrieved, err := cat.Get(ctx, entry.ID)
if err != nil {
t.Fatalf("Failed to get entry: %v", err)
}
if retrieved == nil {
t.Fatal("Expected to retrieve entry, got nil")
}
if retrieved.Database != "testdb" {
t.Errorf("Expected database 'testdb', got '%s'", retrieved.Database)
}
if retrieved.SizeBytes != entry.SizeBytes {
t.Errorf("Expected size %d, got %d", entry.SizeBytes, retrieved.SizeBytes)
}
// Test GetByPath
byPath, err := cat.GetByPath(ctx, entry.BackupPath)
if err != nil {
t.Fatalf("Failed to get by path: %v", err)
}
if byPath == nil || byPath.ID != entry.ID {
t.Error("GetByPath returned wrong entry")
}
// Test List
entries, err := cat.List(ctx, "testdb", 10)
if err != nil {
t.Fatalf("Failed to list entries: %v", err)
}
if len(entries) != 1 {
t.Errorf("Expected 1 entry, got %d", len(entries))
}
// Test ListDatabases
databases, err := cat.ListDatabases(ctx)
if err != nil {
t.Fatalf("Failed to list databases: %v", err)
}
if len(databases) != 1 || databases[0] != "testdb" {
t.Errorf("Expected ['testdb'], got %v", databases)
}
// Test Stats
stats, err := cat.Stats(ctx)
if err != nil {
t.Fatalf("Failed to get stats: %v", err)
}
if stats.TotalBackups != 1 {
t.Errorf("Expected 1 total backup, got %d", stats.TotalBackups)
}
if stats.TotalSize != entry.SizeBytes {
t.Errorf("Expected size %d, got %d", entry.SizeBytes, stats.TotalSize)
}
// Test MarkVerified
err = cat.MarkVerified(ctx, entry.ID, true)
if err != nil {
t.Fatalf("Failed to mark verified: %v", err)
}
verified, _ := cat.Get(ctx, entry.ID)
if verified.VerifiedAt == nil {
t.Error("Expected VerifiedAt to be set")
}
if verified.VerifyValid == nil || !*verified.VerifyValid {
t.Error("Expected VerifyValid to be true")
}
// Test Update
entry.SizeBytes = 200 * 1024 * 1024 // 200 MB
err = cat.Update(ctx, entry)
if err != nil {
t.Fatalf("Failed to update entry: %v", err)
}
updated, _ := cat.Get(ctx, entry.ID)
if updated.SizeBytes != entry.SizeBytes {
t.Errorf("Update failed: expected size %d, got %d", entry.SizeBytes, updated.SizeBytes)
}
// Test Search with filters
query := &SearchQuery{
Database: "testdb",
Limit: 10,
OrderBy: "created_at",
OrderDesc: true,
}
results, err := cat.Search(ctx, query)
if err != nil {
t.Fatalf("Search failed: %v", err)
}
if len(results) != 1 {
t.Errorf("Expected 1 result, got %d", len(results))
}
// Test Search with wildcards
query.Database = "test*"
results, err = cat.Search(ctx, query)
if err != nil {
t.Fatalf("Wildcard search failed: %v", err)
}
if len(results) != 1 {
t.Errorf("Expected 1 result from wildcard search, got %d", len(results))
}
// Test Count
count, err := cat.Count(ctx, &SearchQuery{Database: "testdb"})
if err != nil {
t.Fatalf("Count failed: %v", err)
}
if count != 1 {
t.Errorf("Expected count 1, got %d", count)
}
// Test Delete
err = cat.Delete(ctx, entry.ID)
if err != nil {
t.Fatalf("Failed to delete entry: %v", err)
}
deleted, _ := cat.Get(ctx, entry.ID)
if deleted != nil {
t.Error("Expected entry to be deleted")
}
}
func TestGapDetection(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "catalog_gaps_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test_catalog.db")
cat, err := NewSQLiteCatalog(dbPath)
if err != nil {
t.Fatalf("Failed to create catalog: %v", err)
}
defer cat.Close()
ctx := context.Background()
// Add backups with varying intervals
now := time.Now()
backups := []time.Time{
now.Add(-7 * 24 * time.Hour), // 7 days ago
now.Add(-6 * 24 * time.Hour), // 6 days ago (OK)
now.Add(-5 * 24 * time.Hour), // 5 days ago (OK)
// Missing 4 days ago - GAP
now.Add(-3 * 24 * time.Hour), // 3 days ago
now.Add(-2 * 24 * time.Hour), // 2 days ago (OK)
// Missing 1 day ago and today - GAP to now
}
for i, ts := range backups {
entry := &Entry{
Database: "gaptest",
DatabaseType: "postgresql",
BackupPath: filepath.Join(tmpDir, fmt.Sprintf("backup_%d.dump", i)),
BackupType: "full",
CreatedAt: ts,
Status: StatusCompleted,
}
cat.Add(ctx, entry)
}
// Detect gaps with 24h expected interval
config := &GapDetectionConfig{
ExpectedInterval: 24 * time.Hour,
Tolerance: 2 * time.Hour,
RPOThreshold: 48 * time.Hour,
}
gaps, err := cat.DetectGaps(ctx, "gaptest", config)
if err != nil {
t.Fatalf("Gap detection failed: %v", err)
}
// Should detect at least 2 gaps:
// 1. Between 5 days ago and 3 days ago (missing 4 days ago)
// 2. Between 2 days ago and now (missing recent backups)
if len(gaps) < 2 {
t.Errorf("Expected at least 2 gaps, got %d", len(gaps))
}
// Check gap severities
hasCritical := false
for _, gap := range gaps {
if gap.Severity == SeverityCritical {
hasCritical = true
}
if gap.Duration < config.ExpectedInterval {
t.Errorf("Gap duration %v is less than expected interval", gap.Duration)
}
}
// The gap from 2 days ago to now should be critical (>48h)
if !hasCritical {
t.Log("Note: Expected at least one critical gap")
}
}
func TestFormatSize(t *testing.T) {
tests := []struct {
bytes int64
expected string
}{
{0, "0 B"},
{500, "500 B"},
{1024, "1.0 KB"},
{1024 * 1024, "1.0 MB"},
{1024 * 1024 * 1024, "1.0 GB"},
{1024 * 1024 * 1024 * 1024, "1.0 TB"},
}
for _, test := range tests {
result := FormatSize(test.bytes)
if result != test.expected {
t.Errorf("FormatSize(%d) = %s, expected %s", test.bytes, result, test.expected)
}
}
}
func TestFormatDuration(t *testing.T) {
tests := []struct {
duration time.Duration
expected string
}{
{30 * time.Second, "30s"},
{90 * time.Second, "1m 30s"},
{2 * time.Hour, "2h 0m"},
}
for _, test := range tests {
result := FormatDuration(test.duration)
if result != test.expected {
t.Errorf("FormatDuration(%v) = %s, expected %s", test.duration, result, test.expected)
}
}
}

internal/catalog/gaps.go Normal file

@@ -0,0 +1,299 @@
// Package catalog - Gap detection for backup schedules
package catalog
import (
"context"
"math"
"sort"
"time"
)
// DetectGaps analyzes backup history and finds gaps in the schedule
func (c *SQLiteCatalog) DetectGaps(ctx context.Context, database string, config *GapDetectionConfig) ([]*Gap, error) {
if config == nil {
config = &GapDetectionConfig{
ExpectedInterval: 24 * time.Hour,
Tolerance: time.Hour,
RPOThreshold: 48 * time.Hour,
}
}
// Get all backups for this database, ordered by time
query := &SearchQuery{
Database: database,
Status: string(StatusCompleted),
OrderBy: "created_at",
OrderDesc: false,
}
if config.StartDate != nil {
query.StartDate = config.StartDate
}
if config.EndDate != nil {
query.EndDate = config.EndDate
}
entries, err := c.Search(ctx, query)
if err != nil {
return nil, err
}
if len(entries) == 0 {
return nil, nil // No backups recorded; nothing to analyze
}
var gaps []*Gap
for i := 1; i < len(entries); i++ {
prev := entries[i-1]
curr := entries[i]
actualInterval := curr.CreatedAt.Sub(prev.CreatedAt)
expectedWithTolerance := config.ExpectedInterval + config.Tolerance
if actualInterval > expectedWithTolerance {
gap := &Gap{
Database: database,
GapStart: prev.CreatedAt,
GapEnd: curr.CreatedAt,
Duration: actualInterval,
ExpectedAt: prev.CreatedAt.Add(config.ExpectedInterval),
}
// Determine severity
if actualInterval > config.RPOThreshold {
gap.Severity = SeverityCritical
gap.Description = "CRITICAL: Gap exceeds RPO threshold"
} else if actualInterval > config.ExpectedInterval*2 {
gap.Severity = SeverityWarning
gap.Description = "WARNING: Gap exceeds 2x expected interval"
} else {
gap.Severity = SeverityInfo
gap.Description = "INFO: Gap exceeds expected interval"
}
gaps = append(gaps, gap)
}
}
// Check for gap from last backup to now
lastBackup := entries[len(entries)-1]
now := time.Now()
if config.EndDate != nil {
now = *config.EndDate
}
sinceLastBackup := now.Sub(lastBackup.CreatedAt)
if sinceLastBackup > config.ExpectedInterval+config.Tolerance {
gap := &Gap{
Database: database,
GapStart: lastBackup.CreatedAt,
GapEnd: now,
Duration: sinceLastBackup,
ExpectedAt: lastBackup.CreatedAt.Add(config.ExpectedInterval),
}
if sinceLastBackup > config.RPOThreshold {
gap.Severity = SeverityCritical
gap.Description = "CRITICAL: No backup since " + FormatDuration(sinceLastBackup)
} else if sinceLastBackup > config.ExpectedInterval*2 {
gap.Severity = SeverityWarning
gap.Description = "WARNING: No backup since " + FormatDuration(sinceLastBackup)
} else {
gap.Severity = SeverityInfo
gap.Description = "INFO: Backup overdue by " + FormatDuration(sinceLastBackup-config.ExpectedInterval)
}
gaps = append(gaps, gap)
}
return gaps, nil
}
// DetectAllGaps analyzes all databases for backup gaps
func (c *SQLiteCatalog) DetectAllGaps(ctx context.Context, config *GapDetectionConfig) (map[string][]*Gap, error) {
databases, err := c.ListDatabases(ctx)
if err != nil {
return nil, err
}
allGaps := make(map[string][]*Gap)
for _, db := range databases {
gaps, err := c.DetectGaps(ctx, db, config)
if err != nil {
continue // Skip errors for individual databases
}
if len(gaps) > 0 {
allGaps[db] = gaps
}
}
return allGaps, nil
}
// BackupFrequencyAnalysis provides analysis of backup frequency
type BackupFrequencyAnalysis struct {
Database string `json:"database"`
TotalBackups int `json:"total_backups"`
AnalysisPeriod time.Duration `json:"analysis_period"`
AverageInterval time.Duration `json:"average_interval"`
MinInterval time.Duration `json:"min_interval"`
MaxInterval time.Duration `json:"max_interval"`
StdDeviation time.Duration `json:"std_deviation"`
Regularity float64 `json:"regularity"` // 0-1, higher is more regular
GapsDetected int `json:"gaps_detected"`
MissedBackups int `json:"missed_backups"` // Estimated based on expected interval
}
// AnalyzeFrequency analyzes backup frequency for a database
func (c *SQLiteCatalog) AnalyzeFrequency(ctx context.Context, database string, expectedInterval time.Duration) (*BackupFrequencyAnalysis, error) {
query := &SearchQuery{
Database: database,
Status: string(StatusCompleted),
OrderBy: "created_at",
OrderDesc: false,
}
entries, err := c.Search(ctx, query)
if err != nil {
return nil, err
}
if len(entries) < 2 {
return &BackupFrequencyAnalysis{
Database: database,
TotalBackups: len(entries),
}, nil
}
analysis := &BackupFrequencyAnalysis{
Database: database,
TotalBackups: len(entries),
}
// Calculate intervals
var intervals []time.Duration
for i := 1; i < len(entries); i++ {
interval := entries[i].CreatedAt.Sub(entries[i-1].CreatedAt)
intervals = append(intervals, interval)
}
analysis.AnalysisPeriod = entries[len(entries)-1].CreatedAt.Sub(entries[0].CreatedAt)
// Calculate min, max, average
sort.Slice(intervals, func(i, j int) bool {
return intervals[i] < intervals[j]
})
analysis.MinInterval = intervals[0]
analysis.MaxInterval = intervals[len(intervals)-1]
var total time.Duration
for _, interval := range intervals {
total += interval
}
analysis.AverageInterval = total / time.Duration(len(intervals))
// Calculate standard deviation (square root of the variance)
var sumSquares float64
avgNanos := float64(analysis.AverageInterval.Nanoseconds())
for _, interval := range intervals {
diff := float64(interval.Nanoseconds()) - avgNanos
sumSquares += diff * diff
}
variance := sumSquares / float64(len(intervals))
analysis.StdDeviation = time.Duration(int64(math.Sqrt(variance)))
// Calculate regularity score (lower deviation = higher regularity)
if analysis.AverageInterval > 0 {
deviationRatio := float64(analysis.StdDeviation) / float64(analysis.AverageInterval)
analysis.Regularity = 1.0 - min(deviationRatio, 1.0)
}
// Detect gaps and missed backups
config := &GapDetectionConfig{
ExpectedInterval: expectedInterval,
Tolerance: expectedInterval / 4,
RPOThreshold: expectedInterval * 2,
}
gaps, _ := c.DetectGaps(ctx, database, config)
analysis.GapsDetected = len(gaps)
// Estimate missed backups
if expectedInterval > 0 {
expectedBackups := int(analysis.AnalysisPeriod / expectedInterval)
if expectedBackups > analysis.TotalBackups {
analysis.MissedBackups = expectedBackups - analysis.TotalBackups
}
}
return analysis, nil
}
// RPOStatus describes the current recovery point objective status for a database
type RPOStatus struct {
Database string `json:"database"`
LastBackup time.Time `json:"last_backup"`
TimeSinceBackup time.Duration `json:"time_since_backup"`
TargetRPO time.Duration `json:"target_rpo"`
CurrentRPO time.Duration `json:"current_rpo"`
RPOMet bool `json:"rpo_met"`
NextBackupDue time.Time `json:"next_backup_due"`
BackupsIn24Hours int `json:"backups_in_24h"`
BackupsIn7Days int `json:"backups_in_7d"`
}
// CalculateRPOStatus calculates RPO status for a database
func (c *SQLiteCatalog) CalculateRPOStatus(ctx context.Context, database string, targetRPO time.Duration) (*RPOStatus, error) {
status := &RPOStatus{
Database: database,
TargetRPO: targetRPO,
}
// Get most recent backup
entries, err := c.List(ctx, database, 1)
if err != nil {
return nil, err
}
if len(entries) == 0 {
status.RPOMet = false
status.CurrentRPO = time.Duration(0)
return status, nil
}
status.LastBackup = entries[0].CreatedAt
status.TimeSinceBackup = time.Since(entries[0].CreatedAt)
status.CurrentRPO = status.TimeSinceBackup
status.RPOMet = status.TimeSinceBackup <= targetRPO
status.NextBackupDue = entries[0].CreatedAt.Add(targetRPO)
// Count backups in time windows
now := time.Now()
last24h := now.Add(-24 * time.Hour)
last7d := now.Add(-7 * 24 * time.Hour)
count24h, _ := c.Count(ctx, &SearchQuery{
Database: database,
StartDate: &last24h,
Status: string(StatusCompleted),
})
count7d, _ := c.Count(ctx, &SearchQuery{
Database: database,
StartDate: &last7d,
Status: string(StatusCompleted),
})
status.BackupsIn24Hours = int(count24h)
status.BackupsIn7Days = int(count7d)
return status, nil
}
func min(a, b float64) float64 {
if a < b {
return a
}
return b
}

internal/catalog/sqlite.go Normal file

@@ -0,0 +1,632 @@
// Package catalog - SQLite storage implementation
package catalog
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
)
// SQLiteCatalog implements Catalog interface with SQLite storage
type SQLiteCatalog struct {
db *sql.DB
path string
}
// NewSQLiteCatalog creates a new SQLite-backed catalog
func NewSQLiteCatalog(dbPath string) (*SQLiteCatalog, error) {
// Ensure directory exists
dir := filepath.Dir(dbPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create catalog directory: %w", err)
}
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=ON")
if err != nil {
return nil, fmt.Errorf("failed to open catalog database: %w", err)
}
catalog := &SQLiteCatalog{
db: db,
path: dbPath,
}
if err := catalog.initialize(); err != nil {
db.Close()
return nil, err
}
return catalog, nil
}
// initialize creates the database schema
func (c *SQLiteCatalog) initialize() error {
schema := `
CREATE TABLE IF NOT EXISTS backups (
id INTEGER PRIMARY KEY AUTOINCREMENT,
database TEXT NOT NULL,
database_type TEXT NOT NULL,
host TEXT,
port INTEGER,
backup_path TEXT NOT NULL UNIQUE,
backup_type TEXT DEFAULT 'full',
size_bytes INTEGER,
sha256 TEXT,
compression TEXT,
encrypted INTEGER DEFAULT 0,
created_at DATETIME NOT NULL,
duration REAL,
status TEXT DEFAULT 'completed',
verified_at DATETIME,
verify_valid INTEGER,
drill_tested_at DATETIME,
drill_success INTEGER,
cloud_location TEXT,
retention_policy TEXT,
tags TEXT,
metadata TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_backups_database ON backups(database);
CREATE INDEX IF NOT EXISTS idx_backups_created_at ON backups(created_at);
CREATE INDEX IF NOT EXISTS idx_backups_status ON backups(status);
CREATE INDEX IF NOT EXISTS idx_backups_host ON backups(host);
CREATE INDEX IF NOT EXISTS idx_backups_database_type ON backups(database_type);
CREATE TABLE IF NOT EXISTS catalog_meta (
key TEXT PRIMARY KEY,
value TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
-- Store schema version for migrations
INSERT OR IGNORE INTO catalog_meta (key, value) VALUES ('schema_version', '1');
`
_, err := c.db.Exec(schema)
if err != nil {
return fmt.Errorf("failed to initialize schema: %w", err)
}
return nil
}
// Add inserts a new backup entry
func (c *SQLiteCatalog) Add(ctx context.Context, entry *Entry) error {
tagsJSON, _ := json.Marshal(entry.Tags)
metaJSON, _ := json.Marshal(entry.Metadata)
result, err := c.db.ExecContext(ctx, `
INSERT INTO backups (
database, database_type, host, port, backup_path, backup_type,
size_bytes, sha256, compression, encrypted, created_at, duration,
status, cloud_location, retention_policy, tags, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
entry.Database, entry.DatabaseType, entry.Host, entry.Port,
entry.BackupPath, entry.BackupType, entry.SizeBytes, entry.SHA256,
entry.Compression, entry.Encrypted, entry.CreatedAt, entry.Duration,
entry.Status, entry.CloudLocation, entry.RetentionPolicy,
string(tagsJSON), string(metaJSON),
)
if err != nil {
return fmt.Errorf("failed to add catalog entry: %w", err)
}
id, _ := result.LastInsertId()
entry.ID = id
return nil
}
// Update updates an existing backup entry
func (c *SQLiteCatalog) Update(ctx context.Context, entry *Entry) error {
tagsJSON, _ := json.Marshal(entry.Tags)
metaJSON, _ := json.Marshal(entry.Metadata)
_, err := c.db.ExecContext(ctx, `
UPDATE backups SET
database = ?, database_type = ?, host = ?, port = ?,
backup_type = ?, size_bytes = ?, sha256 = ?, compression = ?,
encrypted = ?, duration = ?, status = ?, verified_at = ?,
verify_valid = ?, drill_tested_at = ?, drill_success = ?,
cloud_location = ?, retention_policy = ?, tags = ?, metadata = ?,
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
`,
entry.Database, entry.DatabaseType, entry.Host, entry.Port,
entry.BackupType, entry.SizeBytes, entry.SHA256, entry.Compression,
entry.Encrypted, entry.Duration, entry.Status, entry.VerifiedAt,
entry.VerifyValid, entry.DrillTestedAt, entry.DrillSuccess,
entry.CloudLocation, entry.RetentionPolicy,
string(tagsJSON), string(metaJSON), entry.ID,
)
if err != nil {
return fmt.Errorf("failed to update catalog entry: %w", err)
}
return nil
}
// Delete removes a backup entry
func (c *SQLiteCatalog) Delete(ctx context.Context, id int64) error {
_, err := c.db.ExecContext(ctx, "DELETE FROM backups WHERE id = ?", id)
if err != nil {
return fmt.Errorf("failed to delete catalog entry: %w", err)
}
return nil
}
// Get retrieves a backup entry by ID
func (c *SQLiteCatalog) Get(ctx context.Context, id int64) (*Entry, error) {
row := c.db.QueryRowContext(ctx, `
SELECT id, database, database_type, host, port, backup_path, backup_type,
size_bytes, sha256, compression, encrypted, created_at, duration,
status, verified_at, verify_valid, drill_tested_at, drill_success,
cloud_location, retention_policy, tags, metadata
FROM backups WHERE id = ?
`, id)
return c.scanEntry(row)
}
// GetByPath retrieves a backup entry by file path
func (c *SQLiteCatalog) GetByPath(ctx context.Context, path string) (*Entry, error) {
row := c.db.QueryRowContext(ctx, `
SELECT id, database, database_type, host, port, backup_path, backup_type,
size_bytes, sha256, compression, encrypted, created_at, duration,
status, verified_at, verify_valid, drill_tested_at, drill_success,
cloud_location, retention_policy, tags, metadata
FROM backups WHERE backup_path = ?
`, path)
return c.scanEntry(row)
}
// scanEntry scans a row into an Entry struct
func (c *SQLiteCatalog) scanEntry(row *sql.Row) (*Entry, error) {
var entry Entry
var tagsJSON, metaJSON sql.NullString
var verifiedAt, drillTestedAt sql.NullTime
var verifyValid, drillSuccess sql.NullBool
err := row.Scan(
&entry.ID, &entry.Database, &entry.DatabaseType, &entry.Host, &entry.Port,
&entry.BackupPath, &entry.BackupType, &entry.SizeBytes, &entry.SHA256,
&entry.Compression, &entry.Encrypted, &entry.CreatedAt, &entry.Duration,
&entry.Status, &verifiedAt, &verifyValid, &drillTestedAt, &drillSuccess,
&entry.CloudLocation, &entry.RetentionPolicy, &tagsJSON, &metaJSON,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to scan entry: %w", err)
}
if verifiedAt.Valid {
entry.VerifiedAt = &verifiedAt.Time
}
if verifyValid.Valid {
entry.VerifyValid = &verifyValid.Bool
}
if drillTestedAt.Valid {
entry.DrillTestedAt = &drillTestedAt.Time
}
if drillSuccess.Valid {
entry.DrillSuccess = &drillSuccess.Bool
}
if tagsJSON.Valid && tagsJSON.String != "" {
json.Unmarshal([]byte(tagsJSON.String), &entry.Tags)
}
if metaJSON.Valid && metaJSON.String != "" {
json.Unmarshal([]byte(metaJSON.String), &entry.Metadata)
}
return &entry, nil
}
// Search finds backup entries matching the query
func (c *SQLiteCatalog) Search(ctx context.Context, query *SearchQuery) ([]*Entry, error) {
where, args := c.buildSearchQuery(query)
orderBy := "created_at DESC"
// OrderBy is interpolated into the statement, so restrict it to known columns
switch query.OrderBy {
case "created_at", "database", "size_bytes", "duration", "status":
orderBy = query.OrderBy
if query.OrderDesc {
orderBy += " DESC"
}
}
stmt := fmt.Sprintf(`
SELECT id, database, database_type, host, port, backup_path, backup_type,
size_bytes, sha256, compression, encrypted, created_at, duration,
status, verified_at, verify_valid, drill_tested_at, drill_success,
cloud_location, retention_policy, tags, metadata
FROM backups
%s
ORDER BY %s
`, where, orderBy)
if query.Limit > 0 {
stmt += fmt.Sprintf(" LIMIT %d", query.Limit)
if query.Offset > 0 {
stmt += fmt.Sprintf(" OFFSET %d", query.Offset)
}
}
rows, err := c.db.QueryContext(ctx, stmt, args...)
if err != nil {
return nil, fmt.Errorf("search query failed: %w", err)
}
defer rows.Close()
return c.scanEntries(rows)
}
// scanEntries scans multiple rows into Entry slices
func (c *SQLiteCatalog) scanEntries(rows *sql.Rows) ([]*Entry, error) {
var entries []*Entry
for rows.Next() {
var entry Entry
var tagsJSON, metaJSON sql.NullString
var verifiedAt, drillTestedAt sql.NullTime
var verifyValid, drillSuccess sql.NullBool
err := rows.Scan(
&entry.ID, &entry.Database, &entry.DatabaseType, &entry.Host, &entry.Port,
&entry.BackupPath, &entry.BackupType, &entry.SizeBytes, &entry.SHA256,
&entry.Compression, &entry.Encrypted, &entry.CreatedAt, &entry.Duration,
&entry.Status, &verifiedAt, &verifyValid, &drillTestedAt, &drillSuccess,
&entry.CloudLocation, &entry.RetentionPolicy, &tagsJSON, &metaJSON,
)
if err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
if verifiedAt.Valid {
entry.VerifiedAt = &verifiedAt.Time
}
if verifyValid.Valid {
entry.VerifyValid = &verifyValid.Bool
}
if drillTestedAt.Valid {
entry.DrillTestedAt = &drillTestedAt.Time
}
if drillSuccess.Valid {
entry.DrillSuccess = &drillSuccess.Bool
}
if tagsJSON.Valid && tagsJSON.String != "" {
json.Unmarshal([]byte(tagsJSON.String), &entry.Tags)
}
if metaJSON.Valid && metaJSON.String != "" {
json.Unmarshal([]byte(metaJSON.String), &entry.Metadata)
}
entries = append(entries, &entry)
}
return entries, rows.Err()
}
// buildSearchQuery builds the WHERE clause from a SearchQuery
func (c *SQLiteCatalog) buildSearchQuery(query *SearchQuery) (string, []interface{}) {
var conditions []string
var args []interface{}
if query.Database != "" {
if strings.Contains(query.Database, "*") {
conditions = append(conditions, "database LIKE ?")
args = append(args, strings.ReplaceAll(query.Database, "*", "%"))
} else {
conditions = append(conditions, "database = ?")
args = append(args, query.Database)
}
}
if query.DatabaseType != "" {
conditions = append(conditions, "database_type = ?")
args = append(args, query.DatabaseType)
}
if query.Host != "" {
conditions = append(conditions, "host = ?")
args = append(args, query.Host)
}
if query.Status != "" {
conditions = append(conditions, "status = ?")
args = append(args, query.Status)
}
if query.StartDate != nil {
conditions = append(conditions, "created_at >= ?")
args = append(args, *query.StartDate)
}
if query.EndDate != nil {
conditions = append(conditions, "created_at <= ?")
args = append(args, *query.EndDate)
}
if query.MinSize > 0 {
conditions = append(conditions, "size_bytes >= ?")
args = append(args, query.MinSize)
}
if query.MaxSize > 0 {
conditions = append(conditions, "size_bytes <= ?")
args = append(args, query.MaxSize)
}
if query.BackupType != "" {
conditions = append(conditions, "backup_type = ?")
args = append(args, query.BackupType)
}
if query.Encrypted != nil {
conditions = append(conditions, "encrypted = ?")
args = append(args, *query.Encrypted)
}
if query.Verified != nil {
if *query.Verified {
conditions = append(conditions, "verified_at IS NOT NULL AND verify_valid = 1")
} else {
conditions = append(conditions, "verified_at IS NULL")
}
}
if query.DrillTested != nil {
if *query.DrillTested {
conditions = append(conditions, "drill_tested_at IS NOT NULL AND drill_success = 1")
} else {
conditions = append(conditions, "drill_tested_at IS NULL")
}
}
if len(conditions) == 0 {
return "", nil
}
return "WHERE " + strings.Join(conditions, " AND "), args
}
// List returns recent backups for a database
func (c *SQLiteCatalog) List(ctx context.Context, database string, limit int) ([]*Entry, error) {
query := &SearchQuery{
Database: database,
Limit: limit,
OrderBy: "created_at",
OrderDesc: true,
}
return c.Search(ctx, query)
}
// ListDatabases returns all unique database names
func (c *SQLiteCatalog) ListDatabases(ctx context.Context) ([]string, error) {
rows, err := c.db.QueryContext(ctx, "SELECT DISTINCT database FROM backups ORDER BY database")
if err != nil {
return nil, fmt.Errorf("failed to list databases: %w", err)
}
defer rows.Close()
var databases []string
for rows.Next() {
var db string
if err := rows.Scan(&db); err != nil {
return nil, err
}
databases = append(databases, db)
}
return databases, rows.Err()
}
// Count returns the number of entries matching the query
func (c *SQLiteCatalog) Count(ctx context.Context, query *SearchQuery) (int64, error) {
where, args := c.buildSearchQuery(query)
stmt := "SELECT COUNT(*) FROM backups " + where
var count int64
err := c.db.QueryRowContext(ctx, stmt, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("count query failed: %w", err)
}
return count, nil
}
// Stats returns overall catalog statistics
func (c *SQLiteCatalog) Stats(ctx context.Context) (*Stats, error) {
stats := &Stats{
ByDatabase: make(map[string]int64),
ByType: make(map[string]int64),
ByStatus: make(map[string]int64),
}
// Basic stats
row := c.db.QueryRowContext(ctx, `
SELECT
COUNT(*),
COALESCE(SUM(size_bytes), 0),
MIN(created_at),
MAX(created_at),
COALESCE(AVG(duration), 0),
CAST(COALESCE(AVG(size_bytes), 0) AS INTEGER),
SUM(CASE WHEN verified_at IS NOT NULL THEN 1 ELSE 0 END),
SUM(CASE WHEN drill_tested_at IS NOT NULL THEN 1 ELSE 0 END)
FROM backups WHERE status != 'deleted'
`)
var oldest, newest sql.NullString
err := row.Scan(
&stats.TotalBackups, &stats.TotalSize, &oldest, &newest,
&stats.AvgDuration, &stats.AvgSize,
&stats.VerifiedCount, &stats.DrillTestedCount,
)
if err != nil {
return nil, fmt.Errorf("failed to get stats: %w", err)
}
if oldest.Valid {
if t, err := time.Parse(time.RFC3339Nano, oldest.String); err == nil {
stats.OldestBackup = &t
} else if t, err := time.Parse("2006-01-02 15:04:05.999999999-07:00", oldest.String); err == nil {
stats.OldestBackup = &t
} else if t, err := time.Parse("2006-01-02T15:04:05Z", oldest.String); err == nil {
stats.OldestBackup = &t
}
}
if newest.Valid {
if t, err := time.Parse(time.RFC3339Nano, newest.String); err == nil {
stats.NewestBackup = &t
} else if t, err := time.Parse("2006-01-02 15:04:05.999999999-07:00", newest.String); err == nil {
stats.NewestBackup = &t
} else if t, err := time.Parse("2006-01-02T15:04:05Z", newest.String); err == nil {
stats.NewestBackup = &t
}
}
stats.TotalSizeHuman = FormatSize(stats.TotalSize)
// By database
rows, err := c.db.QueryContext(ctx, "SELECT database, COUNT(*) FROM backups GROUP BY database")
if err != nil {
return nil, fmt.Errorf("failed to get per-database stats: %w", err)
}
for rows.Next() {
var db string
var count int64
if err := rows.Scan(&db, &count); err == nil {
stats.ByDatabase[db] = count
}
}
rows.Close()
// By type
rows, err = c.db.QueryContext(ctx, "SELECT backup_type, COUNT(*) FROM backups GROUP BY backup_type")
if err != nil {
return nil, fmt.Errorf("failed to get per-type stats: %w", err)
}
for rows.Next() {
var t string
var count int64
if err := rows.Scan(&t, &count); err == nil {
stats.ByType[t] = count
}
}
rows.Close()
// By status
rows, err = c.db.QueryContext(ctx, "SELECT status, COUNT(*) FROM backups GROUP BY status")
if err != nil {
return nil, fmt.Errorf("failed to get per-status stats: %w", err)
}
for rows.Next() {
var s string
var count int64
if err := rows.Scan(&s, &count); err == nil {
stats.ByStatus[s] = count
}
}
rows.Close()
return stats, nil
}
// StatsByDatabase returns statistics for a specific database
func (c *SQLiteCatalog) StatsByDatabase(ctx context.Context, database string) (*Stats, error) {
stats := &Stats{
ByDatabase: make(map[string]int64),
ByType: make(map[string]int64),
ByStatus: make(map[string]int64),
}
row := c.db.QueryRowContext(ctx, `
SELECT
COUNT(*),
COALESCE(SUM(size_bytes), 0),
MIN(created_at),
MAX(created_at),
COALESCE(AVG(duration), 0),
COALESCE(AVG(size_bytes), 0),
SUM(CASE WHEN verified_at IS NOT NULL THEN 1 ELSE 0 END),
SUM(CASE WHEN drill_tested_at IS NOT NULL THEN 1 ELSE 0 END)
FROM backups WHERE database = ? AND status != 'deleted'
`, database)
var oldest, newest sql.NullTime
err := row.Scan(
&stats.TotalBackups, &stats.TotalSize, &oldest, &newest,
&stats.AvgDuration, &stats.AvgSize,
&stats.VerifiedCount, &stats.DrillTestedCount,
)
if err != nil {
return nil, fmt.Errorf("failed to get database stats: %w", err)
}
if oldest.Valid {
stats.OldestBackup = &oldest.Time
}
if newest.Valid {
stats.NewestBackup = &newest.Time
}
stats.TotalSizeHuman = FormatSize(stats.TotalSize)
return stats, nil
}
// MarkVerified updates the verification status of a backup
func (c *SQLiteCatalog) MarkVerified(ctx context.Context, id int64, valid bool) error {
status := StatusVerified
if !valid {
status = StatusCorrupted
}
_, err := c.db.ExecContext(ctx, `
UPDATE backups SET
verified_at = CURRENT_TIMESTAMP,
verify_valid = ?,
status = ?,
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
`, valid, status, id)
return err
}
// MarkDrillTested updates the drill test status of a backup
func (c *SQLiteCatalog) MarkDrillTested(ctx context.Context, id int64, success bool) error {
_, err := c.db.ExecContext(ctx, `
UPDATE backups SET
drill_tested_at = CURRENT_TIMESTAMP,
drill_success = ?,
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
`, success, id)
return err
}
// Prune removes entries older than the given time
func (c *SQLiteCatalog) Prune(ctx context.Context, before time.Time) (int, error) {
result, err := c.db.ExecContext(ctx,
"DELETE FROM backups WHERE created_at < ? AND status = 'deleted'",
before,
)
if err != nil {
return 0, fmt.Errorf("prune failed: %w", err)
}
affected, _ := result.RowsAffected()
return int(affected), nil
}
// Vacuum optimizes the database
func (c *SQLiteCatalog) Vacuum(ctx context.Context) error {
_, err := c.db.ExecContext(ctx, "VACUUM")
return err
}
// Close closes the database connection
func (c *SQLiteCatalog) Close() error {
return c.db.Close()
}

234
internal/catalog/sync.go Normal file

@@ -0,0 +1,234 @@
// Package catalog - Sync functionality for importing backups into catalog
package catalog
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"dbbackup/internal/metadata"
)
// SyncFromDirectory scans a directory and imports backup metadata into the catalog
func (c *SQLiteCatalog) SyncFromDirectory(ctx context.Context, dir string) (*SyncResult, error) {
start := time.Now()
result := &SyncResult{}
// Find all metadata files
pattern := filepath.Join(dir, "*.meta.json")
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, fmt.Errorf("failed to scan directory: %w", err)
}
// Also check subdirectories
subPattern := filepath.Join(dir, "*", "*.meta.json")
subMatches, _ := filepath.Glob(subPattern)
matches = append(matches, subMatches...)
for _, metaPath := range matches {
// Derive backup file path from metadata path
backupPath := strings.TrimSuffix(metaPath, ".meta.json")
// Check if backup file exists
if _, err := os.Stat(backupPath); os.IsNotExist(err) {
result.Details = append(result.Details,
fmt.Sprintf("SKIP: %s (backup file missing)", filepath.Base(backupPath)))
continue
}
// Load metadata
meta, err := metadata.Load(backupPath)
if err != nil {
result.Errors++
result.Details = append(result.Details,
fmt.Sprintf("ERROR: %s - %v", filepath.Base(backupPath), err))
continue
}
// Check if already in catalog
existing, _ := c.GetByPath(ctx, backupPath)
if existing != nil {
// Update if metadata changed
if existing.SHA256 != meta.SHA256 || existing.SizeBytes != meta.SizeBytes {
entry := metadataToEntry(meta, backupPath)
entry.ID = existing.ID
if err := c.Update(ctx, entry); err != nil {
result.Errors++
result.Details = append(result.Details,
fmt.Sprintf("ERROR updating: %s - %v", filepath.Base(backupPath), err))
} else {
result.Updated++
}
}
continue
}
// Add new entry
entry := metadataToEntry(meta, backupPath)
if err := c.Add(ctx, entry); err != nil {
result.Errors++
result.Details = append(result.Details,
fmt.Sprintf("ERROR adding: %s - %v", filepath.Base(backupPath), err))
} else {
result.Added++
result.Details = append(result.Details,
fmt.Sprintf("ADDED: %s (%s)", filepath.Base(backupPath), FormatSize(meta.SizeBytes)))
}
}
// Check for removed backups (backups in catalog but not on disk)
entries, _ := c.Search(ctx, &SearchQuery{})
for _, entry := range entries {
if !strings.HasPrefix(entry.BackupPath, dir) {
continue
}
if _, err := os.Stat(entry.BackupPath); os.IsNotExist(err) {
// Mark as deleted
entry.Status = StatusDeleted
c.Update(ctx, entry)
result.Removed++
result.Details = append(result.Details,
fmt.Sprintf("REMOVED: %s (file not found)", filepath.Base(entry.BackupPath)))
}
}
result.Duration = time.Since(start).Seconds()
return result, nil
}
// SyncFromCloud imports backups from cloud storage
func (c *SQLiteCatalog) SyncFromCloud(ctx context.Context, provider, bucket, prefix string) (*SyncResult, error) {
// Cloud sync will be implemented when integrating with the cloud package.
// For now, return a placeholder result.
return &SyncResult{
Details: []string{"Cloud sync not yet implemented - use directory sync instead"},
}, nil
}
// metadataToEntry converts backup metadata to a catalog entry
func metadataToEntry(meta *metadata.BackupMetadata, backupPath string) *Entry {
entry := &Entry{
Database: meta.Database,
DatabaseType: meta.DatabaseType,
Host: meta.Host,
Port: meta.Port,
BackupPath: backupPath,
BackupType: meta.BackupType,
SizeBytes: meta.SizeBytes,
SHA256: meta.SHA256,
Compression: meta.Compression,
Encrypted: meta.Encrypted,
CreatedAt: meta.Timestamp,
Duration: meta.Duration,
Status: StatusCompleted,
Metadata: meta.ExtraInfo,
}
if entry.BackupType == "" {
entry.BackupType = "full"
}
return entry
}
// ImportEntry creates a catalog entry directly from backup file info
func (c *SQLiteCatalog) ImportEntry(ctx context.Context, backupPath string, info os.FileInfo, dbName, dbType string) error {
entry := &Entry{
Database: dbName,
DatabaseType: dbType,
BackupPath: backupPath,
BackupType: "full",
SizeBytes: info.Size(),
CreatedAt: info.ModTime(),
Status: StatusCompleted,
}
// Detect compression from extension
switch {
case strings.HasSuffix(backupPath, ".gz"):
entry.Compression = "gzip"
case strings.HasSuffix(backupPath, ".lz4"):
entry.Compression = "lz4"
case strings.HasSuffix(backupPath, ".zst"):
entry.Compression = "zstd"
}
// Check if encrypted
if strings.Contains(backupPath, ".enc") {
entry.Encrypted = true
}
// Try to load metadata if exists
if meta, err := metadata.Load(backupPath); err == nil {
entry.SHA256 = meta.SHA256
entry.Duration = meta.Duration
entry.Host = meta.Host
entry.Port = meta.Port
entry.Metadata = meta.ExtraInfo
}
return c.Add(ctx, entry)
}
// SyncStatus returns the sync status summary
type SyncStatus struct {
LastSync *time.Time `json:"last_sync,omitempty"`
TotalEntries int64 `json:"total_entries"`
ActiveEntries int64 `json:"active_entries"`
DeletedEntries int64 `json:"deleted_entries"`
Directories []string `json:"directories"`
}
// GetSyncStatus returns the current sync status
func (c *SQLiteCatalog) GetSyncStatus(ctx context.Context) (*SyncStatus, error) {
status := &SyncStatus{}
// Get last sync time
var lastSync sql.NullString
c.db.QueryRowContext(ctx, "SELECT value FROM catalog_meta WHERE key = 'last_sync'").Scan(&lastSync)
if lastSync.Valid {
if t, err := time.Parse(time.RFC3339, lastSync.String); err == nil {
status.LastSync = &t
}
}
// Count entries
c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM backups").Scan(&status.TotalEntries)
c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM backups WHERE status != 'deleted'").Scan(&status.ActiveEntries)
c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM backups WHERE status = 'deleted'").Scan(&status.DeletedEntries)
// Get unique directories
rows, _ := c.db.QueryContext(ctx, `
SELECT DISTINCT
CASE
WHEN instr(backup_path, '/') > 0
-- rtrim strips filename characters from the right, stopping at the final '/'
THEN rtrim(backup_path, replace(backup_path, '/', ''))
ELSE backup_path
END as dir
FROM backups WHERE status != 'deleted'
`)
if rows != nil {
defer rows.Close()
for rows.Next() {
var dir string
rows.Scan(&dir)
status.Directories = append(status.Directories, dir)
}
}
return status, nil
}
// SetLastSync updates the last sync timestamp
func (c *SQLiteCatalog) SetLastSync(ctx context.Context) error {
_, err := c.db.ExecContext(ctx, `
INSERT OR REPLACE INTO catalog_meta (key, value, updated_at)
VALUES ('last_sync', ?, CURRENT_TIMESTAMP)
`, time.Now().Format(time.RFC3339))
return err
}


@@ -91,6 +91,13 @@ type Config struct {
WALCompression bool // Compress WAL files
WALEncryption bool // Encrypt WAL files
// MySQL PITR options
BinlogDir string // MySQL binary log directory
BinlogArchiveDir string // Directory to archive binlogs
BinlogArchiveInterval string // Interval for binlog archiving (e.g., "30s")
RequireRowFormat bool // Require ROW format for binlog
RequireGTID bool // Require GTID mode enabled
// TUI automation options (for testing)
TUIAutoSelect int // Auto-select menu option (-1 = disabled)
TUIAutoDatabase string // Pre-fill database name

298
internal/drill/docker.go Normal file

@@ -0,0 +1,298 @@
// Package drill - Docker container management for DR drills
package drill
import (
"context"
"fmt"
"os/exec"
"strings"
"time"
)
// DockerManager handles Docker container operations for DR drills
type DockerManager struct {
verbose bool
}
// NewDockerManager creates a new Docker manager
func NewDockerManager(verbose bool) *DockerManager {
return &DockerManager{verbose: verbose}
}
// ContainerConfig holds Docker container configuration
type ContainerConfig struct {
Image string // Docker image (e.g., "postgres:15")
Name string // Container name
Port int // Host port to map
ContainerPort int // Container port
Environment map[string]string // Environment variables
Volumes []string // Volume mounts
Network string // Docker network
Timeout int // Startup timeout in seconds
}
// ContainerInfo holds information about a running container
type ContainerInfo struct {
ID string
Name string
Image string
Port int
Status string
Started time.Time
Healthy bool
}
// CheckDockerAvailable verifies Docker is installed and running
func (dm *DockerManager) CheckDockerAvailable(ctx context.Context) error {
cmd := exec.CommandContext(ctx, "docker", "version")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("docker not available: %w (output: %s)", err, string(output))
}
return nil
}
// PullImage pulls a Docker image if not present
func (dm *DockerManager) PullImage(ctx context.Context, image string) error {
// Check if image exists locally
checkCmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
if err := checkCmd.Run(); err == nil {
// Image exists
return nil
}
// Pull the image
pullCmd := exec.CommandContext(ctx, "docker", "pull", image)
output, err := pullCmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to pull image %s: %w (output: %s)", image, err, string(output))
}
return nil
}
// CreateContainer creates and starts a database container
func (dm *DockerManager) CreateContainer(ctx context.Context, config *ContainerConfig) (*ContainerInfo, error) {
args := []string{
"run", "-d",
"--name", config.Name,
"-p", fmt.Sprintf("%d:%d", config.Port, config.ContainerPort),
}
// Add environment variables
for k, v := range config.Environment {
args = append(args, "-e", fmt.Sprintf("%s=%s", k, v))
}
// Add volumes
for _, v := range config.Volumes {
args = append(args, "-v", v)
}
// Add network if specified
if config.Network != "" {
args = append(args, "--network", config.Network)
}
// Add image
args = append(args, config.Image)
cmd := exec.CommandContext(ctx, "docker", args...)
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to create container: %w (output: %s)", err, string(output))
}
containerID := strings.TrimSpace(string(output))
return &ContainerInfo{
ID: containerID,
Name: config.Name,
Image: config.Image,
Port: config.Port,
Status: "created",
Started: time.Now(),
}, nil
}
// WaitForHealth waits for container to be healthy
func (dm *DockerManager) WaitForHealth(ctx context.Context, containerID string, dbType string, timeout int) error {
deadline := time.Now().Add(time.Duration(timeout) * time.Second)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if time.Now().After(deadline) {
return fmt.Errorf("timeout waiting for container to be healthy")
}
// Check container health
healthCmd := dm.healthCheckCommand(dbType)
args := append([]string{"exec", containerID}, healthCmd...)
cmd := exec.CommandContext(ctx, "docker", args...)
if err := cmd.Run(); err == nil {
return nil // Container is healthy
}
}
}
}
// healthCheckCommand returns the health check command for a database type
func (dm *DockerManager) healthCheckCommand(dbType string) []string {
switch dbType {
case "postgresql", "postgres":
return []string{"pg_isready", "-U", "postgres"}
case "mysql":
return []string{"mysqladmin", "ping", "-h", "localhost", "-u", "root", "--password=root"}
case "mariadb":
return []string{"mariadb-admin", "ping", "-h", "localhost", "-u", "root", "--password=root"}
default:
return []string{"echo", "ok"}
}
}
// ExecCommand executes a command inside the container
func (dm *DockerManager) ExecCommand(ctx context.Context, containerID string, command []string) (string, error) {
args := append([]string{"exec", containerID}, command...)
cmd := exec.CommandContext(ctx, "docker", args...)
output, err := cmd.CombinedOutput()
if err != nil {
return string(output), fmt.Errorf("exec failed: %w", err)
}
return string(output), nil
}
// CopyToContainer copies a file to the container
func (dm *DockerManager) CopyToContainer(ctx context.Context, containerID, src, dest string) error {
cmd := exec.CommandContext(ctx, "docker", "cp", src, fmt.Sprintf("%s:%s", containerID, dest))
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("copy failed: %w (output: %s)", err, string(output))
}
return nil
}
// StopContainer stops a running container
func (dm *DockerManager) StopContainer(ctx context.Context, containerID string) error {
cmd := exec.CommandContext(ctx, "docker", "stop", containerID)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to stop container: %w (output: %s)", err, string(output))
}
return nil
}
// RemoveContainer removes a container
func (dm *DockerManager) RemoveContainer(ctx context.Context, containerID string) error {
cmd := exec.CommandContext(ctx, "docker", "rm", "-f", containerID)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove container: %w (output: %s)", err, string(output))
}
return nil
}
// GetContainerLogs retrieves container logs
func (dm *DockerManager) GetContainerLogs(ctx context.Context, containerID string, tail int) (string, error) {
args := []string{"logs"}
if tail > 0 {
args = append(args, "--tail", fmt.Sprintf("%d", tail))
}
args = append(args, containerID)
cmd := exec.CommandContext(ctx, "docker", args...)
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("failed to get logs: %w", err)
}
return string(output), nil
}
// ListDrillContainers lists all containers created by drill operations
func (dm *DockerManager) ListDrillContainers(ctx context.Context) ([]*ContainerInfo, error) {
cmd := exec.CommandContext(ctx, "docker", "ps", "-a",
"--filter", "name=drill_",
"--format", "{{.ID}}\t{{.Names}}\t{{.Image}}\t{{.Status}}")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to list containers: %w", err)
}
var containers []*ContainerInfo
lines := strings.Split(strings.TrimSpace(string(output)), "\n")
for _, line := range lines {
if line == "" {
continue
}
parts := strings.Split(line, "\t")
if len(parts) >= 4 {
containers = append(containers, &ContainerInfo{
ID: parts[0],
Name: parts[1],
Image: parts[2],
Status: parts[3],
})
}
}
return containers, nil
}
// GetDefaultImage returns the default Docker image for a database type
func GetDefaultImage(dbType, version string) string {
if version == "" {
version = "latest"
}
switch dbType {
case "postgresql", "postgres":
return fmt.Sprintf("postgres:%s", version)
case "mysql":
return fmt.Sprintf("mysql:%s", version)
case "mariadb":
return fmt.Sprintf("mariadb:%s", version)
default:
return ""
}
}
// GetDefaultPort returns the default port for a database type
func GetDefaultPort(dbType string) int {
switch dbType {
case "postgresql", "postgres":
return 5432
case "mysql", "mariadb":
return 3306
default:
return 0
}
}
// GetDefaultEnvironment returns default environment variables for a database container
func GetDefaultEnvironment(dbType string) map[string]string {
switch dbType {
case "postgresql", "postgres":
return map[string]string{
"POSTGRES_PASSWORD": "drill_test_password",
"POSTGRES_USER": "postgres",
"POSTGRES_DB": "postgres",
}
case "mysql":
return map[string]string{
"MYSQL_ROOT_PASSWORD": "root",
"MYSQL_DATABASE": "test",
}
case "mariadb":
return map[string]string{
"MARIADB_ROOT_PASSWORD": "root",
"MARIADB_DATABASE": "test",
}
default:
return map[string]string{}
}
}

247
internal/drill/drill.go Normal file

@@ -0,0 +1,247 @@
// Package drill provides Disaster Recovery drill functionality
// for testing backup restorability in isolated environments
package drill
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
// DrillConfig holds configuration for a DR drill
type DrillConfig struct {
// Backup configuration
BackupPath string `json:"backup_path"`
DatabaseName string `json:"database_name"`
DatabaseType string `json:"database_type"` // postgresql, mysql, mariadb
// Docker configuration
ContainerImage string `json:"container_image"` // e.g., "postgres:15"
ContainerName string `json:"container_name"` // Generated if empty
ContainerPort int `json:"container_port"` // Host port mapping
ContainerTimeout int `json:"container_timeout"` // Startup timeout in seconds
CleanupOnExit bool `json:"cleanup_on_exit"` // Remove container after drill
KeepOnFailure bool `json:"keep_on_failure"` // Keep container if drill fails
// Validation configuration
ValidationQueries []ValidationQuery `json:"validation_queries"`
MinRowCount int64 `json:"min_row_count"` // Minimum rows expected
ExpectedTables []string `json:"expected_tables"` // Tables that must exist
CustomChecks []CustomCheck `json:"custom_checks"`
// Encryption (if backup is encrypted)
EncryptionKeyFile string `json:"encryption_key_file,omitempty"`
EncryptionKeyEnv string `json:"encryption_key_env,omitempty"`
// Performance thresholds
MaxRestoreSeconds int `json:"max_restore_seconds"` // RTO threshold
MaxQuerySeconds int `json:"max_query_seconds"` // Query timeout
// Output
OutputDir string `json:"output_dir"` // Directory for drill reports
ReportFormat string `json:"report_format"` // json, markdown, html
Verbose bool `json:"verbose"`
}
// ValidationQuery represents a SQL query to validate restored data
type ValidationQuery struct {
Name string `json:"name"` // Human-readable name
Query string `json:"query"` // SQL query
ExpectedValue string `json:"expected_value"` // Expected result (optional)
MinValue int64 `json:"min_value"` // Minimum expected value
MaxValue int64 `json:"max_value"` // Maximum expected value
MustSucceed bool `json:"must_succeed"` // Fail drill if query fails
}
// CustomCheck represents a custom validation check
type CustomCheck struct {
Name string `json:"name"`
Type string `json:"type"` // row_count, table_exists, column_check
Table string `json:"table"`
Column string `json:"column,omitempty"`
Condition string `json:"condition,omitempty"` // SQL condition
MinValue int64 `json:"min_value,omitempty"`
MustSucceed bool `json:"must_succeed"`
}
// DrillResult contains the complete result of a DR drill
type DrillResult struct {
// Identification
DrillID string `json:"drill_id"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Duration float64 `json:"duration_seconds"`
// Configuration
BackupPath string `json:"backup_path"`
DatabaseName string `json:"database_name"`
DatabaseType string `json:"database_type"`
// Overall status
Success bool `json:"success"`
Status DrillStatus `json:"status"`
Message string `json:"message"`
// Phase timings
Phases []DrillPhase `json:"phases"`
// Validation results
ValidationResults []ValidationResult `json:"validation_results"`
CheckResults []CheckResult `json:"check_results"`
// Database metrics
TableCount int `json:"table_count"`
TotalRows int64 `json:"total_rows"`
DatabaseSize int64 `json:"database_size_bytes"`
// Performance metrics
RestoreTime float64 `json:"restore_time_seconds"`
ValidationTime float64 `json:"validation_time_seconds"`
QueryTimeAvg float64 `json:"query_time_avg_ms"`
// RTO/RPO metrics
ActualRTO float64 `json:"actual_rto_seconds"` // Total time to usable database
TargetRTO float64 `json:"target_rto_seconds"`
RTOMet bool `json:"rto_met"`
// Container info
ContainerID string `json:"container_id,omitempty"`
ContainerKept bool `json:"container_kept"`
// Errors and warnings
Errors []string `json:"errors,omitempty"`
Warnings []string `json:"warnings,omitempty"`
}
// DrillStatus represents the current status of a drill
type DrillStatus string
const (
StatusPending DrillStatus = "pending"
StatusRunning DrillStatus = "running"
StatusCompleted DrillStatus = "completed"
StatusFailed DrillStatus = "failed"
StatusAborted DrillStatus = "aborted"
StatusPartial DrillStatus = "partial" // Some validations failed
)
// DrillPhase represents a phase in the drill process
type DrillPhase struct {
Name string `json:"name"`
Status string `json:"status"` // pending, running, completed, failed, skipped
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Duration float64 `json:"duration_seconds"`
Message string `json:"message,omitempty"`
}
// ValidationResult holds the result of a validation query
type ValidationResult struct {
Name string `json:"name"`
Query string `json:"query"`
Success bool `json:"success"`
Result string `json:"result,omitempty"`
Expected string `json:"expected,omitempty"`
Duration float64 `json:"duration_ms"`
Error string `json:"error,omitempty"`
}
// CheckResult holds the result of a custom check
type CheckResult struct {
Name string `json:"name"`
Type string `json:"type"`
Success bool `json:"success"`
Actual int64 `json:"actual,omitempty"`
Expected int64 `json:"expected,omitempty"`
Message string `json:"message"`
}
// DefaultConfig returns a DrillConfig with sensible defaults
func DefaultConfig() *DrillConfig {
return &DrillConfig{
ContainerTimeout: 60,
CleanupOnExit: true,
KeepOnFailure: true,
MaxRestoreSeconds: 300, // 5 minutes
MaxQuerySeconds: 30,
ReportFormat: "json",
Verbose: false,
ValidationQueries: []ValidationQuery{},
ExpectedTables: []string{},
CustomChecks: []CustomCheck{},
}
}
// NewDrillID generates a unique drill ID
func NewDrillID() string {
return fmt.Sprintf("drill_%s", time.Now().Format("20060102_150405"))
}
// SaveResult saves the drill result to a file
func (r *DrillResult) SaveResult(outputDir string) error {
if err := os.MkdirAll(outputDir, 0755); err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}
filename := fmt.Sprintf("%s_report.json", r.DrillID)
path := filepath.Join(outputDir, filename) // avoid shadowing the path/filepath package
data, err := json.MarshalIndent(r, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal result: %w", err)
}
if err := os.WriteFile(path, data, 0644); err != nil {
return fmt.Errorf("failed to write result file: %w", err)
}
return nil
}
// LoadResult loads a drill result from a file
func LoadResult(path string) (*DrillResult, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read result file: %w", err)
}
var result DrillResult
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse result: %w", err)
}
return &result, nil
}
// IsSuccess returns true if the drill was successful
func (r *DrillResult) IsSuccess() bool {
return r.Success && r.Status == StatusCompleted
}
// Summary returns a human-readable summary of the drill
func (r *DrillResult) Summary() string {
status := "✅ PASSED"
if !r.Success {
status = "❌ FAILED"
} else if r.Status == StatusPartial {
status = "⚠️ PARTIAL"
}
return fmt.Sprintf("%s - %s (%.2fs) - %d tables, %d rows",
status, r.DatabaseName, r.Duration, r.TableCount, r.TotalRows)
}
// Drill is the interface for DR drill operations
type Drill interface {
// Run executes the full DR drill
Run(ctx context.Context, config *DrillConfig) (*DrillResult, error)
// Validate runs validation queries against an existing database
Validate(ctx context.Context, config *DrillConfig) ([]ValidationResult, error)
// Cleanup removes drill resources (containers, temp files)
Cleanup(ctx context.Context, drillID string) error
}

532
internal/drill/engine.go Normal file

@@ -0,0 +1,532 @@
// Package drill - Main drill execution engine
package drill
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"dbbackup/internal/logger"
)
// Engine executes DR drills
type Engine struct {
docker *DockerManager
log logger.Logger
verbose bool
}
// NewEngine creates a new drill engine
func NewEngine(log logger.Logger, verbose bool) *Engine {
return &Engine{
docker: NewDockerManager(verbose),
log: log,
verbose: verbose,
}
}
// Run executes a complete DR drill
func (e *Engine) Run(ctx context.Context, config *DrillConfig) (*DrillResult, error) {
result := &DrillResult{
DrillID: NewDrillID(),
StartTime: time.Now(),
BackupPath: config.BackupPath,
DatabaseName: config.DatabaseName,
DatabaseType: config.DatabaseType,
Status: StatusRunning,
Phases: make([]DrillPhase, 0),
TargetRTO: float64(config.MaxRestoreSeconds),
}
e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
e.log.Info(" 🧪 DR Drill: " + result.DrillID)
e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
e.log.Info("")
// Cleanup function for error cases
var containerID string
cleanup := func() {
if containerID != "" && config.CleanupOnExit && (result.Success || !config.KeepOnFailure) {
e.log.Info("🗑️ Cleaning up container...")
e.docker.RemoveContainer(context.Background(), containerID)
} else if containerID != "" {
result.ContainerKept = true
e.log.Info("📦 Container kept for debugging: " + containerID)
}
}
defer cleanup()
// Phase 1: Preflight checks
phase := e.startPhase("Preflight Checks")
if err := e.preflightChecks(ctx, config); err != nil {
e.failPhase(&phase, err.Error())
result.Phases = append(result.Phases, phase)
result.Status = StatusFailed
result.Message = "Preflight checks failed: " + err.Error()
result.Errors = append(result.Errors, err.Error())
e.finalize(result)
return result, nil
}
e.completePhase(&phase, "All checks passed")
result.Phases = append(result.Phases, phase)
// Phase 2: Start container
phase = e.startPhase("Start Container")
containerConfig := e.buildContainerConfig(config)
container, err := e.docker.CreateContainer(ctx, containerConfig)
if err != nil {
e.failPhase(&phase, err.Error())
result.Phases = append(result.Phases, phase)
result.Status = StatusFailed
result.Message = "Failed to start container: " + err.Error()
result.Errors = append(result.Errors, err.Error())
e.finalize(result)
return result, nil
}
containerID = container.ID
result.ContainerID = containerID
e.log.Info("📦 Container started: " + containerID[:12])
// Wait for container to be healthy
if err := e.docker.WaitForHealth(ctx, containerID, config.DatabaseType, config.ContainerTimeout); err != nil {
e.failPhase(&phase, "Container health check failed: "+err.Error())
result.Phases = append(result.Phases, phase)
result.Status = StatusFailed
result.Message = "Container failed to start"
result.Errors = append(result.Errors, err.Error())
e.finalize(result)
return result, nil
}
e.completePhase(&phase, "Container healthy")
result.Phases = append(result.Phases, phase)
// Phase 3: Restore backup
phase = e.startPhase("Restore Backup")
restoreStart := time.Now()
if err := e.restoreBackup(ctx, config, containerID, containerConfig); err != nil {
e.failPhase(&phase, err.Error())
result.Phases = append(result.Phases, phase)
result.Status = StatusFailed
result.Message = "Restore failed: " + err.Error()
result.Errors = append(result.Errors, err.Error())
e.finalize(result)
return result, nil
}
result.RestoreTime = time.Since(restoreStart).Seconds()
e.completePhase(&phase, fmt.Sprintf("Restored in %.2fs", result.RestoreTime))
result.Phases = append(result.Phases, phase)
e.log.Info(fmt.Sprintf("✅ Backup restored in %.2fs", result.RestoreTime))
// Phase 4: Validate
phase = e.startPhase("Validate Database")
validateStart := time.Now()
validationErrors := e.validateDatabase(ctx, config, result, containerConfig)
result.ValidationTime = time.Since(validateStart).Seconds()
if validationErrors > 0 {
e.completePhase(&phase, fmt.Sprintf("Completed with %d errors", validationErrors))
} else {
e.completePhase(&phase, "All validations passed")
}
result.Phases = append(result.Phases, phase)
// Determine overall status
result.ActualRTO = result.RestoreTime + result.ValidationTime
result.RTOMet = result.ActualRTO <= result.TargetRTO
criticalFailures := 0
for _, vr := range result.ValidationResults {
if !vr.Success {
criticalFailures++
}
}
for _, cr := range result.CheckResults {
if !cr.Success {
criticalFailures++
}
}
if criticalFailures == 0 {
result.Success = true
result.Status = StatusCompleted
result.Message = "DR drill completed successfully"
} else if criticalFailures < len(result.ValidationResults)+len(result.CheckResults) {
result.Success = false
result.Status = StatusPartial
result.Message = fmt.Sprintf("DR drill completed with %d validation failures", criticalFailures)
} else {
result.Success = false
result.Status = StatusFailed
result.Message = "All validations failed"
}
e.finalize(result)
// Save result if output dir specified
if config.OutputDir != "" {
if err := result.SaveResult(config.OutputDir); err != nil {
e.log.Warn("Failed to save drill result", "error", err)
} else {
e.log.Info("📄 Report saved to: " + filepath.Join(config.OutputDir, result.DrillID+"_report.json"))
}
}
return result, nil
}
// preflightChecks runs preflight checks before the drill
func (e *Engine) preflightChecks(ctx context.Context, config *DrillConfig) error {
// Check Docker is available
if err := e.docker.CheckDockerAvailable(ctx); err != nil {
return fmt.Errorf("docker not available: %w", err)
}
e.log.Info("✓ Docker is available")
// Check backup file exists
if _, err := os.Stat(config.BackupPath); err != nil {
return fmt.Errorf("backup file not accessible: %s: %w", config.BackupPath, err)
}
e.log.Info("✓ Backup file exists: " + filepath.Base(config.BackupPath))
// Pull Docker image
image := config.ContainerImage
if image == "" {
image = GetDefaultImage(config.DatabaseType, "")
}
e.log.Info("⬇️ Pulling image: " + image)
if err := e.docker.PullImage(ctx, image); err != nil {
return fmt.Errorf("failed to pull image: %w", err)
}
e.log.Info("✓ Image ready: " + image)
return nil
}
// buildContainerConfig creates container configuration
func (e *Engine) buildContainerConfig(config *DrillConfig) *ContainerConfig {
containerName := config.ContainerName
if containerName == "" {
containerName = fmt.Sprintf("drill_%s_%s", config.DatabaseName, time.Now().Format("20060102_150405"))
}
image := config.ContainerImage
if image == "" {
image = GetDefaultImage(config.DatabaseType, "")
}
port := config.ContainerPort
if port == 0 {
port = 15432 // Default drill port (different from production)
if config.DatabaseType == "mysql" || config.DatabaseType == "mariadb" {
port = 13306
}
}
containerPort := GetDefaultPort(config.DatabaseType)
env := GetDefaultEnvironment(config.DatabaseType)
return &ContainerConfig{
Image: image,
Name: containerName,
Port: port,
ContainerPort: containerPort,
Environment: env,
Timeout: config.ContainerTimeout,
}
}
// restoreBackup restores the backup into the container
func (e *Engine) restoreBackup(ctx context.Context, config *DrillConfig, containerID string, containerConfig *ContainerConfig) error {
// Copy backup to container
backupName := filepath.Base(config.BackupPath)
containerBackupPath := "/tmp/" + backupName
e.log.Info("📁 Copying backup to container...")
if err := e.docker.CopyToContainer(ctx, containerID, config.BackupPath, containerBackupPath); err != nil {
return fmt.Errorf("failed to copy backup: %w", err)
}
// Handle encrypted backups
if config.EncryptionKeyFile != "" {
// For encrypted backups, we'd need to decrypt first
// This is a simplified implementation
e.log.Warn("Encrypted backup handling not fully implemented in drill mode")
}
// Restore based on database type and format
e.log.Info("🔄 Restoring backup...")
return e.executeRestore(ctx, config, containerID, containerBackupPath, containerConfig)
}
// executeRestore runs the actual restore command
func (e *Engine) executeRestore(ctx context.Context, config *DrillConfig, containerID, backupPath string, containerConfig *ContainerConfig) error {
var cmd []string
switch config.DatabaseType {
case "postgresql", "postgres":
// Decompress if needed
if strings.HasSuffix(backupPath, ".gz") {
decompressedPath := strings.TrimSuffix(backupPath, ".gz")
_, err := e.docker.ExecCommand(ctx, containerID, []string{
"sh", "-c", fmt.Sprintf("gunzip -c %s > %s", backupPath, decompressedPath),
})
if err != nil {
return fmt.Errorf("decompression failed: %w", err)
}
backupPath = decompressedPath
}
// Create database
_, err := e.docker.ExecCommand(ctx, containerID, []string{
"psql", "-U", "postgres", "-c", fmt.Sprintf("CREATE DATABASE %s", config.DatabaseName),
})
if err != nil {
// Non-fatal: the database may already exist in the container
e.log.Debug("Create database returned (may already exist)", "error", err)
}
// Choose the restore method from the file name: pg_restore for custom-format dumps, psql for plain SQL
isCustomFormat := strings.Contains(backupPath, ".dump") || strings.Contains(backupPath, ".custom")
if isCustomFormat {
cmd = []string{"pg_restore", "-U", "postgres", "-d", config.DatabaseName, "-v", backupPath}
} else {
cmd = []string{"sh", "-c", fmt.Sprintf("psql -U postgres -d %s < %s", config.DatabaseName, backupPath)}
}
case "mysql", "mariadb":
// Decompress if needed
if strings.HasSuffix(backupPath, ".gz") {
decompressedPath := strings.TrimSuffix(backupPath, ".gz")
_, err := e.docker.ExecCommand(ctx, containerID, []string{
"sh", "-c", fmt.Sprintf("gunzip -c %s > %s", backupPath, decompressedPath),
})
if err != nil {
return fmt.Errorf("decompression failed: %w", err)
}
backupPath = decompressedPath
}
// Use the matching client binary for each server flavor
client := "mysql"
if config.DatabaseType == "mariadb" {
client = "mariadb"
}
cmd = []string{"sh", "-c", fmt.Sprintf("%s -u root --password=root %s < %s", client, config.DatabaseName, backupPath)}
default:
return fmt.Errorf("unsupported database type: %s", config.DatabaseType)
}
output, err := e.docker.ExecCommand(ctx, containerID, cmd)
if err != nil {
return fmt.Errorf("restore failed: %w (output: %s)", err, output)
}
return nil
}
// validateDatabase runs validation against the restored database
func (e *Engine) validateDatabase(ctx context.Context, config *DrillConfig, result *DrillResult, containerConfig *ContainerConfig) int {
errorCount := 0
// Connect to database
var user, password string
switch config.DatabaseType {
case "postgresql", "postgres":
user = "postgres"
password = containerConfig.Environment["POSTGRES_PASSWORD"]
case "mysql", "mariadb":
user = "root"
password = "root"
}
validator, err := NewValidator(config.DatabaseType, "localhost", containerConfig.Port, user, password, config.DatabaseName, e.verbose)
if err != nil {
e.log.Error("Failed to connect for validation", "error", err)
result.Errors = append(result.Errors, "Validation connection failed: "+err.Error())
return 1
}
defer validator.Close()
// Get database metrics
tables, err := validator.GetTableList(ctx)
if err == nil {
result.TableCount = len(tables)
e.log.Info(fmt.Sprintf("📊 Tables found: %d", result.TableCount))
}
totalRows, err := validator.GetTotalRowCount(ctx)
if err == nil {
result.TotalRows = totalRows
e.log.Info(fmt.Sprintf("📊 Total rows: %d", result.TotalRows))
}
dbSize, err := validator.GetDatabaseSize(ctx, config.DatabaseName)
if err == nil {
result.DatabaseSize = dbSize
}
// Run expected tables check
if len(config.ExpectedTables) > 0 {
tableResults := validator.ValidateExpectedTables(ctx, config.ExpectedTables)
for _, tr := range tableResults {
result.CheckResults = append(result.CheckResults, tr)
if !tr.Success {
errorCount++
e.log.Warn("❌ " + tr.Message)
} else {
e.log.Info("✓ " + tr.Message)
}
}
}
// Run validation queries
if len(config.ValidationQueries) > 0 {
queryResults := validator.RunValidationQueries(ctx, config.ValidationQueries)
result.ValidationResults = append(result.ValidationResults, queryResults...)
var totalQueryTime float64
for _, qr := range queryResults {
totalQueryTime += qr.Duration
if !qr.Success {
errorCount++
e.log.Warn(fmt.Sprintf("❌ %s: %s", qr.Name, qr.Error))
} else {
e.log.Info(fmt.Sprintf("✓ %s: %s (%.0fms)", qr.Name, qr.Result, qr.Duration))
}
}
if len(queryResults) > 0 {
result.QueryTimeAvg = totalQueryTime / float64(len(queryResults))
}
}
// Run custom checks
if len(config.CustomChecks) > 0 {
checkResults := validator.RunCustomChecks(ctx, config.CustomChecks)
for _, cr := range checkResults {
result.CheckResults = append(result.CheckResults, cr)
if !cr.Success {
errorCount++
e.log.Warn("❌ " + cr.Message)
} else {
e.log.Info("✓ " + cr.Message)
}
}
}
// Check minimum row count if specified
if config.MinRowCount > 0 && result.TotalRows < config.MinRowCount {
errorCount++
msg := fmt.Sprintf("Total rows (%d) below minimum (%d)", result.TotalRows, config.MinRowCount)
result.Warnings = append(result.Warnings, msg)
e.log.Warn("⚠️ " + msg)
}
return errorCount
}
// startPhase starts a new drill phase
func (e *Engine) startPhase(name string) DrillPhase {
e.log.Info("▶️ " + name)
return DrillPhase{
Name: name,
Status: "running",
StartTime: time.Now(),
}
}
// completePhase marks a phase as completed
func (e *Engine) completePhase(phase *DrillPhase, message string) {
phase.EndTime = time.Now()
phase.Duration = phase.EndTime.Sub(phase.StartTime).Seconds()
phase.Status = "completed"
phase.Message = message
}
// failPhase marks a phase as failed
func (e *Engine) failPhase(phase *DrillPhase, message string) {
phase.EndTime = time.Now()
phase.Duration = phase.EndTime.Sub(phase.StartTime).Seconds()
phase.Status = "failed"
phase.Message = message
e.log.Error("❌ Phase failed: " + message)
}
// finalize completes the drill result
func (e *Engine) finalize(result *DrillResult) {
result.EndTime = time.Now()
result.Duration = result.EndTime.Sub(result.StartTime).Seconds()
e.log.Info("")
e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
e.log.Info(" " + result.Summary())
e.log.Info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
if result.Success {
e.log.Info(fmt.Sprintf(" RTO: %.2fs (target: %.0fs) %s",
result.ActualRTO, result.TargetRTO, boolIcon(result.RTOMet)))
}
}
func boolIcon(b bool) string {
if b {
return "✅"
}
return "❌"
}
// Cleanup removes drill resources
func (e *Engine) Cleanup(ctx context.Context, drillID string) error {
containers, err := e.docker.ListDrillContainers(ctx)
if err != nil {
return err
}
for _, c := range containers {
if strings.Contains(c.Name, drillID) || (drillID == "" && strings.HasPrefix(c.Name, "drill_")) {
e.log.Info("🗑️ Removing container: " + c.Name)
if err := e.docker.RemoveContainer(ctx, c.ID); err != nil {
e.log.Warn("Failed to remove container", "id", c.ID, "error", err)
}
}
}
return nil
}
// QuickTest runs a quick restore test without full validation
func (e *Engine) QuickTest(ctx context.Context, backupPath, dbType, dbName string) (*DrillResult, error) {
config := DefaultConfig()
config.BackupPath = backupPath
config.DatabaseType = dbType
config.DatabaseName = dbName
config.CleanupOnExit = true
config.MaxRestoreSeconds = 600
return e.Run(ctx, config)
}
// Validate runs validation queries against an existing database (non-Docker)
func (e *Engine) Validate(ctx context.Context, config *DrillConfig, host string, port int, user, password string) ([]ValidationResult, error) {
validator, err := NewValidator(config.DatabaseType, host, port, user, password, config.DatabaseName, e.verbose)
if err != nil {
return nil, err
}
defer validator.Close()
return validator.RunValidationQueries(ctx, config.ValidationQueries), nil
}

internal/drill/validate.go (new file, 358 lines)
// Package drill - Validation logic for DR drills
package drill
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/v5/stdlib"
)
// Validator handles database validation during DR drills
type Validator struct {
db *sql.DB
dbType string
verbose bool
}
// NewValidator creates a new database validator
func NewValidator(dbType string, host string, port int, user, password, dbname string, verbose bool) (*Validator, error) {
var dsn string
var driver string
switch dbType {
case "postgresql", "postgres":
driver = "pgx"
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
host, port, user, password, dbname)
case "mysql", "mariadb":
driver = "mysql"
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
user, password, host, port, dbname)
default:
return nil, fmt.Errorf("unsupported database type: %s", dbType)
}
db, err := sql.Open(driver, dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &Validator{
db: db,
dbType: dbType,
verbose: verbose,
}, nil
}
// Close closes the database connection
func (v *Validator) Close() error {
return v.db.Close()
}
// RunValidationQueries executes validation queries and returns results
func (v *Validator) RunValidationQueries(ctx context.Context, queries []ValidationQuery) []ValidationResult {
var results []ValidationResult
for _, q := range queries {
result := v.runQuery(ctx, q)
results = append(results, result)
}
return results
}
// runQuery executes a single validation query
func (v *Validator) runQuery(ctx context.Context, query ValidationQuery) ValidationResult {
result := ValidationResult{
Name: query.Name,
Query: query.Query,
Expected: query.ExpectedValue,
}
start := time.Now()
rows, err := v.db.QueryContext(ctx, query.Query)
result.Duration = float64(time.Since(start).Milliseconds())
if err != nil {
result.Success = false
result.Error = err.Error()
return result
}
defer rows.Close()
// Get result
if rows.Next() {
var value interface{}
if err := rows.Scan(&value); err != nil {
result.Success = false
result.Error = fmt.Sprintf("scan error: %v", err)
return result
}
result.Result = fmt.Sprintf("%v", value)
}
// Validate result
result.Success = true
if query.ExpectedValue != "" && result.Result != query.ExpectedValue {
result.Success = false
result.Error = fmt.Sprintf("expected %s, got %s", query.ExpectedValue, result.Result)
}
// Check min/max if specified
if query.MinValue > 0 || query.MaxValue > 0 {
var numValue int64
if _, err := fmt.Sscanf(result.Result, "%d", &numValue); err != nil {
result.Success = false
result.Error = fmt.Sprintf("non-numeric result %q for min/max check", result.Result)
return result
}
if query.MinValue > 0 && numValue < query.MinValue {
result.Success = false
result.Error = fmt.Sprintf("value %d below minimum %d", numValue, query.MinValue)
}
if query.MaxValue > 0 && numValue > query.MaxValue {
result.Success = false
result.Error = fmt.Sprintf("value %d above maximum %d", numValue, query.MaxValue)
}
}
return result
}
// RunCustomChecks executes custom validation checks
func (v *Validator) RunCustomChecks(ctx context.Context, checks []CustomCheck) []CheckResult {
var results []CheckResult
for _, check := range checks {
result := v.runCheck(ctx, check)
results = append(results, result)
}
return results
}
// runCheck executes a single custom check
func (v *Validator) runCheck(ctx context.Context, check CustomCheck) CheckResult {
result := CheckResult{
Name: check.Name,
Type: check.Type,
Expected: check.MinValue,
}
switch check.Type {
case "row_count":
count, err := v.getRowCount(ctx, check.Table, check.Condition)
if err != nil {
result.Success = false
result.Message = fmt.Sprintf("failed to get row count: %v", err)
return result
}
result.Actual = count
result.Success = count >= check.MinValue
if result.Success {
result.Message = fmt.Sprintf("Table %s has %d rows (min: %d)", check.Table, count, check.MinValue)
} else {
result.Message = fmt.Sprintf("Table %s has %d rows, expected at least %d", check.Table, count, check.MinValue)
}
case "table_exists":
exists, err := v.tableExists(ctx, check.Table)
if err != nil {
result.Success = false
result.Message = fmt.Sprintf("failed to check table: %v", err)
return result
}
result.Success = exists
if exists {
result.Actual = 1
result.Message = fmt.Sprintf("Table %s exists", check.Table)
} else {
result.Actual = 0
result.Message = fmt.Sprintf("Table %s does not exist", check.Table)
}
case "column_check":
exists, err := v.columnExists(ctx, check.Table, check.Column)
if err != nil {
result.Success = false
result.Message = fmt.Sprintf("failed to check column: %v", err)
return result
}
result.Success = exists
if exists {
result.Actual = 1
result.Message = fmt.Sprintf("Column %s.%s exists", check.Table, check.Column)
} else {
result.Actual = 0
result.Message = fmt.Sprintf("Column %s.%s does not exist", check.Table, check.Column)
}
default:
result.Success = false
result.Message = fmt.Sprintf("unknown check type: %s", check.Type)
}
return result
}
// getRowCount returns the row count for a table
func (v *Validator) getRowCount(ctx context.Context, table, condition string) (int64, error) {
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", v.quoteIdentifier(table))
if condition != "" {
query += " WHERE " + condition
}
var count int64
err := v.db.QueryRowContext(ctx, query).Scan(&count)
return count, err
}
// tableExists checks if a table exists
func (v *Validator) tableExists(ctx context.Context, table string) (bool, error) {
var query string
switch v.dbType {
case "postgresql", "postgres":
query = `SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = $1
)`
case "mysql", "mariadb":
query = `SELECT COUNT(*) > 0 FROM information_schema.tables
WHERE table_name = ?`
default:
return false, fmt.Errorf("unsupported database type: %s", v.dbType)
}
var exists bool
err := v.db.QueryRowContext(ctx, query, table).Scan(&exists)
return exists, err
}
// columnExists checks if a column exists
func (v *Validator) columnExists(ctx context.Context, table, column string) (bool, error) {
var query string
switch v.dbType {
case "postgresql", "postgres":
query = `SELECT EXISTS (
SELECT FROM information_schema.columns
WHERE table_name = $1 AND column_name = $2
)`
case "mysql", "mariadb":
query = `SELECT COUNT(*) > 0 FROM information_schema.columns
WHERE table_name = ? AND column_name = ?`
default:
return false, fmt.Errorf("unsupported database type: %s", v.dbType)
}
var exists bool
err := v.db.QueryRowContext(ctx, query, table, column).Scan(&exists)
return exists, err
}
// GetTableList returns all tables in the database
func (v *Validator) GetTableList(ctx context.Context) ([]string, error) {
var query string
switch v.dbType {
case "postgresql", "postgres":
query = `SELECT table_name FROM information_schema.tables
WHERE table_schema = 'public' AND table_type = 'BASE TABLE'`
case "mysql", "mariadb":
query = `SELECT table_name FROM information_schema.tables
WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE'`
default:
return nil, fmt.Errorf("unsupported database type: %s", v.dbType)
}
rows, err := v.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var tables []string
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, err
}
tables = append(tables, table)
}
return tables, rows.Err()
}
// GetTotalRowCount returns total row count across all tables
func (v *Validator) GetTotalRowCount(ctx context.Context) (int64, error) {
tables, err := v.GetTableList(ctx)
if err != nil {
return 0, err
}
var total int64
for _, table := range tables {
count, err := v.getRowCount(ctx, table, "")
if err != nil {
continue // Skip tables that can't be counted
}
total += count
}
return total, nil
}
// GetDatabaseSize returns the database size in bytes
func (v *Validator) GetDatabaseSize(ctx context.Context, dbname string) (int64, error) {
var query string
switch v.dbType {
case "postgresql", "postgres":
query = "SELECT pg_database_size($1)"
case "mysql", "mariadb":
query = `SELECT SUM(data_length + index_length)
FROM information_schema.tables WHERE table_schema = ?`
default:
return 0, fmt.Errorf("unsupported database type: %s", v.dbType)
}
// Bind dbname as a query parameter instead of interpolating it into the SQL
var size sql.NullInt64
err := v.db.QueryRowContext(ctx, query, dbname).Scan(&size)
if err != nil {
return 0, err
}
return size.Int64, nil
}
// ValidateExpectedTables checks that all expected tables exist
func (v *Validator) ValidateExpectedTables(ctx context.Context, expectedTables []string) []CheckResult {
var results []CheckResult
for _, table := range expectedTables {
check := CustomCheck{
Name: fmt.Sprintf("Table '%s' exists", table),
Type: "table_exists",
Table: table,
}
results = append(results, v.runCheck(ctx, check))
}
return results
}
// quoteIdentifier quotes a database identifier
func (v *Validator) quoteIdentifier(id string) string {
switch v.dbType {
case "postgresql", "postgres":
return fmt.Sprintf(`"%s"`, strings.ReplaceAll(id, `"`, `""`))
case "mysql", "mariadb":
return fmt.Sprintf("`%s`", strings.ReplaceAll(id, "`", "``"))
default:
return id
}
}
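quoteIdentifier above is the validator's guard against unusual table names reaching raw SQL. A stand-alone mirror of that dialect-aware quoting (the function and names here are illustrative, not part of the package API):

```go
package main

import (
	"fmt"
	"strings"
)

// quoteIdent mirrors Validator.quoteIdentifier: wrap the identifier in the
// dialect's quote character and double any embedded quote characters.
func quoteIdent(dbType, id string) string {
	switch dbType {
	case "postgresql", "postgres":
		return fmt.Sprintf(`"%s"`, strings.ReplaceAll(id, `"`, `""`))
	case "mysql", "mariadb":
		return fmt.Sprintf("`%s`", strings.ReplaceAll(id, "`", "``"))
	default:
		return id
	}
}

func main() {
	fmt.Println(quoteIdent("postgres", `weird"name`)) // "weird""name"
	fmt.Println(quoteIdent("mysql", "users"))         // `users`
}
```

Note this protects identifiers only; values (like the row-count `condition`) still need to come from trusted configuration.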

internal/notify/batch.go (new file, 261 lines)
// Package notify - Event batching for aggregated notifications
package notify
import (
"context"
"fmt"
"sync"
"time"
)
// BatchConfig configures notification batching
type BatchConfig struct {
Enabled bool // Enable batching
Window time.Duration // Batch window (e.g., 5 minutes)
MaxEvents int // Maximum events per batch before forced send
GroupBy string // Group by: "database", "type", "severity", "host"
DigestFormat string // Format: "summary", "detailed", "compact"
}
// DefaultBatchConfig returns sensible batch defaults
func DefaultBatchConfig() BatchConfig {
return BatchConfig{
Enabled: false,
Window: 5 * time.Minute,
MaxEvents: 50,
GroupBy: "database",
DigestFormat: "summary",
}
}
// Batcher collects events and sends them in batches
type Batcher struct {
config BatchConfig
manager *Manager
events []*Event
mu sync.Mutex
timer *time.Timer
ctx context.Context
cancel context.CancelFunc
startTime time.Time
}
// NewBatcher creates a new event batcher
func NewBatcher(config BatchConfig, manager *Manager) *Batcher {
ctx, cancel := context.WithCancel(context.Background())
return &Batcher{
config: config,
manager: manager,
events: make([]*Event, 0),
ctx: ctx,
cancel: cancel,
}
}
// Add adds an event to the batch
func (b *Batcher) Add(event *Event) {
if !b.config.Enabled {
// Batching disabled, send immediately
b.manager.Notify(event)
return
}
b.mu.Lock()
defer b.mu.Unlock()
// Start timer on first event
if len(b.events) == 0 {
b.startTime = time.Now()
b.timer = time.AfterFunc(b.config.Window, func() {
b.Flush()
})
}
b.events = append(b.events, event)
// Check if we've hit max events
if len(b.events) >= b.config.MaxEvents {
b.flushLocked()
}
}
// Flush sends all batched events
func (b *Batcher) Flush() {
b.mu.Lock()
defer b.mu.Unlock()
b.flushLocked()
}
// flushLocked sends batched events (must hold mutex)
func (b *Batcher) flushLocked() {
if len(b.events) == 0 {
return
}
// Cancel pending timer
if b.timer != nil {
b.timer.Stop()
b.timer = nil
}
// Group events
groups := b.groupEvents()
// Create digest event for each group
for key, events := range groups {
digest := b.createDigest(key, events)
b.manager.Notify(digest)
}
// Clear events
b.events = make([]*Event, 0)
}
// groupEvents groups events by configured criteria
func (b *Batcher) groupEvents() map[string][]*Event {
groups := make(map[string][]*Event)
for _, event := range b.events {
var key string
switch b.config.GroupBy {
case "database":
key = event.Database
case "type":
key = string(event.Type)
case "severity":
key = string(event.Severity)
case "host":
key = event.Hostname
default:
key = "all"
}
if key == "" {
key = "unknown"
}
groups[key] = append(groups[key], event)
}
return groups
}
// createDigest creates a digest event from multiple events
func (b *Batcher) createDigest(groupKey string, events []*Event) *Event {
// Calculate summary stats
var (
successCount int
failureCount int
highestSev = SeverityInfo
totalDuration time.Duration
databases = make(map[string]bool)
)
for _, e := range events {
switch e.Type {
case EventBackupCompleted, EventRestoreCompleted, EventVerifyCompleted:
successCount++
case EventBackupFailed, EventRestoreFailed, EventVerifyFailed:
failureCount++
}
if severityOrder(e.Severity) > severityOrder(highestSev) {
highestSev = e.Severity
}
totalDuration += e.Duration
if e.Database != "" {
databases[e.Database] = true
}
}
// Create digest message
var message string
switch b.config.DigestFormat {
case "detailed":
message = b.formatDetailedDigest(events)
case "compact":
message = b.formatCompactDigest(events, successCount, failureCount)
default: // summary
message = b.formatSummaryDigest(events, successCount, failureCount, len(databases))
}
digest := NewEvent(EventType("digest"), highestSev, message)
digest.WithDetail("group", groupKey)
digest.WithDetail("event_count", fmt.Sprintf("%d", len(events)))
digest.WithDetail("success_count", fmt.Sprintf("%d", successCount))
digest.WithDetail("failure_count", fmt.Sprintf("%d", failureCount))
digest.WithDetail("batch_duration", fmt.Sprintf("%.0fs", time.Since(b.startTime).Seconds()))
if len(databases) == 1 {
for db := range databases {
digest.Database = db
}
}
return digest
}
func (b *Batcher) formatSummaryDigest(events []*Event, success, failure, dbCount int) string {
total := len(events)
return fmt.Sprintf("Batch Summary: %d events (%d success, %d failed) across %d database(s)",
total, success, failure, dbCount)
}
func (b *Batcher) formatCompactDigest(events []*Event, success, failure int) string {
if failure > 0 {
return fmt.Sprintf("⚠️ %d/%d operations failed", failure, len(events))
}
return fmt.Sprintf("✅ All %d operations successful", success)
}
func (b *Batcher) formatDetailedDigest(events []*Event) string {
var msg string
msg += fmt.Sprintf("=== Batch Digest (%d events) ===\n\n", len(events))
for _, e := range events {
icon := "•"
switch e.Severity {
case SeverityError, SeverityCritical:
icon = "❌"
case SeverityWarning:
icon = "⚠️"
}
msg += fmt.Sprintf("%s [%s] %s: %s\n",
icon,
e.Timestamp.Format("15:04:05"),
e.Type,
e.Message)
}
return msg
}
// Stop stops the batcher and flushes remaining events
func (b *Batcher) Stop() {
b.cancel()
b.Flush()
}
// BatcherStats returns current batcher statistics
type BatcherStats struct {
PendingEvents int `json:"pending_events"`
BatchAge time.Duration `json:"batch_age"`
Config BatchConfig `json:"config"`
}
// Stats returns current batcher statistics
func (b *Batcher) Stats() BatcherStats {
b.mu.Lock()
defer b.mu.Unlock()
var age time.Duration
if len(b.events) > 0 {
age = time.Since(b.startTime)
}
return BatcherStats{
PendingEvents: len(b.events),
BatchAge: age,
Config: b.config,
}
}
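The grouping step in flushLocked is the heart of the batcher: events are bucketed by the configured key, with a fallback bucket for empty keys. A self-contained mirror of that logic (the `miniEvent` type is a simplified stand-in for the package's `Event`):

```go
package main

import "fmt"

// miniEvent carries just the fields the batcher groups on.
type miniEvent struct {
	Database, Type, Severity, Hostname string
}

// groupEvents mirrors Batcher.groupEvents: pick a key per the GroupBy
// setting, default to "all", and map empty keys to "unknown".
func groupEvents(events []miniEvent, groupBy string) map[string][]miniEvent {
	groups := make(map[string][]miniEvent)
	for _, e := range events {
		var key string
		switch groupBy {
		case "database":
			key = e.Database
		case "type":
			key = e.Type
		case "severity":
			key = e.Severity
		case "host":
			key = e.Hostname
		default:
			key = "all"
		}
		if key == "" {
			key = "unknown"
		}
		groups[key] = append(groups[key], e)
	}
	return groups
}

func main() {
	events := []miniEvent{
		{Database: "orders", Type: "backup_failed"},
		{Database: "orders", Type: "backup_completed"},
		{Database: "", Type: "digest"},
	}
	groups := groupEvents(events, "database")
	fmt.Println(len(groups["orders"]), len(groups["unknown"])) // 2 1
}
```

Each resulting group becomes one digest notification, so grouping by "database" turns a noisy burst of per-database events into one message per database.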

internal/notify/escalate.go (new file, 363 lines)
// Package notify - Escalation for critical events
package notify
import (
"context"
"fmt"
"sync"
"time"
)
// EscalationConfig configures notification escalation
type EscalationConfig struct {
Enabled bool // Enable escalation
Levels []EscalationLevel // Escalation levels
AcknowledgeURL string // URL to acknowledge alerts
CooldownPeriod time.Duration // Cooldown between escalations
RepeatInterval time.Duration // Repeat unacknowledged alerts
MaxRepeats int // Maximum repeat attempts
TrackingEnabled bool // Track escalation state
}
// EscalationLevel defines an escalation tier
type EscalationLevel struct {
Name string // Level name (e.g., "primary", "secondary", "manager")
Delay time.Duration // Delay before escalating to this level
Recipients []string // Email recipients for this level
Webhook string // Webhook URL for this level
Severity Severity // Minimum severity to escalate
Message string // Custom message template
}
// DefaultEscalationConfig returns sensible defaults
func DefaultEscalationConfig() EscalationConfig {
return EscalationConfig{
Enabled: false,
CooldownPeriod: 15 * time.Minute,
RepeatInterval: 30 * time.Minute,
MaxRepeats: 3,
Levels: []EscalationLevel{
{
Name: "primary",
Delay: 0,
Severity: SeverityError,
},
{
Name: "secondary",
Delay: 15 * time.Minute,
Severity: SeverityError,
},
{
Name: "critical",
Delay: 30 * time.Minute,
Severity: SeverityCritical,
},
},
}
}
// EscalationState tracks escalation for an alert
type EscalationState struct {
AlertID string `json:"alert_id"`
Event *Event `json:"event"`
CurrentLevel int `json:"current_level"`
StartedAt time.Time `json:"started_at"`
LastEscalation time.Time `json:"last_escalation"`
RepeatCount int `json:"repeat_count"`
Acknowledged bool `json:"acknowledged"`
AcknowledgedBy string `json:"acknowledged_by,omitempty"`
AcknowledgedAt *time.Time `json:"acknowledged_at,omitempty"`
Resolved bool `json:"resolved"`
}
// Escalator manages alert escalation
type Escalator struct {
config EscalationConfig
manager *Manager
alerts map[string]*EscalationState
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
ticker *time.Ticker
}
// NewEscalator creates a new escalation manager
func NewEscalator(config EscalationConfig, manager *Manager) *Escalator {
ctx, cancel := context.WithCancel(context.Background())
e := &Escalator{
config: config,
manager: manager,
alerts: make(map[string]*EscalationState),
ctx: ctx,
cancel: cancel,
}
if config.Enabled {
e.ticker = time.NewTicker(time.Minute)
go e.runEscalationLoop()
}
return e
}
// Handle processes an event for potential escalation
func (e *Escalator) Handle(event *Event) {
if !e.config.Enabled {
return
}
// Only escalate errors and critical events
if severityOrder(event.Severity) < severityOrder(SeverityError) {
return
}
// Generate alert ID
alertID := e.generateAlertID(event)
e.mu.Lock()
defer e.mu.Unlock()
// Check if alert already exists
if existing, ok := e.alerts[alertID]; ok {
if !existing.Acknowledged && !existing.Resolved {
// Alert already being escalated
return
}
}
// Create new escalation state
state := &EscalationState{
AlertID: alertID,
Event: event,
CurrentLevel: 0,
StartedAt: time.Now(),
LastEscalation: time.Now(),
}
e.alerts[alertID] = state
// Send immediate notification to first level
e.notifyLevel(state, 0)
}
// generateAlertID creates a unique ID for an alert
func (e *Escalator) generateAlertID(event *Event) string {
return fmt.Sprintf("%s_%s_%s",
event.Type,
event.Database,
event.Hostname)
}
// notifyLevel sends notification for a specific escalation level
func (e *Escalator) notifyLevel(state *EscalationState, level int) {
if level >= len(e.config.Levels) {
return
}
lvl := e.config.Levels[level]
// Create escalated event
escalatedEvent := &Event{
Type: state.Event.Type,
Severity: state.Event.Severity,
Timestamp: time.Now(),
Database: state.Event.Database,
Hostname: state.Event.Hostname,
Message: e.formatEscalationMessage(state, lvl),
Details: make(map[string]string),
}
escalatedEvent.Details["escalation_level"] = lvl.Name
escalatedEvent.Details["alert_id"] = state.AlertID
escalatedEvent.Details["escalation_time"] = fmt.Sprintf("%d", int(time.Since(state.StartedAt).Minutes()))
escalatedEvent.Details["original_message"] = state.Event.Message
if state.Event.Error != "" {
escalatedEvent.Error = state.Event.Error
}
// Send via manager
e.manager.Notify(escalatedEvent)
state.CurrentLevel = level
state.LastEscalation = time.Now()
}
// formatEscalationMessage creates an escalation message
func (e *Escalator) formatEscalationMessage(state *EscalationState, level EscalationLevel) string {
if level.Message != "" {
return level.Message
}
elapsed := time.Since(state.StartedAt)
return fmt.Sprintf("🚨 ESCALATION [%s] - Alert unacknowledged for %s\n\n%s",
level.Name,
formatDuration(elapsed),
state.Event.Message)
}
// runEscalationLoop checks for alerts that need escalation
func (e *Escalator) runEscalationLoop() {
for {
select {
case <-e.ctx.Done():
return
case <-e.ticker.C:
e.checkEscalations()
}
}
}
// checkEscalations checks all alerts for needed escalation
func (e *Escalator) checkEscalations() {
e.mu.Lock()
defer e.mu.Unlock()
now := time.Now()
for _, state := range e.alerts {
if state.Acknowledged || state.Resolved {
continue
}
// Check if we need to escalate to next level
nextLevel := state.CurrentLevel + 1
if nextLevel < len(e.config.Levels) {
lvl := e.config.Levels[nextLevel]
if now.Sub(state.StartedAt) >= lvl.Delay {
e.notifyLevel(state, nextLevel)
}
}
// Check if we need to repeat the alert
if state.RepeatCount < e.config.MaxRepeats {
if now.Sub(state.LastEscalation) >= e.config.RepeatInterval {
e.notifyLevel(state, state.CurrentLevel)
state.RepeatCount++
}
}
}
}
// Acknowledge acknowledges an alert
func (e *Escalator) Acknowledge(alertID, user string) error {
e.mu.Lock()
defer e.mu.Unlock()
state, ok := e.alerts[alertID]
if !ok {
return fmt.Errorf("alert not found: %s", alertID)
}
now := time.Now()
state.Acknowledged = true
state.AcknowledgedBy = user
state.AcknowledgedAt = &now
return nil
}
// Resolve resolves an alert
func (e *Escalator) Resolve(alertID string) error {
e.mu.Lock()
defer e.mu.Unlock()
state, ok := e.alerts[alertID]
if !ok {
return fmt.Errorf("alert not found: %s", alertID)
}
state.Resolved = true
return nil
}
// GetActiveAlerts returns all active (unacknowledged, unresolved) alerts
func (e *Escalator) GetActiveAlerts() []*EscalationState {
e.mu.RLock()
defer e.mu.RUnlock()
var active []*EscalationState
for _, state := range e.alerts {
if !state.Acknowledged && !state.Resolved {
active = append(active, state)
}
}
return active
}
// GetAlert returns a specific alert
func (e *Escalator) GetAlert(alertID string) (*EscalationState, bool) {
e.mu.RLock()
defer e.mu.RUnlock()
state, ok := e.alerts[alertID]
return state, ok
}
// CleanupOld removes old resolved/acknowledged alerts
func (e *Escalator) CleanupOld(maxAge time.Duration) int {
e.mu.Lock()
defer e.mu.Unlock()
now := time.Now()
removed := 0
for id, state := range e.alerts {
if (state.Acknowledged || state.Resolved) && now.Sub(state.StartedAt) > maxAge {
delete(e.alerts, id)
removed++
}
}
return removed
}
// Stop stops the escalator
func (e *Escalator) Stop() {
e.cancel()
if e.ticker != nil {
e.ticker.Stop()
}
}
// EscalatorStats returns escalator statistics
type EscalatorStats struct {
ActiveAlerts int `json:"active_alerts"`
AcknowledgedAlerts int `json:"acknowledged_alerts"`
ResolvedAlerts int `json:"resolved_alerts"`
EscalationEnabled bool `json:"escalation_enabled"`
LevelCount int `json:"level_count"`
}
// Stats returns escalator statistics
func (e *Escalator) Stats() EscalatorStats {
e.mu.RLock()
defer e.mu.RUnlock()
stats := EscalatorStats{
EscalationEnabled: e.config.Enabled,
LevelCount: len(e.config.Levels),
}
for _, state := range e.alerts {
if state.Resolved {
stats.ResolvedAlerts++
} else if state.Acknowledged {
stats.AcknowledgedAlerts++
} else {
stats.ActiveAlerts++
}
}
return stats
}
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%.0fs", d.Seconds())
}
if d < time.Hour {
return fmt.Sprintf("%.0fm", d.Minutes())
}
// Truncate rather than round so 90m renders as "1h 30m", not "2h 0m"
h := int(d.Hours())
return fmt.Sprintf("%dh %dm", h, int(d.Minutes())-h*60)
}

@@ -11,41 +11,66 @@ import (
type EventType string
const (
EventBackupStarted EventType = "backup_started"
EventBackupCompleted EventType = "backup_completed"
EventBackupFailed EventType = "backup_failed"
EventRestoreStarted EventType = "restore_started"
EventRestoreCompleted EventType = "restore_completed"
EventRestoreFailed EventType = "restore_failed"
EventCleanupCompleted EventType = "cleanup_completed"
EventVerifyCompleted EventType = "verify_completed"
EventVerifyFailed EventType = "verify_failed"
EventPITRRecovery EventType = "pitr_recovery"
EventVerificationPassed EventType = "verification_passed"
EventVerificationFailed EventType = "verification_failed"
EventDRDrillPassed EventType = "dr_drill_passed"
EventDRDrillFailed EventType = "dr_drill_failed"
EventGapDetected EventType = "gap_detected"
EventRPOViolation EventType = "rpo_violation"
)
// Severity represents the severity level of a notification
type Severity string
const (
SeverityInfo Severity = "info"
SeveritySuccess Severity = "success"
SeverityWarning Severity = "warning"
SeverityError Severity = "error"
SeverityCritical Severity = "critical"
)
// severityOrder returns numeric order for severity comparison
func severityOrder(s Severity) int {
switch s {
case SeverityInfo:
return 0
case SeveritySuccess:
return 1
case SeverityWarning:
return 2
case SeverityError:
return 3
case SeverityCritical:
return 4
default:
return 0
}
}
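With the numeric ordering above, a MinSeverity threshold reduces to a single integer comparison. A minimal self-contained sketch (the shouldSend helper is hypothetical, not part of the package):

```go
package main

import "fmt"

type Severity string

const (
	SeverityInfo     Severity = "info"
	SeveritySuccess  Severity = "success"
	SeverityWarning  Severity = "warning"
	SeverityError    Severity = "error"
	SeverityCritical Severity = "critical"
)

// severityOrder mirrors the package's ordering.
func severityOrder(s Severity) int {
	switch s {
	case SeverityInfo:
		return 0
	case SeveritySuccess:
		return 1
	case SeverityWarning:
		return 2
	case SeverityError:
		return 3
	case SeverityCritical:
		return 4
	default:
		return 0
	}
}

// shouldSend (hypothetical): an event passes the gate when its severity
// is at or above the configured minimum.
func shouldSend(event, min Severity) bool {
	return severityOrder(event) >= severityOrder(min)
}

func main() {
	fmt.Println(shouldSend(SeverityWarning, SeverityError))  // false: below minimum
	fmt.Println(shouldSend(SeverityCritical, SeverityError)) // true: at/above minimum
}
```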
// Event represents a notification event
type Event struct {
Type EventType `json:"type"`
Severity Severity `json:"severity"`
Timestamp time.Time `json:"timestamp"`
Database string `json:"database,omitempty"`
Message string `json:"message"`
Details map[string]string `json:"details,omitempty"`
Error string `json:"error,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
BackupFile string `json:"backup_file,omitempty"`
BackupSize int64 `json:"backup_size,omitempty"`
Hostname string `json:"hostname,omitempty"`
}
// NewEvent creates a new notification event
@@ -132,27 +157,27 @@ type Config struct {
WebhookSecret string // For signing payloads
// General settings
OnSuccess bool // Send notifications on successful operations
OnFailure bool // Send notifications on failed operations
OnWarning bool // Send notifications on warnings
MinSeverity Severity
Retries int // Number of retry attempts
RetryDelay time.Duration // Delay between retries
}
// DefaultConfig returns a configuration with sensible defaults
func DefaultConfig() Config {
return Config{
SMTPPort: 587,
SMTPTLS: false,
SMTPStartTLS: true,
WebhookMethod: "POST",
OnSuccess: true,
OnFailure: true,
OnWarning: true,
MinSeverity: SeverityInfo,
Retries: 3,
RetryDelay: 5 * time.Second,
}
}


@@ -0,0 +1,497 @@
// Package notify - Notification templates
package notify
import (
"bytes"
"fmt"
"html/template"
"strings"
"time"
)
// TemplateType represents the notification format type
type TemplateType string
const (
TemplateText TemplateType = "text"
TemplateHTML TemplateType = "html"
TemplateMarkdown TemplateType = "markdown"
TemplateSlack TemplateType = "slack"
)
// Templates holds notification templates
type Templates struct {
Subject string
TextBody string
HTMLBody string
}
// DefaultTemplates returns default notification templates
func DefaultTemplates() map[EventType]Templates {
return map[EventType]Templates{
EventBackupStarted: {
Subject: "🔄 Backup Started: {{.Database}} on {{.Hostname}}",
TextBody: backupStartedText,
HTMLBody: backupStartedHTML,
},
EventBackupCompleted: {
Subject: "✅ Backup Completed: {{.Database}} on {{.Hostname}}",
TextBody: backupCompletedText,
HTMLBody: backupCompletedHTML,
},
EventBackupFailed: {
Subject: "❌ Backup FAILED: {{.Database}} on {{.Hostname}}",
TextBody: backupFailedText,
HTMLBody: backupFailedHTML,
},
EventRestoreStarted: {
Subject: "🔄 Restore Started: {{.Database}} on {{.Hostname}}",
TextBody: restoreStartedText,
HTMLBody: restoreStartedHTML,
},
EventRestoreCompleted: {
Subject: "✅ Restore Completed: {{.Database}} on {{.Hostname}}",
TextBody: restoreCompletedText,
HTMLBody: restoreCompletedHTML,
},
EventRestoreFailed: {
Subject: "❌ Restore FAILED: {{.Database}} on {{.Hostname}}",
TextBody: restoreFailedText,
HTMLBody: restoreFailedHTML,
},
EventVerificationPassed: {
Subject: "✅ Verification Passed: {{.Database}}",
TextBody: verificationPassedText,
HTMLBody: verificationPassedHTML,
},
EventVerificationFailed: {
Subject: "❌ Verification FAILED: {{.Database}}",
TextBody: verificationFailedText,
HTMLBody: verificationFailedHTML,
},
EventDRDrillPassed: {
Subject: "✅ DR Drill Passed: {{.Database}}",
TextBody: drDrillPassedText,
HTMLBody: drDrillPassedHTML,
},
EventDRDrillFailed: {
Subject: "❌ DR Drill FAILED: {{.Database}}",
TextBody: drDrillFailedText,
HTMLBody: drDrillFailedHTML,
},
}
}
// Template strings
const backupStartedText = `
Backup Operation Started
Database: {{.Database}}
Hostname: {{.Hostname}}
Started At: {{formatTime .Timestamp}}
{{if .Message}}{{.Message}}{{end}}
`
const backupStartedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #3498db;">🔄 Backup Started</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Started At:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
</table>
{{if .Message}}<p style="margin-top: 20px;">{{.Message}}</p>{{end}}
</div>
`
const backupCompletedText = `
Backup Operation Completed Successfully
Database: {{.Database}}
Hostname: {{.Hostname}}
Completed: {{formatTime .Timestamp}}
{{with .Details}}
{{if .size}}Size: {{.size}}{{end}}
{{if .duration}}Duration: {{.duration}}{{end}}
{{if .path}}Path: {{.path}}{{end}}
{{end}}
{{if .Message}}{{.Message}}{{end}}
`
const backupCompletedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #27ae60;">✅ Backup Completed</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Completed:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
{{with .Details}}
{{if .size}}<tr><td style="padding: 8px; font-weight: bold;">Size:</td><td style="padding: 8px;">{{.size}}</td></tr>{{end}}
{{if .duration}}<tr><td style="padding: 8px; font-weight: bold;">Duration:</td><td style="padding: 8px;">{{.duration}}</td></tr>{{end}}
{{if .path}}<tr><td style="padding: 8px; font-weight: bold;">Path:</td><td style="padding: 8px;">{{.path}}</td></tr>{{end}}
{{end}}
</table>
{{if .Message}}<p style="margin-top: 20px; color: #27ae60;">{{.Message}}</p>{{end}}
</div>
`
const backupFailedText = `
⚠️ BACKUP FAILED ⚠️
Database: {{.Database}}
Hostname: {{.Hostname}}
Failed At: {{formatTime .Timestamp}}
{{if .Error}}
Error: {{.Error}}
{{end}}
{{if .Message}}{{.Message}}{{end}}
Please investigate immediately.
`
const backupFailedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #e74c3c;">❌ Backup FAILED</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Failed At:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
{{if .Error}}<tr><td style="padding: 8px; font-weight: bold; color: #e74c3c;">Error:</td><td style="padding: 8px; color: #e74c3c;">{{.Error}}</td></tr>{{end}}
</table>
{{if .Message}}<p style="margin-top: 20px;">{{.Message}}</p>{{end}}
<p style="margin-top: 20px; color: #e74c3c; font-weight: bold;">Please investigate immediately.</p>
</div>
`
const restoreStartedText = `
Restore Operation Started
Database: {{.Database}}
Hostname: {{.Hostname}}
Started At: {{formatTime .Timestamp}}
{{if .Message}}{{.Message}}{{end}}
`
const restoreStartedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #3498db;">🔄 Restore Started</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Started At:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
</table>
{{if .Message}}<p style="margin-top: 20px;">{{.Message}}</p>{{end}}
</div>
`
const restoreCompletedText = `
Restore Operation Completed Successfully
Database: {{.Database}}
Hostname: {{.Hostname}}
Completed: {{formatTime .Timestamp}}
{{with .Details}}
{{if .duration}}Duration: {{.duration}}{{end}}
{{end}}
{{if .Message}}{{.Message}}{{end}}
`
const restoreCompletedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #27ae60;">✅ Restore Completed</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Completed:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
{{with .Details}}
{{if .duration}}<tr><td style="padding: 8px; font-weight: bold;">Duration:</td><td style="padding: 8px;">{{.duration}}</td></tr>{{end}}
{{end}}
</table>
{{if .Message}}<p style="margin-top: 20px; color: #27ae60;">{{.Message}}</p>{{end}}
</div>
`
const restoreFailedText = `
⚠️ RESTORE FAILED ⚠️
Database: {{.Database}}
Hostname: {{.Hostname}}
Failed At: {{formatTime .Timestamp}}
{{if .Error}}
Error: {{.Error}}
{{end}}
{{if .Message}}{{.Message}}{{end}}
Please investigate immediately.
`
const restoreFailedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #e74c3c;">❌ Restore FAILED</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Failed At:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
{{if .Error}}<tr><td style="padding: 8px; font-weight: bold; color: #e74c3c;">Error:</td><td style="padding: 8px; color: #e74c3c;">{{.Error}}</td></tr>{{end}}
</table>
{{if .Message}}<p style="margin-top: 20px;">{{.Message}}</p>{{end}}
<p style="margin-top: 20px; color: #e74c3c; font-weight: bold;">Please investigate immediately.</p>
</div>
`
const verificationPassedText = `
Backup Verification Passed
Database: {{.Database}}
Hostname: {{.Hostname}}
Verified: {{formatTime .Timestamp}}
{{with .Details}}
{{if .checksum}}Checksum: {{.checksum}}{{end}}
{{end}}
{{if .Message}}{{.Message}}{{end}}
`
const verificationPassedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #27ae60;">✅ Verification Passed</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Verified:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
{{with .Details}}
{{if .checksum}}<tr><td style="padding: 8px; font-weight: bold;">Checksum:</td><td style="padding: 8px; font-family: monospace;">{{.checksum}}</td></tr>{{end}}
{{end}}
</table>
{{if .Message}}<p style="margin-top: 20px; color: #27ae60;">{{.Message}}</p>{{end}}
</div>
`
const verificationFailedText = `
⚠️ VERIFICATION FAILED ⚠️
Database: {{.Database}}
Hostname: {{.Hostname}}
Failed At: {{formatTime .Timestamp}}
{{if .Error}}
Error: {{.Error}}
{{end}}
{{if .Message}}{{.Message}}{{end}}
Backup integrity may be compromised. Please investigate.
`
const verificationFailedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #e74c3c;">❌ Verification FAILED</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Failed At:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
{{if .Error}}<tr><td style="padding: 8px; font-weight: bold; color: #e74c3c;">Error:</td><td style="padding: 8px; color: #e74c3c;">{{.Error}}</td></tr>{{end}}
</table>
{{if .Message}}<p style="margin-top: 20px;">{{.Message}}</p>{{end}}
<p style="margin-top: 20px; color: #e74c3c; font-weight: bold;">Backup integrity may be compromised. Please investigate.</p>
</div>
`
const drDrillPassedText = `
DR Drill Test Passed
Database: {{.Database}}
Hostname: {{.Hostname}}
Tested At: {{formatTime .Timestamp}}
{{with .Details}}
{{if .tables_restored}}Tables: {{.tables_restored}}{{end}}
{{if .rows_validated}}Rows: {{.rows_validated}}{{end}}
{{if .duration}}Duration: {{.duration}}{{end}}
{{end}}
{{if .Message}}{{.Message}}{{end}}
Backup restore capability verified.
`
const drDrillPassedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #27ae60;">✅ DR Drill Passed</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Tested At:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
{{with .Details}}
{{if .tables_restored}}<tr><td style="padding: 8px; font-weight: bold;">Tables:</td><td style="padding: 8px;">{{.tables_restored}}</td></tr>{{end}}
{{if .rows_validated}}<tr><td style="padding: 8px; font-weight: bold;">Rows:</td><td style="padding: 8px;">{{.rows_validated}}</td></tr>{{end}}
{{if .duration}}<tr><td style="padding: 8px; font-weight: bold;">Duration:</td><td style="padding: 8px;">{{.duration}}</td></tr>{{end}}
{{end}}
</table>
{{if .Message}}<p style="margin-top: 20px; color: #27ae60;">{{.Message}}</p>{{end}}
<p style="margin-top: 20px; color: #27ae60;">✓ Backup restore capability verified</p>
</div>
`
const drDrillFailedText = `
⚠️ DR DRILL FAILED ⚠️
Database: {{.Database}}
Hostname: {{.Hostname}}
Failed At: {{formatTime .Timestamp}}
{{if .Error}}
Error: {{.Error}}
{{end}}
{{if .Message}}{{.Message}}{{end}}
Backup may not be restorable. Please investigate immediately.
`
const drDrillFailedHTML = `
<div style="font-family: Arial, sans-serif; padding: 20px;">
<h2 style="color: #e74c3c;">❌ DR Drill FAILED</h2>
<table style="border-collapse: collapse; width: 100%; max-width: 600px;">
<tr><td style="padding: 8px; font-weight: bold;">Database:</td><td style="padding: 8px;">{{.Database}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Hostname:</td><td style="padding: 8px;">{{.Hostname}}</td></tr>
<tr><td style="padding: 8px; font-weight: bold;">Failed At:</td><td style="padding: 8px;">{{formatTime .Timestamp}}</td></tr>
{{if .Error}}<tr><td style="padding: 8px; font-weight: bold; color: #e74c3c;">Error:</td><td style="padding: 8px; color: #e74c3c;">{{.Error}}</td></tr>{{end}}
</table>
{{if .Message}}<p style="margin-top: 20px;">{{.Message}}</p>{{end}}
<p style="margin-top: 20px; color: #e74c3c; font-weight: bold;">Backup may not be restorable. Please investigate immediately.</p>
</div>
`
// TemplateRenderer renders notification templates
type TemplateRenderer struct {
templates map[EventType]Templates
funcMap template.FuncMap
}
// NewTemplateRenderer creates a new template renderer
func NewTemplateRenderer() *TemplateRenderer {
return &TemplateRenderer{
templates: DefaultTemplates(),
funcMap: template.FuncMap{
"formatTime": func(t time.Time) string {
return t.Format("2006-01-02 15:04:05 MST")
},
"upper": strings.ToUpper,
"lower": strings.ToLower,
},
}
}
// RenderSubject renders the subject template for an event
func (r *TemplateRenderer) RenderSubject(event *Event) (string, error) {
tmpl, ok := r.templates[event.Type]
if !ok {
return fmt.Sprintf("[%s] %s: %s", event.Severity, event.Type, event.Database), nil
}
return r.render(tmpl.Subject, event)
}
// RenderText renders the text body template for an event
func (r *TemplateRenderer) RenderText(event *Event) (string, error) {
tmpl, ok := r.templates[event.Type]
if !ok {
return event.Message, nil
}
return r.render(tmpl.TextBody, event)
}
// RenderHTML renders the HTML body template for an event
func (r *TemplateRenderer) RenderHTML(event *Event) (string, error) {
tmpl, ok := r.templates[event.Type]
if !ok {
return fmt.Sprintf("<p>%s</p>", event.Message), nil
}
return r.render(tmpl.HTMLBody, event)
}
// render executes a template with the given event
func (r *TemplateRenderer) render(templateStr string, event *Event) (string, error) {
tmpl, err := template.New("notification").Funcs(r.funcMap).Parse(templateStr)
if err != nil {
return "", fmt.Errorf("failed to parse template: %w", err)
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, event); err != nil {
return "", fmt.Errorf("failed to execute template: %w", err)
}
return strings.TrimSpace(buf.String()), nil
}
// SetTemplate sets a custom template for an event type
func (r *TemplateRenderer) SetTemplate(eventType EventType, templates Templates) {
r.templates[eventType] = templates
}
// RenderSlackMessage creates a Slack-formatted message
func (r *TemplateRenderer) RenderSlackMessage(event *Event) map[string]interface{} {
color := "#3498db" // blue
switch event.Severity {
case SeveritySuccess:
color = "#27ae60" // green
case SeverityWarning:
color = "#f39c12" // orange
case SeverityError, SeverityCritical:
color = "#e74c3c" // red
}
fields := []map[string]interface{}{
{
"title": "Database",
"value": event.Database,
"short": true,
},
{
"title": "Hostname",
"value": event.Hostname,
"short": true,
},
{
"title": "Event",
"value": string(event.Type),
"short": true,
},
{
"title": "Severity",
"value": string(event.Severity),
"short": true,
},
}
if event.Error != "" {
fields = append(fields, map[string]interface{}{
"title": "Error",
"value": event.Error,
"short": false,
})
}
for key, value := range event.Details {
fields = append(fields, map[string]interface{}{
"title": key,
"value": value,
"short": true,
})
}
subject, _ := r.RenderSubject(event)
return map[string]interface{}{
"attachments": []map[string]interface{}{
{
"color": color,
"title": subject,
"text": event.Message,
"fields": fields,
"footer": "dbbackup",
"ts": event.Timestamp.Unix(),
"mrkdwn_in": []string{"text", "fields"},
},
},
}
}
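The renderer is plain text/template with a small FuncMap. The same pattern, reduced to a self-contained sketch (the render helper and the trimmed-down Event are illustrative):

```go
package main

import (
	"bytes"
	"fmt"
	"strings"
	"text/template"
	"time"
)

// Event carries the fields the notification templates consume.
type Event struct {
	Database  string
	Hostname  string
	Timestamp time.Time
}

// render parses and executes a template string with the formatTime helper.
func render(tmplStr string, ev Event) (string, error) {
	funcs := template.FuncMap{
		"formatTime": func(t time.Time) string {
			return t.Format("2006-01-02 15:04:05 MST")
		},
	}
	tmpl, err := template.New("notification").Funcs(funcs).Parse(tmplStr)
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, ev); err != nil {
		return "", err
	}
	return strings.TrimSpace(buf.String()), nil
}

func main() {
	ev := Event{
		Database:  "orders",
		Hostname:  "db01",
		Timestamp: time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC),
	}
	out, err := render("Backup Started: {{.Database}} on {{.Hostname}} at {{formatTime .Timestamp}}", ev)
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // prints Backup Started: orders on db01 at 2025-01-02 03:04:05 UTC
}
```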

619
internal/parallel/engine.go Normal file

@@ -0,0 +1,619 @@
// Package parallel provides parallel table backup functionality
package parallel
import (
"compress/gzip"
"context"
"database/sql"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"sync"
"sync/atomic"
"time"
)
// Table represents a database table
type Table struct {
Schema string `json:"schema"`
Name string `json:"name"`
RowCount int64 `json:"row_count"`
SizeBytes int64 `json:"size_bytes"`
HasPK bool `json:"has_pk"`
Partitioned bool `json:"partitioned"`
}
// FullName returns the fully qualified table name
func (t *Table) FullName() string {
if t.Schema != "" {
return fmt.Sprintf("%s.%s", t.Schema, t.Name)
}
return t.Name
}
// Config configures parallel backup
type Config struct {
MaxWorkers int `json:"max_workers"`
MaxConcurrency int `json:"max_concurrency"` // Max concurrent dumps
ChunkSize int64 `json:"chunk_size"` // Rows per chunk for large tables
LargeTableThreshold int64 `json:"large_table_threshold"` // Bytes to consider a table "large"
OutputDir string `json:"output_dir"`
Compression string `json:"compression"` // gzip, lz4, zstd, none
TempDir string `json:"temp_dir"`
Timeout time.Duration `json:"timeout"`
IncludeSchemas []string `json:"include_schemas,omitempty"`
ExcludeSchemas []string `json:"exclude_schemas,omitempty"`
IncludeTables []string `json:"include_tables,omitempty"`
ExcludeTables []string `json:"exclude_tables,omitempty"`
EstimateSizes bool `json:"estimate_sizes"`
OrderBySize bool `json:"order_by_size"` // Start with largest tables first
}
// DefaultConfig returns sensible defaults
func DefaultConfig() Config {
return Config{
MaxWorkers: 4,
MaxConcurrency: 4,
ChunkSize: 100000,
LargeTableThreshold: 1 << 30, // 1GB
Compression: "gzip",
Timeout: 24 * time.Hour,
EstimateSizes: true,
OrderBySize: true,
}
}
// TableResult contains the result of backing up a single table
type TableResult struct {
Table *Table `json:"table"`
OutputFile string `json:"output_file"`
SizeBytes int64 `json:"size_bytes"`
RowsWritten int64 `json:"rows_written"`
Duration time.Duration `json:"duration"`
Error error `json:"error,omitempty"`
Checksum string `json:"checksum,omitempty"`
}
// Result contains the overall parallel backup result
type Result struct {
Tables []*TableResult `json:"tables"`
TotalTables int `json:"total_tables"`
SuccessTables int `json:"success_tables"`
FailedTables int `json:"failed_tables"`
TotalBytes int64 `json:"total_bytes"`
TotalRows int64 `json:"total_rows"`
Duration time.Duration `json:"duration"`
Workers int `json:"workers"`
OutputDir string `json:"output_dir"`
}
// Progress tracks backup progress
type Progress struct {
TotalTables int32 `json:"total_tables"`
CompletedTables int32 `json:"completed_tables"`
CurrentTable string `json:"current_table"`
BytesWritten int64 `json:"bytes_written"`
RowsWritten int64 `json:"rows_written"`
}
// ProgressCallback is called with progress updates
type ProgressCallback func(progress *Progress)
// Engine orchestrates parallel table backups
type Engine struct {
config Config
db *sql.DB
dbType string
progress *Progress
callback ProgressCallback
mu sync.Mutex
}
// NewEngine creates a new parallel backup engine
func NewEngine(db *sql.DB, dbType string, config Config) *Engine {
return &Engine{
config: config,
db: db,
dbType: dbType,
progress: &Progress{},
}
}
// SetProgressCallback sets the progress callback
func (e *Engine) SetProgressCallback(cb ProgressCallback) {
e.callback = cb
}
// Run executes the parallel backup
func (e *Engine) Run(ctx context.Context) (*Result, error) {
start := time.Now()
// Discover tables
tables, err := e.discoverTables(ctx)
if err != nil {
return nil, fmt.Errorf("failed to discover tables: %w", err)
}
if len(tables) == 0 {
return &Result{
Tables: []*TableResult{},
Duration: time.Since(start),
OutputDir: e.config.OutputDir,
}, nil
}
// Order tables by size (largest first for better load distribution)
if e.config.OrderBySize {
sort.Slice(tables, func(i, j int) bool {
return tables[i].SizeBytes > tables[j].SizeBytes
})
}
// Create output directory
if err := os.MkdirAll(e.config.OutputDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create output directory: %w", err)
}
// Setup progress
atomic.StoreInt32(&e.progress.TotalTables, int32(len(tables)))
// Create worker pool
results := make([]*TableResult, len(tables))
jobs := make(chan int, len(tables))
var wg sync.WaitGroup
workers := e.config.MaxWorkers
if workers > len(tables) {
workers = len(tables)
}
// Start workers
for w := 0; w < workers; w++ {
wg.Add(1)
go func() {
defer wg.Done()
for idx := range jobs {
select {
case <-ctx.Done():
return
default:
results[idx] = e.backupTable(ctx, tables[idx])
atomic.AddInt32(&e.progress.CompletedTables, 1)
if e.callback != nil {
e.callback(e.progress)
}
}
}
}()
}
// Enqueue jobs
for i := range tables {
jobs <- i
}
close(jobs)
// Wait for completion
wg.Wait()
// Compile result
result := &Result{
Tables: results,
TotalTables: len(tables),
Workers: workers,
Duration: time.Since(start),
OutputDir: e.config.OutputDir,
}
for _, r := range results {
if r.Error == nil {
result.SuccessTables++
result.TotalBytes += r.SizeBytes
result.TotalRows += r.RowsWritten
} else {
result.FailedTables++
}
}
return result, nil
}
// discoverTables discovers tables to backup
func (e *Engine) discoverTables(ctx context.Context) ([]*Table, error) {
switch e.dbType {
case "postgresql", "postgres":
return e.discoverPostgresqlTables(ctx)
case "mysql", "mariadb":
return e.discoverMySQLTables(ctx)
default:
return nil, fmt.Errorf("unsupported database type: %s", e.dbType)
}
}
func (e *Engine) discoverPostgresqlTables(ctx context.Context) ([]*Table, error) {
query := `
SELECT
schemaname,
tablename,
COALESCE(n_live_tup, 0) as row_count,
COALESCE(pg_total_relation_size(quote_ident(schemaname) || '.' || quote_ident(tablename)), 0) as size_bytes
FROM pg_stat_user_tables
WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
ORDER BY schemaname, tablename
`
rows, err := e.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var tables []*Table
for rows.Next() {
var t Table
if err := rows.Scan(&t.Schema, &t.Name, &t.RowCount, &t.SizeBytes); err != nil {
continue
}
if e.shouldInclude(&t) {
tables = append(tables, &t)
}
}
return tables, rows.Err()
}
func (e *Engine) discoverMySQLTables(ctx context.Context) ([]*Table, error) {
query := `
SELECT
TABLE_SCHEMA,
TABLE_NAME,
COALESCE(TABLE_ROWS, 0) as row_count,
COALESCE(DATA_LENGTH + INDEX_LENGTH, 0) as size_bytes
FROM information_schema.TABLES
WHERE TABLE_SCHEMA NOT IN ('mysql', 'information_schema', 'performance_schema', 'sys')
AND TABLE_TYPE = 'BASE TABLE'
ORDER BY TABLE_SCHEMA, TABLE_NAME
`
rows, err := e.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var tables []*Table
for rows.Next() {
var t Table
if err := rows.Scan(&t.Schema, &t.Name, &t.RowCount, &t.SizeBytes); err != nil {
continue
}
if e.shouldInclude(&t) {
tables = append(tables, &t)
}
}
return tables, rows.Err()
}
// shouldInclude checks if a table should be included
func (e *Engine) shouldInclude(t *Table) bool {
// Check schema exclusions
for _, s := range e.config.ExcludeSchemas {
if t.Schema == s {
return false
}
}
// Check table exclusions
for _, name := range e.config.ExcludeTables {
if t.Name == name || t.FullName() == name {
return false
}
}
// Check schema inclusions (if specified)
if len(e.config.IncludeSchemas) > 0 {
found := false
for _, s := range e.config.IncludeSchemas {
if t.Schema == s {
found = true
break
}
}
if !found {
return false
}
}
// Check table inclusions (if specified)
if len(e.config.IncludeTables) > 0 {
found := false
for _, name := range e.config.IncludeTables {
if t.Name == name || t.FullName() == name {
found = true
break
}
}
if !found {
return false
}
}
return true
}
// backupTable backs up a single table
func (e *Engine) backupTable(ctx context.Context, table *Table) *TableResult {
start := time.Now()
result := &TableResult{
Table: table,
}
e.mu.Lock()
e.progress.CurrentTable = table.FullName()
e.mu.Unlock()
// Determine output filename
ext := ".sql"
switch e.config.Compression {
case "gzip":
ext = ".sql.gz"
case "lz4":
ext = ".sql.lz4"
case "zstd":
ext = ".sql.zst"
}
filename := fmt.Sprintf("%s_%s%s", table.Schema, table.Name, ext)
result.OutputFile = filepath.Join(e.config.OutputDir, filename)
// Create output file
file, err := os.Create(result.OutputFile)
if err != nil {
result.Error = fmt.Errorf("failed to create output file: %w", err)
result.Duration = time.Since(start)
return result
}
defer file.Close()
// Wrap with compression if needed
var writer io.WriteCloser = file
if e.config.Compression == "gzip" {
gzWriter, err := newGzipWriter(file)
if err != nil {
result.Error = fmt.Errorf("failed to create gzip writer: %w", err)
result.Duration = time.Since(start)
return result
}
defer gzWriter.Close()
writer = gzWriter
}
// Dump table
rowsWritten, err := e.dumpTable(ctx, table, writer)
if err != nil {
result.Error = fmt.Errorf("failed to dump table: %w", err)
result.Duration = time.Since(start)
return result
}
result.RowsWritten = rowsWritten
atomic.AddInt64(&e.progress.RowsWritten, rowsWritten)
// Close the compression layer before reading the file size, otherwise
// buffered bytes are not yet flushed to disk (gzip's Close is safe to
// call again from the deferred close).
if writer != file {
writer.Close()
}
// Get file size
if stat, err := file.Stat(); err == nil {
result.SizeBytes = stat.Size()
atomic.AddInt64(&e.progress.BytesWritten, result.SizeBytes)
}
result.Duration = time.Since(start)
return result
}
// dumpTable dumps a single table to the writer
func (e *Engine) dumpTable(ctx context.Context, table *Table, w io.Writer) (int64, error) {
switch e.dbType {
case "postgresql", "postgres":
return e.dumpPostgresTable(ctx, table, w)
case "mysql", "mariadb":
return e.dumpMySQLTable(ctx, table, w)
default:
return 0, fmt.Errorf("unsupported database type: %s", e.dbType)
}
}
func (e *Engine) dumpPostgresTable(ctx context.Context, table *Table, w io.Writer) (int64, error) {
// Write header
fmt.Fprintf(w, "-- Table: %s\n", table.FullName())
fmt.Fprintf(w, "-- Dumped at: %s\n\n", time.Now().Format(time.RFC3339))
// Get column info for COPY command
cols, err := e.getPostgresColumns(ctx, table)
if err != nil {
return 0, err
}
// Try COPY ... TO STDOUT for efficiency. Note that database/sql has no
// COPY streaming support (that requires a driver-level API such as pgx's
// CopyTo), so on most drivers this errors and we fall back to a
// row-by-row SELECT.
copyQuery := fmt.Sprintf("COPY %s TO STDOUT WITH (FORMAT csv, HEADER true)", table.FullName())
rows, err := e.db.QueryContext(ctx, copyQuery)
if err != nil {
// Fallback to regular SELECT
return e.dumpViaSelect(ctx, table, cols, w)
}
}
defer rows.Close()
var rowCount int64
for rows.Next() {
var line string
if err := rows.Scan(&line); err != nil {
continue
}
fmt.Fprintln(w, line)
rowCount++
}
return rowCount, rows.Err()
}
func (e *Engine) dumpMySQLTable(ctx context.Context, table *Table, w io.Writer) (int64, error) {
// Write header
fmt.Fprintf(w, "-- Table: %s\n", table.FullName())
fmt.Fprintf(w, "-- Dumped at: %s\n\n", time.Now().Format(time.RFC3339))
// Get column names
cols, err := e.getMySQLColumns(ctx, table)
if err != nil {
return 0, err
}
return e.dumpViaSelect(ctx, table, cols, w)
}
func (e *Engine) dumpViaSelect(ctx context.Context, table *Table, cols []string, w io.Writer) (int64, error) {
query := fmt.Sprintf("SELECT * FROM %s", table.FullName())
rows, err := e.db.QueryContext(ctx, query)
if err != nil {
return 0, err
}
defer rows.Close()
var rowCount int64
// Write column header
fmt.Fprintf(w, "-- Columns: %v\n\n", cols)
// Prepare value holders
values := make([]interface{}, len(cols))
valuePtrs := make([]interface{}, len(cols))
for i := range values {
valuePtrs[i] = &values[i]
}
for rows.Next() {
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
// Write INSERT statement
fmt.Fprintf(w, "INSERT INTO %s VALUES (", table.FullName())
for i, v := range values {
if i > 0 {
fmt.Fprint(w, ", ")
}
fmt.Fprint(w, formatValue(v))
}
fmt.Fprintln(w, ");")
rowCount++
}
return rowCount, rows.Err()
}
func (e *Engine) getPostgresColumns(ctx context.Context, table *Table) ([]string, error) {
query := `
SELECT column_name
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position
`
rows, err := e.db.QueryContext(ctx, query, table.Schema, table.Name)
if err != nil {
return nil, err
}
defer rows.Close()
var cols []string
for rows.Next() {
var col string
if err := rows.Scan(&col); err != nil {
continue
}
cols = append(cols, col)
}
return cols, rows.Err()
}
func (e *Engine) getMySQLColumns(ctx context.Context, table *Table) ([]string, error) {
query := `
SELECT COLUMN_NAME
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
ORDER BY ORDINAL_POSITION
`
rows, err := e.db.QueryContext(ctx, query, table.Schema, table.Name)
if err != nil {
return nil, err
}
defer rows.Close()
var cols []string
for rows.Next() {
var col string
if err := rows.Scan(&col); err != nil {
continue
}
cols = append(cols, col)
}
return cols, rows.Err()
}
func formatValue(v interface{}) string {
if v == nil {
return "NULL"
}
switch val := v.(type) {
case []byte:
return fmt.Sprintf("'%s'", escapeString(string(val)))
case string:
return fmt.Sprintf("'%s'", escapeString(val))
case time.Time:
return fmt.Sprintf("'%s'", val.Format("2006-01-02 15:04:05"))
case int, int32, int64, float32, float64:
return fmt.Sprintf("%v", val)
case bool:
if val {
return "TRUE"
}
return "FALSE"
default:
// Fall back to %v, escaped, so unexpected driver types cannot break quoting
return fmt.Sprintf("'%s'", escapeString(fmt.Sprintf("%v", v)))
}
}
func escapeString(s string) string {
result := make([]byte, 0, len(s)*2)
for i := 0; i < len(s); i++ {
switch s[i] {
case '\'':
result = append(result, '\'', '\'')
case '\\':
result = append(result, '\\', '\\')
default:
result = append(result, s[i])
}
}
return string(result)
}
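The escaping rules above (doubling single quotes and backslashes) can be exercised in isolation. This is a minimal standalone copy for illustration, not an import of the dump package:

```go
package main

import "fmt"

// escapeString doubles single quotes and backslashes, mirroring the
// dump engine's escaping rules (standalone copy for illustration).
func escapeString(s string) string {
	out := make([]byte, 0, len(s)*2)
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case '\'':
			out = append(out, '\'', '\'')
		case '\\':
			out = append(out, '\\', '\\')
		default:
			out = append(out, s[i])
		}
	}
	return string(out)
}

func main() {
	fmt.Println(escapeString(`it's a \ test`)) // it''s a \\ test
}
```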
// gzipWriter wraps compress/gzip so the dump engine can stream compressed output
type gzipWriter struct {
io.WriteCloser
}
// newGzipWriter wraps w in a gzip compressor (requires `import "compress/gzip"`).
// Callers must Close the returned writer to flush the final gzip frame.
func newGzipWriter(w io.Writer) (*gzipWriter, error) {
return &gzipWriter{WriteCloser: gzip.NewWriter(w)}, nil
}

internal/pitr/binlog.go Normal file
@@ -0,0 +1,865 @@
// Package pitr provides Point-in-Time Recovery functionality
// This file contains MySQL/MariaDB binary log handling
package pitr
import (
"bufio"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"time"
)
// BinlogPosition represents a MySQL binary log position
type BinlogPosition struct {
File string `json:"file"` // Binary log filename (e.g., "mysql-bin.000042")
Position uint64 `json:"position"` // Byte position in the file
GTID string `json:"gtid,omitempty"` // GTID set (if available)
ServerID uint32 `json:"server_id,omitempty"`
}
// String returns a string representation of the binlog position
func (p *BinlogPosition) String() string {
if p.GTID != "" {
return fmt.Sprintf("%s:%d (GTID: %s)", p.File, p.Position, p.GTID)
}
return fmt.Sprintf("%s:%d", p.File, p.Position)
}
// IsZero returns true if the position is unset
func (p *BinlogPosition) IsZero() bool {
return p.File == "" && p.Position == 0 && p.GTID == ""
}
// Compare compares two binlog positions
// Returns -1 if p < other, 0 if equal, 1 if p > other
func (p *BinlogPosition) Compare(other LogPosition) int {
o, ok := other.(*BinlogPosition)
if !ok {
return 0
}
// Compare by file first
fileComp := compareBinlogFiles(p.File, o.File)
if fileComp != 0 {
return fileComp
}
// Then by position within file
if p.Position < o.Position {
return -1
} else if p.Position > o.Position {
return 1
}
return 0
}
// ParseBinlogPosition parses a binlog position string
// Format: "filename:position" or "filename:position:gtid"
func ParseBinlogPosition(s string) (*BinlogPosition, error) {
parts := strings.SplitN(s, ":", 3)
if len(parts) < 2 {
return nil, fmt.Errorf("invalid binlog position format: %s (expected file:position)", s)
}
pos, err := strconv.ParseUint(parts[1], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid position value: %s", parts[1])
}
bp := &BinlogPosition{
File: parts[0],
Position: pos,
}
if len(parts) == 3 {
bp.GTID = parts[2]
}
return bp, nil
}
// MarshalJSON serializes the binlog position to JSON
func (p *BinlogPosition) MarshalJSON() ([]byte, error) {
type Alias BinlogPosition
return json.Marshal((*Alias)(p))
}
// compareBinlogFiles compares two binlog filenames numerically
func compareBinlogFiles(a, b string) int {
numA := extractBinlogNumber(a)
numB := extractBinlogNumber(b)
if numA < numB {
return -1
} else if numA > numB {
return 1
}
return 0
}
// extractBinlogNumber extracts the numeric suffix from a binlog filename
func extractBinlogNumber(filename string) int {
// Match pattern like mysql-bin.000042
re := regexp.MustCompile(`\.(\d+)$`)
matches := re.FindStringSubmatch(filename)
if len(matches) < 2 {
return 0
}
num, _ := strconv.Atoi(matches[1])
return num
}
// BinlogFile represents a binary log file with metadata
type BinlogFile struct {
Name string `json:"name"`
Path string `json:"path"`
Size int64 `json:"size"`
ModTime time.Time `json:"mod_time"`
StartTime time.Time `json:"start_time,omitempty"` // First event timestamp
EndTime time.Time `json:"end_time,omitempty"` // Last event timestamp
StartPos uint64 `json:"start_pos"`
EndPos uint64 `json:"end_pos"`
GTID string `json:"gtid,omitempty"`
ServerID uint32 `json:"server_id,omitempty"`
Format string `json:"format,omitempty"` // ROW, STATEMENT, MIXED
Archived bool `json:"archived"`
ArchiveDir string `json:"archive_dir,omitempty"`
}
// BinlogArchiveInfo contains metadata about an archived binlog
type BinlogArchiveInfo struct {
OriginalFile string `json:"original_file"`
ArchivePath string `json:"archive_path"`
Size int64 `json:"size"`
Compressed bool `json:"compressed"`
Encrypted bool `json:"encrypted"`
Checksum string `json:"checksum"`
ArchivedAt time.Time `json:"archived_at"`
StartPos uint64 `json:"start_pos"`
EndPos uint64 `json:"end_pos"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
GTID string `json:"gtid,omitempty"`
}
// BinlogManager handles binary log operations
type BinlogManager struct {
mysqlbinlogPath string
binlogDir string
archiveDir string
compression bool
encryption bool
encryptionKey []byte
serverType DatabaseType // mysql or mariadb
}
// BinlogManagerConfig holds configuration for BinlogManager
type BinlogManagerConfig struct {
BinlogDir string
ArchiveDir string
Compression bool
Encryption bool
EncryptionKey []byte
}
// NewBinlogManager creates a new BinlogManager
func NewBinlogManager(config BinlogManagerConfig) (*BinlogManager, error) {
m := &BinlogManager{
binlogDir: config.BinlogDir,
archiveDir: config.ArchiveDir,
compression: config.Compression,
encryption: config.Encryption,
encryptionKey: config.EncryptionKey,
}
// Find mysqlbinlog executable
if err := m.detectTools(); err != nil {
return nil, err
}
return m, nil
}
// detectTools finds MySQL/MariaDB tools and determines server type
func (m *BinlogManager) detectTools() error {
// Try mariadb-binlog first (MariaDB)
if path, err := exec.LookPath("mariadb-binlog"); err == nil {
m.mysqlbinlogPath = path
m.serverType = DatabaseMariaDB
return nil
}
// Fall back to mysqlbinlog (MySQL or older MariaDB)
if path, err := exec.LookPath("mysqlbinlog"); err == nil {
m.mysqlbinlogPath = path
// Check if it's actually MariaDB's version
m.serverType = m.detectServerType()
return nil
}
return fmt.Errorf("mysqlbinlog or mariadb-binlog not found in PATH")
}
// detectServerType determines if we're working with MySQL or MariaDB
func (m *BinlogManager) detectServerType() DatabaseType {
cmd := exec.Command(m.mysqlbinlogPath, "--version")
output, err := cmd.Output()
if err != nil {
return DatabaseMySQL // Default to MySQL
}
if strings.Contains(strings.ToLower(string(output)), "mariadb") {
return DatabaseMariaDB
}
return DatabaseMySQL
}
// ServerType returns the detected server type
func (m *BinlogManager) ServerType() DatabaseType {
return m.serverType
}
// DiscoverBinlogs finds all binary log files in the configured directory
func (m *BinlogManager) DiscoverBinlogs(ctx context.Context) ([]BinlogFile, error) {
if m.binlogDir == "" {
return nil, fmt.Errorf("binlog directory not configured")
}
entries, err := os.ReadDir(m.binlogDir)
if err != nil {
return nil, fmt.Errorf("reading binlog directory: %w", err)
}
var binlogs []BinlogFile
binlogPattern := regexp.MustCompile(`^[a-zA-Z0-9_-]+-bin\.\d{6}$`)
for _, entry := range entries {
if entry.IsDir() {
continue
}
// Check if it matches binlog naming convention
if !binlogPattern.MatchString(entry.Name()) {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
binlog := BinlogFile{
Name: entry.Name(),
Path: filepath.Join(m.binlogDir, entry.Name()),
Size: info.Size(),
ModTime: info.ModTime(),
}
// Get binlog metadata using mysqlbinlog
if err := m.enrichBinlogMetadata(ctx, &binlog); err != nil {
// Log but don't fail - we can still use basic info
binlog.StartPos = 4 // Magic number size
}
binlogs = append(binlogs, binlog)
}
// Sort by file number
sort.Slice(binlogs, func(i, j int) bool {
return compareBinlogFiles(binlogs[i].Name, binlogs[j].Name) < 0
})
return binlogs, nil
}
// enrichBinlogMetadata extracts metadata from a binlog file
func (m *BinlogManager) enrichBinlogMetadata(ctx context.Context, binlog *BinlogFile) error {
// Use mysqlbinlog to read header and extract timestamps
cmd := exec.CommandContext(ctx, m.mysqlbinlogPath,
"--no-defaults",
"--start-position=4",
"--stop-position=1000", // Just read header area
binlog.Path,
)
output, err := cmd.Output()
if err != nil {
// Try without position limits
cmd = exec.CommandContext(ctx, m.mysqlbinlogPath,
"--no-defaults",
"-v", // Verbose mode for more info
binlog.Path,
)
output, _ = cmd.Output()
}
// Parse output for metadata
m.parseBinlogOutput(string(output), binlog)
// Get file size for end position
if binlog.EndPos == 0 {
binlog.EndPos = uint64(binlog.Size)
}
return nil
}
// parseBinlogOutput parses mysqlbinlog output to extract metadata
func (m *BinlogManager) parseBinlogOutput(output string, binlog *BinlogFile) {
lines := strings.Split(output, "\n")
// Pattern for timestamp: #YYMMDD HH:MM:SS
timestampRe := regexp.MustCompile(`#(\d{6})\s+(\d{1,2}:\d{2}:\d{2})`)
// Pattern for server_id
serverIDRe := regexp.MustCompile(`server id\s+(\d+)`)
// Pattern for end_log_pos
endPosRe := regexp.MustCompile(`end_log_pos\s+(\d+)`)
// Pattern for binlog format
formatRe := regexp.MustCompile(`binlog_format=(\w+)`)
// Pattern for GTID
gtidRe := regexp.MustCompile(`SET @@SESSION.GTID_NEXT=\s*'([^']+)'`)
mariaGtidRe := regexp.MustCompile(`GTID\s+(\d+-\d+-\d+)`)
var firstTimestamp, lastTimestamp time.Time
var maxEndPos uint64
for _, line := range lines {
// Extract timestamps
if matches := timestampRe.FindStringSubmatch(line); len(matches) == 3 {
// Parse YYMMDD format
dateStr := matches[1]
timeStr := matches[2]
if t, err := time.Parse("060102 15:04:05", dateStr+" "+timeStr); err == nil {
if firstTimestamp.IsZero() {
firstTimestamp = t
}
lastTimestamp = t
}
}
// Extract server_id
if matches := serverIDRe.FindStringSubmatch(line); len(matches) == 2 {
if id, err := strconv.ParseUint(matches[1], 10, 32); err == nil {
binlog.ServerID = uint32(id)
}
}
// Extract end_log_pos (track max for EndPos)
if matches := endPosRe.FindStringSubmatch(line); len(matches) == 2 {
if pos, err := strconv.ParseUint(matches[1], 10, 64); err == nil {
if pos > maxEndPos {
maxEndPos = pos
}
}
}
// Extract format
if matches := formatRe.FindStringSubmatch(line); len(matches) == 2 {
binlog.Format = matches[1]
}
// Extract GTID (MySQL format)
if matches := gtidRe.FindStringSubmatch(line); len(matches) == 2 {
binlog.GTID = matches[1]
}
// Extract GTID (MariaDB format)
if matches := mariaGtidRe.FindStringSubmatch(line); len(matches) == 2 {
binlog.GTID = matches[1]
}
}
if !firstTimestamp.IsZero() {
binlog.StartTime = firstTimestamp
}
if !lastTimestamp.IsZero() {
binlog.EndTime = lastTimestamp
}
if maxEndPos > 0 {
binlog.EndPos = maxEndPos
}
}
// GetCurrentPosition retrieves the current binary log position from MySQL
func (m *BinlogManager) GetCurrentPosition(ctx context.Context, dsn string) (*BinlogPosition, error) {
// This would typically connect to MySQL and run SHOW MASTER STATUS
// For now, return an error indicating it needs to be called with a connection
return nil, fmt.Errorf("GetCurrentPosition requires a database connection - use MySQLPITR.GetCurrentPosition instead")
}
// ArchiveBinlog archives a single binlog file to the archive directory
func (m *BinlogManager) ArchiveBinlog(ctx context.Context, binlog *BinlogFile) (*BinlogArchiveInfo, error) {
if m.archiveDir == "" {
return nil, fmt.Errorf("archive directory not configured")
}
// Ensure archive directory exists
if err := os.MkdirAll(m.archiveDir, 0750); err != nil {
return nil, fmt.Errorf("creating archive directory: %w", err)
}
archiveName := binlog.Name
if m.compression {
archiveName += ".gz"
}
archivePath := filepath.Join(m.archiveDir, archiveName)
// Check if already archived
if _, err := os.Stat(archivePath); err == nil {
return nil, fmt.Errorf("binlog already archived: %s", archivePath)
}
// Open source file
src, err := os.Open(binlog.Path)
if err != nil {
return nil, fmt.Errorf("opening binlog: %w", err)
}
defer src.Close()
// Create destination file
dst, err := os.OpenFile(archivePath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0640)
if err != nil {
return nil, fmt.Errorf("creating archive file: %w", err)
}
defer dst.Close()
var writer io.Writer = dst
var gzWriter *gzip.Writer
if m.compression {
gzWriter = gzip.NewWriter(dst)
writer = gzWriter
defer gzWriter.Close()
}
// TODO: Add encryption layer if enabled
if m.encryption && len(m.encryptionKey) > 0 {
// Encryption would be added here
}
// Copy file content
written, err := io.Copy(writer, src)
if err != nil {
os.Remove(archivePath) // Cleanup on error
return nil, fmt.Errorf("copying binlog: %w", err)
}
// Close gzip writer to flush
if gzWriter != nil {
if err := gzWriter.Close(); err != nil {
os.Remove(archivePath)
return nil, fmt.Errorf("closing gzip writer: %w", err)
}
}
// Get final archive size
archiveInfo, err := os.Stat(archivePath)
if err != nil {
return nil, fmt.Errorf("getting archive info: %w", err)
}
// Calculate checksum (simple for now - could use SHA256)
checksum := fmt.Sprintf("size:%d", written)
return &BinlogArchiveInfo{
OriginalFile: binlog.Name,
ArchivePath: archivePath,
Size: archiveInfo.Size(),
Compressed: m.compression,
Encrypted: m.encryption,
Checksum: checksum,
ArchivedAt: time.Now(),
StartPos: binlog.StartPos,
EndPos: binlog.EndPos,
StartTime: binlog.StartTime,
EndTime: binlog.EndTime,
GTID: binlog.GTID,
}, nil
}
// ListArchivedBinlogs returns all archived binlog files
func (m *BinlogManager) ListArchivedBinlogs(ctx context.Context) ([]BinlogArchiveInfo, error) {
if m.archiveDir == "" {
return nil, fmt.Errorf("archive directory not configured")
}
entries, err := os.ReadDir(m.archiveDir)
if err != nil {
if os.IsNotExist(err) {
return []BinlogArchiveInfo{}, nil
}
return nil, fmt.Errorf("reading archive directory: %w", err)
}
var archives []BinlogArchiveInfo
metadataPath := filepath.Join(m.archiveDir, "metadata.json")
// Try to load metadata file for enriched info
metadata := m.loadArchiveMetadata(metadataPath)
for _, entry := range entries {
if entry.IsDir() || entry.Name() == "metadata.json" {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
originalName := entry.Name()
compressed := false
if strings.HasSuffix(originalName, ".gz") {
originalName = strings.TrimSuffix(originalName, ".gz")
compressed = true
}
archive := BinlogArchiveInfo{
OriginalFile: originalName,
ArchivePath: filepath.Join(m.archiveDir, entry.Name()),
Size: info.Size(),
Compressed: compressed,
ArchivedAt: info.ModTime(),
}
// Enrich from metadata if available
if meta, ok := metadata[originalName]; ok {
archive.StartPos = meta.StartPos
archive.EndPos = meta.EndPos
archive.StartTime = meta.StartTime
archive.EndTime = meta.EndTime
archive.GTID = meta.GTID
archive.Checksum = meta.Checksum
}
archives = append(archives, archive)
}
// Sort by file number
sort.Slice(archives, func(i, j int) bool {
return compareBinlogFiles(archives[i].OriginalFile, archives[j].OriginalFile) < 0
})
return archives, nil
}
// loadArchiveMetadata loads the metadata.json file if it exists
func (m *BinlogManager) loadArchiveMetadata(path string) map[string]BinlogArchiveInfo {
result := make(map[string]BinlogArchiveInfo)
data, err := os.ReadFile(path)
if err != nil {
return result
}
var archives []BinlogArchiveInfo
if err := json.Unmarshal(data, &archives); err != nil {
return result
}
for _, a := range archives {
result[a.OriginalFile] = a
}
return result
}
// SaveArchiveMetadata saves metadata for all archived binlogs
func (m *BinlogManager) SaveArchiveMetadata(archives []BinlogArchiveInfo) error {
if m.archiveDir == "" {
return fmt.Errorf("archive directory not configured")
}
metadataPath := filepath.Join(m.archiveDir, "metadata.json")
data, err := json.MarshalIndent(archives, "", " ")
if err != nil {
return fmt.Errorf("marshaling metadata: %w", err)
}
return os.WriteFile(metadataPath, data, 0640)
}
// ValidateBinlogChain validates the integrity of the binlog chain
func (m *BinlogManager) ValidateBinlogChain(ctx context.Context, binlogs []BinlogFile) (*ChainValidation, error) {
result := &ChainValidation{
Valid: true,
LogCount: len(binlogs),
}
if len(binlogs) == 0 {
result.Warnings = append(result.Warnings, "no binlog files found")
return result, nil
}
// Sort binlogs by file number
sorted := make([]BinlogFile, len(binlogs))
copy(sorted, binlogs)
sort.Slice(sorted, func(i, j int) bool {
return compareBinlogFiles(sorted[i].Name, sorted[j].Name) < 0
})
result.StartPos = &BinlogPosition{
File: sorted[0].Name,
Position: sorted[0].StartPos,
GTID: sorted[0].GTID,
}
result.EndPos = &BinlogPosition{
File: sorted[len(sorted)-1].Name,
Position: sorted[len(sorted)-1].EndPos,
GTID: sorted[len(sorted)-1].GTID,
}
// Check for gaps in sequence
var prevNum int
var prevName string
var prevServerID uint32
for i, binlog := range sorted {
result.TotalSize += binlog.Size
num := extractBinlogNumber(binlog.Name)
if i > 0 {
// Check sequence continuity
if num != prevNum+1 {
gap := LogGap{
After: prevName,
Before: binlog.Name,
Reason: fmt.Sprintf("missing binlog file(s) %d to %d", prevNum+1, num-1),
}
result.Gaps = append(result.Gaps, gap)
result.Valid = false
}
// Check server_id consistency
if binlog.ServerID != 0 && prevServerID != 0 && binlog.ServerID != prevServerID {
result.Warnings = append(result.Warnings,
fmt.Sprintf("server_id changed from %d to %d at %s (possible master failover)",
prevServerID, binlog.ServerID, binlog.Name))
}
}
prevNum = num
prevName = binlog.Name
if binlog.ServerID != 0 {
prevServerID = binlog.ServerID
}
}
if len(result.Gaps) > 0 {
result.Errors = append(result.Errors,
fmt.Sprintf("found %d gap(s) in binlog chain", len(result.Gaps)))
}
return result, nil
}
// ReplayBinlogs replays binlog events to a target time or position
func (m *BinlogManager) ReplayBinlogs(ctx context.Context, opts ReplayOptions) error {
if len(opts.BinlogFiles) == 0 {
return fmt.Errorf("no binlog files specified")
}
// Build mysqlbinlog command
args := []string{"--no-defaults"}
// Add start position if specified
if opts.StartPosition != nil && !opts.StartPosition.IsZero() {
startPos, ok := opts.StartPosition.(*BinlogPosition)
if ok && startPos.Position > 0 {
args = append(args, fmt.Sprintf("--start-position=%d", startPos.Position))
}
}
// Add stop time or position
if opts.StopTime != nil && !opts.StopTime.IsZero() {
args = append(args, fmt.Sprintf("--stop-datetime=%s", opts.StopTime.Format("2006-01-02 15:04:05")))
}
if opts.StopPosition != nil && !opts.StopPosition.IsZero() {
stopPos, ok := opts.StopPosition.(*BinlogPosition)
if ok && stopPos.Position > 0 {
args = append(args, fmt.Sprintf("--stop-position=%d", stopPos.Position))
}
}
// Add binlog files
args = append(args, opts.BinlogFiles...)
if opts.DryRun {
// Just decode and show SQL
args = append([]string{args[0], "-v"}, args[1:]...)
cmd := exec.CommandContext(ctx, m.mysqlbinlogPath, args...)
output, err := cmd.Output()
if err != nil {
return fmt.Errorf("parsing binlogs: %w", err)
}
if opts.Output != nil {
opts.Output.Write(output)
}
return nil
}
// Pipe to mysql for replay
mysqlArgs := []string{
"-u", opts.MySQLUser,
"-h", opts.MySQLHost,
"-P", strconv.Itoa(opts.MySQLPort),
}
if opts.MySQLPass != "" {
// A bare "-p" would make mysql prompt interactively and hang the pipe,
// so omit the flag entirely when no password is set. Note that the
// password is visible in the process list; prefer MYSQL_PWD or a
// defaults file where that matters.
mysqlArgs = append(mysqlArgs, "-p"+opts.MySQLPass)
}
mysqlCmd := exec.CommandContext(ctx, "mysql", mysqlArgs...)
binlogCmd := exec.CommandContext(ctx, m.mysqlbinlogPath, args...)
// Pipe mysqlbinlog output to mysql
pipe, err := binlogCmd.StdoutPipe()
if err != nil {
return fmt.Errorf("creating pipe: %w", err)
}
mysqlCmd.Stdin = pipe
// Capture stderr for error reporting
var binlogStderr, mysqlStderr strings.Builder
binlogCmd.Stderr = &binlogStderr
mysqlCmd.Stderr = &mysqlStderr
// Start commands
if err := binlogCmd.Start(); err != nil {
return fmt.Errorf("starting mysqlbinlog: %w", err)
}
if err := mysqlCmd.Start(); err != nil {
binlogCmd.Process.Kill()
return fmt.Errorf("starting mysql: %w", err)
}
// Wait for completion
binlogErr := binlogCmd.Wait()
mysqlErr := mysqlCmd.Wait()
if binlogErr != nil {
return fmt.Errorf("mysqlbinlog failed: %w\nstderr: %s", binlogErr, binlogStderr.String())
}
if mysqlErr != nil {
return fmt.Errorf("mysql replay failed: %w\nstderr: %s", mysqlErr, mysqlStderr.String())
}
return nil
}
// ReplayOptions holds options for replaying binlog files
type ReplayOptions struct {
BinlogFiles []string // Files to replay (in order)
StartPosition LogPosition // Start from this position
StopTime *time.Time // Stop at this time
StopPosition LogPosition // Stop at this position
DryRun bool // Just show what would be done
Output io.Writer // For dry-run output
MySQLHost string // MySQL host for replay
MySQLPort int // MySQL port
MySQLUser string // MySQL user
MySQLPass string // MySQL password
Database string // Limit to specific database
StopOnError bool // Stop on first error
}
// FindBinlogsInRange finds binlog files containing events within a time range
func (m *BinlogManager) FindBinlogsInRange(ctx context.Context, binlogs []BinlogFile, start, end time.Time) []BinlogFile {
var result []BinlogFile
for _, b := range binlogs {
// Include if binlog time range overlaps with requested range
if b.EndTime.IsZero() && b.StartTime.IsZero() {
// No timestamp info, include to be safe
result = append(result, b)
continue
}
// Check for overlap
binlogStart := b.StartTime
binlogEnd := b.EndTime
if binlogEnd.IsZero() {
binlogEnd = time.Now() // Assume current file goes to now
}
if !binlogStart.After(end) && !binlogEnd.Before(start) {
result = append(result, b)
}
}
return result
}
// WatchBinlogs monitors for new binlog files and archives them
func (m *BinlogManager) WatchBinlogs(ctx context.Context, interval time.Duration, callback func(*BinlogFile)) error {
if m.binlogDir == "" {
return fmt.Errorf("binlog directory not configured")
}
// Get initial list
known := make(map[string]struct{})
binlogs, err := m.DiscoverBinlogs(ctx)
if err != nil {
return err
}
for _, b := range binlogs {
known[b.Name] = struct{}{}
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
binlogs, err := m.DiscoverBinlogs(ctx)
if err != nil {
continue // Log error but keep watching
}
for _, b := range binlogs {
if _, exists := known[b.Name]; !exists {
// New binlog found
known[b.Name] = struct{}{}
if callback != nil {
callback(&b)
}
}
}
}
}
}
// ParseBinlogIndex reads the binlog index file
func (m *BinlogManager) ParseBinlogIndex(indexPath string) ([]string, error) {
file, err := os.Open(indexPath)
if err != nil {
return nil, fmt.Errorf("opening index file: %w", err)
}
defer file.Close()
var binlogs []string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line != "" {
binlogs = append(binlogs, line)
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading index file: %w", err)
}
return binlogs, nil
}

@@ -0,0 +1,585 @@
package pitr
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
func TestBinlogPosition_String(t *testing.T) {
tests := []struct {
name string
position BinlogPosition
expected string
}{
{
name: "basic position",
position: BinlogPosition{
File: "mysql-bin.000042",
Position: 1234,
},
expected: "mysql-bin.000042:1234",
},
{
name: "with GTID",
position: BinlogPosition{
File: "mysql-bin.000042",
Position: 1234,
GTID: "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5",
},
expected: "mysql-bin.000042:1234 (GTID: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5)",
},
{
name: "MariaDB GTID",
position: BinlogPosition{
File: "mariadb-bin.000010",
Position: 500,
GTID: "0-1-100",
},
expected: "mariadb-bin.000010:500 (GTID: 0-1-100)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.position.String()
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
func TestBinlogPosition_IsZero(t *testing.T) {
tests := []struct {
name string
position BinlogPosition
expected bool
}{
{
name: "empty position",
position: BinlogPosition{},
expected: true,
},
{
name: "has file",
position: BinlogPosition{
File: "mysql-bin.000001",
},
expected: false,
},
{
name: "has position only",
position: BinlogPosition{
Position: 100,
},
expected: false,
},
{
name: "has GTID only",
position: BinlogPosition{
GTID: "3E11FA47-71CA-11E1-9E33-C80AA9429562:1",
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.position.IsZero()
if result != tt.expected {
t.Errorf("got %v, want %v", result, tt.expected)
}
})
}
}
func TestBinlogPosition_Compare(t *testing.T) {
tests := []struct {
name string
a *BinlogPosition
b *BinlogPosition
expected int
}{
{
name: "equal positions",
a: &BinlogPosition{
File: "mysql-bin.000010",
Position: 1000,
},
b: &BinlogPosition{
File: "mysql-bin.000010",
Position: 1000,
},
expected: 0,
},
{
name: "a before b - same file",
a: &BinlogPosition{
File: "mysql-bin.000010",
Position: 100,
},
b: &BinlogPosition{
File: "mysql-bin.000010",
Position: 200,
},
expected: -1,
},
{
name: "a after b - same file",
a: &BinlogPosition{
File: "mysql-bin.000010",
Position: 300,
},
b: &BinlogPosition{
File: "mysql-bin.000010",
Position: 200,
},
expected: 1,
},
{
name: "a before b - different files",
a: &BinlogPosition{
File: "mysql-bin.000009",
Position: 9999,
},
b: &BinlogPosition{
File: "mysql-bin.000010",
Position: 100,
},
expected: -1,
},
{
name: "a after b - different files",
a: &BinlogPosition{
File: "mysql-bin.000011",
Position: 100,
},
b: &BinlogPosition{
File: "mysql-bin.000010",
Position: 9999,
},
expected: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.a.Compare(tt.b)
if result != tt.expected {
t.Errorf("got %d, want %d", result, tt.expected)
}
})
}
}
func TestParseBinlogPosition(t *testing.T) {
tests := []struct {
name string
input string
expected *BinlogPosition
expectError bool
}{
{
name: "basic position",
input: "mysql-bin.000042:1234",
expected: &BinlogPosition{
File: "mysql-bin.000042",
Position: 1234,
},
expectError: false,
},
{
name: "with GTID",
input: "mysql-bin.000042:1234:3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5",
expected: &BinlogPosition{
File: "mysql-bin.000042",
Position: 1234,
GTID: "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5",
},
expectError: false,
},
{
name: "invalid format",
input: "invalid",
expected: nil,
expectError: true,
},
{
name: "invalid position",
input: "mysql-bin.000042:notanumber",
expected: nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := ParseBinlogPosition(tt.input)
if tt.expectError {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if result.File != tt.expected.File {
t.Errorf("File: got %q, want %q", result.File, tt.expected.File)
}
if result.Position != tt.expected.Position {
t.Errorf("Position: got %d, want %d", result.Position, tt.expected.Position)
}
if result.GTID != tt.expected.GTID {
t.Errorf("GTID: got %q, want %q", result.GTID, tt.expected.GTID)
}
})
}
}
func TestExtractBinlogNumber(t *testing.T) {
tests := []struct {
name string
filename string
expected int
}{
{"mysql binlog", "mysql-bin.000042", 42},
{"mariadb binlog", "mariadb-bin.000100", 100},
{"first binlog", "mysql-bin.000001", 1},
{"large number", "mysql-bin.999999", 999999},
{"no number", "mysql-bin", 0},
{"invalid format", "binlog", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractBinlogNumber(tt.filename)
if result != tt.expected {
t.Errorf("got %d, want %d", result, tt.expected)
}
})
}
}
func TestCompareBinlogFiles(t *testing.T) {
tests := []struct {
name string
a string
b string
expected int
}{
{"equal", "mysql-bin.000010", "mysql-bin.000010", 0},
{"a < b", "mysql-bin.000009", "mysql-bin.000010", -1},
{"a > b", "mysql-bin.000011", "mysql-bin.000010", 1},
{"large difference", "mysql-bin.000001", "mysql-bin.000100", -1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := compareBinlogFiles(tt.a, tt.b)
if result != tt.expected {
t.Errorf("got %d, want %d", result, tt.expected)
}
})
}
}
func TestValidateBinlogChain(t *testing.T) {
ctx := context.Background()
bm := &BinlogManager{}
tests := []struct {
name string
binlogs []BinlogFile
expectValid bool
expectGaps int
expectWarnings bool
}{
{
name: "empty chain",
binlogs: []BinlogFile{},
expectValid: true,
expectGaps: 0,
},
{
name: "continuous chain",
binlogs: []BinlogFile{
{Name: "mysql-bin.000001", ServerID: 1},
{Name: "mysql-bin.000002", ServerID: 1},
{Name: "mysql-bin.000003", ServerID: 1},
},
expectValid: true,
expectGaps: 0,
},
{
name: "chain with gap",
binlogs: []BinlogFile{
{Name: "mysql-bin.000001", ServerID: 1},
{Name: "mysql-bin.000003", ServerID: 1}, // 000002 missing
{Name: "mysql-bin.000004", ServerID: 1},
},
expectValid: false,
expectGaps: 1,
},
{
name: "chain with multiple gaps",
binlogs: []BinlogFile{
{Name: "mysql-bin.000001", ServerID: 1},
{Name: "mysql-bin.000005", ServerID: 1}, // 000002-000004 missing
{Name: "mysql-bin.000010", ServerID: 1}, // 000006-000009 missing
},
expectValid: false,
expectGaps: 2,
},
{
name: "server_id change warning",
binlogs: []BinlogFile{
{Name: "mysql-bin.000001", ServerID: 1},
{Name: "mysql-bin.000002", ServerID: 2}, // Server ID changed
{Name: "mysql-bin.000003", ServerID: 2},
},
expectValid: true,
expectGaps: 0,
expectWarnings: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := bm.ValidateBinlogChain(ctx, tt.binlogs)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Valid != tt.expectValid {
t.Errorf("Valid: got %v, want %v", result.Valid, tt.expectValid)
}
if len(result.Gaps) != tt.expectGaps {
t.Errorf("Gaps: got %d, want %d", len(result.Gaps), tt.expectGaps)
}
if tt.expectWarnings && len(result.Warnings) == 0 {
t.Error("expected warnings, got none")
}
})
}
}
func TestFindBinlogsInRange(t *testing.T) {
ctx := context.Background()
bm := &BinlogManager{}
now := time.Now()
hour := time.Hour
binlogs := []BinlogFile{
{
Name: "mysql-bin.000001",
StartTime: now.Add(-5 * hour),
EndTime: now.Add(-4 * hour),
},
{
Name: "mysql-bin.000002",
StartTime: now.Add(-4 * hour),
EndTime: now.Add(-3 * hour),
},
{
Name: "mysql-bin.000003",
StartTime: now.Add(-3 * hour),
EndTime: now.Add(-2 * hour),
},
{
Name: "mysql-bin.000004",
StartTime: now.Add(-2 * hour),
EndTime: now.Add(-1 * hour),
},
{
Name: "mysql-bin.000005",
StartTime: now.Add(-1 * hour),
EndTime: now,
},
}
tests := []struct {
name string
start time.Time
end time.Time
expected int
}{
{
name: "all binlogs",
start: now.Add(-6 * hour),
end: now.Add(1 * hour),
expected: 5,
},
{
name: "middle range",
start: now.Add(-4 * hour),
end: now.Add(-2 * hour),
expected: 4, // binlogs 1-4 overlap (1 ends at -4h, 4 starts at -2h)
},
{
name:     "last three",
start:    now.Add(-2 * hour),
end:      now,
expected: 3, // binlogs 3-5 overlap (3 ends at -2h, 5 ends at now)
},
{
name:     "one-hour window",
start:    now.Add(-3 * hour),
end:      now.Add(-2 * hour),
expected: 3, // binlogs 2, 3, and 4 overlap with this range
},
{
name: "no overlap - before",
start: now.Add(-10 * hour),
end: now.Add(-6 * hour),
expected: 0,
},
{
name: "no overlap - after",
start: now.Add(1 * hour),
end: now.Add(2 * hour),
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := bm.FindBinlogsInRange(ctx, binlogs, tt.start, tt.end)
if len(result) != tt.expected {
t.Errorf("got %d binlogs, want %d", len(result), tt.expected)
}
})
}
}
func TestBinlogArchiveInfo_Metadata(t *testing.T) {
// Test that archive metadata is properly saved and loaded
tempDir, err := os.MkdirTemp("", "binlog_test")
if err != nil {
t.Fatalf("creating temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
bm := &BinlogManager{
archiveDir: tempDir,
}
archives := []BinlogArchiveInfo{
{
OriginalFile: "mysql-bin.000001",
ArchivePath: filepath.Join(tempDir, "mysql-bin.000001.gz"),
Size: 1024,
Compressed: true,
ArchivedAt: time.Now().Add(-2 * time.Hour),
StartPos: 4,
EndPos: 1024,
StartTime: time.Now().Add(-3 * time.Hour),
EndTime: time.Now().Add(-2 * time.Hour),
},
{
OriginalFile: "mysql-bin.000002",
ArchivePath: filepath.Join(tempDir, "mysql-bin.000002.gz"),
Size: 2048,
Compressed: true,
ArchivedAt: time.Now().Add(-1 * time.Hour),
StartPos: 4,
EndPos: 2048,
StartTime: time.Now().Add(-2 * time.Hour),
EndTime: time.Now().Add(-1 * time.Hour),
},
}
// Save metadata
err = bm.SaveArchiveMetadata(archives)
if err != nil {
t.Fatalf("saving metadata: %v", err)
}
// Verify metadata file exists
metadataPath := filepath.Join(tempDir, "metadata.json")
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
t.Fatal("metadata file was not created")
}
// Load and verify
loaded := bm.loadArchiveMetadata(metadataPath)
if len(loaded) != 2 {
t.Errorf("got %d archives, want 2", len(loaded))
}
if loaded["mysql-bin.000001"].Size != 1024 {
t.Errorf("wrong size for first archive")
}
if loaded["mysql-bin.000002"].Size != 2048 {
t.Errorf("wrong size for second archive")
}
}
func TestLimitedScanner(t *testing.T) {
// Test the limited scanner used for reading dump headers
input := "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10\n"
reader := NewLimitedScanner(strings.NewReader(input), 5)
var lines []string
for reader.Scan() {
lines = append(lines, reader.Text())
}
if len(lines) != 5 {
t.Errorf("got %d lines, want 5", len(lines))
}
}
// TestDatabaseType tests database type constants
func TestDatabaseType(t *testing.T) {
tests := []struct {
name string
dbType DatabaseType
expected string
}{
{"PostgreSQL", DatabasePostgreSQL, "postgres"},
{"MySQL", DatabaseMySQL, "mysql"},
{"MariaDB", DatabaseMariaDB, "mariadb"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if string(tt.dbType) != tt.expected {
t.Errorf("got %q, want %q", tt.dbType, tt.expected)
}
})
}
}
// TestRestoreTargetType tests restore target type constants
func TestRestoreTargetType(t *testing.T) {
tests := []struct {
name string
target RestoreTargetType
expected string
}{
{"Time", RestoreTargetTime, "time"},
{"Position", RestoreTargetPosition, "position"},
{"Immediate", RestoreTargetImmediate, "immediate"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if string(tt.target) != tt.expected {
t.Errorf("got %q, want %q", tt.target, tt.expected)
}
})
}
}

internal/pitr/interface.go Normal file

@@ -0,0 +1,155 @@
// Package pitr provides Point-in-Time Recovery functionality
// This file contains shared interfaces and types for multi-database PITR support
package pitr
import (
"context"
"time"
)
// DatabaseType represents the type of database for PITR
type DatabaseType string
const (
DatabasePostgreSQL DatabaseType = "postgres"
DatabaseMySQL DatabaseType = "mysql"
DatabaseMariaDB DatabaseType = "mariadb"
)
// PITRProvider is the interface for database-specific PITR implementations
type PITRProvider interface {
// DatabaseType returns the database type this provider handles
DatabaseType() DatabaseType
// Enable enables PITR for the database
Enable(ctx context.Context, config PITREnableConfig) error
// Disable disables PITR for the database
Disable(ctx context.Context) error
// Status returns the current PITR status
Status(ctx context.Context) (*PITRStatus, error)
// CreateBackup creates a PITR-capable backup with position recording
CreateBackup(ctx context.Context, opts BackupOptions) (*PITRBackupInfo, error)
// Restore performs a point-in-time restore
Restore(ctx context.Context, backup *PITRBackupInfo, target RestoreTarget) error
// ListRecoveryPoints lists available recovery points/ranges
ListRecoveryPoints(ctx context.Context) ([]RecoveryWindow, error)
// ValidateChain validates the log chain integrity
ValidateChain(ctx context.Context, from, to time.Time) (*ChainValidation, error)
}
// PITREnableConfig holds configuration for enabling PITR
type PITREnableConfig struct {
ArchiveDir string // Directory to store archived logs
RetentionDays int // Days to keep archives
ArchiveInterval time.Duration // How often to check for new logs (MySQL)
Compression bool // Compress archived logs
Encryption bool // Encrypt archived logs
EncryptionKey []byte // Encryption key
}
// PITRStatus represents the current PITR configuration status
type PITRStatus struct {
Enabled bool
DatabaseType DatabaseType
ArchiveDir string
LogLevel string // WAL level (postgres) or binlog format (mysql)
ArchiveMethod string // archive_command (postgres) or manual (mysql)
Position LogPosition
LastArchived time.Time
ArchiveCount int
ArchiveSize int64
}
// LogPosition is a generic interface for database-specific log positions
type LogPosition interface {
// String returns a string representation of the position
String() string
// IsZero returns true if the position is unset
IsZero() bool
// Compare returns -1 if p < other, 0 if equal, 1 if p > other
Compare(other LogPosition) int
}
// BackupOptions holds options for creating a PITR backup
type BackupOptions struct {
Database string // Database name (empty for all)
OutputPath string // Where to save the backup
Compression bool
CompressionLvl int
Encryption bool
EncryptionKey []byte
FlushLogs bool // Flush logs before backup (mysql)
SingleTxn bool // Single transaction mode
}
// PITRBackupInfo contains metadata about a PITR-capable backup
type PITRBackupInfo struct {
BackupFile string `json:"backup_file"`
DatabaseType DatabaseType `json:"database_type"`
DatabaseName string `json:"database_name,omitempty"`
Timestamp time.Time `json:"timestamp"`
ServerVersion string `json:"server_version"`
ServerID int `json:"server_id,omitempty"` // MySQL server_id
Position LogPosition `json:"-"` // Start position (type-specific)
PositionJSON string `json:"position"` // Serialized position
SizeBytes int64 `json:"size_bytes"`
Compressed bool `json:"compressed"`
Encrypted bool `json:"encrypted"`
}
// RestoreTarget specifies the point-in-time to restore to
type RestoreTarget struct {
Type RestoreTargetType
Time *time.Time // For RestoreTargetTime
Position LogPosition // For RestoreTargetPosition (LSN, binlog pos, GTID)
Inclusive bool // Include target transaction
DryRun bool // Only show what would be done
StopOnErr bool // Stop replay on first error
}
// RestoreTargetType defines the type of restore target
type RestoreTargetType string
const (
RestoreTargetTime RestoreTargetType = "time"
RestoreTargetPosition RestoreTargetType = "position"
RestoreTargetImmediate RestoreTargetType = "immediate"
)
// RecoveryWindow represents a time range available for recovery
type RecoveryWindow struct {
BaseBackup string `json:"base_backup"`
BackupTime time.Time `json:"backup_time"`
StartPosition LogPosition `json:"-"`
EndPosition LogPosition `json:"-"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
LogFiles []string `json:"log_files"` // WAL segments or binlog files
HasGaps bool `json:"has_gaps"`
GapDetails []string `json:"gap_details,omitempty"`
}
// ChainValidation contains results of log chain validation
type ChainValidation struct {
Valid bool
StartPos LogPosition
EndPos LogPosition
LogCount int
TotalSize int64
Gaps []LogGap
Errors []string
Warnings []string
}
// LogGap represents a gap in the log chain
type LogGap struct {
After string // Log file/position after which gap occurs
Before string // Log file/position where chain resumes
Reason string // Reason for gap if known
}

internal/pitr/mysql.go Normal file

@@ -0,0 +1,924 @@
// Package pitr provides Point-in-Time Recovery functionality
// This file contains the MySQL/MariaDB PITR provider implementation
package pitr
import (
"bufio"
"compress/gzip"
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
)
// MySQLPITR implements PITRProvider for MySQL and MariaDB
type MySQLPITR struct {
db *sql.DB
config MySQLPITRConfig
binlogManager *BinlogManager
serverType DatabaseType
serverVersion string
serverID uint32
gtidMode bool
}
// MySQLPITRConfig holds configuration for MySQL PITR
type MySQLPITRConfig struct {
// Connection settings
Host string `json:"host"`
Port int `json:"port"`
User string `json:"user"`
Password string `json:"password,omitempty"`
Socket string `json:"socket,omitempty"`
// Paths
DataDir string `json:"data_dir"`
BinlogDir string `json:"binlog_dir"`
ArchiveDir string `json:"archive_dir"`
RestoreDir string `json:"restore_dir"`
// Archive settings
ArchiveInterval time.Duration `json:"archive_interval"`
RetentionDays int `json:"retention_days"`
Compression bool `json:"compression"`
CompressionLevel int `json:"compression_level"`
Encryption bool `json:"encryption"`
EncryptionKey []byte `json:"-"`
// Behavior settings
RequireRowFormat bool `json:"require_row_format"`
RequireGTID bool `json:"require_gtid"`
FlushLogsOnBackup bool `json:"flush_logs_on_backup"`
LockTables bool `json:"lock_tables"`
SingleTransaction bool `json:"single_transaction"`
}
// NewMySQLPITR creates a new MySQL PITR provider
func NewMySQLPITR(db *sql.DB, config MySQLPITRConfig) (*MySQLPITR, error) {
m := &MySQLPITR{
db: db,
config: config,
}
// Detect server type and version
if err := m.detectServerInfo(); err != nil {
return nil, fmt.Errorf("detecting server info: %w", err)
}
// Initialize binlog manager
binlogConfig := BinlogManagerConfig{
BinlogDir: config.BinlogDir,
ArchiveDir: config.ArchiveDir,
Compression: config.Compression,
Encryption: config.Encryption,
EncryptionKey: config.EncryptionKey,
}
var err error
m.binlogManager, err = NewBinlogManager(binlogConfig)
if err != nil {
return nil, fmt.Errorf("creating binlog manager: %w", err)
}
return m, nil
}
// detectServerInfo detects MySQL/MariaDB version and configuration
func (m *MySQLPITR) detectServerInfo() error {
// Get version
var version string
err := m.db.QueryRow("SELECT VERSION()").Scan(&version)
if err != nil {
return fmt.Errorf("getting version: %w", err)
}
m.serverVersion = version
// Detect MariaDB vs MySQL
if strings.Contains(strings.ToLower(version), "mariadb") {
m.serverType = DatabaseMariaDB
} else {
m.serverType = DatabaseMySQL
}
// Get server_id
var serverID int
err = m.db.QueryRow("SELECT @@server_id").Scan(&serverID)
if err == nil {
m.serverID = uint32(serverID)
}
// Check GTID mode
if m.serverType == DatabaseMySQL {
var gtidMode string
err = m.db.QueryRow("SELECT @@gtid_mode").Scan(&gtidMode)
if err == nil {
m.gtidMode = strings.ToUpper(gtidMode) == "ON"
}
} else {
// MariaDB uses different variables
var gtidPos string
err = m.db.QueryRow("SELECT @@gtid_current_pos").Scan(&gtidPos)
m.gtidMode = err == nil && gtidPos != ""
}
return nil
}
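The MariaDB/MySQL split above hinges on a substring check of the `SELECT VERSION()` result. A standalone sketch of that detection step (the sample version strings below are illustrative, not taken from the code):

```go
package main

import (
	"fmt"
	"strings"
)

// detectServerType mirrors the check in detectServerInfo: MariaDB embeds
// "MariaDB" in its version string, while MySQL reports a bare version.
func detectServerType(version string) string {
	if strings.Contains(strings.ToLower(version), "mariadb") {
		return "mariadb"
	}
	return "mysql"
}

func main() {
	fmt.Println(detectServerType("10.6.12-MariaDB-1:10.6.12+maria~ubu2004")) // mariadb
	fmt.Println(detectServerType("8.0.34"))                                  // mysql
}
```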
// DatabaseType returns the database type this provider handles
func (m *MySQLPITR) DatabaseType() DatabaseType {
return m.serverType
}
// Enable enables PITR for the MySQL database
func (m *MySQLPITR) Enable(ctx context.Context, config PITREnableConfig) error {
// Check current binlog settings
status, err := m.Status(ctx)
if err != nil {
return fmt.Errorf("checking status: %w", err)
}
var issues []string
// Check if binlog is enabled
var logBin string
if err := m.db.QueryRowContext(ctx, "SELECT @@log_bin").Scan(&logBin); err != nil {
return fmt.Errorf("checking log_bin: %w", err)
}
if logBin != "1" && strings.ToUpper(logBin) != "ON" {
issues = append(issues, "binary logging is not enabled (log_bin=OFF)")
issues = append(issues, " Add to my.cnf: log_bin = mysql-bin")
}
// Check binlog format
if m.config.RequireRowFormat && status.LogLevel != "ROW" {
issues = append(issues, fmt.Sprintf("binlog_format is %s, not ROW", status.LogLevel))
issues = append(issues, " Add to my.cnf: binlog_format = ROW")
}
// Check GTID mode if required
if m.config.RequireGTID && !m.gtidMode {
issues = append(issues, "GTID mode is not enabled")
if m.serverType == DatabaseMySQL {
issues = append(issues, " Add to my.cnf: gtid_mode = ON, enforce_gtid_consistency = ON")
} else {
issues = append(issues, " MariaDB: GTIDs are automatically managed with log_slave_updates")
}
}
// Check expire_logs_days (we don't want logs expiring before we archive them).
// MySQL 8.0 replaces this variable with binlog_expire_logs_seconds; if the
// query fails, expireDays stays 0 and the check below is skipped.
var expireDays int
_ = m.db.QueryRowContext(ctx, "SELECT @@expire_logs_days").Scan(&expireDays)
if expireDays > 0 && expireDays < config.RetentionDays {
issues = append(issues,
fmt.Sprintf("expire_logs_days (%d) is less than retention days (%d)",
expireDays, config.RetentionDays))
}
if len(issues) > 0 {
return fmt.Errorf("PITR requirements not met:\n - %s", strings.Join(issues, "\n - "))
}
// Update archive configuration
m.config.ArchiveDir = config.ArchiveDir
m.config.RetentionDays = config.RetentionDays
m.config.ArchiveInterval = config.ArchiveInterval
m.config.Compression = config.Compression
m.config.Encryption = config.Encryption
m.config.EncryptionKey = config.EncryptionKey
// Create archive directory
if err := os.MkdirAll(config.ArchiveDir, 0750); err != nil {
return fmt.Errorf("creating archive directory: %w", err)
}
// Save configuration
configPath := filepath.Join(config.ArchiveDir, "pitr_config.json")
configData, _ := json.MarshalIndent(map[string]interface{}{
"enabled": true,
"server_type": m.serverType,
"server_version": m.serverVersion,
"server_id": m.serverID,
"gtid_mode": m.gtidMode,
"archive_dir": config.ArchiveDir,
"retention_days": config.RetentionDays,
"archive_interval": config.ArchiveInterval.String(),
"compression": config.Compression,
"encryption": config.Encryption,
"created_at": time.Now().Format(time.RFC3339),
}, "", " ")
if err := os.WriteFile(configPath, configData, 0640); err != nil {
return fmt.Errorf("saving config: %w", err)
}
return nil
}
// Disable disables PITR for the MySQL database
func (m *MySQLPITR) Disable(ctx context.Context) error {
configPath := filepath.Join(m.config.ArchiveDir, "pitr_config.json")
// Check if config exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("PITR is not enabled (no config file found)")
}
// Update config to disabled
configData, _ := json.MarshalIndent(map[string]interface{}{
"enabled": false,
"disabled_at": time.Now().Format(time.RFC3339),
}, "", " ")
if err := os.WriteFile(configPath, configData, 0640); err != nil {
return fmt.Errorf("updating config: %w", err)
}
return nil
}
// Status returns the current PITR status
func (m *MySQLPITR) Status(ctx context.Context) (*PITRStatus, error) {
status := &PITRStatus{
DatabaseType: m.serverType,
ArchiveDir: m.config.ArchiveDir,
}
// Check if PITR is enabled via config file
configPath := filepath.Join(m.config.ArchiveDir, "pitr_config.json")
if data, err := os.ReadFile(configPath); err == nil {
var config map[string]interface{}
if json.Unmarshal(data, &config) == nil {
if enabled, ok := config["enabled"].(bool); ok {
status.Enabled = enabled
}
}
}
// Get binlog format
var binlogFormat string
if err := m.db.QueryRowContext(ctx, "SELECT @@binlog_format").Scan(&binlogFormat); err == nil {
status.LogLevel = binlogFormat
}
// Get current position
pos, err := m.GetCurrentPosition(ctx)
if err == nil {
status.Position = pos
}
// Get archive stats
if m.config.ArchiveDir != "" {
archives, err := m.binlogManager.ListArchivedBinlogs(ctx)
if err == nil {
status.ArchiveCount = len(archives)
for _, a := range archives {
status.ArchiveSize += a.Size
if a.ArchivedAt.After(status.LastArchived) {
status.LastArchived = a.ArchivedAt
}
}
}
}
status.ArchiveMethod = "manual" // MySQL doesn't have automatic archiving like PostgreSQL
return status, nil
}
// GetCurrentPosition retrieves the current binary log position
func (m *MySQLPITR) GetCurrentPosition(ctx context.Context) (*BinlogPosition, error) {
pos := &BinlogPosition{}
// Use SHOW MASTER STATUS for the current position
// (renamed SHOW BINARY LOG STATUS in MySQL 8.4)
rows, err := m.db.QueryContext(ctx, "SHOW MASTER STATUS")
if err != nil {
return nil, fmt.Errorf("getting master status: %w", err)
}
defer rows.Close()
if rows.Next() {
var file string
var position uint64
var binlogDoDB, binlogIgnoreDB, executedGtidSet sql.NullString
cols, _ := rows.Columns()
switch len(cols) {
case 5: // MySQL 5.6+
err = rows.Scan(&file, &position, &binlogDoDB, &binlogIgnoreDB, &executedGtidSet)
case 4: // Older versions
err = rows.Scan(&file, &position, &binlogDoDB, &binlogIgnoreDB)
default:
err = rows.Scan(&file, &position)
}
if err != nil {
return nil, fmt.Errorf("scanning master status: %w", err)
}
pos.File = file
pos.Position = position
pos.ServerID = m.serverID
if executedGtidSet.Valid {
pos.GTID = executedGtidSet.String
}
} else {
return nil, fmt.Errorf("no master status available (is binary logging enabled?)")
}
// For MariaDB, get GTID position differently
if m.serverType == DatabaseMariaDB && pos.GTID == "" {
var gtidPos string
if err := m.db.QueryRowContext(ctx, "SELECT @@gtid_current_pos").Scan(&gtidPos); err == nil {
pos.GTID = gtidPos
}
}
return pos, nil
}
// CreateBackup creates a PITR-capable backup with position recording
func (m *MySQLPITR) CreateBackup(ctx context.Context, opts BackupOptions) (*PITRBackupInfo, error) {
// Get position BEFORE flushing logs
startPos, err := m.GetCurrentPosition(ctx)
if err != nil {
return nil, fmt.Errorf("getting start position: %w", err)
}
// Optionally flush logs to start a new binlog file
if opts.FlushLogs || m.config.FlushLogsOnBackup {
if _, err := m.db.ExecContext(ctx, "FLUSH BINARY LOGS"); err != nil {
return nil, fmt.Errorf("flushing binary logs: %w", err)
}
// Get new position after flush
startPos, err = m.GetCurrentPosition(ctx)
if err != nil {
return nil, fmt.Errorf("getting position after flush: %w", err)
}
}
// Build mysqldump command
dumpArgs := []string{
"--single-transaction",
"--routines",
"--triggers",
"--events",
"--master-data=2", // Record binlog coordinates as a comment in the dump (renamed --source-data in MySQL 8.0.26+)
}
if m.config.FlushLogsOnBackup {
dumpArgs = append(dumpArgs, "--flush-logs")
}
// Add connection params
if m.config.Host != "" {
dumpArgs = append(dumpArgs, "-h", m.config.Host)
}
if m.config.Port > 0 {
dumpArgs = append(dumpArgs, "-P", strconv.Itoa(m.config.Port))
}
if m.config.User != "" {
dumpArgs = append(dumpArgs, "-u", m.config.User)
}
if m.config.Password != "" {
// NOTE: a password on argv is visible in the process list; prefer
// MYSQL_PWD or a defaults file in hardened environments.
dumpArgs = append(dumpArgs, "-p"+m.config.Password)
}
if m.config.Socket != "" {
dumpArgs = append(dumpArgs, "-S", m.config.Socket)
}
// Add database selection
if opts.Database != "" {
dumpArgs = append(dumpArgs, opts.Database)
} else {
dumpArgs = append(dumpArgs, "--all-databases")
}
// Create output file
timestamp := time.Now().Format("20060102_150405")
backupName := fmt.Sprintf("mysql_pitr_%s.sql", timestamp)
if opts.Compression {
backupName += ".gz"
}
backupPath := filepath.Join(opts.OutputPath, backupName)
if err := os.MkdirAll(opts.OutputPath, 0750); err != nil {
return nil, fmt.Errorf("creating output directory: %w", err)
}
// Run mysqldump
cmd := exec.CommandContext(ctx, "mysqldump", dumpArgs...)
// Create output file
outFile, err := os.Create(backupPath)
if err != nil {
return nil, fmt.Errorf("creating backup file: %w", err)
}
defer outFile.Close()
var writer io.WriteCloser = outFile
if opts.Compression {
gzWriter := NewGzipWriter(outFile, opts.CompressionLvl)
writer = gzWriter
defer gzWriter.Close()
}
cmd.Stdout = writer
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
os.Remove(backupPath)
return nil, fmt.Errorf("mysqldump failed: %w", err)
}
// Close the gzip stream explicitly so the footer is flushed before we stat the file
if opts.Compression {
if err := writer.Close(); err != nil {
os.Remove(backupPath)
return nil, fmt.Errorf("closing compressed writer: %w", err)
}
}
// Get file size
info, err := os.Stat(backupPath)
if err != nil {
return nil, fmt.Errorf("getting backup info: %w", err)
}
// Serialize position for JSON storage
posJSON, _ := json.Marshal(startPos)
backupInfo := &PITRBackupInfo{
BackupFile: backupPath,
DatabaseType: m.serverType,
DatabaseName: opts.Database,
Timestamp: time.Now(),
ServerVersion: m.serverVersion,
ServerID: int(m.serverID),
Position: startPos,
PositionJSON: string(posJSON),
SizeBytes: info.Size(),
Compressed: opts.Compression,
Encrypted: opts.Encryption,
}
// Save metadata alongside backup
metadataPath := backupPath + ".meta"
metaData, _ := json.MarshalIndent(backupInfo, "", " ")
os.WriteFile(metadataPath, metaData, 0640)
return backupInfo, nil
}
// Restore performs a point-in-time restore
func (m *MySQLPITR) Restore(ctx context.Context, backup *PITRBackupInfo, target RestoreTarget) error {
// Step 1: Restore base backup
if err := m.restoreBaseBackup(ctx, backup); err != nil {
return fmt.Errorf("restoring base backup: %w", err)
}
// Step 2: If target time is after backup time, replay binlogs
if target.Type == RestoreTargetImmediate {
return nil // Just restore to backup point
}
// Parse start position from backup
var startPos BinlogPosition
if err := json.Unmarshal([]byte(backup.PositionJSON), &startPos); err != nil {
return fmt.Errorf("parsing backup position: %w", err)
}
// Step 3: Find binlogs to replay
binlogs, err := m.binlogManager.DiscoverBinlogs(ctx)
if err != nil {
return fmt.Errorf("discovering binlogs: %w", err)
}
// Find archived binlogs too
archivedBinlogs, _ := m.binlogManager.ListArchivedBinlogs(ctx)
var filesToReplay []string
// Determine which binlogs to replay based on target
switch target.Type {
case RestoreTargetTime:
if target.Time == nil {
return fmt.Errorf("target time not specified")
}
// Find binlogs in range
relevantBinlogs := m.binlogManager.FindBinlogsInRange(ctx, binlogs, backup.Timestamp, *target.Time)
for _, b := range relevantBinlogs {
filesToReplay = append(filesToReplay, b.Path)
}
// Also check archives
for _, a := range archivedBinlogs {
if compareBinlogFiles(a.OriginalFile, startPos.File) >= 0 {
if !a.EndTime.IsZero() && !a.EndTime.Before(backup.Timestamp) && !a.StartTime.After(*target.Time) {
filesToReplay = append(filesToReplay, a.ArchivePath)
}
}
}
case RestoreTargetPosition:
if target.Position == nil {
return fmt.Errorf("target position not specified")
}
targetPos, ok := target.Position.(*BinlogPosition)
if !ok {
return fmt.Errorf("invalid target position type")
}
// Find binlogs from start to target position
for _, b := range binlogs {
if compareBinlogFiles(b.Name, startPos.File) >= 0 &&
compareBinlogFiles(b.Name, targetPos.File) <= 0 {
filesToReplay = append(filesToReplay, b.Path)
}
}
}
if len(filesToReplay) == 0 {
// Nothing to replay, backup is already at or past target
return nil
}
// Step 4: Replay binlogs
replayOpts := ReplayOptions{
BinlogFiles: filesToReplay,
StartPosition: &startPos,
DryRun: target.DryRun,
MySQLHost: m.config.Host,
MySQLPort: m.config.Port,
MySQLUser: m.config.User,
MySQLPass: m.config.Password,
StopOnError: target.StopOnErr,
}
if target.Type == RestoreTargetTime && target.Time != nil {
replayOpts.StopTime = target.Time
}
if target.Type == RestoreTargetPosition && target.Position != nil {
replayOpts.StopPosition = target.Position
}
if target.DryRun {
replayOpts.Output = os.Stdout
}
return m.binlogManager.ReplayBinlogs(ctx, replayOpts)
}
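Restore orders binlogs with compareBinlogFiles, which is not shown in this excerpt. A plausible standalone sketch, assuming names carry a numeric sequence suffix like `mysql-bin.000042` (the actual implementation may differ):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// binlogSeq extracts the numeric sequence suffix from a binlog name,
// e.g. "mysql-bin.000042" -> 42. Returns -1 if no numeric suffix exists.
func binlogSeq(name string) int {
	idx := strings.LastIndex(name, ".")
	if idx < 0 {
		return -1
	}
	n, err := strconv.Atoi(name[idx+1:])
	if err != nil {
		return -1
	}
	return n
}

// compareBinlogNames orders two binlog names by sequence number,
// returning -1, 0, or 1. Numeric comparison stays correct even if the
// sequence ever exceeds the fixed six-digit width.
func compareBinlogNames(a, b string) int {
	sa, sb := binlogSeq(a), binlogSeq(b)
	switch {
	case sa < sb:
		return -1
	case sa > sb:
		return 1
	default:
		return 0
	}
}

func main() {
	fmt.Println(compareBinlogNames("mysql-bin.000002", "mysql-bin.000010")) // -1
}
```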
// restoreBaseBackup restores the base MySQL backup
func (m *MySQLPITR) restoreBaseBackup(ctx context.Context, backup *PITRBackupInfo) error {
// Build mysql command
mysqlArgs := []string{}
if m.config.Host != "" {
mysqlArgs = append(mysqlArgs, "-h", m.config.Host)
}
if m.config.Port > 0 {
mysqlArgs = append(mysqlArgs, "-P", strconv.Itoa(m.config.Port))
}
if m.config.User != "" {
mysqlArgs = append(mysqlArgs, "-u", m.config.User)
}
if m.config.Password != "" {
mysqlArgs = append(mysqlArgs, "-p"+m.config.Password)
}
if m.config.Socket != "" {
mysqlArgs = append(mysqlArgs, "-S", m.config.Socket)
}
// Prepare input
var input io.Reader
backupFile, err := os.Open(backup.BackupFile)
if err != nil {
return fmt.Errorf("opening backup file: %w", err)
}
defer backupFile.Close()
input = backupFile
// Handle compressed backups
if backup.Compressed || strings.HasSuffix(backup.BackupFile, ".gz") {
gzReader, err := NewGzipReader(backupFile)
if err != nil {
return fmt.Errorf("creating gzip reader: %w", err)
}
defer gzReader.Close()
input = gzReader
}
// Run mysql
cmd := exec.CommandContext(ctx, "mysql", mysqlArgs...)
cmd.Stdin = input
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
// ListRecoveryPoints lists available recovery points/ranges
func (m *MySQLPITR) ListRecoveryPoints(ctx context.Context) ([]RecoveryWindow, error) {
var windows []RecoveryWindow
// Find all backup metadata files
backupPattern := filepath.Join(m.config.ArchiveDir, "..", "*", "*.meta")
metaFiles, _ := filepath.Glob(backupPattern)
// Also check default backup locations
additionalPaths := []string{
filepath.Join(m.config.ArchiveDir, "*.meta"),
filepath.Join(m.config.RestoreDir, "*.meta"),
}
for _, p := range additionalPaths {
matches, _ := filepath.Glob(p)
metaFiles = append(metaFiles, matches...)
}
// Get current binlogs
binlogs, err := m.binlogManager.DiscoverBinlogs(ctx)
if err != nil {
binlogs = []BinlogFile{}
}
// Get archived binlogs
archivedBinlogs, _ := m.binlogManager.ListArchivedBinlogs(ctx)
for _, metaFile := range metaFiles {
data, err := os.ReadFile(metaFile)
if err != nil {
continue
}
var backup PITRBackupInfo
if err := json.Unmarshal(data, &backup); err != nil {
continue
}
// Parse position
var startPos BinlogPosition
json.Unmarshal([]byte(backup.PositionJSON), &startPos)
window := RecoveryWindow{
BaseBackup: backup.BackupFile,
BackupTime: backup.Timestamp,
StartTime: backup.Timestamp,
StartPosition: &startPos,
}
// Find binlogs available after this backup
var relevantBinlogs []string
var latestTime time.Time
var latestPos *BinlogPosition
for _, b := range binlogs {
if compareBinlogFiles(b.Name, startPos.File) >= 0 {
relevantBinlogs = append(relevantBinlogs, b.Name)
if !b.EndTime.IsZero() && b.EndTime.After(latestTime) {
latestTime = b.EndTime
latestPos = &BinlogPosition{
File: b.Name,
Position: b.EndPos,
GTID: b.GTID,
}
}
}
}
for _, a := range archivedBinlogs {
if compareBinlogFiles(a.OriginalFile, startPos.File) >= 0 {
relevantBinlogs = append(relevantBinlogs, a.OriginalFile)
if !a.EndTime.IsZero() && a.EndTime.After(latestTime) {
latestTime = a.EndTime
latestPos = &BinlogPosition{
File: a.OriginalFile,
Position: a.EndPos,
GTID: a.GTID,
}
}
}
}
window.LogFiles = relevantBinlogs
if !latestTime.IsZero() {
window.EndTime = latestTime
} else {
window.EndTime = time.Now()
}
window.EndPosition = latestPos
// Check for gaps
validation, _ := m.binlogManager.ValidateBinlogChain(ctx, binlogs)
if validation != nil {
window.HasGaps = !validation.Valid
for _, gap := range validation.Gaps {
window.GapDetails = append(window.GapDetails, gap.Reason)
}
}
windows = append(windows, window)
}
return windows, nil
}
// ValidateChain validates the log chain integrity
func (m *MySQLPITR) ValidateChain(ctx context.Context, from, to time.Time) (*ChainValidation, error) {
// Discover all binlogs
binlogs, err := m.binlogManager.DiscoverBinlogs(ctx)
if err != nil {
return nil, fmt.Errorf("discovering binlogs: %w", err)
}
// Filter to time range
relevant := m.binlogManager.FindBinlogsInRange(ctx, binlogs, from, to)
// Validate chain
return m.binlogManager.ValidateBinlogChain(ctx, relevant)
}
// ArchiveNewBinlogs archives any binlog files that haven't been archived yet
func (m *MySQLPITR) ArchiveNewBinlogs(ctx context.Context) ([]BinlogArchiveInfo, error) {
// Get current binlogs
binlogs, err := m.binlogManager.DiscoverBinlogs(ctx)
if err != nil {
return nil, fmt.Errorf("discovering binlogs: %w", err)
}
// Get already archived
archived, _ := m.binlogManager.ListArchivedBinlogs(ctx)
archivedSet := make(map[string]struct{})
for _, a := range archived {
archivedSet[a.OriginalFile] = struct{}{}
}
// Get current binlog file (don't archive the active one)
currentPos, _ := m.GetCurrentPosition(ctx)
currentFile := ""
if currentPos != nil {
currentFile = currentPos.File
}
var newArchives []BinlogArchiveInfo
for i := range binlogs {
b := &binlogs[i]
// Skip if already archived
if _, exists := archivedSet[b.Name]; exists {
continue
}
// Skip the current active binlog
if b.Name == currentFile {
continue
}
// Archive
archiveInfo, err := m.binlogManager.ArchiveBinlog(ctx, b)
if err != nil {
// Log but continue
continue
}
newArchives = append(newArchives, *archiveInfo)
}
// Update metadata
if len(newArchives) > 0 {
allArchived, _ := m.binlogManager.ListArchivedBinlogs(ctx)
m.binlogManager.SaveArchiveMetadata(allArchived)
}
return newArchives, nil
}
// PurgeBinlogs purges old binlog files based on retention policy
func (m *MySQLPITR) PurgeBinlogs(ctx context.Context) error {
if m.config.RetentionDays <= 0 {
return fmt.Errorf("retention days not configured")
}
cutoff := time.Now().AddDate(0, 0, -m.config.RetentionDays)
// Get archived binlogs
archived, err := m.binlogManager.ListArchivedBinlogs(ctx)
if err != nil {
return fmt.Errorf("listing archived binlogs: %w", err)
}
for _, a := range archived {
if a.ArchivedAt.Before(cutoff) {
os.Remove(a.ArchivePath)
}
}
return nil
}
// GzipWriter is a helper for gzip compression
type GzipWriter struct {
w *gzip.Writer
}
func NewGzipWriter(w io.Writer, level int) *GzipWriter {
if level < gzip.BestSpeed || level > gzip.BestCompression {
level = gzip.DefaultCompression
}
// level is validated above, so NewWriterLevel cannot fail here
gw, _ := gzip.NewWriterLevel(w, level)
return &GzipWriter{w: gw}
}
func (g *GzipWriter) Write(p []byte) (int, error) {
return g.w.Write(p)
}
func (g *GzipWriter) Close() error {
return g.w.Close()
}
// GzipReader is a helper for gzip decompression
type GzipReader struct {
r *gzip.Reader
}
func NewGzipReader(r io.Reader) (*GzipReader, error) {
gr, err := gzip.NewReader(r)
if err != nil {
return nil, err
}
return &GzipReader{r: gr}, nil
}
func (g *GzipReader) Read(p []byte) (int, error) {
return g.r.Read(p)
}
func (g *GzipReader) Close() error {
return g.r.Close()
}
// ExtractBinlogPositionFromDump extracts the binlog position from a mysqldump file
func ExtractBinlogPositionFromDump(dumpPath string) (*BinlogPosition, error) {
file, err := os.Open(dumpPath)
if err != nil {
return nil, fmt.Errorf("opening dump file: %w", err)
}
defer file.Close()
var reader io.Reader = file
if strings.HasSuffix(dumpPath, ".gz") {
gzReader, err := gzip.NewReader(file)
if err != nil {
return nil, fmt.Errorf("creating gzip reader: %w", err)
}
defer gzReader.Close()
reader = gzReader
}
// Look for CHANGE MASTER TO or -- CHANGE MASTER TO comment
// Pattern: -- CHANGE MASTER TO MASTER_LOG_FILE='mysql-bin.000042', MASTER_LOG_POS=1234;
scanner := NewLimitedScanner(reader, 1000) // Only scan first 1000 lines
posPattern := regexp.MustCompile(`MASTER_LOG_FILE='([^']+)',\s*MASTER_LOG_POS=(\d+)`)
for scanner.Scan() {
line := scanner.Text()
if matches := posPattern.FindStringSubmatch(line); len(matches) == 3 {
pos, _ := strconv.ParseUint(matches[2], 10, 64)
return &BinlogPosition{
File: matches[1],
Position: pos,
}, nil
}
}
return nil, fmt.Errorf("binlog position not found in dump file")
}
// LimitedScanner wraps bufio.Scanner with a line limit
type LimitedScanner struct {
scanner *bufio.Scanner
limit int
count int
}
func NewLimitedScanner(r io.Reader, limit int) *LimitedScanner {
return &LimitedScanner{
scanner: bufio.NewScanner(r),
limit: limit,
}
}
func (s *LimitedScanner) Scan() bool {
if s.count >= s.limit {
return false
}
s.count++
return s.scanner.Scan()
}
func (s *LimitedScanner) Text() string {
return s.scanner.Text()
}


@@ -0,0 +1,499 @@
// Package replica provides replica-aware backup functionality
package replica
import (
"context"
"database/sql"
"fmt"
"sort"
"time"
)
// Role represents the replication role of a database
type Role string
const (
RolePrimary Role = "primary"
RoleReplica Role = "replica"
RoleStandalone Role = "standalone"
RoleUnknown Role = "unknown"
)
// Status represents the health status of a replica
type Status string
const (
StatusHealthy Status = "healthy"
StatusLagging Status = "lagging"
StatusDisconnected Status = "disconnected"
StatusUnknown Status = "unknown"
)
// Node represents a database node in a replication topology
type Node struct {
Host string `json:"host"`
Port int `json:"port"`
Role Role `json:"role"`
Status Status `json:"status"`
ReplicationLag time.Duration `json:"replication_lag"`
IsAvailable bool `json:"is_available"`
LastChecked time.Time `json:"last_checked"`
Priority int `json:"priority"` // Lower = higher priority
Weight int `json:"weight"` // For load balancing
Metadata map[string]string `json:"metadata,omitempty"`
}
// Topology represents the replication topology
type Topology struct {
Primary *Node `json:"primary,omitempty"`
Replicas []*Node `json:"replicas"`
Timestamp time.Time `json:"timestamp"`
}
// Config configures replica-aware backup behavior
type Config struct {
PreferReplica bool `json:"prefer_replica"`
MaxReplicationLag time.Duration `json:"max_replication_lag"`
FallbackToPrimary bool `json:"fallback_to_primary"`
RequireHealthy bool `json:"require_healthy"`
SelectionStrategy Strategy `json:"selection_strategy"`
Nodes []NodeConfig `json:"nodes"`
}
// NodeConfig configures a known node
type NodeConfig struct {
Host string `json:"host"`
Port int `json:"port"`
Priority int `json:"priority"`
Weight int `json:"weight"`
}
// Strategy for selecting a node
type Strategy string
const (
StrategyPreferReplica Strategy = "prefer_replica" // Always prefer replica
StrategyLowestLag Strategy = "lowest_lag" // Choose node with lowest lag
StrategyRoundRobin Strategy = "round_robin" // Rotate between replicas
StrategyPriority Strategy = "priority" // Use configured priorities
StrategyWeighted Strategy = "weighted" // Weighted random selection
)
// DefaultConfig returns default replica configuration
func DefaultConfig() Config {
return Config{
PreferReplica: true,
MaxReplicationLag: 1 * time.Minute,
FallbackToPrimary: true,
RequireHealthy: true,
SelectionStrategy: StrategyLowestLag,
}
}
// Selector selects the best node for backup
type Selector struct {
config Config
lastSelected int // For round-robin
}
// NewSelector creates a new replica selector
func NewSelector(config Config) *Selector {
return &Selector{
config: config,
}
}
// SelectNode selects the best node for backup from the topology
func (s *Selector) SelectNode(topology *Topology) (*Node, error) {
var candidates []*Node
// Collect available candidates
if s.config.PreferReplica {
// Prefer replicas
for _, r := range topology.Replicas {
if s.isAcceptable(r) {
candidates = append(candidates, r)
}
}
// Fallback to primary if no replicas available
if len(candidates) == 0 && s.config.FallbackToPrimary {
if topology.Primary != nil && topology.Primary.IsAvailable {
return topology.Primary, nil
}
}
} else {
// Allow all nodes
if topology.Primary != nil && topology.Primary.IsAvailable {
candidates = append(candidates, topology.Primary)
}
for _, r := range topology.Replicas {
if s.isAcceptable(r) {
candidates = append(candidates, r)
}
}
}
if len(candidates) == 0 {
return nil, fmt.Errorf("no acceptable nodes available for backup")
}
// Apply selection strategy
return s.applyStrategy(candidates)
}
// isAcceptable checks if a node is acceptable for backup
func (s *Selector) isAcceptable(node *Node) bool {
if !node.IsAvailable {
return false
}
if s.config.RequireHealthy && node.Status != StatusHealthy {
return false
}
if s.config.MaxReplicationLag > 0 && node.ReplicationLag > s.config.MaxReplicationLag {
return false
}
return true
}
// applyStrategy selects a node using the configured strategy
func (s *Selector) applyStrategy(candidates []*Node) (*Node, error) {
switch s.config.SelectionStrategy {
case StrategyLowestLag:
return s.selectLowestLag(candidates), nil
case StrategyPriority:
return s.selectByPriority(candidates), nil
case StrategyRoundRobin:
return s.selectRoundRobin(candidates), nil
default:
// Default to lowest lag
return s.selectLowestLag(candidates), nil
}
}
func (s *Selector) selectLowestLag(candidates []*Node) *Node {
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].ReplicationLag < candidates[j].ReplicationLag
})
return candidates[0]
}
func (s *Selector) selectByPriority(candidates []*Node) *Node {
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].Priority < candidates[j].Priority
})
return candidates[0]
}
func (s *Selector) selectRoundRobin(candidates []*Node) *Node {
node := candidates[s.lastSelected%len(candidates)]
s.lastSelected++
return node
}
// Detector detects replication topology
type Detector interface {
Detect(ctx context.Context, db *sql.DB) (*Topology, error)
GetRole(ctx context.Context, db *sql.DB) (Role, error)
GetReplicationLag(ctx context.Context, db *sql.DB) (time.Duration, error)
}
// PostgreSQLDetector detects PostgreSQL replication topology
type PostgreSQLDetector struct{}
// Detect discovers PostgreSQL replication topology
func (d *PostgreSQLDetector) Detect(ctx context.Context, db *sql.DB) (*Topology, error) {
topology := &Topology{
Timestamp: time.Now(),
Replicas: make([]*Node, 0),
}
// Check if we're on primary
var isRecovery bool
err := db.QueryRowContext(ctx, "SELECT pg_is_in_recovery()").Scan(&isRecovery)
if err != nil {
return nil, fmt.Errorf("failed to check recovery status: %w", err)
}
if !isRecovery {
// We're on primary - get replicas from pg_stat_replication
rows, err := db.QueryContext(ctx, `
SELECT
client_addr,
client_port,
state,
COALESCE(EXTRACT(EPOCH FROM replay_lag), 0)::integer AS lag_seconds
FROM pg_stat_replication
`)
if err != nil {
return nil, fmt.Errorf("failed to query replication status: %w", err)
}
defer rows.Close()
for rows.Next() {
var addr sql.NullString
var port sql.NullInt64
var state sql.NullString
var lagSeconds sql.NullInt64
if err := rows.Scan(&addr, &port, &state, &lagSeconds); err != nil {
continue
}
node := &Node{
Host: addr.String,
Port: int(port.Int64),
Role: RoleReplica,
IsAvailable: true,
LastChecked: time.Now(),
}
if lagSeconds.Valid {
node.ReplicationLag = time.Duration(lagSeconds.Int64) * time.Second
}
if state.String == "streaming" {
node.Status = StatusHealthy
} else {
node.Status = StatusLagging
}
topology.Replicas = append(topology.Replicas, node)
}
}
return topology, nil
}
// GetRole returns the replication role
func (d *PostgreSQLDetector) GetRole(ctx context.Context, db *sql.DB) (Role, error) {
var isRecovery bool
err := db.QueryRowContext(ctx, "SELECT pg_is_in_recovery()").Scan(&isRecovery)
if err != nil {
return RoleUnknown, fmt.Errorf("failed to check recovery status: %w", err)
}
if isRecovery {
return RoleReplica, nil
}
return RolePrimary, nil
}
// GetReplicationLag returns the replication lag
func (d *PostgreSQLDetector) GetReplicationLag(ctx context.Context, db *sql.DB) (time.Duration, error) {
var lagSeconds sql.NullFloat64
err := db.QueryRowContext(ctx, `
SELECT EXTRACT(EPOCH FROM (now() - pg_last_xact_replay_timestamp()))
`).Scan(&lagSeconds)
if err != nil {
return 0, fmt.Errorf("failed to get replication lag: %w", err)
}
if !lagSeconds.Valid {
return 0, nil // Not a replica or no lag data
}
return time.Duration(lagSeconds.Float64) * time.Second, nil
}
// MySQLDetector detects MySQL/MariaDB replication topology
type MySQLDetector struct{}
// Detect discovers MySQL replication topology
func (d *MySQLDetector) Detect(ctx context.Context, db *sql.DB) (*Topology, error) {
topology := &Topology{
Timestamp: time.Now(),
Replicas: make([]*Node, 0),
}
// SHOW SLAVE STATUS succeeds on a primary too (it just returns zero rows),
// so check for a row rather than for an error to decide whether this node is a replica
slaveRows, err := db.QueryContext(ctx, "SHOW SLAVE STATUS")
if err == nil {
isReplica := slaveRows.Next()
slaveRows.Close()
if isReplica {
return topology, nil // This node is a replica; nothing downstream to discover
}
}
// Not a replica - look for downstream replicas via SHOW SLAVE HOSTS
rows, err := db.QueryContext(ctx, "SHOW SLAVE HOSTS")
if err != nil {
return topology, nil // Standalone, or insufficient privileges
}
defer rows.Close()
cols, _ := rows.Columns()
values := make([]interface{}, len(cols))
valuePtrs := make([]interface{}, len(cols))
for i := range values {
valuePtrs[i] = &values[i]
}
for rows.Next() {
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
// Extract host and port
var host string
var port int
for i, col := range cols {
switch col {
case "Host":
if v, ok := values[i].([]byte); ok {
host = string(v)
}
case "Port":
if v, ok := values[i].(int64); ok {
port = int(v)
}
}
}
if host != "" {
topology.Replicas = append(topology.Replicas, &Node{
Host: host,
Port: port,
Role: RoleReplica,
IsAvailable: true,
Status: StatusUnknown,
LastChecked: time.Now(),
})
}
}
return topology, nil
}
// GetRole returns the MySQL replication role
func (d *MySQLDetector) GetRole(ctx context.Context, db *sql.DB) (Role, error) {
// Check if this is a slave
rows, err := db.QueryContext(ctx, "SHOW SLAVE STATUS")
if err != nil {
return RoleUnknown, err
}
defer rows.Close()
if rows.Next() {
return RoleReplica, nil
}
// Check if this is a master with slaves
rows2, err := db.QueryContext(ctx, "SHOW SLAVE HOSTS")
if err != nil {
return RoleStandalone, nil
}
defer rows2.Close()
if rows2.Next() {
return RolePrimary, nil
}
return RoleStandalone, nil
}
// GetReplicationLag returns MySQL replication lag
func (d *MySQLDetector) GetReplicationLag(ctx context.Context, db *sql.DB) (time.Duration, error) {
var lagSeconds sql.NullInt64
rows, err := db.QueryContext(ctx, "SHOW SLAVE STATUS")
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, nil // Not a replica
}
cols, _ := rows.Columns()
values := make([]interface{}, len(cols))
valuePtrs := make([]interface{}, len(cols))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return 0, err
}
// Find Seconds_Behind_Master column
for i, col := range cols {
if col == "Seconds_Behind_Master" {
switch v := values[i].(type) {
case int64:
lagSeconds.Int64 = v
lagSeconds.Valid = true
case []byte:
// Seconds_Behind_Master is NULL when replication is stopped; only mark valid on a successful parse
if _, err := fmt.Sscanf(string(v), "%d", &lagSeconds.Int64); err == nil {
lagSeconds.Valid = true
}
}
break
}
}
if !lagSeconds.Valid {
return 0, nil
}
return time.Duration(lagSeconds.Int64) * time.Second, nil
}
// GetDetector returns the appropriate detector for a database type
func GetDetector(dbType string) Detector {
switch dbType {
case "postgresql", "postgres":
return &PostgreSQLDetector{}
case "mysql", "mariadb":
return &MySQLDetector{}
default:
return nil
}
}
// Result contains the result of replica selection
type Result struct {
SelectedNode *Node `json:"selected_node"`
Topology *Topology `json:"topology"`
Reason string `json:"reason"`
Duration time.Duration `json:"detection_duration"`
}
// SelectForBackup performs topology detection and node selection
func SelectForBackup(ctx context.Context, db *sql.DB, dbType string, config Config) (*Result, error) {
start := time.Now()
result := &Result{}
detector := GetDetector(dbType)
if detector == nil {
return nil, fmt.Errorf("unsupported database type: %s", dbType)
}
topology, err := detector.Detect(ctx, db)
if err != nil {
return nil, fmt.Errorf("topology detection failed: %w", err)
}
result.Topology = topology
selector := NewSelector(config)
node, err := selector.SelectNode(topology)
if err != nil {
return nil, err
}
result.SelectedNode = node
result.Duration = time.Since(start)
if node.Role == RoleReplica {
result.Reason = fmt.Sprintf("Selected replica %s:%d with %s lag",
node.Host, node.Port, node.ReplicationLag)
} else {
result.Reason = fmt.Sprintf("Using primary %s:%d", node.Host, node.Port)
}
return result, nil
}


@@ -0,0 +1,424 @@
// Package report - SOC2 framework controls
package report
import (
"time"
)
// SOC2Framework returns SOC2 Trust Service Criteria controls
func SOC2Framework() []Category {
return []Category{
soc2Security(),
soc2Availability(),
soc2ProcessingIntegrity(),
soc2Confidentiality(),
}
}
func soc2Security() Category {
return Category{
ID: "soc2-security",
Name: "Security",
Description: "Protection of system resources against unauthorized access",
Weight: 1.0,
Controls: []Control{
{
ID: "CC6.1",
Reference: "SOC2 CC6.1",
Name: "Encryption at Rest",
Description: "Data is protected at rest using encryption",
},
{
ID: "CC6.7",
Reference: "SOC2 CC6.7",
Name: "Encryption in Transit",
Description: "Data is protected in transit using encryption",
},
{
ID: "CC6.2",
Reference: "SOC2 CC6.2",
Name: "Access Control",
Description: "Logical access to data and system components is restricted",
},
{
ID: "CC6.3",
Reference: "SOC2 CC6.3",
Name: "Authorized Access",
Description: "Only authorized users can access data and systems",
},
},
}
}
func soc2Availability() Category {
return Category{
ID: "soc2-availability",
Name: "Availability",
Description: "System availability for operation and use as agreed",
Weight: 1.0,
Controls: []Control{
{
ID: "A1.1",
Reference: "SOC2 A1.1",
Name: "Backup Policy",
Description: "Backup policies and procedures are established and operating",
},
{
ID: "A1.2",
Reference: "SOC2 A1.2",
Name: "Backup Testing",
Description: "Backups are tested for recoverability",
},
{
ID: "A1.3",
Reference: "SOC2 A1.3",
Name: "Recovery Procedures",
Description: "Recovery procedures are documented and tested",
},
{
ID: "A1.4",
Reference: "SOC2 A1.4",
Name: "Disaster Recovery",
Description: "DR plans are maintained and tested",
},
},
}
}
func soc2ProcessingIntegrity() Category {
return Category{
ID: "soc2-processing-integrity",
Name: "Processing Integrity",
Description: "System processing is complete, valid, accurate, timely, and authorized",
Weight: 0.75,
Controls: []Control{
{
ID: "PI1.1",
Reference: "SOC2 PI1.1",
Name: "Data Integrity",
Description: "Checksums and verification ensure data integrity",
},
{
ID: "PI1.2",
Reference: "SOC2 PI1.2",
Name: "Error Handling",
Description: "Errors are identified and corrected in a timely manner",
},
},
}
}
func soc2Confidentiality() Category {
return Category{
ID: "soc2-confidentiality",
Name: "Confidentiality",
Description: "Information designated as confidential is protected",
Weight: 1.0,
Controls: []Control{
{
ID: "C1.1",
Reference: "SOC2 C1.1",
Name: "Data Classification",
Description: "Confidential data is identified and classified",
},
{
ID: "C1.2",
Reference: "SOC2 C1.2",
Name: "Data Retention",
Description: "Data retention policies are implemented",
},
{
ID: "C1.3",
Reference: "SOC2 C1.3",
Name: "Data Disposal",
Description: "Data is securely disposed when no longer needed",
},
},
}
}
// GDPRFramework returns GDPR-related controls
func GDPRFramework() []Category {
return []Category{
{
ID: "gdpr-data-protection",
Name: "Data Protection",
Description: "Protection of personal data",
Weight: 1.0,
Controls: []Control{
{
ID: "GDPR-25",
Reference: "GDPR Article 25",
Name: "Data Protection by Design",
Description: "Data protection measures are implemented by design",
},
{
ID: "GDPR-32",
Reference: "GDPR Article 32",
Name: "Security of Processing",
Description: "Appropriate technical measures to ensure data security",
},
{
ID: "GDPR-33",
Reference: "GDPR Article 33",
Name: "Breach Notification",
Description: "Procedures for breach detection and notification",
},
},
},
{
ID: "gdpr-data-retention",
Name: "Data Retention",
Description: "Lawful data retention practices",
Weight: 1.0,
Controls: []Control{
{
ID: "GDPR-5.1e",
Reference: "GDPR Article 5(1)(e)",
Name: "Storage Limitation",
Description: "Personal data not kept longer than necessary",
},
{
ID: "GDPR-17",
Reference: "GDPR Article 17",
Name: "Right to Erasure",
Description: "Ability to delete personal data on request",
},
},
},
}
}
// HIPAAFramework returns HIPAA-related controls
func HIPAAFramework() []Category {
return []Category{
{
ID: "hipaa-administrative",
Name: "Administrative Safeguards",
Description: "Administrative policies and procedures",
Weight: 1.0,
Controls: []Control{
{
ID: "164.308a7",
Reference: "HIPAA 164.308(a)(7)",
Name: "Contingency Plan",
Description: "Data backup and disaster recovery procedures",
},
{
ID: "164.308a7iA",
Reference: "HIPAA 164.308(a)(7)(ii)(A)",
Name: "Data Backup Plan",
Description: "Procedures for retrievable exact copies of ePHI",
},
{
ID: "164.308a7iB",
Reference: "HIPAA 164.308(a)(7)(ii)(B)",
Name: "Disaster Recovery Plan",
Description: "Procedures to restore any loss of data",
},
{
ID: "164.308a7iD",
Reference: "HIPAA 164.308(a)(7)(ii)(D)",
Name: "Testing and Revision",
Description: "Testing of contingency plans",
},
},
},
{
ID: "hipaa-technical",
Name: "Technical Safeguards",
Description: "Technical security measures",
Weight: 1.0,
Controls: []Control{
{
ID: "164.312a2iv",
Reference: "HIPAA 164.312(a)(2)(iv)",
Name: "Encryption",
Description: "Encryption of ePHI",
},
{
ID: "164.312c1",
Reference: "HIPAA 164.312(c)(1)",
Name: "Integrity Controls",
Description: "Mechanisms to ensure ePHI is not improperly altered",
},
{
ID: "164.312e1",
Reference: "HIPAA 164.312(e)(1)",
Name: "Transmission Security",
Description: "Technical measures to guard against unauthorized access",
},
},
},
}
}
// PCIDSSFramework returns PCI-DSS related controls
func PCIDSSFramework() []Category {
return []Category{
{
ID: "pci-protect",
Name: "Protect Stored Data",
Description: "Protect stored cardholder data",
Weight: 1.0,
Controls: []Control{
{
ID: "PCI-3.1",
Reference: "PCI-DSS 3.1",
Name: "Data Retention Policy",
Description: "Retention policy limits storage time",
},
{
ID: "PCI-3.4",
Reference: "PCI-DSS 3.4",
Name: "Encryption",
Description: "Render PAN unreadable anywhere it is stored",
},
{
ID: "PCI-3.5",
Reference: "PCI-DSS 3.5",
Name: "Key Management",
Description: "Protect cryptographic keys",
},
},
},
{
ID: "pci-maintain",
Name: "Maintain Security",
Description: "Maintain security policies and procedures",
Weight: 1.0,
Controls: []Control{
{
ID: "PCI-12.10.1",
Reference: "PCI-DSS 12.10.1",
Name: "Incident Response Plan",
Description: "Incident response plan includes data recovery",
},
},
},
}
}
// ISO27001Framework returns ISO 27001 related controls
func ISO27001Framework() []Category {
return []Category{
{
ID: "iso-operations",
Name: "Operations Security",
Description: "A.12 Operations Security controls",
Weight: 1.0,
Controls: []Control{
{
ID: "A.12.3.1",
Reference: "ISO 27001 A.12.3.1",
Name: "Information Backup",
Description: "Backup copies taken and tested regularly",
},
},
},
{
ID: "iso-continuity",
Name: "Business Continuity",
Description: "A.17 Business Continuity controls",
Weight: 1.0,
Controls: []Control{
{
ID: "A.17.1.1",
Reference: "ISO 27001 A.17.1.1",
Name: "Planning Continuity",
Description: "Information security continuity planning",
},
{
ID: "A.17.1.2",
Reference: "ISO 27001 A.17.1.2",
Name: "Implementing Continuity",
Description: "Implementation of security continuity",
},
{
ID: "A.17.1.3",
Reference: "ISO 27001 A.17.1.3",
Name: "Verify and Review",
Description: "Verify and review continuity controls",
},
},
},
{
ID: "iso-cryptography",
Name: "Cryptography",
Description: "A.10 Cryptographic controls",
Weight: 1.0,
Controls: []Control{
{
ID: "A.10.1.1",
Reference: "ISO 27001 A.10.1.1",
Name: "Cryptographic Controls",
Description: "Policy on use of cryptographic controls",
},
{
ID: "A.10.1.2",
Reference: "ISO 27001 A.10.1.2",
Name: "Key Management",
Description: "Policy on cryptographic key management",
},
},
},
}
}
// GetFramework returns the appropriate framework for a report type
func GetFramework(reportType ReportType) []Category {
switch reportType {
case ReportSOC2:
return SOC2Framework()
case ReportGDPR:
return GDPRFramework()
case ReportHIPAA:
return HIPAAFramework()
case ReportPCIDSS:
return PCIDSSFramework()
case ReportISO27001:
return ISO27001Framework()
default:
return nil
}
}
// CreatePeriodReport creates a report for a specific time period
func CreatePeriodReport(reportType ReportType, start, end time.Time) *Report {
title := ""
desc := ""
switch reportType {
case ReportSOC2:
title = "SOC 2 Type II Compliance Report"
desc = "Trust Service Criteria compliance assessment"
case ReportGDPR:
title = "GDPR Data Protection Compliance Report"
desc = "General Data Protection Regulation compliance assessment"
case ReportHIPAA:
title = "HIPAA Security Compliance Report"
desc = "Health Insurance Portability and Accountability Act compliance assessment"
case ReportPCIDSS:
title = "PCI-DSS Compliance Report"
desc = "Payment Card Industry Data Security Standard compliance assessment"
case ReportISO27001:
title = "ISO 27001 Compliance Report"
desc = "Information Security Management System compliance assessment"
default:
title = "Custom Compliance Report"
desc = "Custom compliance assessment"
}
report := NewReport(reportType, title)
report.Description = desc
report.PeriodStart = start
report.PeriodEnd = end
// Load framework controls
framework := GetFramework(reportType)
for _, cat := range framework {
report.AddCategory(cat)
}
return report
}
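Each category carries a `Weight` (e.g. 0.75 for Processing Integrity). One plausible way such weights could roll up into an overall score, shown here as a hypothetical aggregation rather than something the report package is shown doing, is a weighted mean:

```go
package main

import "fmt"

// category mirrors the Score/Weight pair on report.Category; a standalone
// sketch, since internal/report cannot be imported from outside the module.
type category struct {
	score  float64 // 0-100 compliance score
	weight float64 // relative importance, e.g. 0.75 for Processing Integrity
}

// weightedScore combines per-category scores using their weights.
func weightedScore(cats []category) float64 {
	var sum, wsum float64
	for _, c := range cats {
		sum += c.score * c.weight
		wsum += c.weight
	}
	if wsum == 0 {
		return 0
	}
	return sum / wsum
}

func main() {
	cats := []category{{100, 1.0}, {50, 0.75}}
	fmt.Printf("%.1f\n", weightedScore(cats)) // 78.6
}
```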


@@ -0,0 +1,420 @@
// Package report - Report generator
package report
import (
"context"
"fmt"
"time"
"dbbackup/internal/catalog"
)
// Generator generates compliance reports
type Generator struct {
catalog catalog.Catalog
config ReportConfig
}
// NewGenerator creates a new report generator
func NewGenerator(cat catalog.Catalog, config ReportConfig) *Generator {
return &Generator{
catalog: cat,
config: config,
}
}
// Generate creates a compliance report
func (g *Generator) Generate() (*Report, error) {
report := CreatePeriodReport(g.config.Type, g.config.PeriodStart, g.config.PeriodEnd)
report.Title = g.config.Title
if g.config.Description != "" {
report.Description = g.config.Description
}
// Collect evidence from catalog
evidence, err := g.collectEvidence()
if err != nil {
return nil, fmt.Errorf("failed to collect evidence: %w", err)
}
for _, e := range evidence {
report.AddEvidence(e)
}
// Evaluate controls
if err := g.evaluateControls(report, evidence); err != nil {
return nil, fmt.Errorf("failed to evaluate controls: %w", err)
}
// Calculate summary
report.Calculate()
return report, nil
}
// collectEvidence gathers evidence from the backup catalog
func (g *Generator) collectEvidence() ([]Evidence, error) {
var evidence []Evidence
ctx := context.Background()
// Get backup entries in the report period
query := &catalog.SearchQuery{
StartDate: &g.config.PeriodStart,
EndDate: &g.config.PeriodEnd,
Limit: 1000,
}
entries, err := g.catalog.Search(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to search catalog: %w", err)
}
// Create evidence for backups
for _, entry := range entries {
e := Evidence{
ID: fmt.Sprintf("BKP-%d", entry.ID),
Type: EvidenceBackupLog,
Description: fmt.Sprintf("Backup of %s completed", entry.Database),
Source: entry.BackupPath,
CollectedAt: entry.CreatedAt,
Data: map[string]interface{}{
"database": entry.Database,
"database_type": entry.DatabaseType,
"size": entry.SizeBytes,
"sha256": entry.SHA256,
"encrypted": entry.Encrypted,
"compression": entry.Compression,
"status": entry.Status,
},
}
if entry.SHA256 != "" {
e.Hash = entry.SHA256
}
evidence = append(evidence, e)
// Add verification evidence
if entry.VerifiedAt != nil {
evidence = append(evidence, Evidence{
ID: fmt.Sprintf("VRF-%d", entry.ID),
Type: EvidenceAuditLog,
Description: fmt.Sprintf("Verification of backup %s", entry.BackupPath),
Source: "verification_system",
CollectedAt: *entry.VerifiedAt,
Data: map[string]interface{}{
"backup_id": entry.ID,
"database": entry.Database,
"verified": true,
},
})
}
// Add drill evidence
if entry.DrillTestedAt != nil {
evidence = append(evidence, Evidence{
ID: fmt.Sprintf("DRL-%d", entry.ID),
Type: EvidenceDrillResult,
Description: fmt.Sprintf("DR drill test of backup %s", entry.BackupPath),
Source: "drill_system",
CollectedAt: *entry.DrillTestedAt,
Data: map[string]interface{}{
"backup_id": entry.ID,
"database": entry.Database,
"passed": true,
},
})
}
// Add encryption evidence
if entry.Encrypted {
encryption := "AES-256"
if meta, ok := entry.Metadata["encryption_method"]; ok {
encryption = meta
}
evidence = append(evidence, Evidence{
ID: fmt.Sprintf("ENC-%d", entry.ID),
Type: EvidenceEncryptionProof,
Description: fmt.Sprintf("Encrypted backup %s", entry.BackupPath),
Source: entry.BackupPath,
CollectedAt: entry.CreatedAt,
Data: map[string]interface{}{
"backup_id": entry.ID,
"database": entry.Database,
"encryption": encryption,
},
})
}
}
// Get catalog statistics for retention evidence
stats, err := g.catalog.Stats(ctx)
if err == nil {
evidence = append(evidence, Evidence{
ID: "RET-STATS",
Type: EvidenceRetentionProof,
Description: "Backup retention statistics",
Source: "catalog",
CollectedAt: time.Now(),
Data: map[string]interface{}{
"total_backups": stats.TotalBackups,
"oldest_backup": stats.OldestBackup,
"newest_backup": stats.NewestBackup,
"average_size": stats.AvgSize,
"total_size": stats.TotalSize,
"databases": len(stats.ByDatabase),
},
})
}
// Check for gaps
gapConfig := &catalog.GapDetectionConfig{
ExpectedInterval: 24 * time.Hour,
Tolerance: 2 * time.Hour,
StartDate: &g.config.PeriodStart,
EndDate: &g.config.PeriodEnd,
}
allGaps, err := g.catalog.DetectAllGaps(ctx, gapConfig)
if err == nil {
totalGaps := 0
for _, gaps := range allGaps {
totalGaps += len(gaps)
}
if totalGaps > 0 {
evidence = append(evidence, Evidence{
ID: "GAP-ANALYSIS",
Type: EvidenceAuditLog,
Description: "Backup gap analysis",
Source: "catalog",
CollectedAt: time.Now(),
Data: map[string]interface{}{
"gaps_detected": totalGaps,
"gaps": allGaps,
},
})
}
}
return evidence, nil
}
// evaluateControls evaluates compliance controls based on evidence
func (g *Generator) evaluateControls(report *Report, evidence []Evidence) error {
// Index evidence by type for quick lookup
evidenceByType := make(map[EvidenceType][]Evidence)
for _, e := range evidence {
evidenceByType[e.Type] = append(evidenceByType[e.Type], e)
}
// Get backup statistics
backupEvidence := evidenceByType[EvidenceBackupLog]
encryptionEvidence := evidenceByType[EvidenceEncryptionProof]
drillEvidence := evidenceByType[EvidenceDrillResult]
// Include retention-proof evidence so the retention controls (C1.2, GDPR-5.1e, PCI-3.1) can cite it
verificationEvidence := append(evidenceByType[EvidenceAuditLog], evidenceByType[EvidenceRetentionProof]...)
// Evaluate each control
for i := range report.Categories {
cat := &report.Categories[i]
catCompliant := 0
catTotal := 0
for j := range cat.Controls {
ctrl := &cat.Controls[j]
ctrl.LastChecked = time.Now()
catTotal++
// Evaluate based on control type
status, notes, evidenceIDs := g.evaluateControl(ctrl, backupEvidence, encryptionEvidence, drillEvidence, verificationEvidence)
ctrl.Status = status
ctrl.Notes = notes
ctrl.Evidence = evidenceIDs
if status == StatusCompliant {
catCompliant++
} else if status != StatusNotApplicable {
// Create finding for non-compliant controls
finding := g.createFinding(ctrl, report)
if finding != nil {
report.AddFinding(*finding)
ctrl.Findings = append(ctrl.Findings, finding.ID)
}
}
}
// Calculate category score
if catTotal > 0 {
cat.Score = float64(catCompliant) / float64(catTotal) * 100
if cat.Score >= 100 {
cat.Status = StatusCompliant
} else if cat.Score >= 70 {
cat.Status = StatusPartial
} else {
cat.Status = StatusNonCompliant
}
}
}
return nil
}
// evaluateControl evaluates a single control
func (g *Generator) evaluateControl(ctrl *Control, backups, encryption, drills, verifications []Evidence) (ComplianceStatus, string, []string) {
var evidenceIDs []string
switch ctrl.ID {
// SOC2 Controls
case "CC6.1", "GDPR-32", "164.312a2iv", "PCI-3.4", "A.10.1.1":
// Encryption at rest
if len(encryption) == 0 {
return StatusNonCompliant, "No encrypted backups found", nil
}
encryptedCount := len(encryption)
totalCount := len(backups)
if totalCount == 0 {
return StatusNotApplicable, "No backups in period", nil
}
rate := float64(encryptedCount) / float64(totalCount) * 100
for _, e := range encryption {
evidenceIDs = append(evidenceIDs, e.ID)
}
if rate >= 100 {
return StatusCompliant, fmt.Sprintf("100%% of backups encrypted (%d/%d)", encryptedCount, totalCount), evidenceIDs
}
if rate >= 90 {
return StatusPartial, fmt.Sprintf("%.1f%% of backups encrypted (%d/%d)", rate, encryptedCount, totalCount), evidenceIDs
}
return StatusNonCompliant, fmt.Sprintf("Only %.1f%% of backups encrypted", rate), evidenceIDs
case "A1.1", "164.308a7iA", "A.12.3.1":
// Backup policy/plan
if len(backups) == 0 {
return StatusNonCompliant, "No backups found in period", nil
}
for _, e := range backups[:min(5, len(backups))] {
evidenceIDs = append(evidenceIDs, e.ID)
}
return StatusCompliant, fmt.Sprintf("%d backups created in period", len(backups)), evidenceIDs
case "A1.2", "164.308a7iD", "A.17.1.3":
// Backup testing
if len(drills) == 0 {
return StatusNonCompliant, "No DR drill tests performed", nil
}
for _, e := range drills {
evidenceIDs = append(evidenceIDs, e.ID)
}
return StatusCompliant, fmt.Sprintf("%d DR drill tests completed", len(drills)), evidenceIDs
case "A1.3", "A1.4", "164.308a7iB", "A.17.1.1", "A.17.1.2", "PCI-12.10.1":
// DR procedures
if len(drills) > 0 {
for _, e := range drills {
evidenceIDs = append(evidenceIDs, e.ID)
}
return StatusCompliant, "DR procedures tested", evidenceIDs
}
return StatusPartial, "DR procedures exist but not tested", nil
case "PI1.1", "164.312c1":
// Data integrity
integrityCount := 0
for _, e := range backups {
if data, ok := e.Data.(map[string]interface{}); ok {
// collectEvidence stores the checksum under the "sha256" key
if checksum, ok := data["sha256"].(string); ok && checksum != "" {
integrityCount++
evidenceIDs = append(evidenceIDs, e.ID)
}
}
}
if integrityCount == len(backups) && len(backups) > 0 {
return StatusCompliant, "All backups have integrity checksums", evidenceIDs
}
if integrityCount > 0 {
return StatusPartial, fmt.Sprintf("%d/%d backups have checksums", integrityCount, len(backups)), evidenceIDs
}
return StatusNonCompliant, "No integrity checksums found", nil
case "C1.2", "GDPR-5.1e", "PCI-3.1":
// Data retention
for _, e := range verifications {
if e.Type == EvidenceRetentionProof {
evidenceIDs = append(evidenceIDs, e.ID)
}
}
if len(backups) > 0 {
return StatusCompliant, "Retention policy in effect", evidenceIDs
}
return StatusPartial, "Retention policy needs review", nil
default:
// Generic evaluation
if len(backups) > 0 {
return StatusCompliant, "Evidence available", nil
}
return StatusUnknown, "Requires manual review", nil
}
}
// createFinding creates a finding for a non-compliant control
func (g *Generator) createFinding(ctrl *Control, report *Report) *Finding {
if ctrl.Status == StatusCompliant || ctrl.Status == StatusNotApplicable {
return nil
}
severity := SeverityMedium
findingType := FindingGap
// Determine severity based on control
switch ctrl.ID {
case "CC6.1", "164.312a2iv", "PCI-3.4":
severity = SeverityHigh
findingType = FindingViolation
case "A1.2", "164.308a7iD":
severity = SeverityMedium
findingType = FindingGap
}
return &Finding{
ID: fmt.Sprintf("FND-%s-%d", ctrl.ID, time.Now().UnixNano()),
ControlID: ctrl.ID,
Type: findingType,
Severity: severity,
Title: fmt.Sprintf("%s: %s", ctrl.Reference, ctrl.Name),
Description: ctrl.Notes,
Impact: fmt.Sprintf("Non-compliance with %s requirements", report.Type),
Recommendation: g.getRecommendation(ctrl.ID),
Status: FindingOpen,
DetectedAt: time.Now(),
Evidence: ctrl.Evidence,
}
}
// getRecommendation returns remediation recommendation for a control
func (g *Generator) getRecommendation(controlID string) string {
recommendations := map[string]string{
"CC6.1": "Enable encryption for all backups using AES-256",
"CC6.7": "Ensure all backup transfers use TLS",
"A1.1": "Establish and document backup schedule",
"A1.2": "Schedule and perform regular DR drill tests",
"A1.3": "Document and test recovery procedures",
"A1.4": "Develop and test disaster recovery plan",
"PI1.1": "Enable checksum verification for all backups",
"C1.2": "Implement and document retention policies",
"164.312a2iv": "Enable HIPAA-compliant encryption (AES-256)",
"164.308a7iD": "Test backup recoverability quarterly",
"PCI-3.4": "Encrypt all backups containing cardholder data",
}
if rec, ok := recommendations[controlID]; ok {
return rec
}
return "Review and remediate as per compliance requirements"
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
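The per-category scoring rule in `evaluateControls` (100% for compliant, a 70% threshold for partial) can be sketched standalone; the function and status names below are illustrative, not the package's API:

```go
package main

import "fmt"

// categoryStatus reproduces the scoring rule used in evaluateControls:
// score of 100 => compliant, >= 70 => partial, otherwise non-compliant.
func categoryStatus(compliant, total int) (float64, string) {
	if total == 0 {
		return 0, "unknown"
	}
	score := float64(compliant) / float64(total) * 100
	switch {
	case score >= 100:
		return score, "compliant"
	case score >= 70:
		return score, "partial"
	default:
		return score, "non_compliant"
	}
}

func main() {
	score, status := categoryStatus(3, 4)
	fmt.Printf("%.1f%% %s\n", score, status) // 75.0% partial
}
```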

internal/report/output.go (new file)

@@ -0,0 +1,544 @@
// Package report - Output formatters
package report
import (
"encoding/json"
"fmt"
"io"
"strings"
"text/template"
"time"
)
// Formatter formats reports for output
type Formatter interface {
Format(report *Report, w io.Writer) error
}
// JSONFormatter formats reports as JSON
type JSONFormatter struct {
Indent bool
}
// Format writes the report as JSON
func (f *JSONFormatter) Format(report *Report, w io.Writer) error {
var data []byte
var err error
if f.Indent {
data, err = json.MarshalIndent(report, "", " ")
} else {
data, err = json.Marshal(report)
}
if err != nil {
return err
}
_, err = w.Write(data)
return err
}
// MarkdownFormatter formats reports as Markdown
type MarkdownFormatter struct{}
// Format writes the report as Markdown
func (f *MarkdownFormatter) Format(report *Report, w io.Writer) error {
tmpl := template.Must(template.New("report").Funcs(template.FuncMap{
"statusIcon": StatusIcon,
"severityIcon": SeverityIcon,
"formatTime": func(t time.Time) string { return t.Format("2006-01-02 15:04:05") },
"formatDate": func(t time.Time) string { return t.Format("2006-01-02") },
"upper": strings.ToUpper,
}).Parse(markdownTemplate))
return tmpl.Execute(w, report)
}
const markdownTemplate = `# {{.Title}}
**Generated:** {{formatTime .GeneratedAt}}
**Period:** {{formatDate .PeriodStart}} to {{formatDate .PeriodEnd}}
**Overall Status:** {{statusIcon .Status}} {{.Status}}
**Compliance Score:** {{printf "%.1f" .Score}}%
---
## Executive Summary
{{.Description}}
| Metric | Value |
|--------|-------|
| Total Controls | {{.Summary.TotalControls}} |
| Compliant | {{.Summary.CompliantControls}} |
| Non-Compliant | {{.Summary.NonCompliantControls}} |
| Partial | {{.Summary.PartialControls}} |
| Compliance Rate | {{printf "%.1f" .Summary.ComplianceRate}}% |
| Open Findings | {{.Summary.OpenFindings}} |
| Risk Score | {{printf "%.1f" .Summary.RiskScore}} |
---
## Compliance Categories
{{range .Categories}}
### {{statusIcon .Status}} {{.Name}}
**Status:** {{.Status}} | **Score:** {{printf "%.1f" .Score}}%
{{.Description}}
| Control | Reference | Status | Notes |
|---------|-----------|--------|-------|
{{range .Controls}}| {{.Name}} | {{.Reference}} | {{statusIcon .Status}} | {{.Notes}} |
{{end}}
{{end}}
---
## Findings
{{if .Findings}}
| ID | Severity | Title | Status |
|----|----------|-------|--------|
{{range .Findings}}| {{.ID}} | {{severityIcon .Severity}} {{.Severity}} | {{.Title}} | {{.Status}} |
{{end}}
### Finding Details
{{range .Findings}}
#### {{severityIcon .Severity}} {{.Title}}
- **ID:** {{.ID}}
- **Control:** {{.ControlID}}
- **Severity:** {{.Severity}}
- **Type:** {{.Type}}
- **Status:** {{.Status}}
- **Detected:** {{formatTime .DetectedAt}}
**Description:** {{.Description}}
**Impact:** {{.Impact}}
**Recommendation:** {{.Recommendation}}
---
{{end}}
{{else}}
No open findings.
{{end}}
---
## Evidence Summary
{{if .Evidence}}
| ID | Type | Description | Collected |
|----|------|-------------|-----------|
{{range .Evidence}}| {{.ID}} | {{.Type}} | {{.Description}} | {{formatTime .CollectedAt}} |
{{end}}
{{else}}
No evidence collected.
{{end}}
---
*Report generated by dbbackup compliance module*
`
// HTMLFormatter formats reports as HTML
type HTMLFormatter struct{}
// Format writes the report as HTML
func (f *HTMLFormatter) Format(report *Report, w io.Writer) error {
tmpl := template.Must(template.New("report").Funcs(template.FuncMap{
"statusIcon": StatusIcon,
"statusClass": statusClass,
"severityIcon": SeverityIcon,
"severityClass": severityClass,
"formatTime": func(t time.Time) string { return t.Format("2006-01-02 15:04:05") },
"formatDate": func(t time.Time) string { return t.Format("2006-01-02") },
}).Parse(htmlTemplate))
return tmpl.Execute(w, report)
}
func statusClass(s ComplianceStatus) string {
switch s {
case StatusCompliant:
return "status-compliant"
case StatusNonCompliant:
return "status-noncompliant"
case StatusPartial:
return "status-partial"
default:
return "status-unknown"
}
}
func severityClass(s FindingSeverity) string {
switch s {
case SeverityCritical:
return "severity-critical"
case SeverityHigh:
return "severity-high"
case SeverityMedium:
return "severity-medium"
case SeverityLow:
return "severity-low"
default:
return "severity-unknown"
}
}
const htmlTemplate = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
<style>
:root {
--color-compliant: #27ae60;
--color-noncompliant: #e74c3c;
--color-partial: #f39c12;
--color-unknown: #95a5a6;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
line-height: 1.6;
max-width: 1200px;
margin: 0 auto;
padding: 20px;
background: #f5f5f5;
}
.report-header {
background: white;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
margin-bottom: 20px;
}
.report-header h1 {
margin: 0 0 10px 0;
color: #2c3e50;
}
.report-meta {
color: #7f8c8d;
font-size: 0.9em;
}
.status-badge {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-weight: bold;
font-size: 0.85em;
}
.status-compliant { background: var(--color-compliant); color: white; }
.status-noncompliant { background: var(--color-noncompliant); color: white; }
.status-partial { background: var(--color-partial); color: white; }
.status-unknown { background: var(--color-unknown); color: white; }
.score-display {
font-size: 48px;
font-weight: bold;
color: #2c3e50;
}
.summary-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
gap: 20px;
margin: 20px 0;
}
.summary-card {
background: white;
padding: 20px;
border-radius: 8px;
text-align: center;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.summary-card .value {
font-size: 32px;
font-weight: bold;
color: #2c3e50;
}
.summary-card .label {
color: #7f8c8d;
font-size: 0.9em;
}
.section {
background: white;
padding: 20px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
margin-bottom: 20px;
}
.section h2 {
margin-top: 0;
color: #2c3e50;
border-bottom: 2px solid #ecf0f1;
padding-bottom: 10px;
}
.category {
margin-bottom: 30px;
}
.category h3 {
display: flex;
align-items: center;
gap: 10px;
}
table {
width: 100%;
border-collapse: collapse;
margin: 10px 0;
}
th, td {
padding: 12px;
text-align: left;
border-bottom: 1px solid #ecf0f1;
}
th {
background: #f8f9fa;
font-weight: 600;
}
tr:hover {
background: #f8f9fa;
}
.finding-card {
border-left: 4px solid var(--color-unknown);
padding: 15px;
margin: 15px 0;
background: #f8f9fa;
border-radius: 0 8px 8px 0;
}
.finding-card.severity-critical { border-color: #e74c3c; }
.finding-card.severity-high { border-color: #e67e22; }
.finding-card.severity-medium { border-color: #f39c12; }
.finding-card.severity-low { border-color: #27ae60; }
.finding-card h4 {
margin: 0 0 10px 0;
}
.finding-meta {
font-size: 0.85em;
color: #7f8c8d;
}
.progress-bar {
height: 8px;
background: #ecf0f1;
border-radius: 4px;
overflow: hidden;
}
.progress-fill {
height: 100%;
background: var(--color-compliant);
transition: width 0.3s ease;
}
.footer {
text-align: center;
color: #7f8c8d;
padding: 20px;
font-size: 0.85em;
}
</style>
</head>
<body>
<div class="report-header">
<h1>{{.Title}}</h1>
<p class="report-meta">
Generated: {{formatTime .GeneratedAt}} |
Period: {{formatDate .PeriodStart}} to {{formatDate .PeriodEnd}}
</p>
<div style="display: flex; align-items: center; gap: 20px; margin-top: 20px;">
<div class="score-display">{{printf "%.0f" .Score}}%</div>
<div>
<span class="status-badge {{statusClass .Status}}">{{.Status}}</span>
<p style="margin: 10px 0 0 0; color: #7f8c8d;">{{.Description}}</p>
</div>
</div>
</div>
<div class="summary-grid">
<div class="summary-card">
<div class="value">{{.Summary.TotalControls}}</div>
<div class="label">Total Controls</div>
</div>
<div class="summary-card">
<div class="value" style="color: var(--color-compliant);">{{.Summary.CompliantControls}}</div>
<div class="label">Compliant</div>
</div>
<div class="summary-card">
<div class="value" style="color: var(--color-noncompliant);">{{.Summary.NonCompliantControls}}</div>
<div class="label">Non-Compliant</div>
</div>
<div class="summary-card">
<div class="value" style="color: var(--color-partial);">{{.Summary.PartialControls}}</div>
<div class="label">Partial</div>
</div>
<div class="summary-card">
<div class="value">{{.Summary.OpenFindings}}</div>
<div class="label">Open Findings</div>
</div>
<div class="summary-card">
<div class="value">{{printf "%.1f" .Summary.RiskScore}}</div>
<div class="label">Risk Score</div>
</div>
</div>
<div class="section">
<h2>Compliance Categories</h2>
{{range .Categories}}
<div class="category">
<h3>
{{statusIcon .Status}} {{.Name}}
<span class="status-badge {{statusClass .Status}}">{{printf "%.0f" .Score}}%</span>
</h3>
<p style="color: #7f8c8d;">{{.Description}}</p>
<div class="progress-bar">
<div class="progress-fill" style="width: {{printf "%.0f" .Score}}%;"></div>
</div>
<table>
<thead>
<tr>
<th>Control</th>
<th>Reference</th>
<th>Status</th>
<th>Notes</th>
</tr>
</thead>
<tbody>
{{range .Controls}}
<tr>
<td>{{.Name}}</td>
<td>{{.Reference}}</td>
<td>{{statusIcon .Status}}</td>
<td>{{.Notes}}</td>
</tr>
{{end}}
</tbody>
</table>
</div>
{{end}}
</div>
{{if .Findings}}
<div class="section">
<h2>Findings ({{len .Findings}})</h2>
{{range .Findings}}
<div class="finding-card {{severityClass .Severity}}">
<h4>{{severityIcon .Severity}} {{.Title}}</h4>
<div class="finding-meta">
<strong>ID:</strong> {{.ID}} |
<strong>Severity:</strong> {{.Severity}} |
<strong>Status:</strong> {{.Status}} |
<strong>Detected:</strong> {{formatTime .DetectedAt}}
</div>
<p><strong>Description:</strong> {{.Description}}</p>
<p><strong>Impact:</strong> {{.Impact}}</p>
<p><strong>Recommendation:</strong> {{.Recommendation}}</p>
</div>
{{end}}
</div>
{{end}}
{{if .Evidence}}
<div class="section">
<h2>Evidence ({{len .Evidence}} items)</h2>
<table>
<thead>
<tr>
<th>ID</th>
<th>Type</th>
<th>Description</th>
<th>Collected</th>
</tr>
</thead>
<tbody>
{{range .Evidence}}
<tr>
<td>{{.ID}}</td>
<td>{{.Type}}</td>
<td>{{.Description}}</td>
<td>{{formatTime .CollectedAt}}</td>
</tr>
{{end}}
</tbody>
</table>
</div>
{{end}}
<div class="footer">
<p>Report generated by dbbackup compliance module</p>
<p>Report ID: {{.ID}}</p>
</div>
</body>
</html>`
// GetFormatter returns a formatter for the given format
func GetFormatter(format OutputFormat) Formatter {
switch format {
case FormatJSON:
return &JSONFormatter{Indent: true}
case FormatMarkdown:
return &MarkdownFormatter{}
case FormatHTML:
return &HTMLFormatter{}
default:
return &JSONFormatter{Indent: true}
}
}
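The `JSONFormatter` above only toggles between `json.Marshal` and `json.MarshalIndent`. A stand-alone sketch of that behavior (using a throwaway struct rather than the real `Report` type):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// renderJSON mirrors JSONFormatter.Format: indented output when pretty is true.
func renderJSON(v any, pretty bool) (string, error) {
	var data []byte
	var err error
	if pretty {
		data, err = json.MarshalIndent(v, "", "  ")
	} else {
		data, err = json.Marshal(v)
	}
	return string(data), err
}

func main() {
	type summary struct {
		TotalControls int `json:"total_controls"`
	}
	compact, _ := renderJSON(summary{TotalControls: 12}, false)
	pretty, _ := renderJSON(summary{TotalControls: 12}, true)
	fmt.Println(compact)
	fmt.Println(pretty)
}
```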
// ConsoleFormatter formats reports for terminal output
type ConsoleFormatter struct{}
// Format writes the report to console
func (f *ConsoleFormatter) Format(report *Report, w io.Writer) error {
// Header
fmt.Fprintf(w, "\n%s\n", strings.Repeat("=", 60))
fmt.Fprintf(w, " %s\n", report.Title)
fmt.Fprintf(w, "%s\n\n", strings.Repeat("=", 60))
fmt.Fprintf(w, " Generated: %s\n", report.GeneratedAt.Format("2006-01-02 15:04:05"))
fmt.Fprintf(w, " Period: %s to %s\n",
report.PeriodStart.Format("2006-01-02"),
report.PeriodEnd.Format("2006-01-02"))
fmt.Fprintf(w, " Status: %s %s\n", StatusIcon(report.Status), report.Status)
fmt.Fprintf(w, " Score: %.1f%%\n\n", report.Score)
// Summary
fmt.Fprintf(w, " SUMMARY\n")
fmt.Fprintf(w, " %s\n", strings.Repeat("-", 40))
fmt.Fprintf(w, " Controls: %d total, %d compliant, %d non-compliant\n",
report.Summary.TotalControls,
report.Summary.CompliantControls,
report.Summary.NonCompliantControls)
fmt.Fprintf(w, " Compliance: %.1f%%\n", report.Summary.ComplianceRate)
fmt.Fprintf(w, " Open Findings: %d (critical: %d, high: %d)\n",
report.Summary.OpenFindings,
report.Summary.CriticalFindings,
report.Summary.HighFindings)
fmt.Fprintf(w, " Risk Score: %.1f\n\n", report.Summary.RiskScore)
// Categories
fmt.Fprintf(w, " CATEGORIES\n")
fmt.Fprintf(w, " %s\n", strings.Repeat("-", 40))
for _, cat := range report.Categories {
fmt.Fprintf(w, " %s %-25s %.0f%%\n", StatusIcon(cat.Status), cat.Name, cat.Score)
}
fmt.Fprintln(w)
// Findings
if len(report.Findings) > 0 {
fmt.Fprintf(w, " FINDINGS\n")
fmt.Fprintf(w, " %s\n", strings.Repeat("-", 40))
for _, f := range report.Findings {
fmt.Fprintf(w, " %s [%s] %s\n", SeverityIcon(f.Severity), f.Severity, f.Title)
fmt.Fprintf(w, " %s\n", f.Description)
}
fmt.Fprintln(w)
}
fmt.Fprintf(w, "%s\n", strings.Repeat("=", 60))
return nil
}

internal/report/report.go Normal file

@@ -0,0 +1,325 @@
// Package report provides compliance report generation
package report
import (
"encoding/json"
"fmt"
"time"
)
// ReportType represents the compliance framework type
type ReportType string
const (
ReportSOC2 ReportType = "soc2"
ReportGDPR ReportType = "gdpr"
ReportHIPAA ReportType = "hipaa"
ReportPCIDSS ReportType = "pci-dss"
ReportISO27001 ReportType = "iso27001"
ReportCustom ReportType = "custom"
)
// ComplianceStatus represents the status of a compliance check
type ComplianceStatus string
const (
StatusCompliant ComplianceStatus = "compliant"
StatusNonCompliant ComplianceStatus = "non_compliant"
StatusPartial ComplianceStatus = "partial"
StatusNotApplicable ComplianceStatus = "not_applicable"
StatusUnknown ComplianceStatus = "unknown"
)
// Report represents a compliance report
type Report struct {
ID string `json:"id"`
Type ReportType `json:"type"`
Title string `json:"title"`
Description string `json:"description"`
GeneratedAt time.Time `json:"generated_at"`
GeneratedBy string `json:"generated_by"`
PeriodStart time.Time `json:"period_start"`
PeriodEnd time.Time `json:"period_end"`
Status ComplianceStatus `json:"overall_status"`
Score float64 `json:"score"` // 0-100
Categories []Category `json:"categories"`
Summary Summary `json:"summary"`
Findings []Finding `json:"findings"`
Evidence []Evidence `json:"evidence"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// Category represents a compliance category
type Category struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status ComplianceStatus `json:"status"`
Score float64 `json:"score"`
Weight float64 `json:"weight"`
Controls []Control `json:"controls"`
}
// Control represents a compliance control
type Control struct {
ID string `json:"id"`
Reference string `json:"reference"` // e.g., "SOC2 CC6.1"
Name string `json:"name"`
Description string `json:"description"`
Status ComplianceStatus `json:"status"`
Evidence []string `json:"evidence_ids,omitempty"`
Findings []string `json:"finding_ids,omitempty"`
LastChecked time.Time `json:"last_checked"`
Notes string `json:"notes,omitempty"`
}
// Finding represents a compliance finding
type Finding struct {
ID string `json:"id"`
ControlID string `json:"control_id"`
Type FindingType `json:"type"`
Severity FindingSeverity `json:"severity"`
Title string `json:"title"`
Description string `json:"description"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
Status FindingStatus `json:"status"`
DetectedAt time.Time `json:"detected_at"`
ResolvedAt *time.Time `json:"resolved_at,omitempty"`
Evidence []string `json:"evidence_ids,omitempty"`
}
// FindingType represents the type of finding
type FindingType string
const (
FindingGap FindingType = "gap"
FindingViolation FindingType = "violation"
FindingObservation FindingType = "observation"
FindingRecommendation FindingType = "recommendation"
)
// FindingSeverity represents finding severity
type FindingSeverity string
const (
SeverityLow FindingSeverity = "low"
SeverityMedium FindingSeverity = "medium"
SeverityHigh FindingSeverity = "high"
SeverityCritical FindingSeverity = "critical"
)
// FindingStatus represents finding status
type FindingStatus string
const (
FindingOpen FindingStatus = "open"
FindingAccepted FindingStatus = "accepted"
FindingResolved FindingStatus = "resolved"
)
// Evidence represents compliance evidence
type Evidence struct {
ID string `json:"id"`
Type EvidenceType `json:"type"`
Description string `json:"description"`
Source string `json:"source"`
CollectedAt time.Time `json:"collected_at"`
Hash string `json:"hash,omitempty"`
Data interface{} `json:"data,omitempty"`
}
// EvidenceType represents the type of evidence
type EvidenceType string
const (
EvidenceBackupLog EvidenceType = "backup_log"
EvidenceRestoreLog EvidenceType = "restore_log"
EvidenceDrillResult EvidenceType = "drill_result"
EvidenceEncryptionProof EvidenceType = "encryption_proof"
EvidenceRetentionProof EvidenceType = "retention_proof"
EvidenceAccessLog EvidenceType = "access_log"
EvidenceAuditLog EvidenceType = "audit_log"
EvidenceConfiguration EvidenceType = "configuration"
EvidenceScreenshot EvidenceType = "screenshot"
EvidenceOther EvidenceType = "other"
)
// Summary provides a high-level overview
type Summary struct {
TotalControls int `json:"total_controls"`
CompliantControls int `json:"compliant_controls"`
NonCompliantControls int `json:"non_compliant_controls"`
PartialControls int `json:"partial_controls"`
NotApplicable int `json:"not_applicable"`
OpenFindings int `json:"open_findings"`
CriticalFindings int `json:"critical_findings"`
HighFindings int `json:"high_findings"`
MediumFindings int `json:"medium_findings"`
LowFindings int `json:"low_findings"`
ComplianceRate float64 `json:"compliance_rate"`
RiskScore float64 `json:"risk_score"`
}
// ReportConfig configures report generation
type ReportConfig struct {
Type ReportType `json:"type"`
Title string `json:"title"`
Description string `json:"description"`
PeriodStart time.Time `json:"period_start"`
PeriodEnd time.Time `json:"period_end"`
IncludeDatabases []string `json:"include_databases,omitempty"`
ExcludeDatabases []string `json:"exclude_databases,omitempty"`
CatalogPath string `json:"catalog_path"`
OutputFormat OutputFormat `json:"output_format"`
OutputPath string `json:"output_path"`
IncludeEvidence bool `json:"include_evidence"`
CustomControls []Control `json:"custom_controls,omitempty"`
}
// OutputFormat represents report output format
type OutputFormat string
const (
FormatJSON OutputFormat = "json"
FormatHTML OutputFormat = "html"
FormatPDF OutputFormat = "pdf"
FormatMarkdown OutputFormat = "markdown"
)
// NewReport creates a new report
func NewReport(reportType ReportType, title string) *Report {
return &Report{
ID: generateID(),
Type: reportType,
Title: title,
GeneratedAt: time.Now(),
Categories: make([]Category, 0),
Findings: make([]Finding, 0),
Evidence: make([]Evidence, 0),
Metadata: make(map[string]string),
}
}
// AddCategory adds a category to the report
func (r *Report) AddCategory(cat Category) {
r.Categories = append(r.Categories, cat)
}
// AddFinding adds a finding to the report
func (r *Report) AddFinding(f Finding) {
r.Findings = append(r.Findings, f)
}
// AddEvidence adds evidence to the report
func (r *Report) AddEvidence(e Evidence) {
r.Evidence = append(r.Evidence, e)
}
// Calculate computes the summary and overall status
func (r *Report) Calculate() {
var totalWeight float64
var weightedScore float64
for _, cat := range r.Categories {
totalWeight += cat.Weight
weightedScore += cat.Score * cat.Weight
for _, ctrl := range cat.Controls {
r.Summary.TotalControls++
switch ctrl.Status {
case StatusCompliant:
r.Summary.CompliantControls++
case StatusNonCompliant:
r.Summary.NonCompliantControls++
case StatusPartial:
r.Summary.PartialControls++
case StatusNotApplicable:
r.Summary.NotApplicable++
}
}
}
for _, f := range r.Findings {
if f.Status == FindingOpen {
r.Summary.OpenFindings++
switch f.Severity {
case SeverityCritical:
r.Summary.CriticalFindings++
case SeverityHigh:
r.Summary.HighFindings++
case SeverityMedium:
r.Summary.MediumFindings++
case SeverityLow:
r.Summary.LowFindings++
}
}
}
if totalWeight > 0 {
r.Score = weightedScore / totalWeight
}
applicable := r.Summary.TotalControls - r.Summary.NotApplicable
if applicable > 0 {
r.Summary.ComplianceRate = float64(r.Summary.CompliantControls) / float64(applicable) * 100
}
// Calculate risk score based on findings
r.Summary.RiskScore = float64(r.Summary.CriticalFindings)*10 +
float64(r.Summary.HighFindings)*5 +
float64(r.Summary.MediumFindings)*2 +
float64(r.Summary.LowFindings)*1
// Determine overall status
if r.Summary.NonCompliantControls == 0 && r.Summary.CriticalFindings == 0 {
if r.Summary.PartialControls == 0 {
r.Status = StatusCompliant
} else {
r.Status = StatusPartial
}
} else {
r.Status = StatusNonCompliant
}
}
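The weighted-score and risk-score arithmetic in `Calculate` can be checked in isolation. A stand-alone sketch with hypothetical category weights and finding counts (stdlib only):

```go
package main

import "fmt"

// weightedScore reproduces Calculate's score formula: sum(score*weight) / sum(weight).
func weightedScore(scores, weights []float64) float64 {
	var totalWeight, weighted float64
	for i := range scores {
		totalWeight += weights[i]
		weighted += scores[i] * weights[i]
	}
	if totalWeight == 0 {
		return 0
	}
	return weighted / totalWeight
}

// riskScore reproduces Calculate's open-finding weighting: 10/5/2/1 per severity.
func riskScore(critical, high, medium, low int) float64 {
	return float64(critical)*10 + float64(high)*5 + float64(medium)*2 + float64(low)
}

func main() {
	// Two categories: 80% at weight 1, 90% at weight 2 -> (80 + 180) / 3 ≈ 86.7.
	fmt.Printf("score: %.1f\n", weightedScore([]float64{80, 90}, []float64{1, 2}))
	// One critical and two medium open findings -> 10 + 4 = 14.
	fmt.Printf("risk: %.0f\n", riskScore(1, 0, 2, 0))
}
```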
// ToJSON converts the report to JSON
func (r *Report) ToJSON() ([]byte, error) {
return json.MarshalIndent(r, "", " ")
}
func generateID() string {
return fmt.Sprintf("RPT-%d", time.Now().UnixNano())
}
// StatusIcon returns an icon for a compliance status
func StatusIcon(s ComplianceStatus) string {
switch s {
case StatusCompliant:
return "✅"
case StatusNonCompliant:
return "❌"
case StatusPartial:
return "⚠️"
case StatusNotApplicable:
return ""
default:
return "❓"
}
}
// SeverityIcon returns an icon for a finding severity
func SeverityIcon(s FindingSeverity) string {
switch s {
case SeverityCritical:
return "🔴"
case SeverityHigh:
return "🟠"
case SeverityMedium:
return "🟡"
case SeverityLow:
return "🟢"
default:
return "⚪"
}
}

internal/rto/calculator.go Normal file

@@ -0,0 +1,481 @@
// Package rto provides RTO/RPO calculation and analysis
package rto
import (
"context"
"fmt"
"sort"
"time"
"dbbackup/internal/catalog"
)
// Calculator calculates RTO and RPO metrics
type Calculator struct {
catalog catalog.Catalog
config Config
}
// Config configures RTO/RPO calculations
type Config struct {
TargetRTO time.Duration `json:"target_rto"` // Target Recovery Time Objective
TargetRPO time.Duration `json:"target_rpo"` // Target Recovery Point Objective
// Assumptions for calculation
	NetworkSpeedMbps       float64 `json:"network_speed_mbps"`        // Network speed in Mbit/s for cloud restores
	DiskReadSpeedMBps      float64 `json:"disk_read_speed_mbps"`      // Disk read speed in MB/s
	DiskWriteSpeedMBps     float64 `json:"disk_write_speed_mbps"`     // Disk write speed in MB/s
	CloudDownloadSpeedMbps float64 `json:"cloud_download_speed_mbps"` // Cloud download speed in Mbit/s
// Time estimates for various operations
StartupTimeMinutes int `json:"startup_time_minutes"` // DB startup time
ValidationTimeMinutes int `json:"validation_time_minutes"` // Post-restore validation
SwitchoverTimeMinutes int `json:"switchover_time_minutes"` // Application switchover time
}
// DefaultConfig returns sensible defaults
func DefaultConfig() Config {
return Config{
TargetRTO: 4 * time.Hour,
TargetRPO: 1 * time.Hour,
NetworkSpeedMbps: 100,
DiskReadSpeedMBps: 100,
DiskWriteSpeedMBps: 50,
CloudDownloadSpeedMbps: 100,
StartupTimeMinutes: 2,
ValidationTimeMinutes: 5,
SwitchoverTimeMinutes: 5,
}
}
// Analysis contains RTO/RPO analysis results
type Analysis struct {
Database string `json:"database"`
Timestamp time.Time `json:"timestamp"`
// Current state
CurrentRPO time.Duration `json:"current_rpo"`
CurrentRTO time.Duration `json:"current_rto"`
// Target state
TargetRPO time.Duration `json:"target_rpo"`
TargetRTO time.Duration `json:"target_rto"`
// Compliance
RPOCompliant bool `json:"rpo_compliant"`
RTOCompliant bool `json:"rto_compliant"`
// Details
LastBackup *time.Time `json:"last_backup,omitempty"`
NextScheduled *time.Time `json:"next_scheduled,omitempty"`
BackupInterval time.Duration `json:"backup_interval"`
// RTO breakdown
RTOBreakdown RTOBreakdown `json:"rto_breakdown"`
// Recommendations
Recommendations []Recommendation `json:"recommendations,omitempty"`
// Historical
History []HistoricalPoint `json:"history,omitempty"`
}
// RTOBreakdown shows components of RTO calculation
type RTOBreakdown struct {
DetectionTime time.Duration `json:"detection_time"`
DecisionTime time.Duration `json:"decision_time"`
DownloadTime time.Duration `json:"download_time"`
RestoreTime time.Duration `json:"restore_time"`
StartupTime time.Duration `json:"startup_time"`
ValidationTime time.Duration `json:"validation_time"`
SwitchoverTime time.Duration `json:"switchover_time"`
TotalTime time.Duration `json:"total_time"`
}
// Recommendation suggests improvements
type Recommendation struct {
Type RecommendationType `json:"type"`
Priority Priority `json:"priority"`
Title string `json:"title"`
Description string `json:"description"`
Impact string `json:"impact"`
Effort Effort `json:"effort"`
}
// RecommendationType categorizes recommendations
type RecommendationType string
const (
RecommendBackupFrequency RecommendationType = "backup_frequency"
RecommendIncrementalBackup RecommendationType = "incremental_backup"
RecommendCompression RecommendationType = "compression"
RecommendLocalCache RecommendationType = "local_cache"
RecommendParallelRestore RecommendationType = "parallel_restore"
RecommendWALArchiving RecommendationType = "wal_archiving"
RecommendReplication RecommendationType = "replication"
)
// Priority levels
type Priority string
const (
PriorityCritical Priority = "critical"
PriorityHigh Priority = "high"
PriorityMedium Priority = "medium"
PriorityLow Priority = "low"
)
// Effort levels
type Effort string
const (
EffortLow Effort = "low"
EffortMedium Effort = "medium"
EffortHigh Effort = "high"
)
// HistoricalPoint tracks RTO/RPO over time
type HistoricalPoint struct {
Timestamp time.Time `json:"timestamp"`
RPO time.Duration `json:"rpo"`
RTO time.Duration `json:"rto"`
}
// NewCalculator creates a new RTO/RPO calculator
func NewCalculator(cat catalog.Catalog, config Config) *Calculator {
return &Calculator{
catalog: cat,
config: config,
}
}
// Analyze performs RTO/RPO analysis for a database
func (c *Calculator) Analyze(ctx context.Context, database string) (*Analysis, error) {
analysis := &Analysis{
Database: database,
Timestamp: time.Now(),
TargetRPO: c.config.TargetRPO,
TargetRTO: c.config.TargetRTO,
}
// Get recent backups
entries, err := c.catalog.List(ctx, database, 100)
if err != nil {
return nil, fmt.Errorf("failed to list backups: %w", err)
}
if len(entries) == 0 {
// No backups - worst case scenario
analysis.CurrentRPO = 0 // undefined
analysis.CurrentRTO = 0 // undefined
analysis.Recommendations = append(analysis.Recommendations, Recommendation{
Type: RecommendBackupFrequency,
Priority: PriorityCritical,
Title: "No Backups Found",
Description: "No backups exist for this database",
Impact: "Cannot recover in case of failure",
Effort: EffortLow,
})
return analysis, nil
}
// Calculate current RPO (time since last backup)
lastBackup := entries[0].CreatedAt
analysis.LastBackup = &lastBackup
analysis.CurrentRPO = time.Since(lastBackup)
analysis.RPOCompliant = analysis.CurrentRPO <= c.config.TargetRPO
// Calculate backup interval
if len(entries) >= 2 {
analysis.BackupInterval = calculateAverageInterval(entries)
}
// Calculate RTO
latestEntry := entries[0]
analysis.RTOBreakdown = c.calculateRTOBreakdown(latestEntry)
analysis.CurrentRTO = analysis.RTOBreakdown.TotalTime
analysis.RTOCompliant = analysis.CurrentRTO <= c.config.TargetRTO
// Generate recommendations
analysis.Recommendations = c.generateRecommendations(analysis, entries)
// Calculate history
analysis.History = c.calculateHistory(entries)
return analysis, nil
}
// AnalyzeAll analyzes all databases
func (c *Calculator) AnalyzeAll(ctx context.Context) ([]*Analysis, error) {
databases, err := c.catalog.ListDatabases(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list databases: %w", err)
}
var analyses []*Analysis
for _, db := range databases {
analysis, err := c.Analyze(ctx, db)
if err != nil {
continue // Skip errors for individual databases
}
analyses = append(analyses, analysis)
}
return analyses, nil
}
// calculateRTOBreakdown calculates RTO components
func (c *Calculator) calculateRTOBreakdown(entry *catalog.Entry) RTOBreakdown {
breakdown := RTOBreakdown{
// Detection time - assume monitoring catches issues quickly
DetectionTime: 5 * time.Minute,
// Decision time - human decision making
DecisionTime: 10 * time.Minute,
// Startup time
StartupTime: time.Duration(c.config.StartupTimeMinutes) * time.Minute,
// Validation time
ValidationTime: time.Duration(c.config.ValidationTimeMinutes) * time.Minute,
// Switchover time
SwitchoverTime: time.Duration(c.config.SwitchoverTimeMinutes) * time.Minute,
}
// Calculate download time (if cloud backup)
if entry.CloudLocation != "" {
// Cloud download
bytesPerSecond := c.config.CloudDownloadSpeedMbps * 125000 // Mbps to bytes/sec
downloadSeconds := float64(entry.SizeBytes) / bytesPerSecond
breakdown.DownloadTime = time.Duration(downloadSeconds * float64(time.Second))
}
// Calculate restore time
// Estimate based on disk write speed
bytesPerSecond := c.config.DiskWriteSpeedMBps * 1000000 // MB/s to bytes/sec
restoreSeconds := float64(entry.SizeBytes) / bytesPerSecond
// Add overhead for decompression if compressed
if entry.Compression != "" && entry.Compression != "none" {
restoreSeconds *= 1.3 // 30% overhead for decompression
}
// Add overhead for decryption if encrypted
if entry.Encrypted {
restoreSeconds *= 1.1 // 10% overhead for decryption
}
breakdown.RestoreTime = time.Duration(restoreSeconds * float64(time.Second))
// Calculate total
breakdown.TotalTime = breakdown.DetectionTime +
breakdown.DecisionTime +
breakdown.DownloadTime +
breakdown.RestoreTime +
breakdown.StartupTime +
breakdown.ValidationTime +
breakdown.SwitchoverTime
return breakdown
}
// calculateAverageInterval calculates average time between backups
func calculateAverageInterval(entries []*catalog.Entry) time.Duration {
if len(entries) < 2 {
return 0
}
var totalInterval time.Duration
for i := 0; i < len(entries)-1; i++ {
interval := entries[i].CreatedAt.Sub(entries[i+1].CreatedAt)
totalInterval += interval
}
return totalInterval / time.Duration(len(entries)-1)
}
// generateRecommendations creates recommendations based on analysis
func (c *Calculator) generateRecommendations(analysis *Analysis, entries []*catalog.Entry) []Recommendation {
var recommendations []Recommendation
// RPO violations
if !analysis.RPOCompliant {
gap := analysis.CurrentRPO - c.config.TargetRPO
recommendations = append(recommendations, Recommendation{
Type: RecommendBackupFrequency,
Priority: PriorityCritical,
Title: "RPO Target Not Met",
Description: fmt.Sprintf("Current RPO (%s) exceeds target (%s) by %s",
formatDuration(analysis.CurrentRPO),
formatDuration(c.config.TargetRPO),
formatDuration(gap)),
Impact: "Potential data loss exceeds acceptable threshold",
Effort: EffortLow,
})
}
// RTO violations
if !analysis.RTOCompliant {
recommendations = append(recommendations, Recommendation{
Type: RecommendParallelRestore,
Priority: PriorityHigh,
Title: "RTO Target Not Met",
Description: fmt.Sprintf("Estimated recovery time (%s) exceeds target (%s)",
formatDuration(analysis.CurrentRTO),
formatDuration(c.config.TargetRTO)),
Impact: "Recovery may take longer than acceptable",
Effort: EffortMedium,
})
}
// Large download time
if analysis.RTOBreakdown.DownloadTime > 30*time.Minute {
recommendations = append(recommendations, Recommendation{
Type: RecommendLocalCache,
Priority: PriorityMedium,
Title: "Consider Local Backup Cache",
Description: fmt.Sprintf("Cloud download takes %s, local cache would reduce this",
formatDuration(analysis.RTOBreakdown.DownloadTime)),
Impact: "Faster recovery from local storage",
Effort: EffortMedium,
})
}
// No incremental backups
hasIncremental := false
for _, e := range entries {
if e.BackupType == "incremental" {
hasIncremental = true
break
}
}
if !hasIncremental && analysis.BackupInterval > 6*time.Hour {
recommendations = append(recommendations, Recommendation{
Type: RecommendIncrementalBackup,
Priority: PriorityMedium,
Title: "Enable Incremental Backups",
Description: "Incremental backups can reduce backup time and storage",
Impact: "Better RPO with less resource usage",
Effort: EffortLow,
})
}
// WAL archiving for PostgreSQL
if len(entries) > 0 && entries[0].DatabaseType == "postgresql" {
recommendations = append(recommendations, Recommendation{
Type: RecommendWALArchiving,
Priority: PriorityMedium,
Title: "Consider WAL Archiving",
Description: "Enable WAL archiving for point-in-time recovery",
Impact: "Achieve near-zero RPO with PITR",
Effort: EffortMedium,
})
}
return recommendations
}
// calculateHistory generates historical RTO/RPO points
func (c *Calculator) calculateHistory(entries []*catalog.Entry) []HistoricalPoint {
var history []HistoricalPoint
// Sort entries by date (oldest first)
sorted := make([]*catalog.Entry, len(entries))
copy(sorted, entries)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].CreatedAt.Before(sorted[j].CreatedAt)
})
for i, entry := range sorted {
point := HistoricalPoint{
Timestamp: entry.CreatedAt,
RTO: c.calculateRTOBreakdown(entry).TotalTime,
}
// RPO over this interval: the gap until the next backup; for the newest backup, its current age
if i < len(sorted)-1 {
point.RPO = sorted[i+1].CreatedAt.Sub(entry.CreatedAt)
} else {
point.RPO = time.Since(entry.CreatedAt)
}
history = append(history, point)
}
return history
}
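// Illustrative reading of the history above (hypothetical timestamps): for
// backups taken at 00:00, 06:00, and 12:00, the 00:00 and 06:00 points each
// carry an RPO of 6h (the gap to the next backup), while the 12:00 point
// carries its age at the time of analysis.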
// Summary provides aggregate RTO/RPO status
type Summary struct {
TotalDatabases int `json:"total_databases"`
RPOCompliant int `json:"rpo_compliant"`
RTOCompliant int `json:"rto_compliant"`
FullyCompliant int `json:"fully_compliant"`
CriticalIssues int `json:"critical_issues"`
WorstRPO time.Duration `json:"worst_rpo"`
WorstRTO time.Duration `json:"worst_rto"`
WorstRPODatabase string `json:"worst_rpo_database"`
WorstRTODatabase string `json:"worst_rto_database"`
AverageRPO time.Duration `json:"average_rpo"`
AverageRTO time.Duration `json:"average_rto"`
}
// Summarize aggregates per-database analyses into a fleet-wide Summary of
// compliance counts, worst-case databases, and average RTO/RPO figures
func Summarize(analyses []*Analysis) *Summary {
summary := &Summary{}
var totalRPO, totalRTO time.Duration
for _, a := range analyses {
summary.TotalDatabases++
if a.RPOCompliant {
summary.RPOCompliant++
}
if a.RTOCompliant {
summary.RTOCompliant++
}
if a.RPOCompliant && a.RTOCompliant {
summary.FullyCompliant++
}
for _, r := range a.Recommendations {
if r.Priority == PriorityCritical {
summary.CriticalIssues++
break
}
}
if a.CurrentRPO > summary.WorstRPO {
summary.WorstRPO = a.CurrentRPO
summary.WorstRPODatabase = a.Database
}
if a.CurrentRTO > summary.WorstRTO {
summary.WorstRTO = a.CurrentRTO
summary.WorstRTODatabase = a.Database
}
totalRPO += a.CurrentRPO
totalRTO += a.CurrentRTO
}
if len(analyses) > 0 {
summary.AverageRPO = totalRPO / time.Duration(len(analyses))
summary.AverageRTO = totalRTO / time.Duration(len(analyses))
}
return summary
}
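// Usage sketch (hypothetical values; assumes only the Analysis fields read above):
//
//	analyses := []*Analysis{
//	    {Database: "orders", RPOCompliant: true, RTOCompliant: false,
//	        CurrentRPO: 2 * time.Hour, CurrentRTO: 90 * time.Minute},
//	    {Database: "billing", RPOCompliant: true, RTOCompliant: true,
//	        CurrentRPO: 30 * time.Minute, CurrentRTO: 15 * time.Minute},
//	}
//	s := Summarize(analyses)
//	// s.TotalDatabases == 2, s.FullyCompliant == 1,
//	// s.WorstRTODatabase == "orders", s.AverageRPO == 75*time.Minute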
// formatDuration renders a duration compactly, e.g. "45s", "30m", or "1h 35m"
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%.0fs", d.Seconds())
}
if d < time.Hour {
return fmt.Sprintf("%.0fm", d.Minutes())
}
hours := int(d.Hours())
mins := int(d.Minutes()) - hours*60
return fmt.Sprintf("%dh %dm", hours, mins)
}