Initial commit: Database Backup Tool v1.1.0
- PostgreSQL and MySQL support - Interactive TUI with fixed menu navigation - Line-by-line progress display - CPU-aware parallel processing - Cross-platform build support - Configuration settings menu - Silent mode for TUI operations
This commit is contained in:
708
internal/backup/engine.go
Normal file
708
internal/backup/engine.go
Normal file
@ -0,0 +1,708 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/database"
|
||||
"dbbackup/internal/logger"
|
||||
"dbbackup/internal/progress"
|
||||
)
|
||||
|
||||
// Engine handles backup operations
|
||||
type Engine struct {
|
||||
cfg *config.Config
|
||||
log logger.Logger
|
||||
db database.Database
|
||||
progress progress.Indicator
|
||||
detailedReporter *progress.DetailedReporter
|
||||
silent bool // Silent mode for TUI
|
||||
}
|
||||
|
||||
// New creates a new backup engine
|
||||
func New(cfg *config.Config, log logger.Logger, db database.Database) *Engine {
|
||||
progressIndicator := progress.NewIndicator(true, "line") // Use line-by-line indicator
|
||||
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
|
||||
|
||||
return &Engine{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
db: db,
|
||||
progress: progressIndicator,
|
||||
detailedReporter: detailedReporter,
|
||||
silent: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithProgress creates a new backup engine with a custom progress indicator
|
||||
func NewWithProgress(cfg *config.Config, log logger.Logger, db database.Database, progressIndicator progress.Indicator) *Engine {
|
||||
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
|
||||
|
||||
return &Engine{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
db: db,
|
||||
progress: progressIndicator,
|
||||
detailedReporter: detailedReporter,
|
||||
silent: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSilent creates a new backup engine in silent mode (for TUI)
|
||||
func NewSilent(cfg *config.Config, log logger.Logger, db database.Database, progressIndicator progress.Indicator) *Engine {
|
||||
detailedReporter := progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
|
||||
|
||||
return &Engine{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
db: db,
|
||||
progress: progressIndicator,
|
||||
detailedReporter: detailedReporter,
|
||||
silent: true, // Silent mode enabled
|
||||
}
|
||||
}
|
||||
|
||||
// loggerAdapter adapts our logger to the progress.Logger interface
|
||||
type loggerAdapter struct {
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (la *loggerAdapter) Info(msg string, args ...any) {
|
||||
la.logger.Info(msg, args...)
|
||||
}
|
||||
|
||||
func (la *loggerAdapter) Warn(msg string, args ...any) {
|
||||
la.logger.Warn(msg, args...)
|
||||
}
|
||||
|
||||
func (la *loggerAdapter) Error(msg string, args ...any) {
|
||||
la.logger.Error(msg, args...)
|
||||
}
|
||||
|
||||
func (la *loggerAdapter) Debug(msg string, args ...any) {
|
||||
la.logger.Debug(msg, args...)
|
||||
}
|
||||
|
||||
// printf prints to stdout only if not in silent mode
|
||||
func (e *Engine) printf(format string, args ...interface{}) {
|
||||
if !e.silent {
|
||||
fmt.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// generateOperationID creates a unique operation ID
|
||||
func generateOperationID() string {
|
||||
bytes := make([]byte, 8)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// BackupSingle performs a single database backup with detailed progress tracking
|
||||
func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
|
||||
// Start detailed operation tracking
|
||||
operationID := generateOperationID()
|
||||
tracker := e.detailedReporter.StartOperation(operationID, databaseName, "backup")
|
||||
|
||||
// Add operation details
|
||||
tracker.SetDetails("database", databaseName)
|
||||
tracker.SetDetails("type", "single")
|
||||
tracker.SetDetails("compression", strconv.Itoa(e.cfg.CompressionLevel))
|
||||
tracker.SetDetails("format", "custom")
|
||||
|
||||
// Start preparing backup directory
|
||||
prepStep := tracker.AddStep("prepare", "Preparing backup directory")
|
||||
if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
|
||||
prepStep.Fail(fmt.Errorf("failed to create backup directory: %w", err))
|
||||
tracker.Fail(fmt.Errorf("failed to create backup directory: %w", err))
|
||||
return fmt.Errorf("failed to create backup directory: %w", err)
|
||||
}
|
||||
prepStep.Complete("Backup directory prepared")
|
||||
tracker.UpdateProgress(10, "Backup directory prepared")
|
||||
|
||||
// Generate timestamp and filename
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
var outputFile string
|
||||
|
||||
if e.cfg.IsPostgreSQL() {
|
||||
outputFile = filepath.Join(e.cfg.BackupDir, fmt.Sprintf("db_%s_%s.dump", databaseName, timestamp))
|
||||
} else {
|
||||
outputFile = filepath.Join(e.cfg.BackupDir, fmt.Sprintf("db_%s_%s.sql.gz", databaseName, timestamp))
|
||||
}
|
||||
|
||||
tracker.SetDetails("output_file", outputFile)
|
||||
tracker.UpdateProgress(20, "Generated backup filename")
|
||||
|
||||
// Build backup command
|
||||
cmdStep := tracker.AddStep("command", "Building backup command")
|
||||
options := database.BackupOptions{
|
||||
Compression: e.cfg.CompressionLevel,
|
||||
Parallel: e.cfg.DumpJobs,
|
||||
Format: "custom",
|
||||
Blobs: true,
|
||||
NoOwner: false,
|
||||
NoPrivileges: false,
|
||||
}
|
||||
|
||||
cmd := e.db.BuildBackupCommand(databaseName, outputFile, options)
|
||||
cmdStep.Complete("Backup command prepared")
|
||||
tracker.UpdateProgress(30, "Backup command prepared")
|
||||
|
||||
// Execute backup command with progress monitoring
|
||||
execStep := tracker.AddStep("execute", "Executing database backup")
|
||||
tracker.UpdateProgress(40, "Starting database backup...")
|
||||
|
||||
if err := e.executeCommandWithProgress(ctx, cmd, outputFile, tracker); err != nil {
|
||||
execStep.Fail(fmt.Errorf("backup execution failed: %w", err))
|
||||
tracker.Fail(fmt.Errorf("backup failed: %w", err))
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
execStep.Complete("Database backup completed")
|
||||
tracker.UpdateProgress(80, "Database backup completed")
|
||||
|
||||
// Verify backup file
|
||||
verifyStep := tracker.AddStep("verify", "Verifying backup file")
|
||||
if info, err := os.Stat(outputFile); err != nil {
|
||||
verifyStep.Fail(fmt.Errorf("backup file not created: %w", err))
|
||||
tracker.Fail(fmt.Errorf("backup file not created: %w", err))
|
||||
return fmt.Errorf("backup file not created: %w", err)
|
||||
} else {
|
||||
size := formatBytes(info.Size())
|
||||
tracker.SetDetails("file_size", size)
|
||||
tracker.SetByteProgress(info.Size(), info.Size())
|
||||
verifyStep.Complete(fmt.Sprintf("Backup file verified: %s", size))
|
||||
tracker.UpdateProgress(90, fmt.Sprintf("Backup verified: %s", size))
|
||||
}
|
||||
|
||||
// Create metadata file
|
||||
metaStep := tracker.AddStep("metadata", "Creating metadata file")
|
||||
if err := e.createMetadata(outputFile, databaseName, "single", ""); err != nil {
|
||||
e.log.Warn("Failed to create metadata file", "error", err)
|
||||
metaStep.Fail(fmt.Errorf("metadata creation failed: %w", err))
|
||||
} else {
|
||||
metaStep.Complete("Metadata file created")
|
||||
}
|
||||
|
||||
// Complete operation
|
||||
tracker.UpdateProgress(100, "Backup operation completed successfully")
|
||||
tracker.Complete(fmt.Sprintf("Single database backup completed: %s", filepath.Base(outputFile)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BackupSample performs a sample database backup
|
||||
func (e *Engine) BackupSample(ctx context.Context, databaseName string) error {
|
||||
operation := e.log.StartOperation("Sample Database Backup")
|
||||
|
||||
// Ensure backup directory exists
|
||||
if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
|
||||
operation.Fail("Failed to create backup directory")
|
||||
return fmt.Errorf("failed to create backup directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate timestamp and filename
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
outputFile := filepath.Join(e.cfg.BackupDir,
|
||||
fmt.Sprintf("sample_%s_%s%d_%s.sql", databaseName, e.cfg.SampleStrategy, e.cfg.SampleValue, timestamp))
|
||||
|
||||
operation.Update("Starting sample database backup")
|
||||
e.progress.Start(fmt.Sprintf("Creating sample backup of '%s' (%s=%d)", databaseName, e.cfg.SampleStrategy, e.cfg.SampleValue))
|
||||
|
||||
// For sample backups, we need to get the schema first, then sample data
|
||||
if err := e.createSampleBackup(ctx, databaseName, outputFile); err != nil {
|
||||
e.progress.Fail(fmt.Sprintf("Sample backup failed: %v", err))
|
||||
operation.Fail("Sample backup failed")
|
||||
return fmt.Errorf("sample backup failed: %w", err)
|
||||
}
|
||||
|
||||
// Check output file
|
||||
if info, err := os.Stat(outputFile); err != nil {
|
||||
e.progress.Fail("Sample backup file not created")
|
||||
operation.Fail("Sample backup file not found")
|
||||
return fmt.Errorf("sample backup file not created: %w", err)
|
||||
} else {
|
||||
size := formatBytes(info.Size())
|
||||
e.progress.Complete(fmt.Sprintf("Sample backup completed: %s (%s)", filepath.Base(outputFile), size))
|
||||
operation.Complete(fmt.Sprintf("Sample backup created: %s (%s)", outputFile, size))
|
||||
}
|
||||
|
||||
// Create metadata file
|
||||
if err := e.createMetadata(outputFile, databaseName, "sample", e.cfg.SampleStrategy); err != nil {
|
||||
e.log.Warn("Failed to create metadata file", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BackupCluster performs a full cluster backup (PostgreSQL only)
|
||||
func (e *Engine) BackupCluster(ctx context.Context) error {
|
||||
if !e.cfg.IsPostgreSQL() {
|
||||
return fmt.Errorf("cluster backup is only supported for PostgreSQL")
|
||||
}
|
||||
|
||||
operation := e.log.StartOperation("Cluster Backup")
|
||||
|
||||
// Use a quiet progress indicator to avoid duplicate messages
|
||||
quietProgress := progress.NewQuietLineByLine()
|
||||
quietProgress.Start("Starting cluster backup (all databases)")
|
||||
|
||||
// Ensure backup directory exists
|
||||
if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
|
||||
operation.Fail("Failed to create backup directory")
|
||||
quietProgress.Fail("Failed to create backup directory")
|
||||
return fmt.Errorf("failed to create backup directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate timestamp and filename
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
outputFile := filepath.Join(e.cfg.BackupDir, fmt.Sprintf("cluster_%s.tar.gz", timestamp))
|
||||
tempDir := filepath.Join(e.cfg.BackupDir, fmt.Sprintf(".cluster_%s", timestamp))
|
||||
|
||||
operation.Update("Starting cluster backup")
|
||||
|
||||
// Create temporary directory
|
||||
if err := os.MkdirAll(filepath.Join(tempDir, "dumps"), 0755); err != nil {
|
||||
operation.Fail("Failed to create temporary directory")
|
||||
quietProgress.Fail("Failed to create temporary directory")
|
||||
return fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Backup globals
|
||||
e.printf(" Backing up global objects...\n")
|
||||
if err := e.backupGlobals(ctx, tempDir); err != nil {
|
||||
quietProgress.Fail(fmt.Sprintf("Failed to backup globals: %v", err))
|
||||
operation.Fail("Global backup failed")
|
||||
return fmt.Errorf("failed to backup globals: %w", err)
|
||||
}
|
||||
|
||||
// Get list of databases
|
||||
e.printf(" Getting database list...\n")
|
||||
databases, err := e.db.ListDatabases(ctx)
|
||||
if err != nil {
|
||||
quietProgress.Fail(fmt.Sprintf("Failed to list databases: %v", err))
|
||||
operation.Fail("Database listing failed")
|
||||
return fmt.Errorf("failed to list databases: %w", err)
|
||||
}
|
||||
|
||||
// Backup each database
|
||||
e.printf(" Backing up %d databases...\n", len(databases))
|
||||
for i, dbName := range databases {
|
||||
e.printf(" Backing up database %d/%d: %s\n", i+1, len(databases), dbName)
|
||||
|
||||
dumpFile := filepath.Join(tempDir, "dumps", dbName+".dump")
|
||||
options := database.BackupOptions{
|
||||
Compression: e.cfg.CompressionLevel,
|
||||
Parallel: 1, // Individual dumps in cluster are not parallel
|
||||
Format: "custom",
|
||||
Blobs: true,
|
||||
NoOwner: false,
|
||||
NoPrivileges: false,
|
||||
}
|
||||
|
||||
cmd := e.db.BuildBackupCommand(dbName, dumpFile, options)
|
||||
if err := e.executeCommand(ctx, cmd, dumpFile); err != nil {
|
||||
e.log.Warn("Failed to backup database", "database", dbName, "error", err)
|
||||
// Continue with other databases
|
||||
}
|
||||
}
|
||||
|
||||
// Create archive
|
||||
e.printf(" Creating compressed archive...\n")
|
||||
if err := e.createArchive(tempDir, outputFile); err != nil {
|
||||
quietProgress.Fail(fmt.Sprintf("Failed to create archive: %v", err))
|
||||
operation.Fail("Archive creation failed")
|
||||
return fmt.Errorf("failed to create archive: %w", err)
|
||||
}
|
||||
|
||||
// Check output file
|
||||
if info, err := os.Stat(outputFile); err != nil {
|
||||
quietProgress.Fail("Cluster backup archive not created")
|
||||
operation.Fail("Cluster backup archive not found")
|
||||
return fmt.Errorf("cluster backup archive not created: %w", err)
|
||||
} else {
|
||||
size := formatBytes(info.Size())
|
||||
quietProgress.Complete(fmt.Sprintf("Cluster backup completed: %s (%s)", filepath.Base(outputFile), size))
|
||||
operation.Complete(fmt.Sprintf("Cluster backup created: %s (%s)", outputFile, size))
|
||||
}
|
||||
|
||||
// Create metadata file
|
||||
if err := e.createMetadata(outputFile, "cluster", "cluster", ""); err != nil {
|
||||
e.log.Warn("Failed to create metadata file", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeCommandWithProgress executes a backup command with real-time progress monitoring
|
||||
func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []string, outputFile string, tracker *progress.OperationTracker) error {
|
||||
if len(cmdArgs) == 0 {
|
||||
return fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
e.log.Debug("Executing backup command with progress", "cmd", cmdArgs[0], "args", cmdArgs[1:])
|
||||
|
||||
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
|
||||
// Set environment variables for database tools
|
||||
cmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
if e.cfg.IsPostgreSQL() {
|
||||
cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
|
||||
} else if e.cfg.IsMySQL() {
|
||||
cmd.Env = append(cmd.Env, "MYSQL_PWD="+e.cfg.Password)
|
||||
}
|
||||
}
|
||||
|
||||
// For MySQL, handle compression and redirection differently
|
||||
if e.cfg.IsMySQL() && e.cfg.CompressionLevel > 0 {
|
||||
return e.executeMySQLWithProgressAndCompression(ctx, cmdArgs, outputFile, tracker)
|
||||
}
|
||||
|
||||
// Get stderr pipe for progress monitoring
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// Start the command
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start command: %w", err)
|
||||
}
|
||||
|
||||
// Monitor progress via stderr
|
||||
go e.monitorCommandProgress(stderr, tracker)
|
||||
|
||||
// Wait for command to complete
|
||||
if err := cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("backup command failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorCommandProgress monitors command output for progress information
|
||||
func (e *Engine) monitorCommandProgress(stderr io.ReadCloser, tracker *progress.OperationTracker) {
|
||||
defer stderr.Close()
|
||||
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
progressBase := 40 // Start from 40% since command preparation is done
|
||||
progressIncrement := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
e.log.Debug("Command output", "line", line)
|
||||
|
||||
// Increment progress gradually based on output
|
||||
if progressBase < 75 {
|
||||
progressIncrement++
|
||||
if progressIncrement%5 == 0 { // Update every 5 lines
|
||||
progressBase += 2
|
||||
tracker.UpdateProgress(progressBase, "Processing data...")
|
||||
}
|
||||
}
|
||||
|
||||
// Look for specific progress indicators
|
||||
if strings.Contains(line, "COPY") {
|
||||
tracker.UpdateProgress(progressBase+5, "Copying table data...")
|
||||
} else if strings.Contains(line, "completed") {
|
||||
tracker.UpdateProgress(75, "Backup nearly complete...")
|
||||
} else if strings.Contains(line, "done") {
|
||||
tracker.UpdateProgress(78, "Finalizing backup...")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeMySQLWithProgressAndCompression handles MySQL backup with compression and progress
|
||||
func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmdArgs []string, outputFile string, tracker *progress.OperationTracker) error {
|
||||
// Create mysqldump command
|
||||
dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
dumpCmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
|
||||
}
|
||||
|
||||
// Create gzip command
|
||||
gzipCmd := exec.CommandContext(ctx, "gzip", fmt.Sprintf("-%d", e.cfg.CompressionLevel))
|
||||
|
||||
// Create output file
|
||||
outFile, err := os.Create(outputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create output file: %w", err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
// Set up pipeline: mysqldump | gzip > outputfile
|
||||
pipe, err := dumpCmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create pipe: %w", err)
|
||||
}
|
||||
|
||||
gzipCmd.Stdin = pipe
|
||||
gzipCmd.Stdout = outFile
|
||||
|
||||
// Get stderr for progress monitoring
|
||||
stderr, err := dumpCmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// Start monitoring progress
|
||||
go e.monitorCommandProgress(stderr, tracker)
|
||||
|
||||
// Start both commands
|
||||
if err := gzipCmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start gzip: %w", err)
|
||||
}
|
||||
|
||||
if err := dumpCmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start mysqldump: %w", err)
|
||||
}
|
||||
|
||||
// Wait for mysqldump to complete
|
||||
if err := dumpCmd.Wait(); err != nil {
|
||||
return fmt.Errorf("mysqldump failed: %w", err)
|
||||
}
|
||||
|
||||
// Close pipe and wait for gzip
|
||||
pipe.Close()
|
||||
if err := gzipCmd.Wait(); err != nil {
|
||||
return fmt.Errorf("gzip failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeMySQLWithCompression handles MySQL backup with compression
|
||||
func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []string, outputFile string) error {
|
||||
// Create mysqldump command
|
||||
dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
dumpCmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
|
||||
}
|
||||
|
||||
// Create gzip command
|
||||
gzipCmd := exec.CommandContext(ctx, "gzip", fmt.Sprintf("-%d", e.cfg.CompressionLevel))
|
||||
|
||||
// Create output file
|
||||
outFile, err := os.Create(outputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create output file: %w", err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
// Set up pipeline: mysqldump | gzip > outputfile
|
||||
gzipCmd.Stdin, _ = dumpCmd.StdoutPipe()
|
||||
gzipCmd.Stdout = outFile
|
||||
|
||||
// Start both commands
|
||||
if err := gzipCmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start gzip: %w", err)
|
||||
}
|
||||
|
||||
if err := dumpCmd.Run(); err != nil {
|
||||
return fmt.Errorf("mysqldump failed: %w", err)
|
||||
}
|
||||
|
||||
if err := gzipCmd.Wait(); err != nil {
|
||||
return fmt.Errorf("gzip failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSampleBackup creates a sample backup with reduced dataset
|
||||
func (e *Engine) createSampleBackup(ctx context.Context, databaseName, outputFile string) error {
|
||||
// This is a simplified implementation
|
||||
// A full implementation would:
|
||||
// 1. Export schema
|
||||
// 2. Get list of tables
|
||||
// 3. For each table, run sampling query
|
||||
// 4. Combine into single SQL file
|
||||
|
||||
// For now, we'll use a simple approach with schema-only backup first
|
||||
// Then add sample data
|
||||
|
||||
file, err := os.Create(outputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create sample backup file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write header
|
||||
fmt.Fprintf(file, "-- Sample Database Backup\n")
|
||||
fmt.Fprintf(file, "-- Database: %s\n", databaseName)
|
||||
fmt.Fprintf(file, "-- Strategy: %s = %d\n", e.cfg.SampleStrategy, e.cfg.SampleValue)
|
||||
fmt.Fprintf(file, "-- Created: %s\n", time.Now().Format(time.RFC3339))
|
||||
fmt.Fprintf(file, "-- WARNING: This backup may have referential integrity issues!\n\n")
|
||||
|
||||
// For PostgreSQL, we can use pg_dump --schema-only first
|
||||
if e.cfg.IsPostgreSQL() {
|
||||
// Get schema
|
||||
schemaCmd := e.db.BuildBackupCommand(databaseName, "/dev/stdout", database.BackupOptions{
|
||||
SchemaOnly: true,
|
||||
Format: "plain",
|
||||
})
|
||||
|
||||
cmd := exec.CommandContext(ctx, schemaCmd[0], schemaCmd[1:]...)
|
||||
cmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
|
||||
}
|
||||
cmd.Stdout = file
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to export schema: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(file, "\n-- Sample data follows\n\n")
|
||||
|
||||
// Get tables and sample data
|
||||
tables, err := e.db.ListTables(ctx, databaseName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list tables: %w", err)
|
||||
}
|
||||
|
||||
strategy := database.SampleStrategy{
|
||||
Type: e.cfg.SampleStrategy,
|
||||
Value: e.cfg.SampleValue,
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
fmt.Fprintf(file, "-- Data for table: %s\n", table)
|
||||
sampleQuery := e.db.BuildSampleQuery(databaseName, table, strategy)
|
||||
fmt.Fprintf(file, "\\copy (%s) TO STDOUT\n\n", sampleQuery)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// backupGlobals creates a backup of global PostgreSQL objects
|
||||
func (e *Engine) backupGlobals(ctx context.Context, tempDir string) error {
|
||||
globalsFile := filepath.Join(tempDir, "globals.sql")
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dumpall", "--globals-only")
|
||||
if e.cfg.Host != "localhost" {
|
||||
cmd.Args = append(cmd.Args, "-h", e.cfg.Host, "-p", fmt.Sprintf("%d", e.cfg.Port))
|
||||
}
|
||||
cmd.Args = append(cmd.Args, "-U", e.cfg.User)
|
||||
|
||||
cmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
|
||||
}
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pg_dumpall failed: %w", err)
|
||||
}
|
||||
|
||||
return os.WriteFile(globalsFile, output, 0644)
|
||||
}
|
||||
|
||||
// createArchive creates a compressed tar archive
|
||||
func (e *Engine) createArchive(sourceDir, outputFile string) error {
|
||||
cmd := exec.Command("tar", "-czf", outputFile, "-C", sourceDir, ".")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("tar failed: %w, output: %s", err, string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createMetadata creates a metadata file for the backup
|
||||
func (e *Engine) createMetadata(backupFile, database, backupType, strategy string) error {
|
||||
metaFile := backupFile + ".info"
|
||||
|
||||
content := fmt.Sprintf(`{
|
||||
"type": "%s",
|
||||
"database": "%s",
|
||||
"timestamp": "%s",
|
||||
"host": "%s",
|
||||
"port": %d,
|
||||
"user": "%s",
|
||||
"db_type": "%s",
|
||||
"compression": %d`,
|
||||
backupType, database, time.Now().Format("20060102_150405"),
|
||||
e.cfg.Host, e.cfg.Port, e.cfg.User, e.cfg.DatabaseType, e.cfg.CompressionLevel)
|
||||
|
||||
if strategy != "" {
|
||||
content += fmt.Sprintf(`,
|
||||
"sample_strategy": "%s",
|
||||
"sample_value": %d`, e.cfg.SampleStrategy, e.cfg.SampleValue)
|
||||
}
|
||||
|
||||
if info, err := os.Stat(backupFile); err == nil {
|
||||
content += fmt.Sprintf(`,
|
||||
"size_bytes": %d`, info.Size())
|
||||
}
|
||||
|
||||
content += "\n}"
|
||||
|
||||
return os.WriteFile(metaFile, []byte(content), 0644)
|
||||
}
|
||||
|
||||
// executeCommand executes a backup command (simplified version for cluster backups)
|
||||
func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFile string) error {
|
||||
if len(cmdArgs) == 0 {
|
||||
return fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
e.log.Debug("Executing backup command", "cmd", cmdArgs[0], "args", cmdArgs[1:])
|
||||
|
||||
cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
|
||||
|
||||
// Set environment variables for database tools
|
||||
cmd.Env = os.Environ()
|
||||
if e.cfg.Password != "" {
|
||||
if e.cfg.IsPostgreSQL() {
|
||||
cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
|
||||
} else if e.cfg.IsMySQL() {
|
||||
cmd.Env = append(cmd.Env, "MYSQL_PWD="+e.cfg.Password)
|
||||
}
|
||||
}
|
||||
|
||||
// For MySQL, handle compression differently
|
||||
if e.cfg.IsMySQL() && e.cfg.CompressionLevel > 0 {
|
||||
return e.executeMySQLWithCompression(ctx, cmdArgs, outputFile)
|
||||
}
|
||||
|
||||
// Run the command
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
e.log.Error("Backup command failed", "error", err, "output", string(output))
|
||||
return fmt.Errorf("backup command failed: %w, output: %s", err, string(output))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatBytes formats byte count in human-readable format
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
319
internal/config/config.go
Normal file
319
internal/config/config.go
Normal file
@ -0,0 +1,319 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"dbbackup/internal/cpu"
|
||||
)
|
||||
|
||||
// Config holds all configuration options
|
||||
type Config struct {
|
||||
// Version information
|
||||
Version string
|
||||
BuildTime string
|
||||
GitCommit string
|
||||
|
||||
// Database connection
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Database string
|
||||
Password string
|
||||
DatabaseType string // "postgres" or "mysql"
|
||||
SSLMode string
|
||||
Insecure bool
|
||||
|
||||
// Backup options
|
||||
BackupDir string
|
||||
CompressionLevel int
|
||||
Jobs int
|
||||
DumpJobs int
|
||||
MaxCores int
|
||||
AutoDetectCores bool
|
||||
CPUWorkloadType string // "cpu-intensive", "io-intensive", "balanced"
|
||||
|
||||
// CPU detection
|
||||
CPUDetector *cpu.Detector
|
||||
CPUInfo *cpu.CPUInfo
|
||||
|
||||
// Sample backup options
|
||||
SampleStrategy string // "ratio", "percent", "count"
|
||||
SampleValue int
|
||||
|
||||
// Output options
|
||||
NoColor bool
|
||||
Debug bool
|
||||
LogLevel string
|
||||
LogFormat string
|
||||
OutputLength int
|
||||
|
||||
// Single database backup/restore
|
||||
SingleDBName string
|
||||
RestoreDBName string
|
||||
}
|
||||
|
||||
// New creates a new configuration with default values
|
||||
func New() *Config {
|
||||
// Get default backup directory
|
||||
backupDir := getEnvString("BACKUP_DIR", getDefaultBackupDir())
|
||||
|
||||
// Initialize CPU detector
|
||||
cpuDetector := cpu.NewDetector()
|
||||
cpuInfo, _ := cpuDetector.DetectCPU()
|
||||
|
||||
return &Config{
|
||||
// Database defaults
|
||||
Host: getEnvString("PG_HOST", "localhost"),
|
||||
Port: getEnvInt("PG_PORT", 5432),
|
||||
User: getEnvString("PG_USER", getCurrentUser()),
|
||||
Database: getEnvString("PG_DATABASE", "postgres"),
|
||||
Password: getEnvString("PGPASSWORD", ""),
|
||||
DatabaseType: getEnvString("DB_TYPE", "postgres"),
|
||||
SSLMode: getEnvString("PG_SSLMODE", "prefer"),
|
||||
Insecure: getEnvBool("INSECURE", false),
|
||||
|
||||
// Backup defaults
|
||||
BackupDir: backupDir,
|
||||
CompressionLevel: getEnvInt("COMPRESS_LEVEL", 6),
|
||||
Jobs: getEnvInt("JOBS", getDefaultJobs(cpuInfo)),
|
||||
DumpJobs: getEnvInt("DUMP_JOBS", getDefaultDumpJobs(cpuInfo)),
|
||||
MaxCores: getEnvInt("MAX_CORES", getDefaultMaxCores(cpuInfo)),
|
||||
AutoDetectCores: getEnvBool("AUTO_DETECT_CORES", true),
|
||||
CPUWorkloadType: getEnvString("CPU_WORKLOAD_TYPE", "balanced"),
|
||||
|
||||
// CPU detection
|
||||
CPUDetector: cpuDetector,
|
||||
CPUInfo: cpuInfo,
|
||||
|
||||
// Sample backup defaults
|
||||
SampleStrategy: getEnvString("SAMPLE_STRATEGY", "ratio"),
|
||||
SampleValue: getEnvInt("SAMPLE_VALUE", 10),
|
||||
|
||||
// Output defaults
|
||||
NoColor: getEnvBool("NO_COLOR", false),
|
||||
Debug: getEnvBool("DEBUG", false),
|
||||
LogLevel: getEnvString("LOG_LEVEL", "info"),
|
||||
LogFormat: getEnvString("LOG_FORMAT", "text"),
|
||||
OutputLength: getEnvInt("OUTPUT_LENGTH", 0),
|
||||
|
||||
// Single database options
|
||||
SingleDBName: getEnvString("SINGLE_DB_NAME", ""),
|
||||
RestoreDBName: getEnvString("RESTORE_DB_NAME", ""),
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateFromEnvironment updates configuration from environment variables
|
||||
func (c *Config) UpdateFromEnvironment() {
|
||||
if password := os.Getenv("PGPASSWORD"); password != "" {
|
||||
c.Password = password
|
||||
}
|
||||
if password := os.Getenv("MYSQL_PWD"); password != "" && c.DatabaseType == "mysql" {
|
||||
c.Password = password
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
func (c *Config) Validate() error {
|
||||
if c.DatabaseType != "postgres" && c.DatabaseType != "mysql" {
|
||||
return &ConfigError{Field: "database-type", Value: c.DatabaseType, Message: "must be 'postgres' or 'mysql'"}
|
||||
}
|
||||
|
||||
if c.CompressionLevel < 0 || c.CompressionLevel > 9 {
|
||||
return &ConfigError{Field: "compression", Value: string(rune(c.CompressionLevel)), Message: "must be between 0-9"}
|
||||
}
|
||||
|
||||
if c.Jobs < 1 {
|
||||
return &ConfigError{Field: "jobs", Value: string(rune(c.Jobs)), Message: "must be at least 1"}
|
||||
}
|
||||
|
||||
if c.DumpJobs < 1 {
|
||||
return &ConfigError{Field: "dump-jobs", Value: string(rune(c.DumpJobs)), Message: "must be at least 1"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsPostgreSQL returns true if database type is PostgreSQL
|
||||
func (c *Config) IsPostgreSQL() bool {
|
||||
return c.DatabaseType == "postgres"
|
||||
}
|
||||
|
||||
// IsMySQL returns true if database type is MySQL
|
||||
func (c *Config) IsMySQL() bool {
|
||||
return c.DatabaseType == "mysql"
|
||||
}
|
||||
|
||||
// GetDefaultPort returns the default port for the database type
|
||||
func (c *Config) GetDefaultPort() int {
|
||||
if c.IsMySQL() {
|
||||
return 3306
|
||||
}
|
||||
return 5432
|
||||
}
|
||||
|
||||
// OptimizeForCPU optimizes job settings based on detected CPU
|
||||
func (c *Config) OptimizeForCPU() error {
|
||||
if c.CPUDetector == nil {
|
||||
c.CPUDetector = cpu.NewDetector()
|
||||
}
|
||||
|
||||
if c.CPUInfo == nil {
|
||||
info, err := c.CPUDetector.DetectCPU()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.CPUInfo = info
|
||||
}
|
||||
|
||||
if c.AutoDetectCores {
|
||||
// Optimize jobs based on workload type
|
||||
if jobs, err := c.CPUDetector.CalculateOptimalJobs(c.CPUWorkloadType, c.MaxCores); err == nil {
|
||||
c.Jobs = jobs
|
||||
}
|
||||
|
||||
// Optimize dump jobs (more conservative for database dumps)
|
||||
if dumpJobs, err := c.CPUDetector.CalculateOptimalJobs("cpu-intensive", c.MaxCores/2); err == nil {
|
||||
c.DumpJobs = dumpJobs
|
||||
if c.DumpJobs > 8 {
|
||||
c.DumpJobs = 8 // Conservative limit for dumps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCPUInfo returns CPU information, detecting if necessary
|
||||
func (c *Config) GetCPUInfo() (*cpu.CPUInfo, error) {
|
||||
if c.CPUInfo != nil {
|
||||
return c.CPUInfo, nil
|
||||
}
|
||||
|
||||
if c.CPUDetector == nil {
|
||||
c.CPUDetector = cpu.NewDetector()
|
||||
}
|
||||
|
||||
info, err := c.CPUDetector.DetectCPU()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.CPUInfo = info
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// ConfigError represents a configuration validation error
|
||||
type ConfigError struct {
|
||||
Field string
|
||||
Value string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ConfigError) Error() string {
|
||||
return "config error in field '" + e.Field + "' with value '" + e.Value + "': " + e.Message
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func getEnvString(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if i, err := strconv.Atoi(value); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvBool(key string, defaultValue bool) bool {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if b, err := strconv.ParseBool(value); err == nil {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getCurrentUser() string {
|
||||
if user := os.Getenv("USER"); user != "" {
|
||||
return user
|
||||
}
|
||||
if user := os.Getenv("USERNAME"); user != "" {
|
||||
return user
|
||||
}
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func getDefaultBackupDir() string {
|
||||
// Try to create a sensible default backup directory
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
if homeDir != "" {
|
||||
return filepath.Join(homeDir, "db_backups")
|
||||
}
|
||||
|
||||
// Fallback based on OS
|
||||
if runtime.GOOS == "windows" {
|
||||
return "C:\\db_backups"
|
||||
}
|
||||
|
||||
// For PostgreSQL user on Linux/Unix
|
||||
if getCurrentUser() == "postgres" {
|
||||
return "/var/lib/pgsql/pg_backups"
|
||||
}
|
||||
|
||||
return "/tmp/db_backups"
|
||||
}
|
||||
|
||||
// CPU-related helper functions
|
||||
func getDefaultJobs(cpuInfo *cpu.CPUInfo) int {
|
||||
if cpuInfo == nil {
|
||||
return 1
|
||||
}
|
||||
// Default to logical cores for restore operations
|
||||
jobs := cpuInfo.LogicalCores
|
||||
if jobs < 1 {
|
||||
jobs = 1
|
||||
}
|
||||
if jobs > 16 {
|
||||
jobs = 16 // Safety limit
|
||||
}
|
||||
return jobs
|
||||
}
|
||||
|
||||
func getDefaultDumpJobs(cpuInfo *cpu.CPUInfo) int {
|
||||
if cpuInfo == nil {
|
||||
return 1
|
||||
}
|
||||
// Use physical cores for dump operations (CPU intensive)
|
||||
jobs := cpuInfo.PhysicalCores
|
||||
if jobs < 1 {
|
||||
jobs = 1
|
||||
}
|
||||
if jobs > 8 {
|
||||
jobs = 8 // Conservative limit for dumps
|
||||
}
|
||||
return jobs
|
||||
}
|
||||
|
||||
func getDefaultMaxCores(cpuInfo *cpu.CPUInfo) int {
|
||||
if cpuInfo == nil {
|
||||
return 16
|
||||
}
|
||||
// Set max cores to 2x logical cores, with reasonable upper limit
|
||||
maxCores := cpuInfo.LogicalCores * 2
|
||||
if maxCores < 4 {
|
||||
maxCores = 4
|
||||
}
|
||||
if maxCores > 64 {
|
||||
maxCores = 64
|
||||
}
|
||||
return maxCores
|
||||
}
|
||||
346
internal/cpu/detection.go
Normal file
346
internal/cpu/detection.go
Normal file
@ -0,0 +1,346 @@
|
||||
package cpu
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"os"
|
||||
"os/exec"
|
||||
"bufio"
|
||||
)
|
||||
|
||||
// CPUInfo holds information about the system CPU
|
||||
type CPUInfo struct {
|
||||
LogicalCores int `json:"logical_cores"`
|
||||
PhysicalCores int `json:"physical_cores"`
|
||||
Architecture string `json:"architecture"`
|
||||
ModelName string `json:"model_name"`
|
||||
MaxFrequency float64 `json:"max_frequency_mhz"`
|
||||
CacheSize string `json:"cache_size"`
|
||||
Vendor string `json:"vendor"`
|
||||
Features []string `json:"features"`
|
||||
}
|
||||
|
||||
// Detector provides CPU detection functionality
|
||||
type Detector struct {
|
||||
info *CPUInfo
|
||||
}
|
||||
|
||||
// NewDetector creates a new CPU detector
|
||||
func NewDetector() *Detector {
|
||||
return &Detector{}
|
||||
}
|
||||
|
||||
// DetectCPU detects CPU information for the current system
|
||||
func (d *Detector) DetectCPU() (*CPUInfo, error) {
|
||||
if d.info != nil {
|
||||
return d.info, nil
|
||||
}
|
||||
|
||||
info := &CPUInfo{
|
||||
LogicalCores: runtime.NumCPU(),
|
||||
Architecture: runtime.GOARCH,
|
||||
}
|
||||
|
||||
// Platform-specific detection
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
if err := d.detectLinux(info); err != nil {
|
||||
return info, fmt.Errorf("linux CPU detection failed: %w", err)
|
||||
}
|
||||
case "darwin":
|
||||
if err := d.detectDarwin(info); err != nil {
|
||||
return info, fmt.Errorf("darwin CPU detection failed: %w", err)
|
||||
}
|
||||
case "windows":
|
||||
if err := d.detectWindows(info); err != nil {
|
||||
return info, fmt.Errorf("windows CPU detection failed: %w", err)
|
||||
}
|
||||
default:
|
||||
// Fallback for unsupported platforms
|
||||
info.PhysicalCores = info.LogicalCores
|
||||
info.ModelName = "Unknown"
|
||||
info.Vendor = "Unknown"
|
||||
}
|
||||
|
||||
d.info = info
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// detectLinux detects CPU information on Linux systems
|
||||
func (d *Detector) detectLinux(info *CPUInfo) error {
|
||||
file, err := os.Open("/proc/cpuinfo")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
physicalCoreCount := make(map[string]bool)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
switch key {
|
||||
case "model name":
|
||||
if info.ModelName == "" {
|
||||
info.ModelName = value
|
||||
}
|
||||
case "vendor_id":
|
||||
if info.Vendor == "" {
|
||||
info.Vendor = value
|
||||
}
|
||||
case "cpu MHz":
|
||||
if freq, err := strconv.ParseFloat(value, 64); err == nil && info.MaxFrequency < freq {
|
||||
info.MaxFrequency = freq
|
||||
}
|
||||
case "cache size":
|
||||
if info.CacheSize == "" {
|
||||
info.CacheSize = value
|
||||
}
|
||||
case "flags", "Features":
|
||||
if len(info.Features) == 0 {
|
||||
info.Features = strings.Fields(value)
|
||||
}
|
||||
case "physical id":
|
||||
physicalCoreCount[value] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate physical cores
|
||||
if len(physicalCoreCount) > 0 {
|
||||
info.PhysicalCores = len(physicalCoreCount)
|
||||
} else {
|
||||
// Fallback: assume hyperthreading if logical > 1
|
||||
info.PhysicalCores = info.LogicalCores
|
||||
if info.LogicalCores > 1 {
|
||||
info.PhysicalCores = info.LogicalCores / 2
|
||||
}
|
||||
}
|
||||
|
||||
// Try to get more accurate physical core count from lscpu
|
||||
if cmd := exec.Command("lscpu"); cmd != nil {
|
||||
if output, err := cmd.Output(); err == nil {
|
||||
d.parseLscpu(string(output), info)
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// parseLscpu parses lscpu output for more accurate CPU information
|
||||
func (d *Detector) parseLscpu(output string, info *CPUInfo) {
|
||||
lines := strings.Split(output, "\n")
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
switch key {
|
||||
case "Core(s) per socket":
|
||||
if cores, err := strconv.Atoi(value); err == nil {
|
||||
if sockets := d.getSocketCount(output); sockets > 0 {
|
||||
info.PhysicalCores = cores * sockets
|
||||
}
|
||||
}
|
||||
case "Model name":
|
||||
info.ModelName = value
|
||||
case "Vendor ID":
|
||||
info.Vendor = value
|
||||
case "CPU max MHz":
|
||||
if freq, err := strconv.ParseFloat(value, 64); err == nil {
|
||||
info.MaxFrequency = freq
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getSocketCount extracts socket count from lscpu output
|
||||
func (d *Detector) getSocketCount(output string) int {
|
||||
lines := strings.Split(output, "\n")
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "Socket(s):") {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
if sockets, err := strconv.Atoi(strings.TrimSpace(parts[1])); err == nil {
|
||||
return sockets
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 1 // Default to 1 socket
|
||||
}
|
||||
|
||||
// detectDarwin detects CPU information on macOS systems
|
||||
func (d *Detector) detectDarwin(info *CPUInfo) error {
|
||||
// Get CPU brand
|
||||
if output, err := exec.Command("sysctl", "-n", "machdep.cpu.brand_string").Output(); err == nil {
|
||||
info.ModelName = strings.TrimSpace(string(output))
|
||||
}
|
||||
|
||||
// Get physical cores
|
||||
if output, err := exec.Command("sysctl", "-n", "hw.physicalcpu").Output(); err == nil {
|
||||
if cores, err := strconv.Atoi(strings.TrimSpace(string(output))); err == nil {
|
||||
info.PhysicalCores = cores
|
||||
}
|
||||
}
|
||||
|
||||
// Get max frequency
|
||||
if output, err := exec.Command("sysctl", "-n", "hw.cpufrequency_max").Output(); err == nil {
|
||||
if freq, err := strconv.ParseFloat(strings.TrimSpace(string(output)), 64); err == nil {
|
||||
info.MaxFrequency = freq / 1000000 // Convert Hz to MHz
|
||||
}
|
||||
}
|
||||
|
||||
// Get vendor
|
||||
if output, err := exec.Command("sysctl", "-n", "machdep.cpu.vendor").Output(); err == nil {
|
||||
info.Vendor = strings.TrimSpace(string(output))
|
||||
}
|
||||
|
||||
// Get cache size
|
||||
if output, err := exec.Command("sysctl", "-n", "hw.l3cachesize").Output(); err == nil {
|
||||
if cache, err := strconv.Atoi(strings.TrimSpace(string(output))); err == nil {
|
||||
info.CacheSize = fmt.Sprintf("%d KB", cache/1024)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectWindows detects CPU information on Windows systems
|
||||
func (d *Detector) detectWindows(info *CPUInfo) error {
|
||||
// Use wmic to get CPU information
|
||||
cmd := exec.Command("wmic", "cpu", "get", "Name,NumberOfCores,NumberOfLogicalProcessors,MaxClockSpeed", "/format:list")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lines := strings.Split(string(output), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
switch key {
|
||||
case "Name":
|
||||
if value != "" {
|
||||
info.ModelName = value
|
||||
}
|
||||
case "NumberOfCores":
|
||||
if cores, err := strconv.Atoi(value); err == nil {
|
||||
info.PhysicalCores = cores
|
||||
}
|
||||
case "MaxClockSpeed":
|
||||
if freq, err := strconv.ParseFloat(value, 64); err == nil {
|
||||
info.MaxFrequency = freq
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get vendor information
|
||||
cmd = exec.Command("wmic", "cpu", "get", "Manufacturer", "/format:list")
|
||||
if output, err := cmd.Output(); err == nil {
|
||||
lines := strings.Split(string(output), "\n")
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "Manufacturer=") {
|
||||
info.Vendor = strings.TrimSpace(strings.SplitN(line, "=", 2)[1])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateOptimalJobs calculates optimal job count based on CPU info and workload type
|
||||
func (d *Detector) CalculateOptimalJobs(workloadType string, maxJobs int) (int, error) {
|
||||
info, err := d.DetectCPU()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
|
||||
var optimal int
|
||||
|
||||
switch workloadType {
|
||||
case "cpu-intensive":
|
||||
// For CPU-intensive tasks, use physical cores
|
||||
optimal = info.PhysicalCores
|
||||
case "io-intensive":
|
||||
// For I/O intensive tasks, can use more jobs than cores
|
||||
optimal = info.LogicalCores * 2
|
||||
case "balanced":
|
||||
// Balanced workload, use logical cores
|
||||
optimal = info.LogicalCores
|
||||
default:
|
||||
optimal = info.LogicalCores
|
||||
}
|
||||
|
||||
// Apply safety limits
|
||||
if optimal < 1 {
|
||||
optimal = 1
|
||||
}
|
||||
if maxJobs > 0 && optimal > maxJobs {
|
||||
optimal = maxJobs
|
||||
}
|
||||
|
||||
return optimal, nil
|
||||
}
|
||||
|
||||
// GetCPUInfo returns the detected CPU information
|
||||
func (d *Detector) GetCPUInfo() *CPUInfo {
|
||||
return d.info
|
||||
}
|
||||
|
||||
// FormatCPUInfo returns a formatted string representation of CPU info
|
||||
func (info *CPUInfo) FormatCPUInfo() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("Architecture: %s\n", info.Architecture))
|
||||
sb.WriteString(fmt.Sprintf("Logical Cores: %d\n", info.LogicalCores))
|
||||
sb.WriteString(fmt.Sprintf("Physical Cores: %d\n", info.PhysicalCores))
|
||||
|
||||
if info.ModelName != "" {
|
||||
sb.WriteString(fmt.Sprintf("Model: %s\n", info.ModelName))
|
||||
}
|
||||
if info.Vendor != "" {
|
||||
sb.WriteString(fmt.Sprintf("Vendor: %s\n", info.Vendor))
|
||||
}
|
||||
if info.MaxFrequency > 0 {
|
||||
sb.WriteString(fmt.Sprintf("Max Frequency: %.2f MHz\n", info.MaxFrequency))
|
||||
}
|
||||
if info.CacheSize != "" {
|
||||
sb.WriteString(fmt.Sprintf("Cache Size: %s\n", info.CacheSize))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
133
internal/database/interface.go
Normal file
133
internal/database/interface.go
Normal file
@ -0,0 +1,133 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/logger"
|
||||
|
||||
_ "github.com/lib/pq" // PostgreSQL driver
|
||||
_ "github.com/go-sql-driver/mysql" // MySQL driver
|
||||
)
|
||||
|
||||
// Database represents a database connection and operations
|
||||
type Database interface {
|
||||
// Connection management
|
||||
Connect(ctx context.Context) error
|
||||
Close() error
|
||||
Ping(ctx context.Context) error
|
||||
|
||||
// Database discovery
|
||||
ListDatabases(ctx context.Context) ([]string, error)
|
||||
ListTables(ctx context.Context, database string) ([]string, error)
|
||||
|
||||
// Database operations
|
||||
CreateDatabase(ctx context.Context, name string) error
|
||||
DropDatabase(ctx context.Context, name string) error
|
||||
DatabaseExists(ctx context.Context, name string) (bool, error)
|
||||
|
||||
// Information
|
||||
GetVersion(ctx context.Context) (string, error)
|
||||
GetDatabaseSize(ctx context.Context, database string) (int64, error)
|
||||
GetTableRowCount(ctx context.Context, database, table string) (int64, error)
|
||||
|
||||
// Backup/Restore command building
|
||||
BuildBackupCommand(database, outputFile string, options BackupOptions) []string
|
||||
BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string
|
||||
BuildSampleQuery(database, table string, strategy SampleStrategy) string
|
||||
|
||||
// Validation
|
||||
ValidateBackupTools() error
|
||||
}
|
||||
|
||||
// BackupOptions holds options for backup operations
|
||||
type BackupOptions struct {
|
||||
Compression int
|
||||
Parallel int
|
||||
Format string // "custom", "plain", "directory"
|
||||
Blobs bool
|
||||
SchemaOnly bool
|
||||
DataOnly bool
|
||||
NoOwner bool
|
||||
NoPrivileges bool
|
||||
Clean bool
|
||||
IfExists bool
|
||||
Role string
|
||||
}
|
||||
|
||||
// RestoreOptions holds options for restore operations
|
||||
type RestoreOptions struct {
|
||||
Parallel int
|
||||
Clean bool
|
||||
IfExists bool
|
||||
NoOwner bool
|
||||
NoPrivileges bool
|
||||
SingleTransaction bool
|
||||
}
|
||||
|
||||
// SampleStrategy defines how to sample data
|
||||
type SampleStrategy struct {
|
||||
Type string // "ratio", "percent", "count"
|
||||
Value int
|
||||
}
|
||||
|
||||
// DatabaseInfo holds database metadata
|
||||
type DatabaseInfo struct {
|
||||
Name string
|
||||
Size int64
|
||||
Owner string
|
||||
Encoding string
|
||||
Collation string
|
||||
Tables []TableInfo
|
||||
}
|
||||
|
||||
// TableInfo holds table metadata
|
||||
type TableInfo struct {
|
||||
Schema string
|
||||
Name string
|
||||
RowCount int64
|
||||
Size int64
|
||||
}
|
||||
|
||||
// New creates a new database instance based on configuration
|
||||
func New(cfg *config.Config, log logger.Logger) (Database, error) {
|
||||
if cfg.IsPostgreSQL() {
|
||||
return NewPostgreSQL(cfg, log), nil
|
||||
} else if cfg.IsMySQL() {
|
||||
return NewMySQL(cfg, log), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported database type: %s", cfg.DatabaseType)
|
||||
}
|
||||
|
||||
// Common database implementation
|
||||
type baseDatabase struct {
|
||||
cfg *config.Config
|
||||
log logger.Logger
|
||||
db *sql.DB
|
||||
dsn string
|
||||
}
|
||||
|
||||
func (b *baseDatabase) Close() error {
|
||||
if b.db != nil {
|
||||
return b.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *baseDatabase) Ping(ctx context.Context) error {
|
||||
if b.db == nil {
|
||||
return fmt.Errorf("database not connected")
|
||||
}
|
||||
return b.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
// buildTimeout creates a context with timeout for database operations
|
||||
func buildTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
return context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
410
internal/database/mysql.go
Normal file
410
internal/database/mysql.go
Normal file
@ -0,0 +1,410 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// MySQL implements Database interface for MySQL
|
||||
type MySQL struct {
|
||||
baseDatabase
|
||||
}
|
||||
|
||||
// NewMySQL creates a new MySQL database instance
|
||||
func NewMySQL(cfg *config.Config, log logger.Logger) *MySQL {
|
||||
return &MySQL{
|
||||
baseDatabase: baseDatabase{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to MySQL
|
||||
func (m *MySQL) Connect(ctx context.Context) error {
|
||||
// Build MySQL DSN
|
||||
dsn := m.buildDSN()
|
||||
m.dsn = dsn
|
||||
|
||||
m.log.Debug("Connecting to MySQL", "dsn", sanitizeMySQLDSN(dsn))
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open MySQL connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
// Test connection
|
||||
timeoutCtx, cancel := buildTimeout(ctx, 0)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(timeoutCtx); err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("failed to ping MySQL: %w", err)
|
||||
}
|
||||
|
||||
m.db = db
|
||||
m.log.Info("Connected to MySQL successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDatabases returns list of databases (excluding system databases)
|
||||
func (m *MySQL) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
if m.db == nil {
|
||||
return nil, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SHOW DATABASES`
|
||||
|
||||
rows, err := m.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query databases: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var databases []string
|
||||
systemDbs := map[string]bool{
|
||||
"information_schema": true,
|
||||
"performance_schema": true,
|
||||
"mysql": true,
|
||||
"sys": true,
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan database name: %w", err)
|
||||
}
|
||||
|
||||
// Skip system databases
|
||||
if !systemDbs[name] {
|
||||
databases = append(databases, name)
|
||||
}
|
||||
}
|
||||
|
||||
return databases, rows.Err()
|
||||
}
|
||||
|
||||
// ListTables returns list of tables in a database
|
||||
func (m *MySQL) ListTables(ctx context.Context, database string) ([]string, error) {
|
||||
if m.db == nil {
|
||||
return nil, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT table_name FROM information_schema.tables
|
||||
WHERE table_schema = ? AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_name`
|
||||
|
||||
rows, err := m.db.QueryContext(ctx, query, database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table name: %w", err)
|
||||
}
|
||||
tables = append(tables, name)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// CreateDatabase creates a new database
|
||||
func (m *MySQL) CreateDatabase(ctx context.Context, name string) error {
|
||||
if m.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", name)
|
||||
_, err := m.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database %s: %w", name, err)
|
||||
}
|
||||
|
||||
m.log.Info("Created database", "name", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropDatabase drops a database
|
||||
func (m *MySQL) DropDatabase(ctx context.Context, name string) error {
|
||||
if m.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", name)
|
||||
_, err := m.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop database %s: %w", name, err)
|
||||
}
|
||||
|
||||
m.log.Info("Dropped database", "name", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DatabaseExists checks if a database exists
|
||||
func (m *MySQL) DatabaseExists(ctx context.Context, name string) (bool, error) {
|
||||
if m.db == nil {
|
||||
return false, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?`
|
||||
var dbName string
|
||||
err := m.db.QueryRowContext(ctx, query, name).Scan(&dbName)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check database existence: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetVersion returns MySQL version
|
||||
func (m *MySQL) GetVersion(ctx context.Context) (string, error) {
|
||||
if m.db == nil {
|
||||
return "", fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
var version string
|
||||
err := m.db.QueryRowContext(ctx, "SELECT VERSION()").Scan(&version)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get version: %w", err)
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// GetDatabaseSize returns database size in bytes
|
||||
func (m *MySQL) GetDatabaseSize(ctx context.Context, database string) (int64, error) {
|
||||
if m.db == nil {
|
||||
return 0, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT COALESCE(SUM(data_length + index_length), 0) as size_bytes
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = ?`
|
||||
|
||||
var size int64
|
||||
err := m.db.QueryRowContext(ctx, query, database).Scan(&size)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get database size: %w", err)
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// GetTableRowCount returns row count for a table
|
||||
func (m *MySQL) GetTableRowCount(ctx context.Context, database, table string) (int64, error) {
|
||||
if m.db == nil {
|
||||
return 0, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// First try information_schema for approximate count (faster)
|
||||
query := `SELECT table_rows FROM information_schema.tables
|
||||
WHERE table_schema = ? AND table_name = ?`
|
||||
|
||||
var count int64
|
||||
err := m.db.QueryRowContext(ctx, query, database, table).Scan(&count)
|
||||
if err != nil || count == 0 {
|
||||
// Fallback to exact count
|
||||
exactQuery := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`", database, table)
|
||||
err = m.db.QueryRowContext(ctx, exactQuery).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get table row count: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BuildBackupCommand builds mysqldump command
|
||||
func (m *MySQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
|
||||
cmd := []string{"mysqldump"}
|
||||
|
||||
// Connection parameters
|
||||
cmd = append(cmd, "-h", m.cfg.Host)
|
||||
cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
|
||||
cmd = append(cmd, "-u", m.cfg.User)
|
||||
|
||||
if m.cfg.Password != "" {
|
||||
cmd = append(cmd, "-p"+m.cfg.Password)
|
||||
}
|
||||
|
||||
// SSL options
|
||||
if m.cfg.Insecure {
|
||||
cmd = append(cmd, "--skip-ssl")
|
||||
} else if m.cfg.SSLMode != "" {
|
||||
// MySQL SSL modes: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY
|
||||
switch strings.ToLower(m.cfg.SSLMode) {
|
||||
case "disable", "disabled":
|
||||
cmd = append(cmd, "--skip-ssl")
|
||||
case "require", "required":
|
||||
cmd = append(cmd, "--ssl-mode=REQUIRED")
|
||||
case "verify-ca":
|
||||
cmd = append(cmd, "--ssl-mode=VERIFY_CA")
|
||||
case "verify-full", "verify-identity":
|
||||
cmd = append(cmd, "--ssl-mode=VERIFY_IDENTITY")
|
||||
default:
|
||||
cmd = append(cmd, "--ssl-mode=PREFERRED")
|
||||
}
|
||||
}
|
||||
|
||||
// Backup options
|
||||
cmd = append(cmd, "--single-transaction") // Consistent backup
|
||||
cmd = append(cmd, "--routines") // Include stored procedures/functions
|
||||
cmd = append(cmd, "--triggers") // Include triggers
|
||||
cmd = append(cmd, "--events") // Include events
|
||||
|
||||
if options.SchemaOnly {
|
||||
cmd = append(cmd, "--no-data")
|
||||
} else if options.DataOnly {
|
||||
cmd = append(cmd, "--no-create-info")
|
||||
}
|
||||
|
||||
if options.NoOwner || options.NoPrivileges {
|
||||
cmd = append(cmd, "--skip-add-drop-table")
|
||||
}
|
||||
|
||||
// Compression (handled externally for MySQL)
|
||||
// Output redirection will be handled by caller
|
||||
|
||||
// Database
|
||||
cmd = append(cmd, database)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// BuildRestoreCommand builds mysql restore command
|
||||
func (m *MySQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
|
||||
cmd := []string{"mysql"}
|
||||
|
||||
// Connection parameters
|
||||
cmd = append(cmd, "-h", m.cfg.Host)
|
||||
cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
|
||||
cmd = append(cmd, "-u", m.cfg.User)
|
||||
|
||||
if m.cfg.Password != "" {
|
||||
cmd = append(cmd, "-p"+m.cfg.Password)
|
||||
}
|
||||
|
||||
// SSL options
|
||||
if m.cfg.Insecure {
|
||||
cmd = append(cmd, "--skip-ssl")
|
||||
}
|
||||
|
||||
// Options
|
||||
if options.SingleTransaction {
|
||||
cmd = append(cmd, "--single-transaction")
|
||||
}
|
||||
|
||||
// Database
|
||||
cmd = append(cmd, database)
|
||||
|
||||
// Input file (will be handled via stdin redirection)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// BuildSampleQuery builds SQL query for sampling data
|
||||
func (m *MySQL) BuildSampleQuery(database, table string, strategy SampleStrategy) string {
|
||||
switch strategy.Type {
|
||||
case "ratio":
|
||||
// Every Nth record using row_number (MySQL 8.0+) or modulo
|
||||
return fmt.Sprintf("SELECT * FROM (SELECT *, (@row_number:=@row_number + 1) AS rn FROM %s.%s CROSS JOIN (SELECT @row_number:=0) AS t) AS numbered WHERE rn %% %d = 1",
|
||||
database, table, strategy.Value)
|
||||
case "percent":
|
||||
// Percentage sampling using RAND()
|
||||
return fmt.Sprintf("SELECT * FROM %s.%s WHERE RAND() <= %f",
|
||||
database, table, float64(strategy.Value)/100.0)
|
||||
case "count":
|
||||
// First N records
|
||||
return fmt.Sprintf("SELECT * FROM %s.%s LIMIT %d", database, table, strategy.Value)
|
||||
default:
|
||||
return fmt.Sprintf("SELECT * FROM %s.%s LIMIT 1000", database, table)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateBackupTools checks if required MySQL tools are available
|
||||
func (m *MySQL) ValidateBackupTools() error {
|
||||
tools := []string{"mysqldump", "mysql"}
|
||||
|
||||
for _, tool := range tools {
|
||||
if _, err := exec.LookPath(tool); err != nil {
|
||||
return fmt.Errorf("required tool not found: %s", tool)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildDSN constructs MySQL connection string
|
||||
func (m *MySQL) buildDSN() string {
|
||||
dsn := ""
|
||||
|
||||
if m.cfg.User != "" {
|
||||
dsn += m.cfg.User
|
||||
}
|
||||
|
||||
if m.cfg.Password != "" {
|
||||
dsn += ":" + m.cfg.Password
|
||||
}
|
||||
|
||||
dsn += "@"
|
||||
|
||||
if m.cfg.Host != "" && m.cfg.Host != "localhost" {
|
||||
dsn += "tcp(" + m.cfg.Host + ":" + strconv.Itoa(m.cfg.Port) + ")"
|
||||
}
|
||||
|
||||
dsn += "/" + m.cfg.Database
|
||||
|
||||
// Add connection parameters
|
||||
params := []string{}
|
||||
|
||||
if m.cfg.Insecure {
|
||||
params = append(params, "tls=skip-verify")
|
||||
} else if m.cfg.SSLMode != "" {
|
||||
switch strings.ToLower(m.cfg.SSLMode) {
|
||||
case "disable", "disabled":
|
||||
params = append(params, "tls=false")
|
||||
case "require", "required":
|
||||
params = append(params, "tls=true")
|
||||
}
|
||||
}
|
||||
|
||||
// Add charset
|
||||
params = append(params, "charset=utf8mb4")
|
||||
params = append(params, "parseTime=true")
|
||||
|
||||
if len(params) > 0 {
|
||||
dsn += "?" + strings.Join(params, "&")
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
// sanitizeMySQLDSN removes password from DSN for logging
|
||||
func sanitizeMySQLDSN(dsn string) string {
|
||||
// Find password part and replace it
|
||||
if idx := strings.Index(dsn, ":"); idx != -1 {
|
||||
if endIdx := strings.Index(dsn[idx:], "@"); endIdx != -1 {
|
||||
return dsn[:idx] + ":***" + dsn[idx+endIdx:]
|
||||
}
|
||||
}
|
||||
return dsn
|
||||
}
|
||||
427
internal/database/postgresql.go
Normal file
427
internal/database/postgresql.go
Normal file
@ -0,0 +1,427 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// PostgreSQL implements Database interface for PostgreSQL
|
||||
type PostgreSQL struct {
|
||||
baseDatabase
|
||||
}
|
||||
|
||||
// NewPostgreSQL creates a new PostgreSQL database instance
|
||||
func NewPostgreSQL(cfg *config.Config, log logger.Logger) *PostgreSQL {
|
||||
return &PostgreSQL{
|
||||
baseDatabase: baseDatabase{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to PostgreSQL
|
||||
func (p *PostgreSQL) Connect(ctx context.Context) error {
|
||||
// Build PostgreSQL DSN
|
||||
dsn := p.buildDSN()
|
||||
p.dsn = dsn
|
||||
|
||||
p.log.Debug("Connecting to PostgreSQL", "dsn", sanitizeDSN(dsn))
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open PostgreSQL connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
// Test connection
|
||||
timeoutCtx, cancel := buildTimeout(ctx, 0)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(timeoutCtx); err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("failed to ping PostgreSQL: %w", err)
|
||||
}
|
||||
|
||||
p.db = db
|
||||
p.log.Info("Connected to PostgreSQL successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDatabases returns list of non-template databases
|
||||
func (p *PostgreSQL) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
if p.db == nil {
|
||||
return nil, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT datname FROM pg_database
|
||||
WHERE datistemplate = false
|
||||
ORDER BY datname`
|
||||
|
||||
rows, err := p.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query databases: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var databases []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan database name: %w", err)
|
||||
}
|
||||
databases = append(databases, name)
|
||||
}
|
||||
|
||||
return databases, rows.Err()
|
||||
}
|
||||
|
||||
// ListTables returns list of tables in a database
|
||||
func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string, error) {
|
||||
if p.db == nil {
|
||||
return nil, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT schemaname||'.'||tablename as full_name
|
||||
FROM pg_tables
|
||||
WHERE schemaname NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
|
||||
ORDER BY schemaname, tablename`
|
||||
|
||||
rows, err := p.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table name: %w", err)
|
||||
}
|
||||
tables = append(tables, name)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// CreateDatabase creates a new database
|
||||
func (p *PostgreSQL) CreateDatabase(ctx context.Context, name string) error {
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// PostgreSQL doesn't support CREATE DATABASE in transactions or prepared statements
|
||||
query := fmt.Sprintf("CREATE DATABASE %s", name)
|
||||
_, err := p.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database %s: %w", name, err)
|
||||
}
|
||||
|
||||
p.log.Info("Created database", "name", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropDatabase drops a database
|
||||
func (p *PostgreSQL) DropDatabase(ctx context.Context, name string) error {
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// Force drop connections and drop database
|
||||
query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", name)
|
||||
_, err := p.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop database %s: %w", name, err)
|
||||
}
|
||||
|
||||
p.log.Info("Dropped database", "name", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DatabaseExists checks if a database exists
|
||||
func (p *PostgreSQL) DatabaseExists(ctx context.Context, name string) (bool, error) {
|
||||
if p.db == nil {
|
||||
return false, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)`
|
||||
var exists bool
|
||||
err := p.db.QueryRowContext(ctx, query, name).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check database existence: %w", err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetVersion returns PostgreSQL version
|
||||
func (p *PostgreSQL) GetVersion(ctx context.Context) (string, error) {
|
||||
if p.db == nil {
|
||||
return "", fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
var version string
|
||||
err := p.db.QueryRowContext(ctx, "SELECT version()").Scan(&version)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get version: %w", err)
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// GetDatabaseSize returns database size in bytes
|
||||
func (p *PostgreSQL) GetDatabaseSize(ctx context.Context, database string) (int64, error) {
|
||||
if p.db == nil {
|
||||
return 0, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := `SELECT pg_database_size($1)`
|
||||
var size int64
|
||||
err := p.db.QueryRowContext(ctx, query, database).Scan(&size)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get database size: %w", err)
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// GetTableRowCount returns approximate row count for a table
|
||||
func (p *PostgreSQL) GetTableRowCount(ctx context.Context, database, table string) (int64, error) {
|
||||
if p.db == nil {
|
||||
return 0, fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// Use pg_stat_user_tables for approximate count (faster)
|
||||
parts := strings.Split(table, ".")
|
||||
if len(parts) != 2 {
|
||||
return 0, fmt.Errorf("table name must be in format schema.table")
|
||||
}
|
||||
|
||||
query := `SELECT COALESCE(n_tup_ins, 0) FROM pg_stat_user_tables
|
||||
WHERE schemaname = $1 AND relname = $2`
|
||||
|
||||
var count int64
|
||||
err := p.db.QueryRowContext(ctx, query, parts[0], parts[1]).Scan(&count)
|
||||
if err != nil {
|
||||
// Fallback to exact count if stats not available
|
||||
exactQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
|
||||
err = p.db.QueryRowContext(ctx, exactQuery).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get table row count: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BuildBackupCommand builds pg_dump command
|
||||
func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
|
||||
cmd := []string{"pg_dump"}
|
||||
|
||||
// Connection parameters
|
||||
if p.cfg.Host != "localhost" {
|
||||
cmd = append(cmd, "-h", p.cfg.Host)
|
||||
cmd = append(cmd, "-p", strconv.Itoa(p.cfg.Port))
|
||||
cmd = append(cmd, "--no-password")
|
||||
}
|
||||
cmd = append(cmd, "-U", p.cfg.User)
|
||||
|
||||
// Format and compression
|
||||
if options.Format != "" {
|
||||
cmd = append(cmd, "--format="+options.Format)
|
||||
} else {
|
||||
cmd = append(cmd, "--format=custom")
|
||||
}
|
||||
|
||||
if options.Compression > 0 {
|
||||
cmd = append(cmd, "--compress="+strconv.Itoa(options.Compression))
|
||||
}
|
||||
|
||||
// Parallel jobs (only for directory format)
|
||||
if options.Parallel > 1 && options.Format == "directory" {
|
||||
cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
|
||||
}
|
||||
|
||||
// Options
|
||||
if options.Blobs {
|
||||
cmd = append(cmd, "--blobs")
|
||||
}
|
||||
if options.SchemaOnly {
|
||||
cmd = append(cmd, "--schema-only")
|
||||
}
|
||||
if options.DataOnly {
|
||||
cmd = append(cmd, "--data-only")
|
||||
}
|
||||
if options.NoOwner {
|
||||
cmd = append(cmd, "--no-owner")
|
||||
}
|
||||
if options.NoPrivileges {
|
||||
cmd = append(cmd, "--no-privileges")
|
||||
}
|
||||
if options.Role != "" {
|
||||
cmd = append(cmd, "--role="+options.Role)
|
||||
}
|
||||
|
||||
// Database and output
|
||||
cmd = append(cmd, "--dbname="+database)
|
||||
cmd = append(cmd, "--file="+outputFile)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// BuildRestoreCommand builds pg_restore command
|
||||
func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
|
||||
cmd := []string{"pg_restore"}
|
||||
|
||||
// Connection parameters
|
||||
if p.cfg.Host != "localhost" {
|
||||
cmd = append(cmd, "-h", p.cfg.Host)
|
||||
cmd = append(cmd, "-p", strconv.Itoa(p.cfg.Port))
|
||||
cmd = append(cmd, "--no-password")
|
||||
}
|
||||
cmd = append(cmd, "-U", p.cfg.User)
|
||||
|
||||
// Parallel jobs
|
||||
if options.Parallel > 1 {
|
||||
cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
|
||||
}
|
||||
|
||||
// Options
|
||||
if options.Clean {
|
||||
cmd = append(cmd, "--clean")
|
||||
}
|
||||
if options.IfExists {
|
||||
cmd = append(cmd, "--if-exists")
|
||||
}
|
||||
if options.NoOwner {
|
||||
cmd = append(cmd, "--no-owner")
|
||||
}
|
||||
if options.NoPrivileges {
|
||||
cmd = append(cmd, "--no-privileges")
|
||||
}
|
||||
if options.SingleTransaction {
|
||||
cmd = append(cmd, "--single-transaction")
|
||||
}
|
||||
|
||||
// Database and input
|
||||
cmd = append(cmd, "--dbname="+database)
|
||||
cmd = append(cmd, inputFile)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// BuildSampleQuery builds SQL query for sampling data
|
||||
func (p *PostgreSQL) BuildSampleQuery(database, table string, strategy SampleStrategy) string {
|
||||
switch strategy.Type {
|
||||
case "ratio":
|
||||
// Every Nth record using row_number
|
||||
return fmt.Sprintf("SELECT * FROM (SELECT *, row_number() OVER () as rn FROM %s) t WHERE rn %% %d = 1",
|
||||
table, strategy.Value)
|
||||
case "percent":
|
||||
// Percentage sampling using TABLESAMPLE (PostgreSQL 9.5+)
|
||||
return fmt.Sprintf("SELECT * FROM %s TABLESAMPLE BERNOULLI(%d)", table, strategy.Value)
|
||||
case "count":
|
||||
// First N records
|
||||
return fmt.Sprintf("SELECT * FROM %s LIMIT %d", table, strategy.Value)
|
||||
default:
|
||||
return fmt.Sprintf("SELECT * FROM %s LIMIT 1000", table)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateBackupTools checks if required PostgreSQL tools are available
|
||||
func (p *PostgreSQL) ValidateBackupTools() error {
|
||||
tools := []string{"pg_dump", "pg_restore", "pg_dumpall", "psql"}
|
||||
|
||||
for _, tool := range tools {
|
||||
if _, err := exec.LookPath(tool); err != nil {
|
||||
return fmt.Errorf("required tool not found: %s", tool)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildDSN constructs PostgreSQL connection string
|
||||
func (p *PostgreSQL) buildDSN() string {
|
||||
dsn := fmt.Sprintf("user=%s dbname=%s", p.cfg.User, p.cfg.Database)
|
||||
|
||||
if p.cfg.Password != "" {
|
||||
dsn += " password=" + p.cfg.Password
|
||||
}
|
||||
|
||||
// For localhost connections, try socket first for peer auth
|
||||
if p.cfg.Host == "localhost" && p.cfg.Password == "" {
|
||||
// Try Unix socket connection for peer authentication
|
||||
// Common PostgreSQL socket locations
|
||||
socketDirs := []string{
|
||||
"/var/run/postgresql",
|
||||
"/tmp",
|
||||
"/var/lib/pgsql",
|
||||
}
|
||||
|
||||
for _, dir := range socketDirs {
|
||||
socketPath := fmt.Sprintf("%s/.s.PGSQL.%d", dir, p.cfg.Port)
|
||||
if _, err := os.Stat(socketPath); err == nil {
|
||||
dsn += " host=" + dir
|
||||
p.log.Debug("Using PostgreSQL socket", "path", socketPath)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if p.cfg.Host != "localhost" || p.cfg.Password != "" {
|
||||
// Use TCP connection
|
||||
dsn += " host=" + p.cfg.Host
|
||||
dsn += " port=" + strconv.Itoa(p.cfg.Port)
|
||||
}
|
||||
|
||||
if p.cfg.SSLMode != "" && !p.cfg.Insecure {
|
||||
// Map SSL modes to supported values for lib/pq
|
||||
switch strings.ToLower(p.cfg.SSLMode) {
|
||||
case "prefer", "preferred":
|
||||
dsn += " sslmode=require" // lib/pq default, closest to prefer
|
||||
case "require", "required":
|
||||
dsn += " sslmode=require"
|
||||
case "verify-ca":
|
||||
dsn += " sslmode=verify-ca"
|
||||
case "verify-full", "verify-identity":
|
||||
dsn += " sslmode=verify-full"
|
||||
case "disable", "disabled":
|
||||
dsn += " sslmode=disable"
|
||||
default:
|
||||
dsn += " sslmode=require" // Safe default
|
||||
}
|
||||
} else if p.cfg.Insecure {
|
||||
dsn += " sslmode=disable"
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
// sanitizeDSN removes password from DSN for logging
|
||||
func sanitizeDSN(dsn string) string {
|
||||
// Simple password removal for logging
|
||||
parts := strings.Split(dsn, " ")
|
||||
var sanitized []string
|
||||
|
||||
for _, part := range parts {
|
||||
if strings.HasPrefix(part, "password=") {
|
||||
sanitized = append(sanitized, "password=***")
|
||||
} else {
|
||||
sanitized = append(sanitized, part)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(sanitized, " ")
|
||||
}
|
||||
185
internal/logger/logger.go
Normal file
185
internal/logger/logger.go
Normal file
@ -0,0 +1,185 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logger defines the interface for logging
|
||||
type Logger interface {
|
||||
Debug(msg string, args ...any)
|
||||
Info(msg string, args ...any)
|
||||
Warn(msg string, args ...any)
|
||||
Error(msg string, args ...any)
|
||||
Time(msg string, args ...any)
|
||||
|
||||
// Progress logging for operations
|
||||
StartOperation(name string) OperationLogger
|
||||
}
|
||||
|
||||
// OperationLogger tracks timing for operations
|
||||
type OperationLogger interface {
|
||||
Update(msg string, args ...any)
|
||||
Complete(msg string, args ...any)
|
||||
Fail(msg string, args ...any)
|
||||
}
|
||||
|
||||
// logger implements Logger interface using slog
|
||||
type logger struct {
|
||||
slog *slog.Logger
|
||||
level slog.Level
|
||||
format string
|
||||
}
|
||||
|
||||
// operationLogger tracks a single operation
|
||||
type operationLogger struct {
|
||||
name string
|
||||
startTime time.Time
|
||||
parent *logger
|
||||
}
|
||||
|
||||
// New creates a new logger
|
||||
func New(level, format string) Logger {
|
||||
var slogLevel slog.Level
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
slogLevel = slog.LevelDebug
|
||||
case "info":
|
||||
slogLevel = slog.LevelInfo
|
||||
case "warn", "warning":
|
||||
slogLevel = slog.LevelWarn
|
||||
case "error":
|
||||
slogLevel = slog.LevelError
|
||||
default:
|
||||
slogLevel = slog.LevelInfo
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: slogLevel,
|
||||
}
|
||||
|
||||
switch strings.ToLower(format) {
|
||||
case "json":
|
||||
handler = slog.NewJSONHandler(os.Stdout, opts)
|
||||
default:
|
||||
handler = slog.NewTextHandler(os.Stdout, opts)
|
||||
}
|
||||
|
||||
return &logger{
|
||||
slog: slog.New(handler),
|
||||
level: slogLevel,
|
||||
format: format,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *logger) Debug(msg string, args ...any) {
|
||||
l.slog.Debug(msg, args...)
|
||||
}
|
||||
|
||||
func (l *logger) Info(msg string, args ...any) {
|
||||
l.slog.Info(msg, args...)
|
||||
}
|
||||
|
||||
func (l *logger) Warn(msg string, args ...any) {
|
||||
l.slog.Warn(msg, args...)
|
||||
}
|
||||
|
||||
func (l *logger) Error(msg string, args ...any) {
|
||||
l.slog.Error(msg, args...)
|
||||
}
|
||||
|
||||
func (l *logger) Time(msg string, args ...any) {
|
||||
// Time logs are always at info level with special formatting
|
||||
l.slog.Info("[TIME] "+msg, args...)
|
||||
}
|
||||
|
||||
func (l *logger) StartOperation(name string) OperationLogger {
|
||||
return &operationLogger{
|
||||
name: name,
|
||||
startTime: time.Now(),
|
||||
parent: l,
|
||||
}
|
||||
}
|
||||
|
||||
func (ol *operationLogger) Update(msg string, args ...any) {
|
||||
elapsed := time.Since(ol.startTime)
|
||||
ol.parent.Info(fmt.Sprintf("[%s] %s", ol.name, msg),
|
||||
append(args, "elapsed", elapsed.String())...)
|
||||
}
|
||||
|
||||
func (ol *operationLogger) Complete(msg string, args ...any) {
|
||||
elapsed := time.Since(ol.startTime)
|
||||
ol.parent.Info(fmt.Sprintf("[%s] COMPLETED: %s", ol.name, msg),
|
||||
append(args, "duration", formatDuration(elapsed))...)
|
||||
}
|
||||
|
||||
func (ol *operationLogger) Fail(msg string, args ...any) {
|
||||
elapsed := time.Since(ol.startTime)
|
||||
ol.parent.Error(fmt.Sprintf("[%s] FAILED: %s", ol.name, msg),
|
||||
append(args, "duration", formatDuration(elapsed))...)
|
||||
}
|
||||
|
||||
// formatDuration formats duration in human-readable format
|
||||
func formatDuration(d time.Duration) string {
|
||||
if d < time.Minute {
|
||||
return fmt.Sprintf("%.1fs", d.Seconds())
|
||||
} else if d < time.Hour {
|
||||
minutes := int(d.Minutes())
|
||||
seconds := int(d.Seconds()) % 60
|
||||
return fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
} else {
|
||||
hours := int(d.Hours())
|
||||
minutes := int(d.Minutes()) % 60
|
||||
seconds := int(d.Seconds()) % 60
|
||||
return fmt.Sprintf("%dh %dm %ds", hours, minutes, seconds)
|
||||
}
|
||||
}
|
||||
|
||||
// FileLogger creates a logger that writes to both stdout and a file
|
||||
func FileLogger(level, format, filename string) (Logger, error) {
|
||||
var slogLevel slog.Level
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
slogLevel = slog.LevelDebug
|
||||
case "info":
|
||||
slogLevel = slog.LevelInfo
|
||||
case "warn", "warning":
|
||||
slogLevel = slog.LevelWarn
|
||||
case "error":
|
||||
slogLevel = slog.LevelError
|
||||
default:
|
||||
slogLevel = slog.LevelInfo
|
||||
}
|
||||
|
||||
// Open log file
|
||||
file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open log file: %w", err)
|
||||
}
|
||||
|
||||
// Create multi-writer (stdout + file)
|
||||
multiWriter := io.MultiWriter(os.Stdout, file)
|
||||
|
||||
var handler slog.Handler
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: slogLevel,
|
||||
}
|
||||
|
||||
switch strings.ToLower(format) {
|
||||
case "json":
|
||||
handler = slog.NewJSONHandler(multiWriter, opts)
|
||||
default:
|
||||
handler = slog.NewTextHandler(multiWriter, opts)
|
||||
}
|
||||
|
||||
return &logger{
|
||||
slog: slog.New(handler),
|
||||
level: slogLevel,
|
||||
format: format,
|
||||
}, nil
|
||||
}
|
||||
427
internal/progress/detailed.go
Normal file
427
internal/progress/detailed.go
Normal file
@ -0,0 +1,427 @@
|
||||
package progress
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DetailedReporter provides comprehensive progress reporting with timestamps and status
|
||||
type DetailedReporter struct {
|
||||
mu sync.RWMutex
|
||||
operations []OperationStatus
|
||||
startTime time.Time
|
||||
indicator Indicator
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// OperationStatus represents the status of a backup/restore operation
|
||||
type OperationStatus struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "backup", "restore", "verify"
|
||||
Status string `json:"status"` // "running", "completed", "failed"
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime *time.Time `json:"end_time,omitempty"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Progress int `json:"progress"` // 0-100
|
||||
Message string `json:"message"`
|
||||
Details map[string]string `json:"details"`
|
||||
Steps []StepStatus `json:"steps"`
|
||||
BytesTotal int64 `json:"bytes_total"`
|
||||
BytesDone int64 `json:"bytes_done"`
|
||||
FilesTotal int `json:"files_total"`
|
||||
FilesDone int `json:"files_done"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
// StepStatus represents individual steps within an operation
|
||||
type StepStatus struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime *time.Time `json:"end_time,omitempty"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Logger interface for detailed reporting
|
||||
type Logger interface {
|
||||
Info(msg string, args ...any)
|
||||
Warn(msg string, args ...any)
|
||||
Error(msg string, args ...any)
|
||||
Debug(msg string, args ...any)
|
||||
}
|
||||
|
||||
// NewDetailedReporter creates a new detailed progress reporter
|
||||
func NewDetailedReporter(indicator Indicator, logger Logger) *DetailedReporter {
|
||||
return &DetailedReporter{
|
||||
operations: make([]OperationStatus, 0),
|
||||
indicator: indicator,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// StartOperation begins tracking a new operation
|
||||
func (dr *DetailedReporter) StartOperation(id, name, opType string) *OperationTracker {
|
||||
dr.mu.Lock()
|
||||
defer dr.mu.Unlock()
|
||||
|
||||
operation := OperationStatus{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Type: opType,
|
||||
Status: "running",
|
||||
StartTime: time.Now(),
|
||||
Progress: 0,
|
||||
Details: make(map[string]string),
|
||||
Steps: make([]StepStatus, 0),
|
||||
}
|
||||
|
||||
dr.operations = append(dr.operations, operation)
|
||||
|
||||
if dr.startTime.IsZero() {
|
||||
dr.startTime = time.Now()
|
||||
}
|
||||
|
||||
// Start visual indicator
|
||||
if dr.indicator != nil {
|
||||
dr.indicator.Start(fmt.Sprintf("Starting %s: %s", opType, name))
|
||||
}
|
||||
|
||||
// Log operation start
|
||||
dr.logger.Info("Operation started",
|
||||
"id", id,
|
||||
"name", name,
|
||||
"type", opType,
|
||||
"timestamp", operation.StartTime.Format(time.RFC3339))
|
||||
|
||||
return &OperationTracker{
|
||||
reporter: dr,
|
||||
operationID: id,
|
||||
}
|
||||
}
|
||||
|
||||
// OperationTracker provides methods to update operation progress
|
||||
type OperationTracker struct {
|
||||
reporter *DetailedReporter
|
||||
operationID string
|
||||
}
|
||||
|
||||
// UpdateProgress updates the progress of the operation
|
||||
func (ot *OperationTracker) UpdateProgress(progress int, message string) {
|
||||
ot.reporter.mu.Lock()
|
||||
defer ot.reporter.mu.Unlock()
|
||||
|
||||
for i := range ot.reporter.operations {
|
||||
if ot.reporter.operations[i].ID == ot.operationID {
|
||||
ot.reporter.operations[i].Progress = progress
|
||||
ot.reporter.operations[i].Message = message
|
||||
|
||||
// Update visual indicator
|
||||
if ot.reporter.indicator != nil {
|
||||
progressMsg := fmt.Sprintf("[%d%%] %s", progress, message)
|
||||
ot.reporter.indicator.Update(progressMsg)
|
||||
}
|
||||
|
||||
// Log progress update
|
||||
ot.reporter.logger.Debug("Progress update",
|
||||
"operation_id", ot.operationID,
|
||||
"progress", progress,
|
||||
"message", message,
|
||||
"timestamp", time.Now().Format(time.RFC3339))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddStep adds a new step to the operation
|
||||
func (ot *OperationTracker) AddStep(name, message string) *StepTracker {
|
||||
ot.reporter.mu.Lock()
|
||||
defer ot.reporter.mu.Unlock()
|
||||
|
||||
step := StepStatus{
|
||||
Name: name,
|
||||
Status: "running",
|
||||
StartTime: time.Now(),
|
||||
Message: message,
|
||||
}
|
||||
|
||||
for i := range ot.reporter.operations {
|
||||
if ot.reporter.operations[i].ID == ot.operationID {
|
||||
ot.reporter.operations[i].Steps = append(ot.reporter.operations[i].Steps, step)
|
||||
|
||||
// Log step start
|
||||
ot.reporter.logger.Info("Step started",
|
||||
"operation_id", ot.operationID,
|
||||
"step", name,
|
||||
"message", message,
|
||||
"timestamp", step.StartTime.Format(time.RFC3339))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &StepTracker{
|
||||
reporter: ot.reporter,
|
||||
operationID: ot.operationID,
|
||||
stepName: name,
|
||||
}
|
||||
}
|
||||
|
||||
// SetDetails adds metadata to the operation
|
||||
func (ot *OperationTracker) SetDetails(key, value string) {
|
||||
ot.reporter.mu.Lock()
|
||||
defer ot.reporter.mu.Unlock()
|
||||
|
||||
for i := range ot.reporter.operations {
|
||||
if ot.reporter.operations[i].ID == ot.operationID {
|
||||
ot.reporter.operations[i].Details[key] = value
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetFileProgress updates file-based progress
|
||||
func (ot *OperationTracker) SetFileProgress(filesDone, filesTotal int) {
|
||||
ot.reporter.mu.Lock()
|
||||
defer ot.reporter.mu.Unlock()
|
||||
|
||||
for i := range ot.reporter.operations {
|
||||
if ot.reporter.operations[i].ID == ot.operationID {
|
||||
ot.reporter.operations[i].FilesDone = filesDone
|
||||
ot.reporter.operations[i].FilesTotal = filesTotal
|
||||
|
||||
if filesTotal > 0 {
|
||||
progress := (filesDone * 100) / filesTotal
|
||||
ot.reporter.operations[i].Progress = progress
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetByteProgress updates byte-based progress
|
||||
func (ot *OperationTracker) SetByteProgress(bytesDone, bytesTotal int64) {
|
||||
ot.reporter.mu.Lock()
|
||||
defer ot.reporter.mu.Unlock()
|
||||
|
||||
for i := range ot.reporter.operations {
|
||||
if ot.reporter.operations[i].ID == ot.operationID {
|
||||
ot.reporter.operations[i].BytesDone = bytesDone
|
||||
ot.reporter.operations[i].BytesTotal = bytesTotal
|
||||
|
||||
if bytesTotal > 0 {
|
||||
progress := int((bytesDone * 100) / bytesTotal)
|
||||
ot.reporter.operations[i].Progress = progress
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Complete marks the operation as completed
|
||||
func (ot *OperationTracker) Complete(message string) {
|
||||
ot.reporter.mu.Lock()
|
||||
defer ot.reporter.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for i := range ot.reporter.operations {
|
||||
if ot.reporter.operations[i].ID == ot.operationID {
|
||||
ot.reporter.operations[i].Status = "completed"
|
||||
ot.reporter.operations[i].Progress = 100
|
||||
ot.reporter.operations[i].EndTime = &now
|
||||
ot.reporter.operations[i].Duration = now.Sub(ot.reporter.operations[i].StartTime)
|
||||
ot.reporter.operations[i].Message = message
|
||||
|
||||
// Complete visual indicator
|
||||
if ot.reporter.indicator != nil {
|
||||
ot.reporter.indicator.Complete(fmt.Sprintf("✅ %s", message))
|
||||
}
|
||||
|
||||
// Log completion with duration
|
||||
ot.reporter.logger.Info("Operation completed",
|
||||
"operation_id", ot.operationID,
|
||||
"message", message,
|
||||
"duration", ot.reporter.operations[i].Duration.String(),
|
||||
"timestamp", now.Format(time.RFC3339))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fail marks the operation as failed
|
||||
func (ot *OperationTracker) Fail(err error) {
|
||||
ot.reporter.mu.Lock()
|
||||
defer ot.reporter.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for i := range ot.reporter.operations {
|
||||
if ot.reporter.operations[i].ID == ot.operationID {
|
||||
ot.reporter.operations[i].Status = "failed"
|
||||
ot.reporter.operations[i].EndTime = &now
|
||||
ot.reporter.operations[i].Duration = now.Sub(ot.reporter.operations[i].StartTime)
|
||||
ot.reporter.operations[i].Message = err.Error()
|
||||
ot.reporter.operations[i].Errors = append(ot.reporter.operations[i].Errors, err.Error())
|
||||
|
||||
// Fail visual indicator
|
||||
if ot.reporter.indicator != nil {
|
||||
ot.reporter.indicator.Fail(fmt.Sprintf("❌ %s", err.Error()))
|
||||
}
|
||||
|
||||
// Log failure
|
||||
ot.reporter.logger.Error("Operation failed",
|
||||
"operation_id", ot.operationID,
|
||||
"error", err.Error(),
|
||||
"duration", ot.reporter.operations[i].Duration.String(),
|
||||
"timestamp", now.Format(time.RFC3339))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StepTracker manages individual step progress
|
||||
type StepTracker struct {
|
||||
reporter *DetailedReporter
|
||||
operationID string
|
||||
stepName string
|
||||
}
|
||||
|
||||
// Complete marks the step as completed
|
||||
func (st *StepTracker) Complete(message string) {
|
||||
st.reporter.mu.Lock()
|
||||
defer st.reporter.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for i := range st.reporter.operations {
|
||||
if st.reporter.operations[i].ID == st.operationID {
|
||||
for j := range st.reporter.operations[i].Steps {
|
||||
if st.reporter.operations[i].Steps[j].Name == st.stepName {
|
||||
st.reporter.operations[i].Steps[j].Status = "completed"
|
||||
st.reporter.operations[i].Steps[j].EndTime = &now
|
||||
st.reporter.operations[i].Steps[j].Duration = now.Sub(st.reporter.operations[i].Steps[j].StartTime)
|
||||
st.reporter.operations[i].Steps[j].Message = message
|
||||
|
||||
// Log step completion
|
||||
st.reporter.logger.Info("Step completed",
|
||||
"operation_id", st.operationID,
|
||||
"step", st.stepName,
|
||||
"message", message,
|
||||
"duration", st.reporter.operations[i].Steps[j].Duration.String(),
|
||||
"timestamp", now.Format(time.RFC3339))
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fail marks the step as failed
|
||||
func (st *StepTracker) Fail(err error) {
|
||||
st.reporter.mu.Lock()
|
||||
defer st.reporter.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for i := range st.reporter.operations {
|
||||
if st.reporter.operations[i].ID == st.operationID {
|
||||
for j := range st.reporter.operations[i].Steps {
|
||||
if st.reporter.operations[i].Steps[j].Name == st.stepName {
|
||||
st.reporter.operations[i].Steps[j].Status = "failed"
|
||||
st.reporter.operations[i].Steps[j].EndTime = &now
|
||||
st.reporter.operations[i].Steps[j].Duration = now.Sub(st.reporter.operations[i].Steps[j].StartTime)
|
||||
st.reporter.operations[i].Steps[j].Message = err.Error()
|
||||
|
||||
// Log step failure
|
||||
st.reporter.logger.Error("Step failed",
|
||||
"operation_id", st.operationID,
|
||||
"step", st.stepName,
|
||||
"error", err.Error(),
|
||||
"duration", st.reporter.operations[i].Steps[j].Duration.String(),
|
||||
"timestamp", now.Format(time.RFC3339))
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetOperationStatus returns the current status of an operation
|
||||
func (dr *DetailedReporter) GetOperationStatus(id string) *OperationStatus {
|
||||
dr.mu.RLock()
|
||||
defer dr.mu.RUnlock()
|
||||
|
||||
for _, op := range dr.operations {
|
||||
if op.ID == id {
|
||||
return &op
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllOperations returns all tracked operations
|
||||
func (dr *DetailedReporter) GetAllOperations() []OperationStatus {
|
||||
dr.mu.RLock()
|
||||
defer dr.mu.RUnlock()
|
||||
|
||||
return append([]OperationStatus(nil), dr.operations...)
|
||||
}
|
||||
|
||||
// GetSummary returns a summary of all operations
|
||||
func (dr *DetailedReporter) GetSummary() OperationSummary {
|
||||
dr.mu.RLock()
|
||||
defer dr.mu.RUnlock()
|
||||
|
||||
summary := OperationSummary{
|
||||
TotalOperations: len(dr.operations),
|
||||
CompletedOperations: 0,
|
||||
FailedOperations: 0,
|
||||
RunningOperations: 0,
|
||||
TotalDuration: time.Since(dr.startTime),
|
||||
}
|
||||
|
||||
for _, op := range dr.operations {
|
||||
switch op.Status {
|
||||
case "completed":
|
||||
summary.CompletedOperations++
|
||||
case "failed":
|
||||
summary.FailedOperations++
|
||||
case "running":
|
||||
summary.RunningOperations++
|
||||
}
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
|
||||
// OperationSummary provides overall statistics
|
||||
type OperationSummary struct {
|
||||
TotalOperations int `json:"total_operations"`
|
||||
CompletedOperations int `json:"completed_operations"`
|
||||
FailedOperations int `json:"failed_operations"`
|
||||
RunningOperations int `json:"running_operations"`
|
||||
TotalDuration time.Duration `json:"total_duration"`
|
||||
}
|
||||
|
||||
// FormatSummary returns a formatted string representation of the summary
|
||||
func (os *OperationSummary) FormatSummary() string {
|
||||
return fmt.Sprintf(
|
||||
"📊 Operations Summary:\n"+
|
||||
" Total: %d | Completed: %d | Failed: %d | Running: %d\n"+
|
||||
" Total Duration: %s",
|
||||
os.TotalOperations,
|
||||
os.CompletedOperations,
|
||||
os.FailedOperations,
|
||||
os.RunningOperations,
|
||||
formatDuration(os.TotalDuration))
|
||||
}
|
||||
|
||||
// formatDuration formats a duration in a human-readable way
|
||||
func formatDuration(d time.Duration) string {
|
||||
if d < time.Minute {
|
||||
return fmt.Sprintf("%.1fs", d.Seconds())
|
||||
} else if d < time.Hour {
|
||||
return fmt.Sprintf("%.1fm", d.Minutes())
|
||||
}
|
||||
return fmt.Sprintf("%.1fh", d.Hours())
|
||||
}
|
||||
398
internal/progress/progress.go
Normal file
398
internal/progress/progress.go
Normal file
@ -0,0 +1,398 @@
|
||||
package progress
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Indicator represents a progress indicator interface
|
||||
type Indicator interface {
|
||||
Start(message string)
|
||||
Update(message string)
|
||||
Complete(message string)
|
||||
Fail(message string)
|
||||
Stop()
|
||||
}
|
||||
|
||||
// Spinner creates a spinning progress indicator
|
||||
type Spinner struct {
|
||||
writer io.Writer
|
||||
message string
|
||||
active bool
|
||||
frames []string
|
||||
interval time.Duration
|
||||
stopCh chan bool
|
||||
}
|
||||
|
||||
// NewSpinner creates a new spinner progress indicator
|
||||
func NewSpinner() *Spinner {
|
||||
return &Spinner{
|
||||
writer: os.Stdout,
|
||||
frames: []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"},
|
||||
interval: 80 * time.Millisecond,
|
||||
stopCh: make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the spinner with a message
|
||||
func (s *Spinner) Start(message string) {
|
||||
s.message = message
|
||||
s.active = true
|
||||
|
||||
go func() {
|
||||
i := 0
|
||||
lastMessage := ""
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
if s.active {
|
||||
currentFrame := fmt.Sprintf("%s %s", s.frames[i%len(s.frames)], s.message)
|
||||
if s.message != lastMessage {
|
||||
// Print new line for new messages
|
||||
fmt.Fprintf(s.writer, "\n%s", currentFrame)
|
||||
lastMessage = s.message
|
||||
} else {
|
||||
// Update in place for same message
|
||||
fmt.Fprintf(s.writer, "\r%s", currentFrame)
|
||||
}
|
||||
i++
|
||||
time.Sleep(s.interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Update changes the spinner message
|
||||
func (s *Spinner) Update(message string) {
|
||||
s.message = message
|
||||
}
|
||||
|
||||
// Complete stops the spinner with a success message
|
||||
func (s *Spinner) Complete(message string) {
|
||||
s.Stop()
|
||||
fmt.Fprintf(s.writer, "\n✅ %s\n", message)
|
||||
}
|
||||
|
||||
// Fail stops the spinner with a failure message
|
||||
func (s *Spinner) Fail(message string) {
|
||||
s.Stop()
|
||||
fmt.Fprintf(s.writer, "\n❌ %s\n", message)
|
||||
}
|
||||
|
||||
// Stop stops the spinner
|
||||
func (s *Spinner) Stop() {
|
||||
if s.active {
|
||||
s.active = false
|
||||
s.stopCh <- true
|
||||
fmt.Fprint(s.writer, "\n") // New line instead of clearing
|
||||
}
|
||||
}
|
||||
|
||||
// Dots creates a dots progress indicator
|
||||
type Dots struct {
|
||||
writer io.Writer
|
||||
message string
|
||||
active bool
|
||||
stopCh chan bool
|
||||
}
|
||||
|
||||
// NewDots creates a new dots progress indicator
|
||||
func NewDots() *Dots {
|
||||
return &Dots{
|
||||
writer: os.Stdout,
|
||||
stopCh: make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the dots indicator
|
||||
func (d *Dots) Start(message string) {
|
||||
d.message = message
|
||||
d.active = true
|
||||
|
||||
fmt.Fprint(d.writer, message)
|
||||
|
||||
go func() {
|
||||
count := 0
|
||||
for {
|
||||
select {
|
||||
case <-d.stopCh:
|
||||
return
|
||||
default:
|
||||
if d.active {
|
||||
fmt.Fprint(d.writer, ".")
|
||||
count++
|
||||
if count%3 == 0 {
|
||||
// Reset dots
|
||||
fmt.Fprint(d.writer, "\r"+d.message)
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Update changes the dots message
|
||||
func (d *Dots) Update(message string) {
|
||||
d.message = message
|
||||
if d.active {
|
||||
fmt.Fprintf(d.writer, "\n%s", message)
|
||||
}
|
||||
}
|
||||
|
||||
// Complete stops the dots with a success message
|
||||
func (d *Dots) Complete(message string) {
|
||||
d.Stop()
|
||||
fmt.Fprintf(d.writer, " ✅ %s\n", message)
|
||||
}
|
||||
|
||||
// Fail stops the dots with a failure message
|
||||
func (d *Dots) Fail(message string) {
|
||||
d.Stop()
|
||||
fmt.Fprintf(d.writer, " ❌ %s\n", message)
|
||||
}
|
||||
|
||||
// Stop stops the dots indicator
|
||||
func (d *Dots) Stop() {
|
||||
if d.active {
|
||||
d.active = false
|
||||
d.stopCh <- true
|
||||
}
|
||||
}
|
||||
|
||||
// ProgressBar creates a visual progress bar
|
||||
type ProgressBar struct {
|
||||
writer io.Writer
|
||||
message string
|
||||
total int
|
||||
current int
|
||||
width int
|
||||
active bool
|
||||
stopCh chan bool
|
||||
}
|
||||
|
||||
// NewProgressBar creates a new progress bar
|
||||
func NewProgressBar(total int) *ProgressBar {
|
||||
return &ProgressBar{
|
||||
writer: os.Stdout,
|
||||
total: total,
|
||||
width: 40,
|
||||
stopCh: make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the progress bar
|
||||
func (p *ProgressBar) Start(message string) {
|
||||
p.message = message
|
||||
p.active = true
|
||||
p.current = 0
|
||||
p.render()
|
||||
}
|
||||
|
||||
// Update advances the progress bar
|
||||
func (p *ProgressBar) Update(message string) {
|
||||
if p.current < p.total {
|
||||
p.current++
|
||||
}
|
||||
p.message = message
|
||||
p.render()
|
||||
}
|
||||
|
||||
// SetProgress sets specific progress value
|
||||
func (p *ProgressBar) SetProgress(current int, message string) {
|
||||
p.current = current
|
||||
p.message = message
|
||||
p.render()
|
||||
}
|
||||
|
||||
// Complete finishes the progress bar
|
||||
func (p *ProgressBar) Complete(message string) {
|
||||
p.current = p.total
|
||||
p.message = message
|
||||
p.render()
|
||||
fmt.Fprintf(p.writer, " ✅ %s\n", message)
|
||||
p.Stop()
|
||||
}
|
||||
|
||||
// Fail stops the progress bar with failure
|
||||
func (p *ProgressBar) Fail(message string) {
|
||||
p.render()
|
||||
fmt.Fprintf(p.writer, " ❌ %s\n", message)
|
||||
p.Stop()
|
||||
}
|
||||
|
||||
// Stop stops the progress bar
|
||||
func (p *ProgressBar) Stop() {
|
||||
p.active = false
|
||||
}
|
||||
|
||||
// render draws the progress bar
|
||||
func (p *ProgressBar) render() {
|
||||
if !p.active {
|
||||
return
|
||||
}
|
||||
|
||||
percent := float64(p.current) / float64(p.total)
|
||||
filled := int(percent * float64(p.width))
|
||||
|
||||
bar := strings.Repeat("█", filled) + strings.Repeat("░", p.width-filled)
|
||||
|
||||
fmt.Fprintf(p.writer, "\n%s [%s] %d%%", p.message, bar, int(percent*100))
|
||||
}
|
||||
|
||||
// Static creates a simple static progress indicator
|
||||
type Static struct {
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
// NewStatic creates a new static progress indicator
|
||||
func NewStatic() *Static {
|
||||
return &Static{
|
||||
writer: os.Stdout,
|
||||
}
|
||||
}
|
||||
|
||||
// Start shows the initial message
|
||||
func (s *Static) Start(message string) {
|
||||
fmt.Fprintf(s.writer, "→ %s", message)
|
||||
}
|
||||
|
||||
// Update shows an update message
|
||||
func (s *Static) Update(message string) {
|
||||
fmt.Fprintf(s.writer, " - %s", message)
|
||||
}
|
||||
|
||||
// Complete shows completion message
|
||||
func (s *Static) Complete(message string) {
|
||||
fmt.Fprintf(s.writer, " ✅ %s\n", message)
|
||||
}
|
||||
|
||||
// Fail shows failure message
|
||||
func (s *Static) Fail(message string) {
|
||||
fmt.Fprintf(s.writer, " ❌ %s\n", message)
|
||||
}
|
||||
|
||||
// Stop does nothing for static indicator
|
||||
func (s *Static) Stop() {
|
||||
// No-op for static indicator
|
||||
}
|
||||
|
||||
// LineByLine creates a line-by-line progress indicator
|
||||
type LineByLine struct {
|
||||
writer io.Writer
|
||||
silent bool
|
||||
}
|
||||
|
||||
// NewLineByLine creates a new line-by-line progress indicator
|
||||
func NewLineByLine() *LineByLine {
|
||||
return &LineByLine{
|
||||
writer: os.Stdout,
|
||||
silent: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Light creates a minimal progress indicator with just essential status
|
||||
type Light struct {
|
||||
writer io.Writer
|
||||
silent bool
|
||||
}
|
||||
|
||||
// NewLight creates a new light progress indicator
|
||||
func NewLight() *Light {
|
||||
return &Light{
|
||||
writer: os.Stdout,
|
||||
silent: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewQuietLineByLine creates a quiet line-by-line progress indicator
|
||||
func NewQuietLineByLine() *LineByLine {
|
||||
return &LineByLine{
|
||||
writer: os.Stdout,
|
||||
silent: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Start shows the initial message
|
||||
func (l *LineByLine) Start(message string) {
|
||||
fmt.Fprintf(l.writer, "\n🔄 %s\n", message)
|
||||
}
|
||||
|
||||
// Update shows an update message
|
||||
func (l *LineByLine) Update(message string) {
|
||||
if !l.silent {
|
||||
fmt.Fprintf(l.writer, " %s\n", message)
|
||||
}
|
||||
}
|
||||
|
||||
// Complete shows completion message
|
||||
func (l *LineByLine) Complete(message string) {
|
||||
fmt.Fprintf(l.writer, "✅ %s\n\n", message)
|
||||
}
|
||||
|
||||
// Fail shows failure message
|
||||
func (l *LineByLine) Fail(message string) {
|
||||
fmt.Fprintf(l.writer, "❌ %s\n\n", message)
|
||||
}
|
||||
|
||||
// Stop does nothing for line-by-line (no cleanup needed)
|
||||
func (l *LineByLine) Stop() {
|
||||
// No cleanup needed for line-by-line
|
||||
}
|
||||
|
||||
// Light indicator methods - minimal output
|
||||
func (l *Light) Start(message string) {
|
||||
if !l.silent {
|
||||
fmt.Fprintf(l.writer, "▶ %s\n", message)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Light) Update(message string) {
|
||||
if !l.silent {
|
||||
fmt.Fprintf(l.writer, " %s\n", message)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Light) Complete(message string) {
|
||||
if !l.silent {
|
||||
fmt.Fprintf(l.writer, "✓ %s\n", message)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Light) Fail(message string) {
|
||||
if !l.silent {
|
||||
fmt.Fprintf(l.writer, "✗ %s\n", message)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Light) Stop() {
|
||||
// No cleanup needed for light indicator
|
||||
}
|
||||
|
||||
// NewIndicator creates an appropriate progress indicator based on environment
|
||||
func NewIndicator(interactive bool, indicatorType string) Indicator {
|
||||
if !interactive {
|
||||
return NewLineByLine() // Use line-by-line for non-interactive mode
|
||||
}
|
||||
|
||||
switch indicatorType {
|
||||
case "spinner":
|
||||
return NewSpinner()
|
||||
case "dots":
|
||||
return NewDots()
|
||||
case "bar":
|
||||
return NewProgressBar(100) // Default to 100 steps
|
||||
case "line":
|
||||
return NewLineByLine()
|
||||
case "light":
|
||||
return NewLight()
|
||||
default:
|
||||
return NewLineByLine() // Default to line-by-line for better compatibility
|
||||
}
|
||||
}
|
||||
658
internal/tui/menu.go
Normal file
658
internal/tui/menu.go
Normal file
@ -0,0 +1,658 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/spinner"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/database"
|
||||
"dbbackup/internal/logger"
|
||||
"dbbackup/internal/progress"
|
||||
)
|
||||
|
||||
// Style definitions
|
||||
var (
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
Background(lipgloss.Color("#7D56F4")).
|
||||
Padding(0, 1)
|
||||
|
||||
menuStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#626262"))
|
||||
|
||||
selectedStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FF75B7")).
|
||||
Bold(true)
|
||||
|
||||
infoStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#626262"))
|
||||
|
||||
successStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#04B575")).
|
||||
Bold(true)
|
||||
|
||||
errorStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FF6B6B")).
|
||||
Bold(true)
|
||||
|
||||
progressStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FFD93D")).
|
||||
Bold(true)
|
||||
|
||||
stepStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#6BCF7F")).
|
||||
MarginLeft(2)
|
||||
|
||||
detailStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#A8A8A8")).
|
||||
MarginLeft(4).
|
||||
Italic(true)
|
||||
)
|
||||
|
||||
// MenuModel represents the enhanced menu state with progress tracking
|
||||
type MenuModel struct {
|
||||
choices []string
|
||||
cursor int
|
||||
config *config.Config
|
||||
logger logger.Logger
|
||||
quitting bool
|
||||
message string
|
||||
|
||||
// Progress tracking
|
||||
showProgress bool
|
||||
showCompletion bool
|
||||
completionMessage string
|
||||
completionDismissed bool // Track if user manually dismissed completion
|
||||
currentOperation *progress.OperationStatus
|
||||
allOperations []progress.OperationStatus
|
||||
lastUpdate time.Time
|
||||
spinner spinner.Model
|
||||
|
||||
// Background operations
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// TUI Progress Reporter
|
||||
progressReporter *TUIProgressReporter
|
||||
}
|
||||
|
||||
// completionMsg carries completion status
|
||||
type completionMsg struct {
|
||||
success bool
|
||||
message string
|
||||
}
|
||||
|
||||
// operationUpdateMsg carries operation updates
|
||||
type operationUpdateMsg struct {
|
||||
operations []progress.OperationStatus
|
||||
}
|
||||
|
||||
// operationCompleteMsg signals operation completion
|
||||
type operationCompleteMsg struct {
|
||||
operation *progress.OperationStatus
|
||||
success bool
|
||||
}
|
||||
|
||||
// Initialize the menu model
|
||||
func NewMenuModel(cfg *config.Config, log logger.Logger) MenuModel {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
s := spinner.New()
|
||||
s.Spinner = spinner.Dot
|
||||
s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("#FFD93D"))
|
||||
|
||||
// Create TUI progress reporter
|
||||
progressReporter := NewTUIProgressReporter()
|
||||
|
||||
model := MenuModel{
|
||||
choices: []string{
|
||||
"Single Database Backup",
|
||||
"Sample Database Backup (with ratio)",
|
||||
"Cluster Backup (all databases)",
|
||||
"View Active Operations",
|
||||
"Show Operation History",
|
||||
"Database Status & Health Check",
|
||||
"Configuration Settings",
|
||||
"Clear Operation History",
|
||||
"Quit",
|
||||
},
|
||||
config: cfg,
|
||||
logger: log,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
spinner: s,
|
||||
lastUpdate: time.Now(),
|
||||
progressReporter: progressReporter,
|
||||
}
|
||||
|
||||
// Set up progress callback
|
||||
progressReporter.AddCallback(func(operations []progress.OperationStatus) {
|
||||
// This will be called when operations update
|
||||
// The TUI will pick up these updates in the pollOperations method
|
||||
})
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
// Init initializes the model
|
||||
func (m MenuModel) Init() tea.Cmd {
|
||||
return tea.Batch(
|
||||
m.spinner.Tick,
|
||||
m.pollOperations(),
|
||||
)
|
||||
}
|
||||
|
||||
// pollOperations periodically checks for operation updates
|
||||
func (m MenuModel) pollOperations() tea.Cmd {
|
||||
return tea.Tick(time.Millisecond*500, func(t time.Time) tea.Msg {
|
||||
// Get operations from our TUI progress reporter
|
||||
operations := m.progressReporter.GetOperations()
|
||||
return operationUpdateMsg{operations: operations}
|
||||
})
|
||||
}
|
||||
|
||||
// Update handles messages
|
||||
func (m MenuModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q":
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
|
||||
case "up", "k":
|
||||
// Clear completion status and allow navigation
|
||||
if m.showCompletion {
|
||||
m.showCompletion = false
|
||||
m.completionMessage = ""
|
||||
m.message = ""
|
||||
m.completionDismissed = true // Mark as manually dismissed
|
||||
}
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
|
||||
case "down", "j":
|
||||
// Clear completion status and allow navigation
|
||||
if m.showCompletion {
|
||||
m.showCompletion = false
|
||||
m.completionMessage = ""
|
||||
m.message = ""
|
||||
m.completionDismissed = true // Mark as manually dismissed
|
||||
}
|
||||
if m.cursor < len(m.choices)-1 {
|
||||
m.cursor++
|
||||
}
|
||||
|
||||
case "enter", " ":
|
||||
// Clear completion status and allow selection
|
||||
if m.showCompletion {
|
||||
m.showCompletion = false
|
||||
m.completionMessage = ""
|
||||
m.message = ""
|
||||
m.completionDismissed = true // Mark as manually dismissed
|
||||
return m, m.pollOperations()
|
||||
}
|
||||
|
||||
switch m.cursor {
|
||||
case 0: // Single Database Backup
|
||||
return m.handleSingleBackup()
|
||||
case 1: // Sample Database Backup
|
||||
return m.handleSampleBackup()
|
||||
case 2: // Cluster Backup
|
||||
return m.handleClusterBackup()
|
||||
case 3: // View Active Operations
|
||||
return m.handleViewOperations()
|
||||
case 4: // Show Operation History
|
||||
return m.handleOperationHistory()
|
||||
case 5: // Database Status
|
||||
return m.handleStatus()
|
||||
case 6: // Settings
|
||||
return m.handleSettings()
|
||||
case 7: // Clear History
|
||||
return m.handleClearHistory()
|
||||
case 8: // Quit
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case "esc":
|
||||
// Clear completion status on escape
|
||||
if m.showCompletion {
|
||||
m.showCompletion = false
|
||||
m.completionMessage = ""
|
||||
m.message = ""
|
||||
m.completionDismissed = true // Mark as manually dismissed
|
||||
}
|
||||
}
|
||||
|
||||
case operationUpdateMsg:
|
||||
m.allOperations = msg.operations
|
||||
if len(msg.operations) > 0 {
|
||||
latest := msg.operations[len(msg.operations)-1]
|
||||
if latest.Status == "running" {
|
||||
m.currentOperation = &latest
|
||||
m.showProgress = true
|
||||
m.showCompletion = false
|
||||
m.completionDismissed = false // Reset dismissal flag for new operation
|
||||
} else if m.currentOperation != nil && latest.ID == m.currentOperation.ID {
|
||||
m.currentOperation = &latest
|
||||
m.showProgress = false
|
||||
// Only show completion status if user hasn't manually dismissed it
|
||||
if !m.completionDismissed {
|
||||
if latest.Status == "completed" {
|
||||
m.showCompletion = true
|
||||
m.completionMessage = fmt.Sprintf("✅ %s", latest.Message)
|
||||
} else if latest.Status == "failed" {
|
||||
m.showCompletion = true
|
||||
m.completionMessage = fmt.Sprintf("❌ %s", latest.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, m.pollOperations()
|
||||
|
||||
case completionMsg:
|
||||
m.showProgress = false
|
||||
m.showCompletion = true
|
||||
if msg.success {
|
||||
m.completionMessage = fmt.Sprintf("✅ %s", msg.message)
|
||||
} else {
|
||||
m.completionMessage = fmt.Sprintf("❌ %s", msg.message)
|
||||
}
|
||||
return m, m.pollOperations()
|
||||
|
||||
case operationCompleteMsg:
|
||||
m.currentOperation = msg.operation
|
||||
m.showProgress = false
|
||||
if msg.success {
|
||||
m.message = fmt.Sprintf("✅ Operation completed: %s", msg.operation.Message)
|
||||
} else {
|
||||
m.message = fmt.Sprintf("❌ Operation failed: %s", msg.operation.Message)
|
||||
}
|
||||
return m, m.pollOperations()
|
||||
|
||||
case spinner.TickMsg:
|
||||
var cmd tea.Cmd
|
||||
m.spinner, cmd = m.spinner.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// View renders the enhanced menu with progress tracking
|
||||
func (m MenuModel) View() string {
|
||||
if m.quitting {
|
||||
return "Thanks for using DB Backup Tool!\n"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
// Header
|
||||
header := titleStyle.Render("🗄️ Database Backup Tool - Interactive Menu")
|
||||
b.WriteString(fmt.Sprintf("\n%s\n\n", header))
|
||||
|
||||
// Database info
|
||||
dbInfo := infoStyle.Render(fmt.Sprintf("Database: %s@%s:%d (%s)",
|
||||
m.config.User, m.config.Host, m.config.Port, m.config.DatabaseType))
|
||||
b.WriteString(fmt.Sprintf("%s\n\n", dbInfo))
|
||||
|
||||
// Menu items
|
||||
for i, choice := range m.choices {
|
||||
cursor := " "
|
||||
if m.cursor == i {
|
||||
cursor = ">"
|
||||
b.WriteString(selectedStyle.Render(fmt.Sprintf("%s %s", cursor, choice)))
|
||||
} else {
|
||||
b.WriteString(menuStyle.Render(fmt.Sprintf("%s %s", cursor, choice)))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Current operation progress
|
||||
if m.showProgress && m.currentOperation != nil {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(m.renderOperationProgress(m.currentOperation))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Completion status (persistent until key press)
|
||||
if m.showCompletion {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(successStyle.Render(m.completionMessage))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(infoStyle.Render("💡 Press any key to continue..."))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Message area
|
||||
if m.message != "" && !m.showCompletion {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(m.message)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Operations summary
|
||||
if len(m.allOperations) > 0 {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(m.renderOperationsSummary())
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Footer
|
||||
var footer string
|
||||
if m.showCompletion {
|
||||
footer = infoStyle.Render("\n⌨️ Press Enter, ↑/↓ arrows, or Esc to continue...")
|
||||
} else {
|
||||
footer = infoStyle.Render("\n⌨️ Press ↑/↓ to navigate • Enter to select • q to quit")
|
||||
}
|
||||
b.WriteString(footer)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// renderOperationProgress renders detailed progress for the current operation
|
||||
func (m MenuModel) renderOperationProgress(op *progress.OperationStatus) string {
|
||||
var b strings.Builder
|
||||
|
||||
// Operation header with spinner
|
||||
spinnerView := ""
|
||||
if op.Status == "running" {
|
||||
spinnerView = m.spinner.View() + " "
|
||||
}
|
||||
|
||||
status := "🔄"
|
||||
if op.Status == "completed" {
|
||||
status = "✅"
|
||||
} else if op.Status == "failed" {
|
||||
status = "❌"
|
||||
}
|
||||
|
||||
b.WriteString(progressStyle.Render(fmt.Sprintf("%s%s %s [%d%%]",
|
||||
spinnerView, status, strings.Title(op.Type), op.Progress)))
|
||||
b.WriteString("\n")
|
||||
|
||||
// Progress bar
|
||||
barWidth := 40
|
||||
filledWidth := (op.Progress * barWidth) / 100
|
||||
if filledWidth > barWidth {
|
||||
filledWidth = barWidth
|
||||
}
|
||||
bar := strings.Repeat("█", filledWidth) + strings.Repeat("░", barWidth-filledWidth)
|
||||
b.WriteString(detailStyle.Render(fmt.Sprintf("[%s] %s", bar, op.Message)))
|
||||
b.WriteString("\n")
|
||||
|
||||
// Time and details
|
||||
elapsed := time.Since(op.StartTime)
|
||||
timeInfo := fmt.Sprintf("Elapsed: %s", formatDuration(elapsed))
|
||||
if op.EndTime != nil {
|
||||
timeInfo = fmt.Sprintf("Duration: %s", op.Duration.String())
|
||||
}
|
||||
b.WriteString(detailStyle.Render(timeInfo))
|
||||
b.WriteString("\n")
|
||||
|
||||
// File/byte progress
|
||||
if op.FilesTotal > 0 {
|
||||
b.WriteString(detailStyle.Render(fmt.Sprintf("Files: %d/%d", op.FilesDone, op.FilesTotal)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if op.BytesTotal > 0 {
|
||||
b.WriteString(detailStyle.Render(fmt.Sprintf("Data: %s/%s",
|
||||
formatBytes(op.BytesDone), formatBytes(op.BytesTotal))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Current steps
|
||||
if len(op.Steps) > 0 {
|
||||
b.WriteString(stepStyle.Render("Steps:"))
|
||||
b.WriteString("\n")
|
||||
for _, step := range op.Steps {
|
||||
stepStatus := "⏳"
|
||||
if step.Status == "completed" {
|
||||
stepStatus = "✅"
|
||||
} else if step.Status == "failed" {
|
||||
stepStatus = "❌"
|
||||
}
|
||||
b.WriteString(detailStyle.Render(fmt.Sprintf(" %s %s", stepStatus, step.Name)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// renderOperationsSummary renders a summary of all operations
|
||||
func (m MenuModel) renderOperationsSummary() string {
|
||||
if len(m.allOperations) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
completed := 0
|
||||
failed := 0
|
||||
running := 0
|
||||
|
||||
for _, op := range m.allOperations {
|
||||
switch op.Status {
|
||||
case "completed":
|
||||
completed++
|
||||
case "failed":
|
||||
failed++
|
||||
case "running":
|
||||
running++
|
||||
}
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf("📊 Operations: %d total | %d completed | %d failed | %d running",
|
||||
len(m.allOperations), completed, failed, running)
|
||||
|
||||
return infoStyle.Render(summary)
|
||||
}
|
||||
|
||||
// Enhanced backup handlers with progress tracking
|
||||
|
||||
// Handle single database backup with progress
|
||||
func (m MenuModel) handleSingleBackup() (tea.Model, tea.Cmd) {
|
||||
if m.config.Database == "" {
|
||||
m.message = errorStyle.Render("❌ No database specified. Use --database flag or set in config.")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.message = progressStyle.Render(fmt.Sprintf("🔄 Starting single backup for: %s", m.config.Database))
|
||||
m.showProgress = true
|
||||
m.showCompletion = false
|
||||
|
||||
// Start backup and return polling command
|
||||
go func() {
|
||||
err := RunBackupInTUI(m.ctx, m.config, m.logger, "single", m.config.Database, m.progressReporter)
|
||||
// The completion will be handled by the progress reporter callback system
|
||||
_ = err // Handle error in the progress reporter
|
||||
}()
|
||||
|
||||
return m, m.pollOperations()
|
||||
}
|
||||
|
||||
// Handle sample backup with progress
|
||||
func (m MenuModel) handleSampleBackup() (tea.Model, tea.Cmd) {
|
||||
m.message = progressStyle.Render("🔄 Starting sample backup...")
|
||||
m.showProgress = true
|
||||
m.showCompletion = false
|
||||
m.completionDismissed = false // Reset for new operation
|
||||
|
||||
// Start backup and return polling command
|
||||
go func() {
|
||||
err := RunBackupInTUI(m.ctx, m.config, m.logger, "sample", "", m.progressReporter)
|
||||
// The completion will be handled by the progress reporter callback system
|
||||
_ = err // Handle error in the progress reporter
|
||||
}()
|
||||
|
||||
return m, m.pollOperations()
|
||||
}
|
||||
|
||||
// Handle cluster backup with progress
|
||||
func (m MenuModel) handleClusterBackup() (tea.Model, tea.Cmd) {
|
||||
m.message = progressStyle.Render("🔄 Starting cluster backup (all databases)...")
|
||||
m.showProgress = true
|
||||
m.showCompletion = false
|
||||
m.completionDismissed = false // Reset for new operation
|
||||
|
||||
// Start backup and return polling command
|
||||
go func() {
|
||||
err := RunBackupInTUI(m.ctx, m.config, m.logger, "cluster", "", m.progressReporter)
|
||||
// The completion will be handled by the progress reporter callback system
|
||||
_ = err // Handle error in the progress reporter
|
||||
}()
|
||||
|
||||
return m, m.pollOperations()
|
||||
}
|
||||
|
||||
// Handle viewing active operations
|
||||
func (m MenuModel) handleViewOperations() (tea.Model, tea.Cmd) {
|
||||
if len(m.allOperations) == 0 {
|
||||
m.message = infoStyle.Render("ℹ️ No operations currently running or completed")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var activeOps []progress.OperationStatus
|
||||
for _, op := range m.allOperations {
|
||||
if op.Status == "running" {
|
||||
activeOps = append(activeOps, op)
|
||||
}
|
||||
}
|
||||
|
||||
if len(activeOps) == 0 {
|
||||
m.message = infoStyle.Render("ℹ️ No operations currently running")
|
||||
} else {
|
||||
m.message = progressStyle.Render(fmt.Sprintf("🔄 %d active operations", len(activeOps)))
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Handle showing operation history
|
||||
func (m MenuModel) handleOperationHistory() (tea.Model, tea.Cmd) {
|
||||
if len(m.allOperations) == 0 {
|
||||
m.message = infoStyle.Render("ℹ️ No operation history available")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var history strings.Builder
|
||||
history.WriteString("📋 Operation History:\n")
|
||||
|
||||
for i, op := range m.allOperations {
|
||||
if i >= 5 { // Show last 5 operations
|
||||
break
|
||||
}
|
||||
|
||||
status := "🔄"
|
||||
if op.Status == "completed" {
|
||||
status = "✅"
|
||||
} else if op.Status == "failed" {
|
||||
status = "❌"
|
||||
}
|
||||
|
||||
history.WriteString(fmt.Sprintf("%s %s - %s (%s)\n",
|
||||
status, op.Name, op.Type, op.StartTime.Format("15:04:05")))
|
||||
}
|
||||
|
||||
m.message = history.String()
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Handle status check
|
||||
func (m MenuModel) handleStatus() (tea.Model, tea.Cmd) {
|
||||
db, err := database.New(m.config, m.logger)
|
||||
if err != nil {
|
||||
m.message = errorStyle.Render(fmt.Sprintf("❌ Connection failed: %v", err))
|
||||
return m, nil
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Connect(m.ctx)
|
||||
if err != nil {
|
||||
m.message = errorStyle.Render(fmt.Sprintf("❌ Connection failed: %v", err))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
err = db.Ping(m.ctx)
|
||||
if err != nil {
|
||||
m.message = errorStyle.Render(fmt.Sprintf("❌ Ping failed: %v", err))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
version, err := db.GetVersion(m.ctx)
|
||||
if err != nil {
|
||||
m.message = errorStyle.Render(fmt.Sprintf("❌ Failed to get version: %v", err))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.message = successStyle.Render(fmt.Sprintf("✅ Connected successfully!\nVersion: %s", version))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Handle settings display
|
||||
func (m MenuModel) handleSettings() (tea.Model, tea.Cmd) {
|
||||
// Create and switch to settings model
|
||||
settingsModel := NewSettingsModel(m.config, m.logger, m)
|
||||
return settingsModel, settingsModel.Init()
|
||||
}
|
||||
|
||||
// Handle clearing operation history
|
||||
func (m MenuModel) handleClearHistory() (tea.Model, tea.Cmd) {
|
||||
m.allOperations = []progress.OperationStatus{}
|
||||
m.currentOperation = nil
|
||||
m.showProgress = false
|
||||
m.message = successStyle.Render("✅ Operation history cleared")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
|
||||
// formatDuration formats a duration in a human-readable way
|
||||
func formatDuration(d time.Duration) string {
|
||||
if d < time.Minute {
|
||||
return fmt.Sprintf("%.1fs", d.Seconds())
|
||||
} else if d < time.Hour {
|
||||
return fmt.Sprintf("%.1fm", d.Minutes())
|
||||
}
|
||||
return fmt.Sprintf("%.1fh", d.Hours())
|
||||
}
|
||||
|
||||
// formatBytes formats byte count in human-readable format
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
// RunInteractiveMenu starts the enhanced TUI with progress tracking
|
||||
func RunInteractiveMenu(cfg *config.Config, log logger.Logger) error {
|
||||
m := NewMenuModel(cfg, log)
|
||||
p := tea.NewProgram(m, tea.WithAltScreen())
|
||||
|
||||
if _, err := p.Run(); err != nil {
|
||||
return fmt.Errorf("error running interactive menu: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
212
internal/tui/progress.go
Normal file
212
internal/tui/progress.go
Normal file
@ -0,0 +1,212 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dbbackup/internal/backup"
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/database"
|
||||
"dbbackup/internal/logger"
|
||||
"dbbackup/internal/progress"
|
||||
)
|
||||
|
||||
// TUIProgressReporter is a progress reporter that integrates with the TUI
|
||||
type TUIProgressReporter struct {
|
||||
mu sync.RWMutex
|
||||
operations map[string]*progress.OperationStatus
|
||||
callbacks []func([]progress.OperationStatus)
|
||||
}
|
||||
|
||||
// NewTUIProgressReporter creates a new TUI-compatible progress reporter
|
||||
func NewTUIProgressReporter() *TUIProgressReporter {
|
||||
return &TUIProgressReporter{
|
||||
operations: make(map[string]*progress.OperationStatus),
|
||||
callbacks: make([]func([]progress.OperationStatus), 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddCallback adds a callback function to be called when operations update
|
||||
func (t *TUIProgressReporter) AddCallback(callback func([]progress.OperationStatus)) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.callbacks = append(t.callbacks, callback)
|
||||
}
|
||||
|
||||
// notifyCallbacks calls all registered callbacks with current operations
|
||||
func (t *TUIProgressReporter) notifyCallbacks() {
|
||||
operations := make([]progress.OperationStatus, 0, len(t.operations))
|
||||
for _, op := range t.operations {
|
||||
operations = append(operations, *op)
|
||||
}
|
||||
|
||||
for _, callback := range t.callbacks {
|
||||
go callback(operations)
|
||||
}
|
||||
}
|
||||
|
||||
// StartOperation starts tracking a new operation
|
||||
func (t *TUIProgressReporter) StartOperation(id, name, opType string) *TUIOperationTracker {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
operation := &progress.OperationStatus{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Type: opType,
|
||||
Status: "running",
|
||||
StartTime: time.Now(),
|
||||
Progress: 0,
|
||||
Message: fmt.Sprintf("Starting %s: %s", opType, name),
|
||||
Details: make(map[string]string),
|
||||
Steps: make([]progress.StepStatus, 0),
|
||||
}
|
||||
|
||||
t.operations[id] = operation
|
||||
t.notifyCallbacks()
|
||||
|
||||
return &TUIOperationTracker{
|
||||
reporter: t,
|
||||
operationID: id,
|
||||
}
|
||||
}
|
||||
|
||||
// TUIOperationTracker tracks progress for TUI display
|
||||
type TUIOperationTracker struct {
|
||||
reporter *TUIProgressReporter
|
||||
operationID string
|
||||
}
|
||||
|
||||
// UpdateProgress updates the operation progress
|
||||
func (t *TUIOperationTracker) UpdateProgress(progress int, message string) {
|
||||
t.reporter.mu.Lock()
|
||||
defer t.reporter.mu.Unlock()
|
||||
|
||||
if op, exists := t.reporter.operations[t.operationID]; exists {
|
||||
op.Progress = progress
|
||||
op.Message = message
|
||||
t.reporter.notifyCallbacks()
|
||||
}
|
||||
}
|
||||
|
||||
// Complete marks the operation as completed
|
||||
func (t *TUIOperationTracker) Complete(message string) {
|
||||
t.reporter.mu.Lock()
|
||||
defer t.reporter.mu.Unlock()
|
||||
|
||||
if op, exists := t.reporter.operations[t.operationID]; exists {
|
||||
now := time.Now()
|
||||
op.Status = "completed"
|
||||
op.Progress = 100
|
||||
op.Message = message
|
||||
op.EndTime = &now
|
||||
op.Duration = now.Sub(op.StartTime)
|
||||
t.reporter.notifyCallbacks()
|
||||
}
|
||||
}
|
||||
|
||||
// Fail marks the operation as failed
|
||||
func (t *TUIOperationTracker) Fail(message string) {
|
||||
t.reporter.mu.Lock()
|
||||
defer t.reporter.mu.Unlock()
|
||||
|
||||
if op, exists := t.reporter.operations[t.operationID]; exists {
|
||||
now := time.Now()
|
||||
op.Status = "failed"
|
||||
op.Message = message
|
||||
op.EndTime = &now
|
||||
op.Duration = now.Sub(op.StartTime)
|
||||
t.reporter.notifyCallbacks()
|
||||
}
|
||||
}
|
||||
|
||||
// GetOperations returns all current operations
|
||||
func (t *TUIProgressReporter) GetOperations() []progress.OperationStatus {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
operations := make([]progress.OperationStatus, 0, len(t.operations))
|
||||
for _, op := range t.operations {
|
||||
operations = append(operations, *op)
|
||||
}
|
||||
return operations
|
||||
}
|
||||
|
||||
// SilentLogger implements logger.Logger but doesn't output anything
|
||||
type SilentLogger struct{}
|
||||
|
||||
func (s *SilentLogger) Info(msg string, args ...any) {}
|
||||
func (s *SilentLogger) Warn(msg string, args ...any) {}
|
||||
func (s *SilentLogger) Error(msg string, args ...any) {}
|
||||
func (s *SilentLogger) Debug(msg string, args ...any) {}
|
||||
func (s *SilentLogger) Time(msg string, args ...any) {}
|
||||
func (s *SilentLogger) StartOperation(name string) logger.OperationLogger {
|
||||
return &SilentOperation{}
|
||||
}
|
||||
|
||||
// SilentOperation implements logger.OperationLogger but doesn't output anything
|
||||
type SilentOperation struct{}
|
||||
|
||||
func (s *SilentOperation) Update(message string, args ...any) {}
|
||||
func (s *SilentOperation) Complete(message string, args ...any) {}
|
||||
func (s *SilentOperation) Fail(message string, args ...any) {}
|
||||
|
||||
// SilentProgressIndicator implements progress.Indicator but doesn't output anything
|
||||
type SilentProgressIndicator struct{}
|
||||
|
||||
func (s *SilentProgressIndicator) Start(message string) {}
|
||||
func (s *SilentProgressIndicator) Update(message string) {}
|
||||
func (s *SilentProgressIndicator) Complete(message string) {}
|
||||
func (s *SilentProgressIndicator) Fail(message string) {}
|
||||
func (s *SilentProgressIndicator) Stop() {}
|
||||
|
||||
// RunBackupInTUI runs a backup operation with TUI-compatible progress reporting
|
||||
func RunBackupInTUI(ctx context.Context, cfg *config.Config, log logger.Logger,
|
||||
backupType string, databaseName string, reporter *TUIProgressReporter) error {
|
||||
|
||||
// Create database connection
|
||||
db, err := database.New(cfg, &SilentLogger{}) // Use silent logger
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database connection: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Connect(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Create backup engine with silent progress indicator and logger
|
||||
silentProgress := &SilentProgressIndicator{}
|
||||
engine := backup.NewSilent(cfg, &SilentLogger{}, db, silentProgress)
|
||||
|
||||
// Start operation tracking
|
||||
operationID := fmt.Sprintf("%s_%d", backupType, time.Now().Unix())
|
||||
tracker := reporter.StartOperation(operationID, databaseName, backupType)
|
||||
|
||||
// Run the appropriate backup type
|
||||
switch backupType {
|
||||
case "single":
|
||||
tracker.UpdateProgress(10, "Preparing single database backup...")
|
||||
err = engine.BackupSingle(ctx, databaseName)
|
||||
case "cluster":
|
||||
tracker.UpdateProgress(10, "Preparing cluster backup...")
|
||||
err = engine.BackupCluster(ctx)
|
||||
case "sample":
|
||||
tracker.UpdateProgress(10, "Preparing sample backup...")
|
||||
err = engine.BackupSample(ctx, databaseName)
|
||||
default:
|
||||
err = fmt.Errorf("unknown backup type: %s", backupType)
|
||||
}
|
||||
|
||||
// Update final status
|
||||
if err != nil {
|
||||
tracker.Fail(fmt.Sprintf("Backup failed: %v", err))
|
||||
return err
|
||||
} else {
|
||||
tracker.Complete(fmt.Sprintf("%s backup completed successfully", backupType))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
465
internal/tui/settings.go
Normal file
465
internal/tui/settings.go
Normal file
@ -0,0 +1,465 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"dbbackup/internal/config"
|
||||
"dbbackup/internal/logger"
|
||||
)
|
||||
|
||||
// SettingsModel represents the settings configuration state
|
||||
type SettingsModel struct {
|
||||
config *config.Config
|
||||
logger logger.Logger
|
||||
cursor int
|
||||
editing bool
|
||||
editingField string
|
||||
editingValue string
|
||||
settings []SettingItem
|
||||
quitting bool
|
||||
message string
|
||||
parent tea.Model
|
||||
}
|
||||
|
||||
// SettingItem represents a configurable setting
|
||||
type SettingItem struct {
|
||||
Key string
|
||||
DisplayName string
|
||||
Value func(*config.Config) string
|
||||
Update func(*config.Config, string) error
|
||||
Type string // "string", "int", "bool", "path"
|
||||
Description string
|
||||
}
|
||||
|
||||
// Initialize settings model
|
||||
func NewSettingsModel(cfg *config.Config, log logger.Logger, parent tea.Model) SettingsModel {
|
||||
settings := []SettingItem{
|
||||
{
|
||||
Key: "backup_dir",
|
||||
DisplayName: "Backup Directory",
|
||||
Value: func(c *config.Config) string { return c.BackupDir },
|
||||
Update: func(c *config.Config, v string) error {
|
||||
if v == "" {
|
||||
return fmt.Errorf("backup directory cannot be empty")
|
||||
}
|
||||
c.BackupDir = filepath.Clean(v)
|
||||
return nil
|
||||
},
|
||||
Type: "path",
|
||||
Description: "Directory where backup files will be stored",
|
||||
},
|
||||
{
|
||||
Key: "compression_level",
|
||||
DisplayName: "Compression Level",
|
||||
Value: func(c *config.Config) string { return fmt.Sprintf("%d", c.CompressionLevel) },
|
||||
Update: func(c *config.Config, v string) error {
|
||||
val, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compression level must be a number")
|
||||
}
|
||||
if val < 0 || val > 9 {
|
||||
return fmt.Errorf("compression level must be between 0-9")
|
||||
}
|
||||
c.CompressionLevel = val
|
||||
return nil
|
||||
},
|
||||
Type: "int",
|
||||
Description: "Compression level (0=fastest, 9=smallest)",
|
||||
},
|
||||
{
|
||||
Key: "jobs",
|
||||
DisplayName: "Parallel Jobs",
|
||||
Value: func(c *config.Config) string { return fmt.Sprintf("%d", c.Jobs) },
|
||||
Update: func(c *config.Config, v string) error {
|
||||
val, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("jobs must be a number")
|
||||
}
|
||||
if val < 1 {
|
||||
return fmt.Errorf("jobs must be at least 1")
|
||||
}
|
||||
c.Jobs = val
|
||||
return nil
|
||||
},
|
||||
Type: "int",
|
||||
Description: "Number of parallel jobs for backup operations",
|
||||
},
|
||||
{
|
||||
Key: "dump_jobs",
|
||||
DisplayName: "Dump Jobs",
|
||||
Value: func(c *config.Config) string { return fmt.Sprintf("%d", c.DumpJobs) },
|
||||
Update: func(c *config.Config, v string) error {
|
||||
val, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dump jobs must be a number")
|
||||
}
|
||||
if val < 1 {
|
||||
return fmt.Errorf("dump jobs must be at least 1")
|
||||
}
|
||||
c.DumpJobs = val
|
||||
return nil
|
||||
},
|
||||
Type: "int",
|
||||
Description: "Number of parallel jobs for database dumps",
|
||||
},
|
||||
{
|
||||
Key: "host",
|
||||
DisplayName: "Database Host",
|
||||
Value: func(c *config.Config) string { return c.Host },
|
||||
Update: func(c *config.Config, v string) error {
|
||||
if v == "" {
|
||||
return fmt.Errorf("host cannot be empty")
|
||||
}
|
||||
c.Host = v
|
||||
return nil
|
||||
},
|
||||
Type: "string",
|
||||
Description: "Database server hostname or IP address",
|
||||
},
|
||||
{
|
||||
Key: "port",
|
||||
DisplayName: "Database Port",
|
||||
Value: func(c *config.Config) string { return fmt.Sprintf("%d", c.Port) },
|
||||
Update: func(c *config.Config, v string) error {
|
||||
val, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("port must be a number")
|
||||
}
|
||||
if val < 1 || val > 65535 {
|
||||
return fmt.Errorf("port must be between 1-65535")
|
||||
}
|
||||
c.Port = val
|
||||
return nil
|
||||
},
|
||||
Type: "int",
|
||||
Description: "Database server port number",
|
||||
},
|
||||
{
|
||||
Key: "user",
|
||||
DisplayName: "Database User",
|
||||
Value: func(c *config.Config) string { return c.User },
|
||||
Update: func(c *config.Config, v string) error {
|
||||
if v == "" {
|
||||
return fmt.Errorf("user cannot be empty")
|
||||
}
|
||||
c.User = v
|
||||
return nil
|
||||
},
|
||||
Type: "string",
|
||||
Description: "Database username for connections",
|
||||
},
|
||||
{
|
||||
Key: "database",
|
||||
DisplayName: "Default Database",
|
||||
Value: func(c *config.Config) string { return c.Database },
|
||||
Update: func(c *config.Config, v string) error {
|
||||
c.Database = v // Can be empty for cluster operations
|
||||
return nil
|
||||
},
|
||||
Type: "string",
|
||||
Description: "Default database name (optional)",
|
||||
},
|
||||
{
|
||||
Key: "ssl_mode",
|
||||
DisplayName: "SSL Mode",
|
||||
Value: func(c *config.Config) string { return c.SSLMode },
|
||||
Update: func(c *config.Config, v string) error {
|
||||
validModes := []string{"disable", "allow", "prefer", "require", "verify-ca", "verify-full"}
|
||||
for _, mode := range validModes {
|
||||
if v == mode {
|
||||
c.SSLMode = v
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("invalid SSL mode. Valid options: %s", strings.Join(validModes, ", "))
|
||||
},
|
||||
Type: "string",
|
||||
Description: "SSL connection mode (disable, allow, prefer, require, verify-ca, verify-full)",
|
||||
},
|
||||
{
|
||||
Key: "auto_detect_cores",
|
||||
DisplayName: "Auto Detect CPU Cores",
|
||||
Value: func(c *config.Config) string {
|
||||
if c.AutoDetectCores { return "true" } else { return "false" }
|
||||
},
|
||||
Update: func(c *config.Config, v string) error {
|
||||
val, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("must be true or false")
|
||||
}
|
||||
c.AutoDetectCores = val
|
||||
return nil
|
||||
},
|
||||
Type: "bool",
|
||||
Description: "Automatically detect and optimize for CPU cores",
|
||||
},
|
||||
}
|
||||
|
||||
return SettingsModel{
|
||||
config: cfg,
|
||||
logger: log,
|
||||
settings: settings,
|
||||
parent: parent,
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the settings model
|
||||
func (m SettingsModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update handles messages
|
||||
func (m SettingsModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
if m.editing {
|
||||
return m.handleEditingInput(msg)
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q", "esc":
|
||||
m.quitting = true
|
||||
return m.parent, nil
|
||||
|
||||
case "up", "k":
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
|
||||
case "down", "j":
|
||||
if m.cursor < len(m.settings)-1 {
|
||||
m.cursor++
|
||||
}
|
||||
|
||||
case "enter", " ":
|
||||
return m.startEditing()
|
||||
|
||||
case "r":
|
||||
return m.resetToDefaults()
|
||||
|
||||
case "s":
|
||||
return m.saveSettings()
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// handleEditingInput handles input when editing a setting
|
||||
func (m SettingsModel) handleEditingInput(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "ctrl+c":
|
||||
m.quitting = true
|
||||
return m.parent, nil
|
||||
|
||||
case "esc":
|
||||
m.editing = false
|
||||
m.editingField = ""
|
||||
m.editingValue = ""
|
||||
m.message = ""
|
||||
return m, nil
|
||||
|
||||
case "enter":
|
||||
return m.saveEditedValue()
|
||||
|
||||
case "backspace":
|
||||
if len(m.editingValue) > 0 {
|
||||
m.editingValue = m.editingValue[:len(m.editingValue)-1]
|
||||
}
|
||||
|
||||
default:
|
||||
// Add character to editing value
|
||||
if len(msg.String()) == 1 {
|
||||
m.editingValue += msg.String()
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// startEditing begins editing a setting
|
||||
func (m SettingsModel) startEditing() (tea.Model, tea.Cmd) {
|
||||
if m.cursor >= len(m.settings) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
setting := m.settings[m.cursor]
|
||||
m.editing = true
|
||||
m.editingField = setting.Key
|
||||
m.editingValue = setting.Value(m.config)
|
||||
m.message = ""
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// saveEditedValue saves the currently edited value
|
||||
func (m SettingsModel) saveEditedValue() (tea.Model, tea.Cmd) {
|
||||
if m.editingField == "" {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Find the setting being edited
|
||||
var setting *SettingItem
|
||||
for i := range m.settings {
|
||||
if m.settings[i].Key == m.editingField {
|
||||
setting = &m.settings[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if setting == nil {
|
||||
m.message = errorStyle.Render("❌ Setting not found")
|
||||
m.editing = false
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Update the configuration
|
||||
if err := setting.Update(m.config, m.editingValue); err != nil {
|
||||
m.message = errorStyle.Render(fmt.Sprintf("❌ %s", err.Error()))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.message = successStyle.Render(fmt.Sprintf("✅ Updated %s", setting.DisplayName))
|
||||
m.editing = false
|
||||
m.editingField = ""
|
||||
m.editingValue = ""
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// resetToDefaults resets configuration to default values
|
||||
func (m SettingsModel) resetToDefaults() (tea.Model, tea.Cmd) {
|
||||
newConfig := config.New()
|
||||
|
||||
// Copy important connection details
|
||||
newConfig.Host = m.config.Host
|
||||
newConfig.Port = m.config.Port
|
||||
newConfig.User = m.config.User
|
||||
newConfig.Database = m.config.Database
|
||||
newConfig.DatabaseType = m.config.DatabaseType
|
||||
|
||||
*m.config = *newConfig
|
||||
m.message = successStyle.Render("✅ Settings reset to defaults")
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// saveSettings validates and saves current settings
|
||||
func (m SettingsModel) saveSettings() (tea.Model, tea.Cmd) {
|
||||
if err := m.config.Validate(); err != nil {
|
||||
m.message = errorStyle.Render(fmt.Sprintf("❌ Validation failed: %s", err.Error()))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Optimize CPU settings if auto-detect is enabled
|
||||
if m.config.AutoDetectCores {
|
||||
if err := m.config.OptimizeForCPU(); err != nil {
|
||||
m.message = errorStyle.Render(fmt.Sprintf("❌ CPU optimization failed: %s", err.Error()))
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
m.message = successStyle.Render("✅ Settings validated and saved")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// View renders the settings interface
|
||||
func (m SettingsModel) View() string {
|
||||
if m.quitting {
|
||||
return "Returning to main menu...\n"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
// Header
|
||||
header := titleStyle.Render("⚙️ Configuration Settings")
|
||||
b.WriteString(fmt.Sprintf("\n%s\n\n", header))
|
||||
|
||||
// Settings list
|
||||
for i, setting := range m.settings {
|
||||
cursor := " "
|
||||
value := setting.Value(m.config)
|
||||
|
||||
if m.cursor == i {
|
||||
cursor = ">"
|
||||
if m.editing && m.editingField == setting.Key {
|
||||
// Show editing interface
|
||||
editValue := m.editingValue
|
||||
if setting.Type == "bool" {
|
||||
editValue += " (true/false)"
|
||||
}
|
||||
line := fmt.Sprintf("%s %s: %s", cursor, setting.DisplayName, editValue)
|
||||
b.WriteString(selectedStyle.Render(line))
|
||||
b.WriteString(" ✏️")
|
||||
} else {
|
||||
line := fmt.Sprintf("%s %s: %s", cursor, setting.DisplayName, value)
|
||||
b.WriteString(selectedStyle.Render(line))
|
||||
}
|
||||
} else {
|
||||
line := fmt.Sprintf("%s %s: %s", cursor, setting.DisplayName, value)
|
||||
b.WriteString(menuStyle.Render(line))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
// Show description for selected item
|
||||
if m.cursor == i && !m.editing {
|
||||
desc := detailStyle.Render(fmt.Sprintf(" %s", setting.Description))
|
||||
b.WriteString(desc)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Message area
|
||||
if m.message != "" {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(m.message)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Current configuration summary
|
||||
if !m.editing {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(infoStyle.Render("📋 Current Configuration:"))
|
||||
b.WriteString("\n")
|
||||
|
||||
summary := []string{
|
||||
fmt.Sprintf("Database: %s@%s:%d", m.config.User, m.config.Host, m.config.Port),
|
||||
fmt.Sprintf("Backup Dir: %s", m.config.BackupDir),
|
||||
fmt.Sprintf("Compression: Level %d", m.config.CompressionLevel),
|
||||
fmt.Sprintf("Jobs: %d parallel, %d dump", m.config.Jobs, m.config.DumpJobs),
|
||||
}
|
||||
|
||||
for _, line := range summary {
|
||||
b.WriteString(detailStyle.Render(fmt.Sprintf(" %s", line)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Footer with instructions
|
||||
var footer string
|
||||
if m.editing {
|
||||
footer = infoStyle.Render("\n⌨️ Type new value • Enter to save • Esc to cancel")
|
||||
} else {
|
||||
footer = infoStyle.Render("\n⌨️ ↑/↓ navigate • Enter to edit • 's' save • 'r' reset • 'q' back to menu")
|
||||
}
|
||||
b.WriteString(footer)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// RunSettingsMenu starts the settings configuration interface
|
||||
func RunSettingsMenu(cfg *config.Config, log logger.Logger, parent tea.Model) error {
|
||||
m := NewSettingsModel(cfg, log, parent)
|
||||
p := tea.NewProgram(m, tea.WithAltScreen())
|
||||
|
||||
if _, err := p.Run(); err != nil {
|
||||
return fmt.Errorf("error running settings menu: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user