Initial commit: Database Backup Tool v1.1.0

- PostgreSQL and MySQL support
- Interactive TUI with fixed menu navigation
- Line-by-line progress display
- CPU-aware parallel processing
- Cross-platform build support
- Configuration settings menu
- Silent mode for TUI operations
This commit is contained in:
2025-10-22 19:27:38 +00:00
commit 9b3c3f2b1b
39 changed files with 6498 additions and 0 deletions

708
internal/backup/engine.go Normal file
View File

@@ -0,0 +1,708 @@
package backup
import (
	"bufio"
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"dbbackup/internal/config"
	"dbbackup/internal/database"
	"dbbackup/internal/logger"
	"dbbackup/internal/progress"
)
// Engine handles backup operations. It coordinates database-specific
// command building, process execution, and progress reporting.
type Engine struct {
	cfg              *config.Config             // connection, paths and tuning options
	log              logger.Logger              // structured logger
	db               database.Database          // database-specific command builder and queries
	progress         progress.Indicator         // user-facing progress output
	detailedReporter *progress.DetailedReporter // step-level operation tracking
	silent           bool                       // Silent mode for TUI
}
// New creates a new backup engine with the default line-by-line progress
// indicator and silent mode disabled.
func New(cfg *config.Config, log logger.Logger, db database.Database) *Engine {
	// Use line-by-line indicator by default.
	return NewWithProgress(cfg, log, db, progress.NewIndicator(true, "line"))
}
// NewWithProgress creates a new backup engine with a custom progress
// indicator; silent mode is disabled.
func NewWithProgress(cfg *config.Config, log logger.Logger, db database.Database, progressIndicator progress.Indicator) *Engine {
	e := &Engine{
		cfg:      cfg,
		log:      log,
		db:       db,
		progress: progressIndicator,
		silent:   false,
	}
	// The detailed reporter wraps the indicator and mirrors events to the log.
	e.detailedReporter = progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
	return e
}
// NewSilent creates a new backup engine in silent mode (for TUI); stdout
// output via printf is suppressed while progress still flows through the
// supplied indicator.
func NewSilent(cfg *config.Config, log logger.Logger, db database.Database, progressIndicator progress.Indicator) *Engine {
	e := &Engine{
		cfg:      cfg,
		log:      log,
		db:       db,
		progress: progressIndicator,
		silent:   true, // Silent mode enabled
	}
	e.detailedReporter = progress.NewDetailedReporter(progressIndicator, &loggerAdapter{logger: log})
	return e
}
// loggerAdapter adapts our logger.Logger to the progress.Logger interface
// expected by progress.NewDetailedReporter.
type loggerAdapter struct {
	logger logger.Logger // wrapped project logger
}
// Info forwards an info-level message to the wrapped logger.
func (la *loggerAdapter) Info(msg string, args ...any) {
	la.logger.Info(msg, args...)
}

// Warn forwards a warning-level message to the wrapped logger.
func (la *loggerAdapter) Warn(msg string, args ...any) {
	la.logger.Warn(msg, args...)
}

// Error forwards an error-level message to the wrapped logger.
func (la *loggerAdapter) Error(msg string, args ...any) {
	la.logger.Error(msg, args...)
}

// Debug forwards a debug-level message to the wrapped logger.
func (la *loggerAdapter) Debug(msg string, args ...any) {
	la.logger.Debug(msg, args...)
}
// printf prints to stdout only if not in silent mode (TUI runs suppress
// direct stdout writes).
func (e *Engine) printf(format string, args ...interface{}) {
	if e.silent {
		return
	}
	fmt.Printf(format, args...)
}
// generateOperationID creates a unique operation ID: 16 hex characters
// derived from crypto/rand. If the system entropy source fails (extremely
// unlikely), it falls back to a nanosecond timestamp instead of silently
// returning an all-zero ID.
func generateOperationID() string {
	bytes := make([]byte, 8)
	if _, err := rand.Read(bytes); err != nil {
		// Fallback: weaker uniqueness, but never a constant value.
		return strconv.FormatInt(time.Now().UnixNano(), 16)
	}
	return hex.EncodeToString(bytes)
}
// BackupSingle performs a single database backup with detailed progress tracking.
//
// The backup is written into cfg.BackupDir as db_<name>_<timestamp>.dump
// (PostgreSQL, custom format) or db_<name>_<timestamp>.sql.gz (MySQL).
// The operation is reported as named steps (prepare, command, execute,
// verify, metadata); a metadata-file failure is logged but does not fail
// the backup.
func (e *Engine) BackupSingle(ctx context.Context, databaseName string) error {
	// Start detailed operation tracking
	operationID := generateOperationID()
	tracker := e.detailedReporter.StartOperation(operationID, databaseName, "backup")
	// Add operation details
	tracker.SetDetails("database", databaseName)
	tracker.SetDetails("type", "single")
	tracker.SetDetails("compression", strconv.Itoa(e.cfg.CompressionLevel))
	tracker.SetDetails("format", "custom")
	// Start preparing backup directory
	prepStep := tracker.AddStep("prepare", "Preparing backup directory")
	if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
		prepStep.Fail(fmt.Errorf("failed to create backup directory: %w", err))
		tracker.Fail(fmt.Errorf("failed to create backup directory: %w", err))
		return fmt.Errorf("failed to create backup directory: %w", err)
	}
	prepStep.Complete("Backup directory prepared")
	tracker.UpdateProgress(10, "Backup directory prepared")
	// Generate timestamp and filename (extension depends on engine type)
	timestamp := time.Now().Format("20060102_150405")
	var outputFile string
	if e.cfg.IsPostgreSQL() {
		outputFile = filepath.Join(e.cfg.BackupDir, fmt.Sprintf("db_%s_%s.dump", databaseName, timestamp))
	} else {
		outputFile = filepath.Join(e.cfg.BackupDir, fmt.Sprintf("db_%s_%s.sql.gz", databaseName, timestamp))
	}
	tracker.SetDetails("output_file", outputFile)
	tracker.UpdateProgress(20, "Generated backup filename")
	// Build backup command
	cmdStep := tracker.AddStep("command", "Building backup command")
	options := database.BackupOptions{
		Compression:  e.cfg.CompressionLevel,
		Parallel:     e.cfg.DumpJobs,
		Format:       "custom",
		Blobs:        true,
		NoOwner:      false,
		NoPrivileges: false,
	}
	cmd := e.db.BuildBackupCommand(databaseName, outputFile, options)
	cmdStep.Complete("Backup command prepared")
	tracker.UpdateProgress(30, "Backup command prepared")
	// Execute backup command with progress monitoring
	execStep := tracker.AddStep("execute", "Executing database backup")
	tracker.UpdateProgress(40, "Starting database backup...")
	if err := e.executeCommandWithProgress(ctx, cmd, outputFile, tracker); err != nil {
		execStep.Fail(fmt.Errorf("backup execution failed: %w", err))
		tracker.Fail(fmt.Errorf("backup failed: %w", err))
		return fmt.Errorf("backup failed: %w", err)
	}
	execStep.Complete("Database backup completed")
	tracker.UpdateProgress(80, "Database backup completed")
	// Verify the backup file exists and record its size
	verifyStep := tracker.AddStep("verify", "Verifying backup file")
	if info, err := os.Stat(outputFile); err != nil {
		verifyStep.Fail(fmt.Errorf("backup file not created: %w", err))
		tracker.Fail(fmt.Errorf("backup file not created: %w", err))
		return fmt.Errorf("backup file not created: %w", err)
	} else {
		size := formatBytes(info.Size())
		tracker.SetDetails("file_size", size)
		tracker.SetByteProgress(info.Size(), info.Size())
		verifyStep.Complete(fmt.Sprintf("Backup file verified: %s", size))
		tracker.UpdateProgress(90, fmt.Sprintf("Backup verified: %s", size))
	}
	// Create metadata file (best-effort: a failure is logged, not returned)
	metaStep := tracker.AddStep("metadata", "Creating metadata file")
	if err := e.createMetadata(outputFile, databaseName, "single", ""); err != nil {
		e.log.Warn("Failed to create metadata file", "error", err)
		metaStep.Fail(fmt.Errorf("metadata creation failed: %w", err))
	} else {
		metaStep.Complete("Metadata file created")
	}
	// Complete operation
	tracker.UpdateProgress(100, "Backup operation completed successfully")
	tracker.Complete(fmt.Sprintf("Single database backup completed: %s", filepath.Base(outputFile)))
	return nil
}
// BackupSample performs a sample database backup: the schema plus a
// reduced data set selected by cfg.SampleStrategy / cfg.SampleValue.
// The result is written as sample_<db>_<strategy><value>_<timestamp>.sql
// in cfg.BackupDir. See createSampleBackup for the sampling caveats
// (referential integrity is not guaranteed).
func (e *Engine) BackupSample(ctx context.Context, databaseName string) error {
	operation := e.log.StartOperation("Sample Database Backup")
	// Ensure backup directory exists
	if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
		operation.Fail("Failed to create backup directory")
		return fmt.Errorf("failed to create backup directory: %w", err)
	}
	// Generate timestamp and filename encoding the sampling parameters
	timestamp := time.Now().Format("20060102_150405")
	outputFile := filepath.Join(e.cfg.BackupDir,
		fmt.Sprintf("sample_%s_%s%d_%s.sql", databaseName, e.cfg.SampleStrategy, e.cfg.SampleValue, timestamp))
	operation.Update("Starting sample database backup")
	e.progress.Start(fmt.Sprintf("Creating sample backup of '%s' (%s=%d)", databaseName, e.cfg.SampleStrategy, e.cfg.SampleValue))
	// For sample backups, we need to get the schema first, then sample data
	if err := e.createSampleBackup(ctx, databaseName, outputFile); err != nil {
		e.progress.Fail(fmt.Sprintf("Sample backup failed: %v", err))
		operation.Fail("Sample backup failed")
		return fmt.Errorf("sample backup failed: %w", err)
	}
	// Check output file exists and report its size
	if info, err := os.Stat(outputFile); err != nil {
		e.progress.Fail("Sample backup file not created")
		operation.Fail("Sample backup file not found")
		return fmt.Errorf("sample backup file not created: %w", err)
	} else {
		size := formatBytes(info.Size())
		e.progress.Complete(fmt.Sprintf("Sample backup completed: %s (%s)", filepath.Base(outputFile), size))
		operation.Complete(fmt.Sprintf("Sample backup created: %s (%s)", outputFile, size))
	}
	// Create metadata file (best-effort: a failure is only logged)
	if err := e.createMetadata(outputFile, databaseName, "sample", e.cfg.SampleStrategy); err != nil {
		e.log.Warn("Failed to create metadata file", "error", err)
	}
	return nil
}
// BackupCluster performs a full cluster backup (PostgreSQL only).
//
// It dumps global objects (roles, tablespaces) via pg_dumpall, then dumps
// every database in the cluster into a temporary directory, and finally
// packs everything into cluster_<timestamp>.tar.gz in cfg.BackupDir.
// Individual database failures are logged and skipped so one broken
// database does not abort the whole cluster backup. The temporary
// directory is removed on return via defer.
func (e *Engine) BackupCluster(ctx context.Context) error {
	if !e.cfg.IsPostgreSQL() {
		return fmt.Errorf("cluster backup is only supported for PostgreSQL")
	}
	operation := e.log.StartOperation("Cluster Backup")
	// Use a quiet progress indicator to avoid duplicate messages
	quietProgress := progress.NewQuietLineByLine()
	quietProgress.Start("Starting cluster backup (all databases)")
	// Ensure backup directory exists
	if err := os.MkdirAll(e.cfg.BackupDir, 0755); err != nil {
		operation.Fail("Failed to create backup directory")
		quietProgress.Fail("Failed to create backup directory")
		return fmt.Errorf("failed to create backup directory: %w", err)
	}
	// Generate timestamp, final archive name, and hidden staging directory
	timestamp := time.Now().Format("20060102_150405")
	outputFile := filepath.Join(e.cfg.BackupDir, fmt.Sprintf("cluster_%s.tar.gz", timestamp))
	tempDir := filepath.Join(e.cfg.BackupDir, fmt.Sprintf(".cluster_%s", timestamp))
	operation.Update("Starting cluster backup")
	// Create temporary directory (dumps/ holds the per-database files)
	if err := os.MkdirAll(filepath.Join(tempDir, "dumps"), 0755); err != nil {
		operation.Fail("Failed to create temporary directory")
		quietProgress.Fail("Failed to create temporary directory")
		return fmt.Errorf("failed to create temp directory: %w", err)
	}
	defer os.RemoveAll(tempDir)
	// Backup globals (roles, tablespaces) first
	e.printf("  Backing up global objects...\n")
	if err := e.backupGlobals(ctx, tempDir); err != nil {
		quietProgress.Fail(fmt.Sprintf("Failed to backup globals: %v", err))
		operation.Fail("Global backup failed")
		return fmt.Errorf("failed to backup globals: %w", err)
	}
	// Get list of databases
	e.printf("  Getting database list...\n")
	databases, err := e.db.ListDatabases(ctx)
	if err != nil {
		quietProgress.Fail(fmt.Sprintf("Failed to list databases: %v", err))
		operation.Fail("Database listing failed")
		return fmt.Errorf("failed to list databases: %w", err)
	}
	// Backup each database into dumps/<name>.dump
	e.printf("  Backing up %d databases...\n", len(databases))
	for i, dbName := range databases {
		e.printf("    Backing up database %d/%d: %s\n", i+1, len(databases), dbName)
		dumpFile := filepath.Join(tempDir, "dumps", dbName+".dump")
		options := database.BackupOptions{
			Compression:  e.cfg.CompressionLevel,
			Parallel:     1, // Individual dumps in cluster are not parallel
			Format:       "custom",
			Blobs:        true,
			NoOwner:      false,
			NoPrivileges: false,
		}
		cmd := e.db.BuildBackupCommand(dbName, dumpFile, options)
		if err := e.executeCommand(ctx, cmd, dumpFile); err != nil {
			e.log.Warn("Failed to backup database", "database", dbName, "error", err)
			// Continue with other databases
		}
	}
	// Create archive from the staging directory
	e.printf("  Creating compressed archive...\n")
	if err := e.createArchive(tempDir, outputFile); err != nil {
		quietProgress.Fail(fmt.Sprintf("Failed to create archive: %v", err))
		operation.Fail("Archive creation failed")
		return fmt.Errorf("failed to create archive: %w", err)
	}
	// Check output file exists and report its size
	if info, err := os.Stat(outputFile); err != nil {
		quietProgress.Fail("Cluster backup archive not created")
		operation.Fail("Cluster backup archive not found")
		return fmt.Errorf("cluster backup archive not created: %w", err)
	} else {
		size := formatBytes(info.Size())
		quietProgress.Complete(fmt.Sprintf("Cluster backup completed: %s (%s)", filepath.Base(outputFile), size))
		operation.Complete(fmt.Sprintf("Cluster backup created: %s (%s)", outputFile, size))
	}
	// Create metadata file (best-effort: a failure is only logged)
	if err := e.createMetadata(outputFile, "cluster", "cluster", ""); err != nil {
		e.log.Warn("Failed to create metadata file", "error", err)
	}
	return nil
}
// executeCommandWithProgress executes a backup command with real-time
// progress monitoring via the tool's stderr output.
//
// Credentials are passed through the environment (PGPASSWORD / MYSQL_PWD)
// so they never appear in the process list. MySQL backups with compression
// enabled are delegated to the gzip pipeline before any command is built.
//
// Fix: stderr is drained to EOF *before* cmd.Wait is called. The previous
// version read the pipe in a goroutine concurrent with Wait, which the
// os/exec documentation forbids (Wait closes the pipe and can race the
// reader, losing output).
func (e *Engine) executeCommandWithProgress(ctx context.Context, cmdArgs []string, outputFile string, tracker *progress.OperationTracker) error {
	if len(cmdArgs) == 0 {
		return fmt.Errorf("empty command")
	}
	// For MySQL, handle compression and redirection differently.
	if e.cfg.IsMySQL() && e.cfg.CompressionLevel > 0 {
		return e.executeMySQLWithProgressAndCompression(ctx, cmdArgs, outputFile, tracker)
	}
	e.log.Debug("Executing backup command with progress", "cmd", cmdArgs[0], "args", cmdArgs[1:])
	cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
	// Set environment variables for database tools.
	cmd.Env = os.Environ()
	if e.cfg.Password != "" {
		if e.cfg.IsPostgreSQL() {
			cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
		} else if e.cfg.IsMySQL() {
			cmd.Env = append(cmd.Env, "MYSQL_PWD="+e.cfg.Password)
		}
	}
	// Get stderr pipe for progress monitoring.
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to get stderr pipe: %w", err)
	}
	// Start the command.
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("failed to start command: %w", err)
	}
	// Monitor progress via stderr; returns at EOF when the tool exits.
	e.monitorCommandProgress(stderr, tracker)
	// Wait for command to complete (safe now that the pipe is drained).
	if err := cmd.Wait(); err != nil {
		return fmt.Errorf("backup command failed: %w", err)
	}
	return nil
}
// monitorCommandProgress monitors command output for progress information.
// It reads stderr line by line until EOF, logging each line at debug level,
// nudging the percentage upward gradually and mapping well-known tool
// messages onto coarse milestones.
func (e *Engine) monitorCommandProgress(stderr io.ReadCloser, tracker *progress.OperationTracker) {
	defer stderr.Close()
	pct := 40 // command preparation already accounts for the first 40%
	nonEmpty := 0
	sc := bufio.NewScanner(stderr)
	for sc.Scan() {
		text := strings.TrimSpace(sc.Text())
		if text == "" {
			continue
		}
		e.log.Debug("Command output", "line", text)
		// Bump the percentage by 2 every five non-empty lines, capped at 75%.
		if pct < 75 {
			nonEmpty++
			if nonEmpty%5 == 0 {
				pct += 2
				tracker.UpdateProgress(pct, "Processing data...")
			}
		}
		// Look for specific progress indicators in the tool output.
		switch {
		case strings.Contains(text, "COPY"):
			tracker.UpdateProgress(pct+5, "Copying table data...")
		case strings.Contains(text, "completed"):
			tracker.UpdateProgress(75, "Backup nearly complete...")
		case strings.Contains(text, "done"):
			tracker.UpdateProgress(78, "Finalizing backup...")
		}
	}
}
// executeMySQLWithProgressAndCompression handles MySQL backup with
// compression and progress: mysqldump | gzip > outputFile.
//
// Fixes over the previous version:
//   - stderr is drained before dumpCmd.Wait (os/exec forbids reading a
//     pipe concurrently with Wait);
//   - gzip is always waited on, even when mysqldump fails, so no zombie
//     process is left behind.
func (e *Engine) executeMySQLWithProgressAndCompression(ctx context.Context, cmdArgs []string, outputFile string, tracker *progress.OperationTracker) error {
	// Create mysqldump command with credentials in the environment.
	dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
	dumpCmd.Env = os.Environ()
	if e.cfg.Password != "" {
		dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
	}
	// Create gzip command at the configured compression level.
	gzipCmd := exec.CommandContext(ctx, "gzip", fmt.Sprintf("-%d", e.cfg.CompressionLevel))
	// Create output file.
	outFile, err := os.Create(outputFile)
	if err != nil {
		return fmt.Errorf("failed to create output file: %w", err)
	}
	defer outFile.Close()
	// Set up pipeline: mysqldump | gzip > outputfile
	pipe, err := dumpCmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("failed to create pipe: %w", err)
	}
	gzipCmd.Stdin = pipe
	gzipCmd.Stdout = outFile
	// Get stderr for progress monitoring.
	stderr, err := dumpCmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to get stderr pipe: %w", err)
	}
	// Start both commands.
	if err := gzipCmd.Start(); err != nil {
		return fmt.Errorf("failed to start gzip: %w", err)
	}
	if err := dumpCmd.Start(); err != nil {
		// Close gzip's stdin so it sees EOF, then reap it.
		pipe.Close()
		gzipCmd.Wait()
		return fmt.Errorf("failed to start mysqldump: %w", err)
	}
	// Drain stderr fully (progress monitoring) before calling Wait.
	e.monitorCommandProgress(stderr, tracker)
	// Wait for mysqldump, then always reap gzip regardless of the outcome.
	dumpErr := dumpCmd.Wait()
	pipe.Close()
	gzipErr := gzipCmd.Wait()
	if dumpErr != nil {
		return fmt.Errorf("mysqldump failed: %w", dumpErr)
	}
	if gzipErr != nil {
		return fmt.Errorf("gzip failed: %w", gzipErr)
	}
	return nil
}
// executeMySQLWithCompression handles MySQL backup with compression:
// mysqldump | gzip > outputFile (no progress tracking).
//
// Fixes over the previous version: the StdoutPipe error is no longer
// silently discarded, and gzip is always waited on — even when mysqldump
// fails — so no zombie process is left behind.
func (e *Engine) executeMySQLWithCompression(ctx context.Context, cmdArgs []string, outputFile string) error {
	// Create mysqldump command with credentials in the environment.
	dumpCmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
	dumpCmd.Env = os.Environ()
	if e.cfg.Password != "" {
		dumpCmd.Env = append(dumpCmd.Env, "MYSQL_PWD="+e.cfg.Password)
	}
	// Create gzip command at the configured compression level.
	gzipCmd := exec.CommandContext(ctx, "gzip", fmt.Sprintf("-%d", e.cfg.CompressionLevel))
	// Create output file.
	outFile, err := os.Create(outputFile)
	if err != nil {
		return fmt.Errorf("failed to create output file: %w", err)
	}
	defer outFile.Close()
	// Set up pipeline: mysqldump | gzip > outputfile
	pipe, err := dumpCmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("failed to create pipe: %w", err)
	}
	gzipCmd.Stdin = pipe
	gzipCmd.Stdout = outFile
	// Start gzip first so it is ready to consume mysqldump's output.
	if err := gzipCmd.Start(); err != nil {
		return fmt.Errorf("failed to start gzip: %w", err)
	}
	// Run mysqldump to completion, then always reap gzip.
	dumpErr := dumpCmd.Run()
	pipe.Close() // ensure gzip sees EOF even if Run failed early
	gzipErr := gzipCmd.Wait()
	if dumpErr != nil {
		return fmt.Errorf("mysqldump failed: %w", dumpErr)
	}
	if gzipErr != nil {
		return fmt.Errorf("gzip failed: %w", gzipErr)
	}
	return nil
}
// createSampleBackup creates a sample backup with a reduced dataset.
//
// For PostgreSQL it writes the schema (pg_dump --schema-only, plain
// format) into the file, then appends one \copy statement per table built
// from the configured sampling strategy. For other engines only the
// header comment is written. The output intentionally carries a warning:
// sampled rows may break referential integrity.
func (e *Engine) createSampleBackup(ctx context.Context, databaseName, outputFile string) error {
	// This is a simplified implementation
	// A full implementation would:
	// 1. Export schema
	// 2. Get list of tables
	// 3. For each table, run sampling query
	// 4. Combine into single SQL file
	// For now, we'll use a simple approach with schema-only backup first
	// Then add sample data
	file, err := os.Create(outputFile)
	if err != nil {
		return fmt.Errorf("failed to create sample backup file: %w", err)
	}
	defer file.Close()
	// Write header describing how the sample was produced
	fmt.Fprintf(file, "-- Sample Database Backup\n")
	fmt.Fprintf(file, "-- Database: %s\n", databaseName)
	fmt.Fprintf(file, "-- Strategy: %s = %d\n", e.cfg.SampleStrategy, e.cfg.SampleValue)
	fmt.Fprintf(file, "-- Created: %s\n", time.Now().Format(time.RFC3339))
	fmt.Fprintf(file, "-- WARNING: This backup may have referential integrity issues!\n\n")
	// For PostgreSQL, we can use pg_dump --schema-only first
	if e.cfg.IsPostgreSQL() {
		// Get schema, streaming pg_dump's stdout straight into the file
		schemaCmd := e.db.BuildBackupCommand(databaseName, "/dev/stdout", database.BackupOptions{
			SchemaOnly: true,
			Format:     "plain",
		})
		cmd := exec.CommandContext(ctx, schemaCmd[0], schemaCmd[1:]...)
		cmd.Env = os.Environ()
		if e.cfg.Password != "" {
			cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
		}
		cmd.Stdout = file
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("failed to export schema: %w", err)
		}
		fmt.Fprintf(file, "\n-- Sample data follows\n\n")
		// Get tables and emit a sampling \copy statement for each
		tables, err := e.db.ListTables(ctx, databaseName)
		if err != nil {
			return fmt.Errorf("failed to list tables: %w", err)
		}
		strategy := database.SampleStrategy{
			Type:  e.cfg.SampleStrategy,
			Value: e.cfg.SampleValue,
		}
		for _, table := range tables {
			fmt.Fprintf(file, "-- Data for table: %s\n", table)
			sampleQuery := e.db.BuildSampleQuery(databaseName, table, strategy)
			fmt.Fprintf(file, "\\copy (%s) TO STDOUT\n\n", sampleQuery)
		}
	}
	return nil
}
// backupGlobals creates a backup of global PostgreSQL objects (roles,
// tablespaces) by running pg_dumpall --globals-only and writing its
// stdout to <tempDir>/globals.sql.
func (e *Engine) backupGlobals(ctx context.Context, tempDir string) error {
	// Assemble the argument list up front; host/port flags are only
	// needed for remote servers.
	args := []string{"--globals-only"}
	if e.cfg.Host != "localhost" {
		args = append(args, "-h", e.cfg.Host, "-p", fmt.Sprintf("%d", e.cfg.Port))
	}
	args = append(args, "-U", e.cfg.User)
	cmd := exec.CommandContext(ctx, "pg_dumpall", args...)
	cmd.Env = os.Environ()
	if e.cfg.Password != "" {
		cmd.Env = append(cmd.Env, "PGPASSWORD="+e.cfg.Password)
	}
	output, err := cmd.Output()
	if err != nil {
		return fmt.Errorf("pg_dumpall failed: %w", err)
	}
	return os.WriteFile(filepath.Join(tempDir, "globals.sql"), output, 0644)
}
// createArchive creates a compressed tar archive of sourceDir's contents
// at outputFile, shelling out to the system tar.
func (e *Engine) createArchive(sourceDir, outputFile string) error {
	tarCmd := exec.Command("tar", "-czf", outputFile, "-C", sourceDir, ".")
	// Capture combined output so a failure message includes tar's own error text.
	if out, err := tarCmd.CombinedOutput(); err != nil {
		return fmt.Errorf("tar failed: %w, output: %s", err, string(out))
	}
	return nil
}
// createMetadata creates a JSON metadata sidecar file (<backupFile>.info)
// describing the backup: type, database, timestamp, connection details,
// compression level, optional sampling parameters, and the file size when
// the backup file exists.
//
// Fix: the metadata is now encoded with encoding/json instead of manual
// string concatenation, so values containing quotes, backslashes or
// control characters no longer produce invalid JSON.
func (e *Engine) createMetadata(backupFile, database, backupType, strategy string) error {
	meta := map[string]any{
		"type":        backupType,
		"database":    database,
		"timestamp":   time.Now().Format("20060102_150405"),
		"host":        e.cfg.Host,
		"port":        e.cfg.Port,
		"user":        e.cfg.User,
		"db_type":     e.cfg.DatabaseType,
		"compression": e.cfg.CompressionLevel,
	}
	// Sampling details are recorded only for sample backups (non-empty strategy).
	if strategy != "" {
		meta["sample_strategy"] = e.cfg.SampleStrategy
		meta["sample_value"] = e.cfg.SampleValue
	}
	// Best-effort: include the size only if the backup file is present.
	if info, err := os.Stat(backupFile); err == nil {
		meta["size_bytes"] = info.Size()
	}
	data, err := json.MarshalIndent(meta, "", "  ")
	if err != nil {
		return fmt.Errorf("failed to encode metadata: %w", err)
	}
	return os.WriteFile(backupFile+".info", data, 0644)
}
// executeCommand executes a backup command (simplified version for cluster
// backups). Output is captured and only surfaced on failure; MySQL backups
// with compression enabled are routed through the gzip pipeline instead.
func (e *Engine) executeCommand(ctx context.Context, cmdArgs []string, outputFile string) error {
	if len(cmdArgs) == 0 {
		return fmt.Errorf("empty command")
	}
	e.log.Debug("Executing backup command", "cmd", cmdArgs[0], "args", cmdArgs[1:])
	// For MySQL, handle compression differently.
	if e.cfg.IsMySQL() && e.cfg.CompressionLevel > 0 {
		return e.executeMySQLWithCompression(ctx, cmdArgs, outputFile)
	}
	cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
	// Credentials are passed via the environment for the database tools.
	env := os.Environ()
	if e.cfg.Password != "" {
		switch {
		case e.cfg.IsPostgreSQL():
			env = append(env, "PGPASSWORD="+e.cfg.Password)
		case e.cfg.IsMySQL():
			env = append(env, "MYSQL_PWD="+e.cfg.Password)
		}
	}
	cmd.Env = env
	// Run the command, capturing stdout+stderr for diagnostics.
	output, err := cmd.CombinedOutput()
	if err != nil {
		e.log.Error("Backup command failed", "error", err, "output", string(output))
		return fmt.Errorf("backup command failed: %w, output: %s", err, string(output))
	}
	return nil
}
// formatBytes formats a byte count in human-readable form using binary
// units (1024-based): "512 B", "1.5 KB", "2.0 MB", and so on up to EB.
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	// Determine the unit exponent: exp==0 -> KB, 1 -> MB, ...
	scaled := bytes / unit
	exp := 0
	for scaled >= unit {
		scaled /= unit
		exp++
	}
	// div == unit^(exp+1), the divisor for the chosen unit.
	div := int64(1)
	for i := 0; i <= exp; i++ {
		div *= unit
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}

319
internal/config/config.go Normal file
View File

@@ -0,0 +1,319 @@
package config
import (
"os"
"path/filepath"
"runtime"
"strconv"
"dbbackup/internal/cpu"
)
// Config holds all configuration options for the backup tool. Values are
// populated from environment variables in New and may be adjusted later
// (e.g. by OptimizeForCPU).
type Config struct {
	// Version information
	Version   string
	BuildTime string
	GitCommit string
	// Database connection
	Host         string
	Port         int
	User         string
	Database     string
	Password     string
	DatabaseType string // "postgres" or "mysql"
	SSLMode      string
	Insecure     bool
	// Backup options
	BackupDir        string
	CompressionLevel int // 0-9, enforced by Validate
	Jobs             int // parallel jobs (restore-oriented, see getDefaultJobs)
	DumpJobs         int // parallel jobs for dumps (see getDefaultDumpJobs)
	MaxCores         int
	AutoDetectCores  bool
	CPUWorkloadType  string // "cpu-intensive", "io-intensive", "balanced"
	// CPU detection
	CPUDetector *cpu.Detector
	CPUInfo     *cpu.CPUInfo
	// Sample backup options
	SampleStrategy string // "ratio", "percent", "count"
	SampleValue    int
	// Output options
	NoColor      bool
	Debug        bool
	LogLevel     string
	LogFormat    string
	OutputLength int
	// Single database backup/restore
	SingleDBName  string
	RestoreDBName string
}
// New creates a new configuration with default values, each overridable
// by an environment variable (PG_HOST, PG_PORT, BACKUP_DIR, ...). CPU
// detection runs once here so job defaults can be sized from the machine;
// a detection error is deliberately ignored and leaves CPUInfo nil.
func New() *Config {
	// Get default backup directory
	backupDir := getEnvString("BACKUP_DIR", getDefaultBackupDir())
	// Initialize CPU detector (error intentionally discarded; nil CPUInfo
	// makes the job helpers fall back to conservative defaults)
	cpuDetector := cpu.NewDetector()
	cpuInfo, _ := cpuDetector.DetectCPU()
	return &Config{
		// Database defaults
		Host:         getEnvString("PG_HOST", "localhost"),
		Port:         getEnvInt("PG_PORT", 5432),
		User:         getEnvString("PG_USER", getCurrentUser()),
		Database:     getEnvString("PG_DATABASE", "postgres"),
		Password:     getEnvString("PGPASSWORD", ""),
		DatabaseType: getEnvString("DB_TYPE", "postgres"),
		SSLMode:      getEnvString("PG_SSLMODE", "prefer"),
		Insecure:     getEnvBool("INSECURE", false),
		// Backup defaults
		BackupDir:        backupDir,
		CompressionLevel: getEnvInt("COMPRESS_LEVEL", 6),
		Jobs:             getEnvInt("JOBS", getDefaultJobs(cpuInfo)),
		DumpJobs:         getEnvInt("DUMP_JOBS", getDefaultDumpJobs(cpuInfo)),
		MaxCores:         getEnvInt("MAX_CORES", getDefaultMaxCores(cpuInfo)),
		AutoDetectCores:  getEnvBool("AUTO_DETECT_CORES", true),
		CPUWorkloadType:  getEnvString("CPU_WORKLOAD_TYPE", "balanced"),
		// CPU detection
		CPUDetector: cpuDetector,
		CPUInfo:     cpuInfo,
		// Sample backup defaults
		SampleStrategy: getEnvString("SAMPLE_STRATEGY", "ratio"),
		SampleValue:    getEnvInt("SAMPLE_VALUE", 10),
		// Output defaults
		NoColor:      getEnvBool("NO_COLOR", false),
		Debug:        getEnvBool("DEBUG", false),
		LogLevel:     getEnvString("LOG_LEVEL", "info"),
		LogFormat:    getEnvString("LOG_FORMAT", "text"),
		OutputLength: getEnvInt("OUTPUT_LENGTH", 0),
		// Single database options
		SingleDBName:  getEnvString("SINGLE_DB_NAME", ""),
		RestoreDBName: getEnvString("RESTORE_DB_NAME", ""),
	}
}
// UpdateFromEnvironment updates configuration from environment variables:
// PGPASSWORD always wins if set, and MYSQL_PWD overrides the password for
// MySQL configurations.
func (c *Config) UpdateFromEnvironment() {
	if v := os.Getenv("PGPASSWORD"); v != "" {
		c.Password = v
	}
	if c.DatabaseType == "mysql" {
		if v := os.Getenv("MYSQL_PWD"); v != "" {
			c.Password = v
		}
	}
}
// Validate validates the configuration, returning a *ConfigError describing
// the first invalid field found, or nil when the configuration is usable.
//
// Fix: integer values are rendered with strconv.Itoa. The previous
// string(rune(n)) conversion produced the Unicode code point n (e.g. a
// control character for small values), not the decimal number.
func (c *Config) Validate() error {
	if c.DatabaseType != "postgres" && c.DatabaseType != "mysql" {
		return &ConfigError{Field: "database-type", Value: c.DatabaseType, Message: "must be 'postgres' or 'mysql'"}
	}
	if c.CompressionLevel < 0 || c.CompressionLevel > 9 {
		return &ConfigError{Field: "compression", Value: strconv.Itoa(c.CompressionLevel), Message: "must be between 0-9"}
	}
	if c.Jobs < 1 {
		return &ConfigError{Field: "jobs", Value: strconv.Itoa(c.Jobs), Message: "must be at least 1"}
	}
	if c.DumpJobs < 1 {
		return &ConfigError{Field: "dump-jobs", Value: strconv.Itoa(c.DumpJobs), Message: "must be at least 1"}
	}
	return nil
}
// IsPostgreSQL returns true if database type is PostgreSQL ("postgres").
func (c *Config) IsPostgreSQL() bool {
	return c.DatabaseType == "postgres"
}

// IsMySQL returns true if database type is MySQL ("mysql").
func (c *Config) IsMySQL() bool {
	return c.DatabaseType == "mysql"
}
// GetDefaultPort returns the default port for the configured database
// type: 3306 for MySQL, 5432 otherwise (PostgreSQL).
func (c *Config) GetDefaultPort() int {
	switch {
	case c.IsMySQL():
		return 3306
	default:
		return 5432
	}
}
// OptimizeForCPU optimizes job settings based on detected CPU.
// The detector and CPU info are created lazily if missing; Jobs/DumpJobs
// are only adjusted when AutoDetectCores is enabled, and calculation
// errors leave the previous values untouched.
func (c *Config) OptimizeForCPU() error {
	if c.CPUDetector == nil {
		c.CPUDetector = cpu.NewDetector()
	}
	if c.CPUInfo == nil {
		info, err := c.CPUDetector.DetectCPU()
		if err != nil {
			return err
		}
		c.CPUInfo = info
	}
	if c.AutoDetectCores {
		// Optimize jobs based on workload type
		if jobs, err := c.CPUDetector.CalculateOptimalJobs(c.CPUWorkloadType, c.MaxCores); err == nil {
			c.Jobs = jobs
		}
		// Optimize dump jobs (more conservative for database dumps:
		// half the core budget, hard-capped at 8)
		if dumpJobs, err := c.CPUDetector.CalculateOptimalJobs("cpu-intensive", c.MaxCores/2); err == nil {
			c.DumpJobs = dumpJobs
			if c.DumpJobs > 8 {
				c.DumpJobs = 8 // Conservative limit for dumps
			}
		}
	}
	return nil
}
// GetCPUInfo returns CPU information, detecting (and caching) it on
// first use; the detector itself is also created lazily.
func (c *Config) GetCPUInfo() (*cpu.CPUInfo, error) {
	if c.CPUInfo == nil {
		if c.CPUDetector == nil {
			c.CPUDetector = cpu.NewDetector()
		}
		info, err := c.CPUDetector.DetectCPU()
		if err != nil {
			return nil, err
		}
		c.CPUInfo = info
	}
	return c.CPUInfo, nil
}
// ConfigError represents a configuration validation error for one field.
type ConfigError struct {
	Field   string // option name that failed validation
	Value   string // offending value, rendered as text
	Message string // human-readable constraint description
}

// Error implements the error interface.
func (e *ConfigError) Error() string {
	msg := "config error in field '" + e.Field + "'"
	msg += " with value '" + e.Value + "'"
	msg += ": " + e.Message
	return msg
}
// Helper functions
func getEnvString(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// getEnvInt returns the environment variable parsed as an int, or
// defaultValue when the variable is unset, empty, or not a valid integer.
func getEnvInt(key string, defaultValue int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return defaultValue
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return defaultValue
	}
	return parsed
}
// getEnvBool returns the environment variable parsed as a bool
// (strconv.ParseBool syntax: 1/t/true/0/f/false, ...), or defaultValue
// when the variable is unset, empty, or unparsable.
func getEnvBool(key string, defaultValue bool) bool {
	raw := os.Getenv(key)
	if raw == "" {
		return defaultValue
	}
	parsed, err := strconv.ParseBool(raw)
	if err != nil {
		return defaultValue
	}
	return parsed
}
// getCurrentUser returns the current user name from $USER, then
// $USERNAME (Windows convention), falling back to "postgres" when
// neither is set.
func getCurrentUser() string {
	for _, key := range []string{"USER", "USERNAME"} {
		if user := os.Getenv(key); user != "" {
			return user
		}
	}
	return "postgres"
}
// getDefaultBackupDir picks a sensible default backup directory:
// <home>/db_backups when a home directory is known, otherwise an
// OS-specific fallback.
func getDefaultBackupDir() string {
	if homeDir, _ := os.UserHomeDir(); homeDir != "" {
		return filepath.Join(homeDir, "db_backups")
	}
	// Fallback based on OS
	if runtime.GOOS == "windows" {
		return "C:\\db_backups"
	}
	// For the PostgreSQL service account on Linux/Unix
	if getCurrentUser() == "postgres" {
		return "/var/lib/pgsql/pg_backups"
	}
	return "/tmp/db_backups"
}
// CPU-related helper functions

// getDefaultJobs returns the default parallel-job count for restore-type
// operations: the logical core count clamped to [1, 16], or 1 when CPU
// info is unavailable.
func getDefaultJobs(cpuInfo *cpu.CPUInfo) int {
	if cpuInfo == nil {
		return 1
	}
	// Default to logical cores for restore operations
	switch jobs := cpuInfo.LogicalCores; {
	case jobs < 1:
		return 1
	case jobs > 16:
		return 16 // Safety limit
	default:
		return jobs
	}
}
// getDefaultDumpJobs returns the default parallel-job count for dump
// operations: the physical core count clamped to [1, 8], or 1 when CPU
// info is unavailable. Dumps are CPU intensive, hence the lower cap.
func getDefaultDumpJobs(cpuInfo *cpu.CPUInfo) int {
	if cpuInfo == nil {
		return 1
	}
	// Use physical cores for dump operations (CPU intensive)
	switch jobs := cpuInfo.PhysicalCores; {
	case jobs < 1:
		return 1
	case jobs > 8:
		return 8 // Conservative limit for dumps
	default:
		return jobs
	}
}
// getDefaultMaxCores returns the default core ceiling: twice the logical
// core count, clamped to [4, 64]; 16 when CPU info is unavailable.
func getDefaultMaxCores(cpuInfo *cpu.CPUInfo) int {
	if cpuInfo == nil {
		return 16
	}
	// Set max cores to 2x logical cores, with reasonable bounds
	switch maxCores := 2 * cpuInfo.LogicalCores; {
	case maxCores < 4:
		return 4
	case maxCores > 64:
		return 64
	default:
		return maxCores
	}
}

346
internal/cpu/detection.go Normal file
View File

@@ -0,0 +1,346 @@
package cpu
import (
"fmt"
"runtime"
"strconv"
"strings"
"os"
"os/exec"
"bufio"
)
// CPUInfo holds information about the system CPU as gathered by Detector.
// Fields other than LogicalCores and Architecture are best-effort and may
// be zero/empty on platforms where detection fails.
type CPUInfo struct {
	LogicalCores  int      `json:"logical_cores"`     // from runtime.NumCPU
	PhysicalCores int      `json:"physical_cores"`    // best-effort, platform specific
	Architecture  string   `json:"architecture"`      // from runtime.GOARCH
	ModelName     string   `json:"model_name"`
	MaxFrequency  float64  `json:"max_frequency_mhz"`
	CacheSize     string   `json:"cache_size"`
	Vendor        string   `json:"vendor"`
	Features      []string `json:"features"`
}
// Detector provides CPU detection functionality and caches the result of
// the first successful detection.
type Detector struct {
	info *CPUInfo // cached result; nil until DetectCPU succeeds
}

// NewDetector creates a new CPU detector with an empty cache.
func NewDetector() *Detector {
	return &Detector{}
}
// DetectCPU detects CPU information for the current system.
// A successful result is cached on the Detector, so subsequent calls are
// cheap and return the same *CPUInfo. On platform-specific detection
// failure the partially filled info (logical cores + architecture) is
// returned along with the error, and the cache stays empty.
func (d *Detector) DetectCPU() (*CPUInfo, error) {
	if d.info != nil {
		return d.info, nil
	}
	// Baseline available on every platform via the runtime package.
	info := &CPUInfo{
		LogicalCores: runtime.NumCPU(),
		Architecture: runtime.GOARCH,
	}
	// Platform-specific detection
	switch runtime.GOOS {
	case "linux":
		if err := d.detectLinux(info); err != nil {
			return info, fmt.Errorf("linux CPU detection failed: %w", err)
		}
	case "darwin":
		if err := d.detectDarwin(info); err != nil {
			return info, fmt.Errorf("darwin CPU detection failed: %w", err)
		}
	case "windows":
		if err := d.detectWindows(info); err != nil {
			return info, fmt.Errorf("windows CPU detection failed: %w", err)
		}
	default:
		// Fallback for unsupported platforms
		info.PhysicalCores = info.LogicalCores
		info.ModelName = "Unknown"
		info.Vendor = "Unknown"
	}
	d.info = info
	return info, nil
}
// detectLinux detects CPU information on Linux systems by parsing
// /proc/cpuinfo, then refining the result with lscpu when available.
//
// NOTE(review): "physical id" identifies the CPU socket/package, so the
// map below counts sockets, not physical cores; the lscpu pass afterwards
// normally corrects PhysicalCores ("Core(s) per socket" x sockets), but on
// systems without lscpu a multi-core single-socket CPU would report 1 —
// confirm whether that is acceptable.
func (d *Detector) detectLinux(info *CPUInfo) error {
	file, err := os.Open("/proc/cpuinfo")
	if err != nil {
		return err
	}
	defer file.Close()
	scanner := bufio.NewScanner(file)
	physicalCoreCount := make(map[string]bool)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.TrimSpace(line) == "" {
			continue
		}
		// /proc/cpuinfo lines look like "key<tab>: value"
		parts := strings.SplitN(line, ":", 2)
		if len(parts) != 2 {
			continue
		}
		key := strings.TrimSpace(parts[0])
		value := strings.TrimSpace(parts[1])
		switch key {
		case "model name":
			// Only record the first occurrence (repeated per logical CPU)
			if info.ModelName == "" {
				info.ModelName = value
			}
		case "vendor_id":
			if info.Vendor == "" {
				info.Vendor = value
			}
		case "cpu MHz":
			// Keep the highest frequency seen across all cores
			if freq, err := strconv.ParseFloat(value, 64); err == nil && info.MaxFrequency < freq {
				info.MaxFrequency = freq
			}
		case "cache size":
			if info.CacheSize == "" {
				info.CacheSize = value
			}
		case "flags", "Features":
			if len(info.Features) == 0 {
				info.Features = strings.Fields(value)
			}
		case "physical id":
			physicalCoreCount[value] = true
		}
	}
	// Calculate physical cores (see NOTE above: this is a socket count)
	if len(physicalCoreCount) > 0 {
		info.PhysicalCores = len(physicalCoreCount)
	} else {
		// Fallback: assume hyperthreading if logical > 1
		info.PhysicalCores = info.LogicalCores
		if info.LogicalCores > 1 {
			info.PhysicalCores = info.LogicalCores / 2
		}
	}
	// Try to get more accurate physical core count from lscpu
	if cmd := exec.Command("lscpu"); cmd != nil {
		if output, err := cmd.Output(); err == nil {
			d.parseLscpu(string(output), info)
		}
	}
	return scanner.Err()
}
// parseLscpu merges fields from `lscpu` output (Key: Value lines) into
// info, overriding the values derived from /proc/cpuinfo where present.
func (d *Detector) parseLscpu(output string, info *CPUInfo) {
	for _, line := range strings.Split(output, "\n") {
		if strings.TrimSpace(line) == "" {
			continue
		}
		rawKey, rawValue, found := strings.Cut(line, ":")
		if !found {
			continue
		}
		key := strings.TrimSpace(rawKey)
		value := strings.TrimSpace(rawValue)
		switch key {
		case "Core(s) per socket":
			cores, err := strconv.Atoi(value)
			if err != nil {
				continue
			}
			// Total physical cores = cores per socket * sockets.
			if sockets := d.getSocketCount(output); sockets > 0 {
				info.PhysicalCores = cores * sockets
			}
		case "Model name":
			info.ModelName = value
		case "Vendor ID":
			info.Vendor = value
		case "CPU max MHz":
			if freq, err := strconv.ParseFloat(value, 64); err == nil {
				info.MaxFrequency = freq
			}
		}
	}
}
// getSocketCount extracts the "Socket(s):" value from lscpu output,
// defaulting to 1 when the field is absent or unparseable.
func (d *Detector) getSocketCount(output string) int {
	for _, line := range strings.Split(output, "\n") {
		if !strings.Contains(line, "Socket(s):") {
			continue
		}
		if _, value, ok := strings.Cut(line, ":"); ok {
			if sockets, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
				return sockets
			}
		}
	}
	return 1 // Default to 1 socket
}
// detectDarwin detects CPU information on macOS systems via sysctl.
//
// Each sysctl query is best-effort: a failing or missing key leaves the
// corresponding field at its zero value, and the function itself never
// fails. NOTE(review): some keys used here (machdep.cpu.vendor,
// hw.cpufrequency_max) may not exist on Apple Silicon machines — confirm;
// on such systems those fields will simply stay empty/zero.
func (d *Detector) detectDarwin(info *CPUInfo) error {
	// Get CPU brand
	if output, err := exec.Command("sysctl", "-n", "machdep.cpu.brand_string").Output(); err == nil {
		info.ModelName = strings.TrimSpace(string(output))
	}
	// Get physical cores
	if output, err := exec.Command("sysctl", "-n", "hw.physicalcpu").Output(); err == nil {
		if cores, err := strconv.Atoi(strings.TrimSpace(string(output))); err == nil {
			info.PhysicalCores = cores
		}
	}
	// Get max frequency
	if output, err := exec.Command("sysctl", "-n", "hw.cpufrequency_max").Output(); err == nil {
		if freq, err := strconv.ParseFloat(strings.TrimSpace(string(output)), 64); err == nil {
			info.MaxFrequency = freq / 1000000 // Convert Hz to MHz
		}
	}
	// Get vendor
	if output, err := exec.Command("sysctl", "-n", "machdep.cpu.vendor").Output(); err == nil {
		info.Vendor = strings.TrimSpace(string(output))
	}
	// Get cache size (L3, reported in bytes; stored as "<n> KB")
	if output, err := exec.Command("sysctl", "-n", "hw.l3cachesize").Output(); err == nil {
		if cache, err := strconv.Atoi(strings.TrimSpace(string(output))); err == nil {
			info.CacheSize = fmt.Sprintf("%d KB", cache/1024)
		}
	}
	return nil
}
// detectWindows detects CPU information on Windows systems by parsing
// `wmic cpu get ... /format:list` output (Key=Value lines).
//
// NOTE(review): NumberOfLogicalProcessors is requested but never parsed —
// the logical-core count remains the runtime.NumCPU() baseline. Also,
// wmic is deprecated on recent Windows releases; confirm availability or
// consider a PowerShell/CIM fallback.
func (d *Detector) detectWindows(info *CPUInfo) error {
	// Use wmic to get CPU information
	cmd := exec.Command("wmic", "cpu", "get", "Name,NumberOfCores,NumberOfLogicalProcessors,MaxClockSpeed", "/format:list")
	output, err := cmd.Output()
	if err != nil {
		return err
	}
	lines := strings.Split(string(output), "\n")
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		parts := strings.SplitN(line, "=", 2)
		if len(parts) != 2 {
			continue
		}
		key := strings.TrimSpace(parts[0])
		value := strings.TrimSpace(parts[1])
		switch key {
		case "Name":
			if value != "" {
				info.ModelName = value
			}
		case "NumberOfCores":
			if cores, err := strconv.Atoi(value); err == nil {
				info.PhysicalCores = cores
			}
		case "MaxClockSpeed":
			// Value is used as-is (presumably MHz — confirm against wmic docs).
			if freq, err := strconv.ParseFloat(value, 64); err == nil {
				info.MaxFrequency = freq
			}
		}
	}
	// Get vendor information (separate best-effort wmic query).
	cmd = exec.Command("wmic", "cpu", "get", "Manufacturer", "/format:list")
	if output, err := cmd.Output(); err == nil {
		lines := strings.Split(string(output), "\n")
		for _, line := range lines {
			if strings.HasPrefix(line, "Manufacturer=") {
				info.Vendor = strings.TrimSpace(strings.SplitN(line, "=", 2)[1])
				break
			}
		}
	}
	return nil
}
// CalculateOptimalJobs returns a recommended parallel job count for the
// given workload type ("cpu-intensive", "io-intensive", "balanced" or
// anything else = balanced), clamped to at least 1 and, when maxJobs > 0,
// to at most maxJobs. On detection failure it returns 1 and the error.
func (d *Detector) CalculateOptimalJobs(workloadType string, maxJobs int) (int, error) {
	info, err := d.DetectCPU()
	if err != nil {
		return 1, err
	}

	optimal := info.LogicalCores // "balanced" and the default case
	switch workloadType {
	case "cpu-intensive":
		// CPU-bound work gains little from hyperthread siblings.
		optimal = info.PhysicalCores
	case "io-intensive":
		// I/O-bound work can safely oversubscribe the cores.
		optimal = info.LogicalCores * 2
	}

	// Apply safety limits.
	if optimal < 1 {
		optimal = 1
	}
	if maxJobs > 0 && optimal > maxJobs {
		optimal = maxJobs
	}
	return optimal, nil
}
// GetCPUInfo returns the cached CPU information, or nil if DetectCPU has
// not yet completed successfully.
func (d *Detector) GetCPUInfo() *CPUInfo {
	return d.info
}
// FormatCPUInfo renders the CPU information as a multi-line,
// human-readable string. Optional fields (model, vendor, frequency,
// cache) are omitted when unknown.
func (info *CPUInfo) FormatCPUInfo() string {
	var b strings.Builder
	fmt.Fprintf(&b, "Architecture: %s\n", info.Architecture)
	fmt.Fprintf(&b, "Logical Cores: %d\n", info.LogicalCores)
	fmt.Fprintf(&b, "Physical Cores: %d\n", info.PhysicalCores)
	if info.ModelName != "" {
		fmt.Fprintf(&b, "Model: %s\n", info.ModelName)
	}
	if info.Vendor != "" {
		fmt.Fprintf(&b, "Vendor: %s\n", info.Vendor)
	}
	if info.MaxFrequency > 0 {
		fmt.Fprintf(&b, "Max Frequency: %.2f MHz\n", info.MaxFrequency)
	}
	if info.CacheSize != "" {
		fmt.Fprintf(&b, "Cache Size: %s\n", info.CacheSize)
	}
	return b.String()
}

View File

@ -0,0 +1,133 @@
package database
import (
"context"
"database/sql"
"fmt"
"time"
"dbbackup/internal/config"
"dbbackup/internal/logger"
_ "github.com/lib/pq" // PostgreSQL driver
_ "github.com/go-sql-driver/mysql" // MySQL driver
)
// Database represents a database connection and the operations shared by
// all supported engines (PostgreSQL, MySQL).
type Database interface {
	// Connection management
	Connect(ctx context.Context) error
	Close() error
	Ping(ctx context.Context) error
	// Database discovery
	ListDatabases(ctx context.Context) ([]string, error)
	ListTables(ctx context.Context, database string) ([]string, error)
	// Database operations
	CreateDatabase(ctx context.Context, name string) error
	DropDatabase(ctx context.Context, name string) error
	DatabaseExists(ctx context.Context, name string) (bool, error)
	// Information
	GetVersion(ctx context.Context) (string, error)
	GetDatabaseSize(ctx context.Context, database string) (int64, error)
	GetTableRowCount(ctx context.Context, database, table string) (int64, error)
	// Backup/Restore command building: these return argv slices for the
	// engine's native tools (pg_dump/pg_restore, mysqldump/mysql); the
	// caller is responsible for executing them and for any redirection.
	BuildBackupCommand(database, outputFile string, options BackupOptions) []string
	BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string
	BuildSampleQuery(database, table string, strategy SampleStrategy) string
	// Validation
	ValidateBackupTools() error
}

// BackupOptions holds options for backup operations. Not every option is
// honored by every engine (see the engine-specific Build*Command methods).
type BackupOptions struct {
	Compression  int    // compression level (engine-specific scale)
	Parallel     int    // parallel jobs, where the tool supports them
	Format       string // "custom", "plain", "directory"
	Blobs        bool   // include large objects
	SchemaOnly   bool   // dump schema without data
	DataOnly     bool   // dump data without schema
	NoOwner      bool   // omit ownership statements
	NoPrivileges bool   // omit privilege (GRANT) statements
	Clean        bool   // emit drop statements before create
	IfExists     bool   // use IF EXISTS with drop statements
	Role         string // role to switch to before dumping (PostgreSQL)
}

// RestoreOptions holds options for restore operations.
type RestoreOptions struct {
	Parallel          int  // parallel restore jobs
	Clean             bool // drop objects before recreating them
	IfExists          bool // use IF EXISTS with drops
	NoOwner           bool // skip ownership restoration
	NoPrivileges      bool // skip privilege restoration
	SingleTransaction bool // run the whole restore in one transaction
}

// SampleStrategy defines how to sample data.
// Value's meaning depends on Type: every-Nth stride for "ratio",
// a percentage for "percent", a row count for "count".
type SampleStrategy struct {
	Type  string // "ratio", "percent", "count"
	Value int
}

// DatabaseInfo holds database metadata.
type DatabaseInfo struct {
	Name      string
	Size      int64 // size in bytes
	Owner     string
	Encoding  string
	Collation string
	Tables    []TableInfo
}

// TableInfo holds table metadata.
type TableInfo struct {
	Schema   string
	Name     string
	RowCount int64
	Size     int64 // size in bytes
}
// New returns the Database implementation matching the configured engine
// (PostgreSQL or MySQL); any other database type is an error.
func New(cfg *config.Config, log logger.Logger) (Database, error) {
	switch {
	case cfg.IsPostgreSQL():
		return NewPostgreSQL(cfg, log), nil
	case cfg.IsMySQL():
		return NewMySQL(cfg, log), nil
	default:
		return nil, fmt.Errorf("unsupported database type: %s", cfg.DatabaseType)
	}
}
// Common database implementation embedded by the engine-specific types;
// holds the shared connection pool, configuration, and logger.
type baseDatabase struct {
	cfg *config.Config
	log logger.Logger
	db  *sql.DB // nil until Connect succeeds
	dsn string  // DSN used by the last Connect (may contain credentials)
}
// Close releases the underlying connection pool; it is a no-op when the
// database was never connected.
func (b *baseDatabase) Close() error {
	if b.db == nil {
		return nil
	}
	return b.db.Close()
}
// Ping verifies the connection is alive; it fails immediately when
// Connect has not been called.
func (b *baseDatabase) Ping(ctx context.Context) error {
	if b.db == nil {
		return fmt.Errorf("database not connected")
	}
	return b.db.PingContext(ctx)
}
// buildTimeout creates a context with timeout for database operations
func buildTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
if timeout <= 0 {
timeout = 30 * time.Second
}
return context.WithTimeout(ctx, timeout)
}

410
internal/database/mysql.go Normal file
View File

@ -0,0 +1,410 @@
package database
import (
"context"
"database/sql"
"fmt"
"os/exec"
"strconv"
"strings"
"dbbackup/internal/config"
"dbbackup/internal/logger"
)
// MySQL implements Database interface for MySQL, layering MySQL-specific
// queries and tool invocations on the shared baseDatabase state.
type MySQL struct {
	baseDatabase
}

// NewMySQL creates a new MySQL database instance; no connection is opened
// until Connect is called.
func NewMySQL(cfg *config.Config, log logger.Logger) *MySQL {
	return &MySQL{
		baseDatabase: baseDatabase{
			cfg: cfg,
			log: log,
		},
	}
}
// Connect establishes a connection to MySQL and verifies it with a ping.
// The DSN is built from the engine configuration; buildTimeout(ctx, 0)
// applies the package default 30s timeout to the ping.
func (m *MySQL) Connect(ctx context.Context) error {
	// Build MySQL DSN (logged with the password masked).
	dsn := m.buildDSN()
	m.dsn = dsn
	m.log.Debug("Connecting to MySQL", "dsn", sanitizeMySQLDSN(dsn))
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return fmt.Errorf("failed to open MySQL connection: %w", err)
	}
	// Configure connection pool
	db.SetMaxOpenConns(10)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(0) // keep pooled connections indefinitely
	// Test connection before publishing it on m.db.
	timeoutCtx, cancel := buildTimeout(ctx, 0)
	defer cancel()
	if err := db.PingContext(timeoutCtx); err != nil {
		db.Close()
		return fmt.Errorf("failed to ping MySQL: %w", err)
	}
	m.db = db
	m.log.Info("Connected to MySQL successfully")
	return nil
}
// ListDatabases returns all user databases, filtering out MySQL's
// built-in system schemas (information_schema, performance_schema,
// mysql, sys).
func (m *MySQL) ListDatabases(ctx context.Context) ([]string, error) {
	if m.db == nil {
		return nil, fmt.Errorf("not connected to database")
	}

	rows, err := m.db.QueryContext(ctx, `SHOW DATABASES`)
	if err != nil {
		return nil, fmt.Errorf("failed to query databases: %w", err)
	}
	defer rows.Close()

	isSystem := func(schema string) bool {
		switch schema {
		case "information_schema", "performance_schema", "mysql", "sys":
			return true
		}
		return false
	}

	var databases []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("failed to scan database name: %w", err)
		}
		if !isSystem(name) {
			databases = append(databases, name)
		}
	}
	return databases, rows.Err()
}
// ListTables returns the base-table names (no views) of the given schema,
// sorted alphabetically.
func (m *MySQL) ListTables(ctx context.Context, database string) ([]string, error) {
	if m.db == nil {
		return nil, fmt.Errorf("not connected to database")
	}
	query := `SELECT table_name FROM information_schema.tables
	          WHERE table_schema = ? AND table_type = 'BASE TABLE'
	          ORDER BY table_name`
	rows, err := m.db.QueryContext(ctx, query, database)
	if err != nil {
		return nil, fmt.Errorf("failed to query tables: %w", err)
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			return nil, fmt.Errorf("failed to scan table name: %w", err)
		}
		tables = append(tables, tableName)
	}
	return tables, rows.Err()
}
// CreateDatabase creates the named database if it does not already exist.
//
// MySQL identifiers cannot be bound as placeholders, so the name is
// backtick-quoted with embedded backticks doubled; previously the raw
// name was interpolated, letting a name containing a backtick break out
// of the identifier position (SQL injection).
func (m *MySQL) CreateDatabase(ctx context.Context, name string) error {
	if m.db == nil {
		return fmt.Errorf("not connected to database")
	}
	quoted := "`" + strings.ReplaceAll(name, "`", "``") + "`"
	query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", quoted)
	if _, err := m.db.ExecContext(ctx, query); err != nil {
		return fmt.Errorf("failed to create database %s: %w", name, err)
	}
	m.log.Info("Created database", "name", name)
	return nil
}
// DropDatabase drops the named database if it exists.
//
// The identifier is backtick-quoted with embedded backticks doubled
// (identifiers cannot be bound as placeholders); previously the raw name
// was interpolated, allowing identifier-position SQL injection.
func (m *MySQL) DropDatabase(ctx context.Context, name string) error {
	if m.db == nil {
		return fmt.Errorf("not connected to database")
	}
	quoted := "`" + strings.ReplaceAll(name, "`", "``") + "`"
	query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", quoted)
	if _, err := m.db.ExecContext(ctx, query); err != nil {
		return fmt.Errorf("failed to drop database %s: %w", name, err)
	}
	m.log.Info("Dropped database", "name", name)
	return nil
}
// DatabaseExists reports whether a schema with the given name exists.
func (m *MySQL) DatabaseExists(ctx context.Context, name string) (bool, error) {
	if m.db == nil {
		return false, fmt.Errorf("not connected to database")
	}
	query := `SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?`
	var dbName string
	switch err := m.db.QueryRowContext(ctx, query, name).Scan(&dbName); {
	case err == sql.ErrNoRows:
		return false, nil
	case err != nil:
		return false, fmt.Errorf("failed to check database existence: %w", err)
	default:
		return true, nil
	}
}
// GetVersion returns the server version string reported by VERSION().
func (m *MySQL) GetVersion(ctx context.Context) (string, error) {
	if m.db == nil {
		return "", fmt.Errorf("not connected to database")
	}
	var version string
	if err := m.db.QueryRowContext(ctx, "SELECT VERSION()").Scan(&version); err != nil {
		return "", fmt.Errorf("failed to get version: %w", err)
	}
	return version, nil
}
// GetDatabaseSize returns the total data+index size in bytes of all
// tables in the schema (0 when the schema has no tables).
func (m *MySQL) GetDatabaseSize(ctx context.Context, database string) (int64, error) {
	if m.db == nil {
		return 0, fmt.Errorf("not connected to database")
	}
	query := `SELECT COALESCE(SUM(data_length + index_length), 0) as size_bytes
	          FROM information_schema.tables
	          WHERE table_schema = ?`
	var size int64
	if err := m.db.QueryRowContext(ctx, query, database).Scan(&size); err != nil {
		return 0, fmt.Errorf("failed to get database size: %w", err)
	}
	return size, nil
}
// GetTableRowCount returns the row count for database.table. It first
// consults information_schema.tables (fast, approximate for InnoDB) and
// falls back to an exact COUNT(*) when stats are missing or report zero.
//
// The fallback now backtick-quotes both identifiers (embedded backticks
// doubled); previously they were interpolated raw into the SQL.
func (m *MySQL) GetTableRowCount(ctx context.Context, database, table string) (int64, error) {
	if m.db == nil {
		return 0, fmt.Errorf("not connected to database")
	}
	query := `SELECT table_rows FROM information_schema.tables
	          WHERE table_schema = ? AND table_name = ?`
	var count int64
	err := m.db.QueryRowContext(ctx, query, database, table).Scan(&count)
	if err != nil || count == 0 {
		// Identifiers cannot be bound as placeholders; quote them manually.
		q := func(id string) string { return "`" + strings.ReplaceAll(id, "`", "``") + "`" }
		exactQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s.%s", q(database), q(table))
		if err := m.db.QueryRowContext(ctx, exactQuery).Scan(&count); err != nil {
			return 0, fmt.Errorf("failed to get table row count: %w", err)
		}
	}
	return count, nil
}
// BuildBackupCommand builds the mysqldump argv for backing up a single
// database. The dump goes to stdout — outputFile is unused here and
// redirection/compression are handled by the caller.
//
// SECURITY NOTE(review): the password is passed as -p<password> on the
// command line, which is visible in the process list; consider MYSQL_PWD
// or --defaults-extra-file instead.
func (m *MySQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
	cmd := []string{"mysqldump"}
	// Connection parameters
	cmd = append(cmd, "-h", m.cfg.Host)
	cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
	cmd = append(cmd, "-u", m.cfg.User)
	if m.cfg.Password != "" {
		cmd = append(cmd, "-p"+m.cfg.Password)
	}
	// SSL options
	if m.cfg.Insecure {
		cmd = append(cmd, "--skip-ssl")
	} else if m.cfg.SSLMode != "" {
		// MySQL SSL modes: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY
		switch strings.ToLower(m.cfg.SSLMode) {
		case "disable", "disabled":
			cmd = append(cmd, "--skip-ssl")
		case "require", "required":
			cmd = append(cmd, "--ssl-mode=REQUIRED")
		case "verify-ca":
			cmd = append(cmd, "--ssl-mode=VERIFY_CA")
		case "verify-full", "verify-identity":
			cmd = append(cmd, "--ssl-mode=VERIFY_IDENTITY")
		default:
			cmd = append(cmd, "--ssl-mode=PREFERRED")
		}
	}
	// Backup options
	cmd = append(cmd, "--single-transaction") // Consistent backup
	cmd = append(cmd, "--routines")           // Include stored procedures/functions
	cmd = append(cmd, "--triggers")           // Include triggers
	cmd = append(cmd, "--events")             // Include events
	if options.SchemaOnly {
		cmd = append(cmd, "--no-data")
	} else if options.DataOnly {
		cmd = append(cmd, "--no-create-info")
	}
	// NOTE(review): mapping NoOwner/NoPrivileges to --skip-add-drop-table
	// looks unrelated to ownership/privileges — confirm the intent.
	if options.NoOwner || options.NoPrivileges {
		cmd = append(cmd, "--skip-add-drop-table")
	}
	// Compression (handled externally for MySQL)
	// Output redirection will be handled by caller
	// Database
	cmd = append(cmd, database)
	return cmd
}
// BuildRestoreCommand builds the mysql client argv for restoring a dump.
// The dump is fed via stdin — inputFile is unused here and redirection is
// handled by the caller.
//
// SECURITY NOTE(review): the password appears as -p<password> on the
// command line (visible in the process list); consider MYSQL_PWD instead.
func (m *MySQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
	cmd := []string{"mysql"}
	// Connection parameters
	cmd = append(cmd, "-h", m.cfg.Host)
	cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
	cmd = append(cmd, "-u", m.cfg.User)
	if m.cfg.Password != "" {
		cmd = append(cmd, "-p"+m.cfg.Password)
	}
	// SSL options (note: only the Insecure flag is honored here, unlike
	// BuildBackupCommand which also maps SSLMode).
	if m.cfg.Insecure {
		cmd = append(cmd, "--skip-ssl")
	}
	// Options
	if options.SingleTransaction {
		cmd = append(cmd, "--single-transaction")
	}
	// Database
	cmd = append(cmd, database)
	// Input file (will be handled via stdin redirection)
	return cmd
}
// BuildSampleQuery builds a SQL statement that samples rows from
// database.table according to strategy:
//   - "ratio":   every Nth row (N = strategy.Value) via a session row counter
//   - "percent": ~Value% of rows via RAND()
//   - "count":   the first Value rows
//   - default:   the first 1000 rows
//
// Identifiers are now backtick-quoted (embedded backticks doubled) so
// schema/table names cannot inject SQL; previously they were interpolated
// raw. NOTE(review): a "ratio" Value of 0 still yields a modulo-by-zero
// in SQL — callers must pass a positive Value.
func (m *MySQL) BuildSampleQuery(database, table string, strategy SampleStrategy) string {
	q := func(id string) string { return "`" + strings.ReplaceAll(id, "`", "``") + "`" }
	target := q(database) + "." + q(table)
	switch strategy.Type {
	case "ratio":
		// Number rows with a user variable, then keep every Nth one.
		return fmt.Sprintf("SELECT * FROM (SELECT *, (@row_number:=@row_number + 1) AS rn FROM %s CROSS JOIN (SELECT @row_number:=0) AS t) AS numbered WHERE rn %% %d = 1",
			target, strategy.Value)
	case "percent":
		// Percentage sampling using RAND()
		return fmt.Sprintf("SELECT * FROM %s WHERE RAND() <= %f",
			target, float64(strategy.Value)/100.0)
	case "count":
		// First N records
		return fmt.Sprintf("SELECT * FROM %s LIMIT %d", target, strategy.Value)
	default:
		return fmt.Sprintf("SELECT * FROM %s LIMIT 1000", target)
	}
}
// ValidateBackupTools verifies that mysqldump and mysql are on PATH.
func (m *MySQL) ValidateBackupTools() error {
	for _, tool := range []string{"mysqldump", "mysql"} {
		if _, err := exec.LookPath(tool); err != nil {
			return fmt.Errorf("required tool not found: %s", tool)
		}
	}
	return nil
}
// buildDSN constructs a go-sql-driver/mysql DSN of the form
// user:pass@tcp(host:port)/dbname?params.
//
// The network address is now written explicitly for every non-empty host:
// previously it was omitted for "localhost", which makes the driver fall
// back to its default tcp(127.0.0.1:3306) and silently ignore a
// configured non-default port.
func (m *MySQL) buildDSN() string {
	var sb strings.Builder
	if m.cfg.User != "" {
		sb.WriteString(m.cfg.User)
	}
	if m.cfg.Password != "" {
		sb.WriteString(":" + m.cfg.Password)
	}
	sb.WriteString("@")
	if m.cfg.Host != "" {
		sb.WriteString("tcp(" + m.cfg.Host + ":" + strconv.Itoa(m.cfg.Port) + ")")
	}
	sb.WriteString("/" + m.cfg.Database)

	// Connection parameters.
	params := []string{}
	if m.cfg.Insecure {
		params = append(params, "tls=skip-verify")
	} else if m.cfg.SSLMode != "" {
		switch strings.ToLower(m.cfg.SSLMode) {
		case "disable", "disabled":
			params = append(params, "tls=false")
		case "require", "required":
			params = append(params, "tls=true")
		}
	}
	// Charset and time parsing are always requested.
	params = append(params, "charset=utf8mb4")
	params = append(params, "parseTime=true")
	sb.WriteString("?" + strings.Join(params, "&"))
	return sb.String()
}
// sanitizeMySQLDSN masks the password portion of a MySQL DSN
// (user:password@tcp(host:port)/db) for safe logging.
//
// The credentials are everything before the first '@'; searching for ':'
// only inside that prefix avoids mistaking a colon elsewhere in the DSN
// (e.g. the host:port pair, or query parameters) for the password
// separator, which the previous first-colon heuristic could do.
func sanitizeMySQLDSN(dsn string) string {
	creds, rest, found := strings.Cut(dsn, "@")
	if !found {
		return dsn // no credentials section at all
	}
	user, _, hasPassword := strings.Cut(creds, ":")
	if !hasPassword {
		return dsn // user only, nothing to mask
	}
	return user + ":***@" + rest
}

View File

@ -0,0 +1,427 @@
package database
import (
"context"
"database/sql"
"fmt"
"os"
"os/exec"
"strconv"
"strings"
"dbbackup/internal/config"
"dbbackup/internal/logger"
)
// PostgreSQL implements Database interface for PostgreSQL, layering
// engine-specific queries and tool invocations on the shared baseDatabase
// state.
type PostgreSQL struct {
	baseDatabase
}

// NewPostgreSQL creates a new PostgreSQL database instance; no connection
// is opened until Connect is called.
func NewPostgreSQL(cfg *config.Config, log logger.Logger) *PostgreSQL {
	return &PostgreSQL{
		baseDatabase: baseDatabase{
			cfg: cfg,
			log: log,
		},
	}
}
// Connect establishes a connection to PostgreSQL and verifies it with a
// ping. The DSN is built from the engine configuration; buildTimeout(ctx, 0)
// applies the package default 30s timeout to the ping.
func (p *PostgreSQL) Connect(ctx context.Context) error {
	// Build PostgreSQL DSN (logged with the password masked).
	dsn := p.buildDSN()
	p.dsn = dsn
	p.log.Debug("Connecting to PostgreSQL", "dsn", sanitizeDSN(dsn))
	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return fmt.Errorf("failed to open PostgreSQL connection: %w", err)
	}
	// Configure connection pool
	db.SetMaxOpenConns(10)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(0) // keep pooled connections indefinitely
	// Test connection before publishing it on p.db.
	timeoutCtx, cancel := buildTimeout(ctx, 0)
	defer cancel()
	if err := db.PingContext(timeoutCtx); err != nil {
		db.Close()
		return fmt.Errorf("failed to ping PostgreSQL: %w", err)
	}
	p.db = db
	p.log.Info("Connected to PostgreSQL successfully")
	return nil
}
// ListDatabases returns every non-template database name, sorted.
func (p *PostgreSQL) ListDatabases(ctx context.Context) ([]string, error) {
	if p.db == nil {
		return nil, fmt.Errorf("not connected to database")
	}
	query := `SELECT datname FROM pg_database
	          WHERE datistemplate = false
	          ORDER BY datname`
	rows, err := p.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to query databases: %w", err)
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("failed to scan database name: %w", err)
		}
		names = append(names, name)
	}
	return names, rows.Err()
}
// ListTables returns all user tables as "schema.table", excluding system
// schemas, ordered by schema then table name. The database parameter is
// unused: the query runs against the connected database.
func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string, error) {
	if p.db == nil {
		return nil, fmt.Errorf("not connected to database")
	}
	query := `SELECT schemaname||'.'||tablename as full_name
	          FROM pg_tables
	          WHERE schemaname NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
	          ORDER BY schemaname, tablename`
	rows, err := p.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to query tables: %w", err)
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var fullName string
		if err := rows.Scan(&fullName); err != nil {
			return nil, fmt.Errorf("failed to scan table name: %w", err)
		}
		names = append(names, fullName)
	}
	return names, rows.Err()
}
// CreateDatabase creates the named database.
//
// CREATE DATABASE cannot take bound parameters, so the identifier is
// double-quoted manually with embedded quotes doubled; previously the raw
// name was interpolated, allowing identifier-position SQL injection.
// Note that quoting makes the name case-sensitive (unquoted identifiers
// fold to lowercase in PostgreSQL).
func (p *PostgreSQL) CreateDatabase(ctx context.Context, name string) error {
	if p.db == nil {
		return fmt.Errorf("not connected to database")
	}
	quoted := `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
	query := fmt.Sprintf("CREATE DATABASE %s", quoted)
	if _, err := p.db.ExecContext(ctx, query); err != nil {
		return fmt.Errorf("failed to create database %s: %w", name, err)
	}
	p.log.Info("Created database", "name", name)
	return nil
}
// DropDatabase drops the named database if it exists.
//
// The identifier is double-quoted with embedded quotes doubled (DROP
// DATABASE cannot take bound parameters); previously the raw name was
// interpolated, allowing identifier-position SQL injection. Quoting makes
// the name case-sensitive, matching names as reported by ListDatabases.
func (p *PostgreSQL) DropDatabase(ctx context.Context, name string) error {
	if p.db == nil {
		return fmt.Errorf("not connected to database")
	}
	quoted := `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
	query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", quoted)
	if _, err := p.db.ExecContext(ctx, query); err != nil {
		return fmt.Errorf("failed to drop database %s: %w", name, err)
	}
	p.log.Info("Dropped database", "name", name)
	return nil
}
// DatabaseExists reports whether a database with the given name exists.
func (p *PostgreSQL) DatabaseExists(ctx context.Context, name string) (bool, error) {
	if p.db == nil {
		return false, fmt.Errorf("not connected to database")
	}
	query := `SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)`
	var exists bool
	if err := p.db.QueryRowContext(ctx, query, name).Scan(&exists); err != nil {
		return false, fmt.Errorf("failed to check database existence: %w", err)
	}
	return exists, nil
}
// GetVersion returns the server version string from version().
func (p *PostgreSQL) GetVersion(ctx context.Context) (string, error) {
	if p.db == nil {
		return "", fmt.Errorf("not connected to database")
	}
	var version string
	if err := p.db.QueryRowContext(ctx, "SELECT version()").Scan(&version); err != nil {
		return "", fmt.Errorf("failed to get version: %w", err)
	}
	return version, nil
}
// GetDatabaseSize returns the on-disk size of database in bytes, as
// reported by pg_database_size.
func (p *PostgreSQL) GetDatabaseSize(ctx context.Context, database string) (int64, error) {
	if p.db == nil {
		return 0, fmt.Errorf("not connected to database")
	}
	var size int64
	err := p.db.QueryRowContext(ctx, `SELECT pg_database_size($1)`, database).Scan(&size)
	if err != nil {
		return 0, fmt.Errorf("failed to get database size: %w", err)
	}
	return size, nil
}
// GetTableRowCount returns an approximate row count for table (given as
// "schema.table"), using the statistics view pg_stat_user_tables, falling
// back to an exact COUNT(*) when the stats query fails.
//
// Bug fix: the estimate now reads n_live_tup (estimated number of live
// rows) instead of n_tup_ins, which is the cumulative count of inserted
// rows since the stats were reset and overcounts any table that has seen
// deletes.
func (p *PostgreSQL) GetTableRowCount(ctx context.Context, database, table string) (int64, error) {
	if p.db == nil {
		return 0, fmt.Errorf("not connected to database")
	}
	parts := strings.Split(table, ".")
	if len(parts) != 2 {
		return 0, fmt.Errorf("table name must be in format schema.table")
	}
	query := `SELECT COALESCE(n_live_tup, 0) FROM pg_stat_user_tables
	          WHERE schemaname = $1 AND relname = $2`
	var count int64
	err := p.db.QueryRowContext(ctx, query, parts[0], parts[1]).Scan(&count)
	if err != nil {
		// Stats may be unavailable; fall back to an exact — but
		// potentially slow — COUNT(*).
		exactQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)
		if err := p.db.QueryRowContext(ctx, exactQuery).Scan(&count); err != nil {
			return 0, fmt.Errorf("failed to get table row count: %w", err)
		}
	}
	return count, nil
}
// BuildBackupCommand builds the pg_dump argv for backing up database to
// outputFile.
//
// For non-localhost hosts it adds -h/-p and --no-password, so the
// password must be supplied externally (presumably via PGPASSWORD or
// ~/.pgpass — confirm against the caller). Localhost connections rely on
// socket/peer authentication. --jobs is only emitted for the directory
// format, mirroring pg_dump's own restriction.
func (p *PostgreSQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
	cmd := []string{"pg_dump"}
	// Connection parameters
	if p.cfg.Host != "localhost" {
		cmd = append(cmd, "-h", p.cfg.Host)
		cmd = append(cmd, "-p", strconv.Itoa(p.cfg.Port))
		cmd = append(cmd, "--no-password")
	}
	cmd = append(cmd, "-U", p.cfg.User)
	// Format and compression (custom format is the default)
	if options.Format != "" {
		cmd = append(cmd, "--format="+options.Format)
	} else {
		cmd = append(cmd, "--format=custom")
	}
	if options.Compression > 0 {
		cmd = append(cmd, "--compress="+strconv.Itoa(options.Compression))
	}
	// Parallel jobs (only for directory format)
	if options.Parallel > 1 && options.Format == "directory" {
		cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
	}
	// Options
	if options.Blobs {
		cmd = append(cmd, "--blobs")
	}
	if options.SchemaOnly {
		cmd = append(cmd, "--schema-only")
	}
	if options.DataOnly {
		cmd = append(cmd, "--data-only")
	}
	if options.NoOwner {
		cmd = append(cmd, "--no-owner")
	}
	if options.NoPrivileges {
		cmd = append(cmd, "--no-privileges")
	}
	if options.Role != "" {
		cmd = append(cmd, "--role="+options.Role)
	}
	// Database and output
	cmd = append(cmd, "--dbname="+database)
	cmd = append(cmd, "--file="+outputFile)
	return cmd
}
// BuildRestoreCommand builds the pg_restore argv for restoring inputFile
// into database.
//
// As with BuildBackupCommand, host/port and --no-password are only added
// for non-localhost hosts; the password must come from the environment
// (presumably PGPASSWORD or ~/.pgpass — confirm against the caller).
func (p *PostgreSQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
	cmd := []string{"pg_restore"}
	// Connection parameters
	if p.cfg.Host != "localhost" {
		cmd = append(cmd, "-h", p.cfg.Host)
		cmd = append(cmd, "-p", strconv.Itoa(p.cfg.Port))
		cmd = append(cmd, "--no-password")
	}
	cmd = append(cmd, "-U", p.cfg.User)
	// Parallel jobs
	if options.Parallel > 1 {
		cmd = append(cmd, "--jobs="+strconv.Itoa(options.Parallel))
	}
	// Options
	if options.Clean {
		cmd = append(cmd, "--clean")
	}
	if options.IfExists {
		cmd = append(cmd, "--if-exists")
	}
	if options.NoOwner {
		cmd = append(cmd, "--no-owner")
	}
	if options.NoPrivileges {
		cmd = append(cmd, "--no-privileges")
	}
	if options.SingleTransaction {
		cmd = append(cmd, "--single-transaction")
	}
	// Database and input
	cmd = append(cmd, "--dbname="+database)
	cmd = append(cmd, inputFile)
	return cmd
}
// BuildSampleQuery builds a SQL statement that samples rows from table
// according to strategy ("ratio": every Nth row via row_number,
// "percent": TABLESAMPLE BERNOULLI, "count": first N rows, default:
// first 1000 rows).
//
// The database parameter is unused — the query targets whatever database
// the connection is bound to. NOTE(review): table is interpolated
// unquoted, and a "ratio" Value of 0 would produce a modulo-by-zero in
// SQL; callers must pass trusted identifiers and a positive Value.
func (p *PostgreSQL) BuildSampleQuery(database, table string, strategy SampleStrategy) string {
	switch strategy.Type {
	case "ratio":
		// Every Nth record using row_number
		return fmt.Sprintf("SELECT * FROM (SELECT *, row_number() OVER () as rn FROM %s) t WHERE rn %% %d = 1",
			table, strategy.Value)
	case "percent":
		// Percentage sampling using TABLESAMPLE (PostgreSQL 9.5+)
		return fmt.Sprintf("SELECT * FROM %s TABLESAMPLE BERNOULLI(%d)", table, strategy.Value)
	case "count":
		// First N records
		return fmt.Sprintf("SELECT * FROM %s LIMIT %d", table, strategy.Value)
	default:
		return fmt.Sprintf("SELECT * FROM %s LIMIT 1000", table)
	}
}
// ValidateBackupTools verifies that every PostgreSQL client tool used for
// backup and restore is on PATH.
func (p *PostgreSQL) ValidateBackupTools() error {
	for _, tool := range []string{"pg_dump", "pg_restore", "pg_dumpall", "psql"} {
		if _, err := exec.LookPath(tool); err != nil {
			return fmt.Errorf("required tool not found: %s", tool)
		}
	}
	return nil
}
// buildDSN constructs a lib/pq keyword/value connection string.
//
// For localhost with no password it probes common Unix-socket directories
// and, when a socket is found, uses host=<dir> so peer authentication can
// apply; otherwise it falls back to TCP with explicit host/port.
// NOTE(review): the `else if` condition below is exactly the negation of
// the first branch's condition, so it is always true when reached and
// behaves as a plain `else`.
func (p *PostgreSQL) buildDSN() string {
	dsn := fmt.Sprintf("user=%s dbname=%s", p.cfg.User, p.cfg.Database)
	if p.cfg.Password != "" {
		dsn += " password=" + p.cfg.Password
	}
	// For localhost connections, try socket first for peer auth
	if p.cfg.Host == "localhost" && p.cfg.Password == "" {
		// Try Unix socket connection for peer authentication
		// Common PostgreSQL socket locations
		socketDirs := []string{
			"/var/run/postgresql",
			"/tmp",
			"/var/lib/pgsql",
		}
		for _, dir := range socketDirs {
			socketPath := fmt.Sprintf("%s/.s.PGSQL.%d", dir, p.cfg.Port)
			if _, err := os.Stat(socketPath); err == nil {
				dsn += " host=" + dir
				p.log.Debug("Using PostgreSQL socket", "path", socketPath)
				break
			}
		}
	} else if p.cfg.Host != "localhost" || p.cfg.Password != "" {
		// Use TCP connection
		dsn += " host=" + p.cfg.Host
		dsn += " port=" + strconv.Itoa(p.cfg.Port)
	}
	// Map configured SSL modes onto the sslmode values lib/pq supports.
	if p.cfg.SSLMode != "" && !p.cfg.Insecure {
		// Map SSL modes to supported values for lib/pq
		switch strings.ToLower(p.cfg.SSLMode) {
		case "prefer", "preferred":
			dsn += " sslmode=require" // lib/pq default, closest to prefer
		case "require", "required":
			dsn += " sslmode=require"
		case "verify-ca":
			dsn += " sslmode=verify-ca"
		case "verify-full", "verify-identity":
			dsn += " sslmode=verify-full"
		case "disable", "disabled":
			dsn += " sslmode=disable"
		default:
			dsn += " sslmode=require" // Safe default
		}
	} else if p.cfg.Insecure {
		dsn += " sslmode=disable"
	}
	return dsn
}
// sanitizeDSN masks the password field of a keyword/value PostgreSQL DSN
// ("user=x password=y ...") so the string can be logged safely.
func sanitizeDSN(dsn string) string {
	fields := strings.Split(dsn, " ")
	for i, field := range fields {
		if strings.HasPrefix(field, "password=") {
			fields[i] = "password=***"
		}
	}
	return strings.Join(fields, " ")
}

185
internal/logger/logger.go Normal file
View File

@ -0,0 +1,185 @@
package logger
import (
"fmt"
"io"
"log/slog"
"os"
"strings"
"time"
)
// Logger defines the interface for logging used throughout the
// application; implementations are backed by log/slog.
type Logger interface {
	Debug(msg string, args ...any)
	Info(msg string, args ...any)
	Warn(msg string, args ...any)
	Error(msg string, args ...any)
	// Time logs timing information (emitted at info level).
	Time(msg string, args ...any)
	// StartOperation begins a named, timed operation whose progress and
	// outcome are reported through the returned OperationLogger.
	StartOperation(name string) OperationLogger
}

// OperationLogger tracks timing for a single long-running operation.
type OperationLogger interface {
	Update(msg string, args ...any)   // progress update with elapsed time
	Complete(msg string, args ...any) // successful completion with duration
	Fail(msg string, args ...any)     // failure with duration
}

// logger implements Logger interface using slog.
type logger struct {
	slog   *slog.Logger
	level  slog.Level // minimum level the handler emits
	format string     // "json" or text
}

// operationLogger tracks a single operation started via StartOperation.
type operationLogger struct {
	name      string
	startTime time.Time
	parent    *logger
}
// New creates a stdout logger with the given minimum level ("debug",
// "info", "warn"/"warning", "error"; anything else means info) and output
// format ("json" for JSON lines, anything else for human-readable text).
func New(level, format string) Logger {
	slogLevel := parseSlogLevel(level)
	return &logger{
		slog:   slog.New(newStdoutHandler(format, slogLevel)),
		level:  slogLevel,
		format: format,
	}
}

// parseSlogLevel maps a case-insensitive level name to a slog.Level,
// defaulting to info for unknown names.
func parseSlogLevel(level string) slog.Level {
	switch strings.ToLower(level) {
	case "debug":
		return slog.LevelDebug
	case "info":
		return slog.LevelInfo
	case "warn", "warning":
		return slog.LevelWarn
	case "error":
		return slog.LevelError
	default:
		return slog.LevelInfo
	}
}

// newStdoutHandler builds a JSON or text slog handler writing to stdout.
func newStdoutHandler(format string, level slog.Level) slog.Handler {
	opts := &slog.HandlerOptions{
		Level: level,
	}
	if strings.ToLower(format) == "json" {
		return slog.NewJSONHandler(os.Stdout, opts)
	}
	return slog.NewTextHandler(os.Stdout, opts)
}
// Debug logs at debug level.
func (l *logger) Debug(msg string, args ...any) {
	l.slog.Debug(msg, args...)
}

// Info logs at info level.
func (l *logger) Info(msg string, args ...any) {
	l.slog.Info(msg, args...)
}

// Warn logs at warn level.
func (l *logger) Warn(msg string, args ...any) {
	l.slog.Warn(msg, args...)
}

// Error logs at error level.
func (l *logger) Error(msg string, args ...any) {
	l.slog.Error(msg, args...)
}

// Time logs timing information; entries are emitted at info level and
// prefixed with "[TIME]" for easy filtering.
func (l *logger) Time(msg string, args ...any) {
	// Time logs are always at info level with special formatting
	l.slog.Info("[TIME] "+msg, args...)
}
// StartOperation begins tracking a named operation, recording its start
// time for later elapsed/duration reporting.
func (l *logger) StartOperation(name string) OperationLogger {
	op := &operationLogger{name: name, parent: l, startTime: time.Now()}
	return op
}
// Update logs a progress message for the operation, tagged with the
// elapsed time since it started.
func (ol *operationLogger) Update(msg string, args ...any) {
	elapsed := time.Since(ol.startTime)
	ol.parent.Info(fmt.Sprintf("[%s] %s", ol.name, msg),
		append(args, "elapsed", elapsed.String())...)
}

// Complete logs successful completion at info level with a
// human-readable duration.
func (ol *operationLogger) Complete(msg string, args ...any) {
	elapsed := time.Since(ol.startTime)
	ol.parent.Info(fmt.Sprintf("[%s] COMPLETED: %s", ol.name, msg),
		append(args, "duration", formatDuration(elapsed))...)
}

// Fail logs failure at error level with a human-readable duration.
func (ol *operationLogger) Fail(msg string, args ...any) {
	elapsed := time.Since(ol.startTime)
	ol.parent.Error(fmt.Sprintf("[%s] FAILED: %s", ol.name, msg),
		append(args, "duration", formatDuration(elapsed))...)
}
// formatDuration formats a duration for humans: sub-minute values as
// "N.Ns", sub-hour values as "Nm Ns", everything else as "Nh Nm Ns".
func formatDuration(d time.Duration) string {
	switch {
	case d < time.Minute:
		return fmt.Sprintf("%.1fs", d.Seconds())
	case d < time.Hour:
		m, s := int(d.Minutes()), int(d.Seconds())%60
		return fmt.Sprintf("%dm %ds", m, s)
	default:
		h := int(d.Hours())
		m := int(d.Minutes()) % 60
		s := int(d.Seconds()) % 60
		return fmt.Sprintf("%dh %dm %ds", h, m, s)
	}
}
// FileLogger creates a logger that writes to both stdout and a file.
// level and format follow the same rules as New: level is one of
// debug|info|warn|warning|error (case-insensitive, defaulting to info)
// and format is "json" for JSON output, anything else for text.
// The file is opened in append mode (created with 0644 if missing) and
// stays open for the lifetime of the returned logger.
func FileLogger(level, format, filename string) (Logger, error) {
	minLevel := slog.LevelInfo
	switch strings.ToLower(level) {
	case "debug":
		minLevel = slog.LevelDebug
	case "warn", "warning":
		minLevel = slog.LevelWarn
	case "error":
		minLevel = slog.LevelError
	}
	// Open log file in append mode so repeated runs accumulate output.
	file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return nil, fmt.Errorf("failed to open log file: %w", err)
	}
	// Every record goes to both stdout and the file.
	sink := io.MultiWriter(os.Stdout, file)
	opts := &slog.HandlerOptions{Level: minLevel}
	var handler slog.Handler
	if strings.EqualFold(format, "json") {
		handler = slog.NewJSONHandler(sink, opts)
	} else {
		handler = slog.NewTextHandler(sink, opts)
	}
	return &logger{
		slog:   slog.New(handler),
		level:  minLevel,
		format: format,
	}, nil
}

View File

@ -0,0 +1,427 @@
package progress
import (
"fmt"
"sync"
"time"
)
// DetailedReporter provides comprehensive progress reporting with timestamps and status
type DetailedReporter struct {
	mu         sync.RWMutex      // guards operations and startTime
	operations []OperationStatus // every operation ever started, in start order
	startTime  time.Time         // time of the first StartOperation call (basis for GetSummary)
	indicator  Indicator         // optional visual indicator; nil-checked before use
	logger     Logger            // structured logger; NOTE(review): never nil-checked, unlike indicator — confirm callers always pass one
}
// OperationStatus represents the status of a backup/restore operation as
// a JSON-serialisable snapshot, updated in place by OperationTracker.
type OperationStatus struct {
	ID         string            `json:"id"`   // unique key used for all lookups
	Name       string            `json:"name"`
	Type       string            `json:"type"`   // "backup", "restore", "verify"
	Status     string            `json:"status"` // "running", "completed", "failed"
	StartTime  time.Time         `json:"start_time"`
	EndTime    *time.Time        `json:"end_time,omitempty"` // nil while still running
	Duration   time.Duration     `json:"duration"`           // set on Complete/Fail
	Progress   int               `json:"progress"` // 0-100
	Message    string            `json:"message"`  // latest human-readable status line
	Details    map[string]string `json:"details"`  // free-form metadata set via SetDetails
	Steps      []StepStatus      `json:"steps"`
	BytesTotal int64             `json:"bytes_total"`
	BytesDone  int64             `json:"bytes_done"`
	FilesTotal int               `json:"files_total"`
	FilesDone  int               `json:"files_done"`
	Errors     []string          `json:"errors,omitempty"` // accumulated error messages
}
// StepStatus represents individual steps within an operation.
// Steps are identified by Name when later marked complete/failed.
type StepStatus struct {
	Name      string        `json:"name"`
	Status    string        `json:"status"` // "running", "completed", "failed"
	StartTime time.Time     `json:"start_time"`
	EndTime   *time.Time    `json:"end_time,omitempty"` // nil while still running
	Duration  time.Duration `json:"duration"`           // set on Complete/Fail
	Message   string        `json:"message"`
}
// Logger interface for detailed reporting. Methods take a message plus
// alternating key/value pairs, slog-style.
type Logger interface {
	Info(msg string, args ...any)
	Warn(msg string, args ...any)
	Error(msg string, args ...any)
	Debug(msg string, args ...any)
}
// NewDetailedReporter creates a new detailed progress reporter.
// indicator may be nil (visual output is then skipped); logger is stored
// as-is and used by every tracking method.
func NewDetailedReporter(indicator Indicator, logger Logger) *DetailedReporter {
	return &DetailedReporter{
		operations: make([]OperationStatus, 0),
		indicator:  indicator,
		logger:     logger,
	}
}
// StartOperation begins tracking a new operation and returns a tracker
// for updating it. id is the key all later lookups match on, so it must
// be unique among tracked operations.
func (dr *DetailedReporter) StartOperation(id, name, opType string) *OperationTracker {
	dr.mu.Lock()
	defer dr.mu.Unlock()
	operation := OperationStatus{
		ID:        id,
		Name:      name,
		Type:      opType,
		Status:    "running",
		StartTime: time.Now(),
		Progress:  0,
		Details:   make(map[string]string),
		Steps:     make([]StepStatus, 0),
	}
	dr.operations = append(dr.operations, operation)
	// Remember when the very first operation started; GetSummary reports
	// total elapsed time relative to this.
	if dr.startTime.IsZero() {
		dr.startTime = time.Now()
	}
	// Start visual indicator
	if dr.indicator != nil {
		dr.indicator.Start(fmt.Sprintf("Starting %s: %s", opType, name))
	}
	// Log operation start. NOTE(review): logger is not nil-checked here,
	// unlike indicator — confirm a logger is always supplied.
	dr.logger.Info("Operation started",
		"id", id,
		"name", name,
		"type", opType,
		"timestamp", operation.StartTime.Format(time.RFC3339))
	return &OperationTracker{
		reporter:    dr,
		operationID: id,
	}
}
// OperationTracker provides methods to update operation progress.
// It stores only the reporter and the operation ID; every method looks
// the operation up by ID under the reporter's lock.
type OperationTracker struct {
	reporter    *DetailedReporter
	operationID string
}
// UpdateProgress updates the progress of the operation.
// progress is stored as given (expected 0-100, not clamped here) and the
// message becomes the operation's current status line.
func (ot *OperationTracker) UpdateProgress(progress int, message string) {
	ot.reporter.mu.Lock()
	defer ot.reporter.mu.Unlock()
	for i := range ot.reporter.operations {
		if ot.reporter.operations[i].ID == ot.operationID {
			ot.reporter.operations[i].Progress = progress
			ot.reporter.operations[i].Message = message
			// Update visual indicator
			if ot.reporter.indicator != nil {
				progressMsg := fmt.Sprintf("[%d%%] %s", progress, message)
				ot.reporter.indicator.Update(progressMsg)
			}
			// Log progress update
			ot.reporter.logger.Debug("Progress update",
				"operation_id", ot.operationID,
				"progress", progress,
				"message", message,
				"timestamp", time.Now().Format(time.RFC3339))
			break
		}
	}
}
// AddStep adds a new step to the operation and returns a tracker for it.
// NOTE(review): the returned StepTracker matches steps by name, so two
// steps with the same name within one operation are ambiguous — the
// first match wins.
func (ot *OperationTracker) AddStep(name, message string) *StepTracker {
	ot.reporter.mu.Lock()
	defer ot.reporter.mu.Unlock()
	step := StepStatus{
		Name:      name,
		Status:    "running",
		StartTime: time.Now(),
		Message:   message,
	}
	for i := range ot.reporter.operations {
		if ot.reporter.operations[i].ID == ot.operationID {
			ot.reporter.operations[i].Steps = append(ot.reporter.operations[i].Steps, step)
			// Log step start
			ot.reporter.logger.Info("Step started",
				"operation_id", ot.operationID,
				"step", name,
				"message", message,
				"timestamp", step.StartTime.Format(time.RFC3339))
			break
		}
	}
	return &StepTracker{
		reporter:    ot.reporter,
		operationID: ot.operationID,
		stepName:    name,
	}
}
// SetDetails records a key/value metadata pair on the tracked operation.
// Unknown operation IDs are silently ignored.
func (ot *OperationTracker) SetDetails(key, value string) {
	ot.reporter.mu.Lock()
	defer ot.reporter.mu.Unlock()
	ops := ot.reporter.operations
	for i := 0; i < len(ops); i++ {
		if ops[i].ID != ot.operationID {
			continue
		}
		ops[i].Details[key] = value
		return
	}
}
// SetFileProgress updates file-based progress counters and, when a total
// is known, derives the overall percentage from them.
func (ot *OperationTracker) SetFileProgress(filesDone, filesTotal int) {
	ot.reporter.mu.Lock()
	defer ot.reporter.mu.Unlock()
	for i := range ot.reporter.operations {
		op := &ot.reporter.operations[i]
		if op.ID != ot.operationID {
			continue
		}
		op.FilesDone = filesDone
		op.FilesTotal = filesTotal
		if filesTotal > 0 {
			op.Progress = (filesDone * 100) / filesTotal
		}
		return
	}
}
// SetByteProgress updates byte-based progress counters and, when a total
// is known, derives the overall percentage from them.
func (ot *OperationTracker) SetByteProgress(bytesDone, bytesTotal int64) {
	ot.reporter.mu.Lock()
	defer ot.reporter.mu.Unlock()
	for i := range ot.reporter.operations {
		op := &ot.reporter.operations[i]
		if op.ID != ot.operationID {
			continue
		}
		op.BytesDone = bytesDone
		op.BytesTotal = bytesTotal
		if bytesTotal > 0 {
			op.Progress = int((bytesDone * 100) / bytesTotal)
		}
		return
	}
}
// Complete marks the operation as completed: progress is forced to 100,
// the end time and duration are recorded, and the indicator/logger are
// notified.
func (ot *OperationTracker) Complete(message string) {
	ot.reporter.mu.Lock()
	defer ot.reporter.mu.Unlock()
	now := time.Now()
	for i := range ot.reporter.operations {
		if ot.reporter.operations[i].ID == ot.operationID {
			ot.reporter.operations[i].Status = "completed"
			ot.reporter.operations[i].Progress = 100
			ot.reporter.operations[i].EndTime = &now
			ot.reporter.operations[i].Duration = now.Sub(ot.reporter.operations[i].StartTime)
			ot.reporter.operations[i].Message = message
			// Complete visual indicator
			if ot.reporter.indicator != nil {
				ot.reporter.indicator.Complete(fmt.Sprintf("✅ %s", message))
			}
			// Log completion with duration
			ot.reporter.logger.Info("Operation completed",
				"operation_id", ot.operationID,
				"message", message,
				"duration", ot.reporter.operations[i].Duration.String(),
				"timestamp", now.Format(time.RFC3339))
			break
		}
	}
}
// Fail marks the operation as failed: the error text becomes the status
// message and is appended to the Errors list, the end time and duration
// are recorded, and the indicator/logger are notified.
func (ot *OperationTracker) Fail(err error) {
	ot.reporter.mu.Lock()
	defer ot.reporter.mu.Unlock()
	now := time.Now()
	for i := range ot.reporter.operations {
		if ot.reporter.operations[i].ID == ot.operationID {
			ot.reporter.operations[i].Status = "failed"
			ot.reporter.operations[i].EndTime = &now
			ot.reporter.operations[i].Duration = now.Sub(ot.reporter.operations[i].StartTime)
			ot.reporter.operations[i].Message = err.Error()
			ot.reporter.operations[i].Errors = append(ot.reporter.operations[i].Errors, err.Error())
			// Fail visual indicator
			if ot.reporter.indicator != nil {
				ot.reporter.indicator.Fail(fmt.Sprintf("❌ %s", err.Error()))
			}
			// Log failure
			ot.reporter.logger.Error("Operation failed",
				"operation_id", ot.operationID,
				"error", err.Error(),
				"duration", ot.reporter.operations[i].Duration.String(),
				"timestamp", now.Format(time.RFC3339))
			break
		}
	}
}
// StepTracker manages individual step progress. It identifies its step
// by (operation ID, step name); the first step with a matching name wins.
type StepTracker struct {
	reporter    *DetailedReporter
	operationID string
	stepName    string
}
// Complete marks the step as completed, recording end time and duration.
func (st *StepTracker) Complete(message string) {
	st.reporter.mu.Lock()
	defer st.reporter.mu.Unlock()
	now := time.Now()
	for i := range st.reporter.operations {
		if st.reporter.operations[i].ID == st.operationID {
			// Match the step by name within the owning operation.
			for j := range st.reporter.operations[i].Steps {
				if st.reporter.operations[i].Steps[j].Name == st.stepName {
					st.reporter.operations[i].Steps[j].Status = "completed"
					st.reporter.operations[i].Steps[j].EndTime = &now
					st.reporter.operations[i].Steps[j].Duration = now.Sub(st.reporter.operations[i].Steps[j].StartTime)
					st.reporter.operations[i].Steps[j].Message = message
					// Log step completion
					st.reporter.logger.Info("Step completed",
						"operation_id", st.operationID,
						"step", st.stepName,
						"message", message,
						"duration", st.reporter.operations[i].Steps[j].Duration.String(),
						"timestamp", now.Format(time.RFC3339))
					break
				}
			}
			break
		}
	}
}
// Fail marks the step as failed, recording the error text as its message
// along with end time and duration.
func (st *StepTracker) Fail(err error) {
	st.reporter.mu.Lock()
	defer st.reporter.mu.Unlock()
	now := time.Now()
	for i := range st.reporter.operations {
		if st.reporter.operations[i].ID == st.operationID {
			// Match the step by name within the owning operation.
			for j := range st.reporter.operations[i].Steps {
				if st.reporter.operations[i].Steps[j].Name == st.stepName {
					st.reporter.operations[i].Steps[j].Status = "failed"
					st.reporter.operations[i].Steps[j].EndTime = &now
					st.reporter.operations[i].Steps[j].Duration = now.Sub(st.reporter.operations[i].Steps[j].StartTime)
					st.reporter.operations[i].Steps[j].Message = err.Error()
					// Log step failure
					st.reporter.logger.Error("Step failed",
						"operation_id", st.operationID,
						"step", st.stepName,
						"error", err.Error(),
						"duration", st.reporter.operations[i].Steps[j].Duration.String(),
						"timestamp", now.Format(time.RFC3339))
					break
				}
			}
			break
		}
	}
}
// GetOperationStatus returns a snapshot copy of the operation with the
// given id, or nil if no such operation exists. The copy is shallow:
// Details, Steps and Errors still share storage with the live entry.
func (dr *DetailedReporter) GetOperationStatus(id string) *OperationStatus {
	dr.mu.RLock()
	defer dr.mu.RUnlock()
	for i := range dr.operations {
		if dr.operations[i].ID != id {
			continue
		}
		snapshot := dr.operations[i]
		return &snapshot
	}
	return nil
}
// GetAllOperations returns a shallow copy of all tracked operations,
// or nil when none have been started.
func (dr *DetailedReporter) GetAllOperations() []OperationStatus {
	dr.mu.RLock()
	defer dr.mu.RUnlock()
	if len(dr.operations) == 0 {
		return nil
	}
	snapshot := make([]OperationStatus, len(dr.operations))
	copy(snapshot, dr.operations)
	return snapshot
}
// GetSummary returns aggregate counts for all tracked operations plus
// the total elapsed time since the first operation started.
// (If no operation ever started, startTime is zero and TotalDuration is
// measured from the zero time — same as the original behavior.)
func (dr *DetailedReporter) GetSummary() OperationSummary {
	dr.mu.RLock()
	defer dr.mu.RUnlock()
	result := OperationSummary{
		TotalOperations: len(dr.operations),
		TotalDuration:   time.Since(dr.startTime),
	}
	for i := range dr.operations {
		switch dr.operations[i].Status {
		case "completed":
			result.CompletedOperations++
		case "failed":
			result.FailedOperations++
		case "running":
			result.RunningOperations++
		}
	}
	return result
}
// OperationSummary provides overall statistics across all tracked operations.
type OperationSummary struct {
	TotalOperations     int           `json:"total_operations"`
	CompletedOperations int           `json:"completed_operations"`
	FailedOperations    int           `json:"failed_operations"`
	RunningOperations   int           `json:"running_operations"`
	TotalDuration       time.Duration `json:"total_duration"` // elapsed since the first operation started
}

// FormatSummary returns a formatted string representation of the summary.
// (Note: the receiver name "os" shadows the standard os package inside
// this method; harmless here since the method uses only fmt.)
func (os *OperationSummary) FormatSummary() string {
	return fmt.Sprintf(
		"📊 Operations Summary:\n"+
			"  Total: %d | Completed: %d | Failed: %d | Running: %d\n"+
			"  Total Duration: %s",
		os.TotalOperations,
		os.CompletedOperations,
		os.FailedOperations,
		os.RunningOperations,
		formatDuration(os.TotalDuration))
}
// formatDuration formats a duration in a human-readable way: hours as
// "N.Nh", minutes as "N.Nm", otherwise seconds as "N.Ns".
func formatDuration(d time.Duration) string {
	if d >= time.Hour {
		return fmt.Sprintf("%.1fh", d.Hours())
	}
	if d >= time.Minute {
		return fmt.Sprintf("%.1fm", d.Minutes())
	}
	return fmt.Sprintf("%.1fs", d.Seconds())
}

View File

@ -0,0 +1,398 @@
package progress
import (
"fmt"
"io"
"os"
"strings"
"time"
)
// Indicator represents a progress indicator interface. Implementations
// render lifecycle events (start → updates → complete/fail) for the user.
type Indicator interface {
	Start(message string)    // begin showing progress with an initial message
	Update(message string)   // report an intermediate status change
	Complete(message string) // finish successfully with a final message
	Fail(message string)     // finish unsuccessfully with a final message
	Stop()                   // halt any background rendering without a final message
}
// Spinner creates a spinning progress indicator
type Spinner struct {
	writer   io.Writer     // destination for rendered frames
	message  string        // current status text; NOTE(review): read by the render goroutine without synchronization
	active   bool          // whether the spinner is running; also unsynchronized
	frames   []string      // animation frames, cycled in order
	interval time.Duration // delay between frames
	stopCh   chan bool     // signals the render goroutine to exit
}

// NewSpinner creates a new spinner progress indicator writing to stdout,
// using braille animation frames at 80ms per frame.
func NewSpinner() *Spinner {
	return &Spinner{
		writer:   os.Stdout,
		frames:   []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"},
		interval: 80 * time.Millisecond,
		stopCh:   make(chan bool, 1),
	}
}
// Start begins the spinner with a message, spawning a goroutine that
// redraws the current frame every interval until Stop signals stopCh.
// A changed message is printed on a new line; an unchanged one is
// redrawn in place via carriage return.
//
// NOTE(review): message/active are written by Update/Stop and read here
// without synchronization — a data race under `go test -race`; confirm
// whether this is accepted intentionally.
func (s *Spinner) Start(message string) {
	s.message = message
	s.active = true
	go func() {
		i := 0
		lastMessage := ""
		for {
			select {
			case <-s.stopCh:
				return
			default:
				if s.active {
					currentFrame := fmt.Sprintf("%s %s", s.frames[i%len(s.frames)], s.message)
					if s.message != lastMessage {
						// Print new line for new messages
						fmt.Fprintf(s.writer, "\n%s", currentFrame)
						lastMessage = s.message
					} else {
						// Update in place for same message
						fmt.Fprintf(s.writer, "\r%s", currentFrame)
					}
					i++
					time.Sleep(s.interval)
				}
			}
		}
	}()
}
// Update changes the spinner message (picked up on the next frame).
func (s *Spinner) Update(message string) {
	s.message = message
}

// Complete stops the spinner with a success message
func (s *Spinner) Complete(message string) {
	s.Stop()
	fmt.Fprintf(s.writer, "\n✅ %s\n", message)
}

// Fail stops the spinner with a failure message
func (s *Spinner) Fail(message string) {
	s.Stop()
	fmt.Fprintf(s.writer, "\n❌ %s\n", message)
}

// Stop stops the spinner; a second call is a no-op because active is
// already false.
func (s *Spinner) Stop() {
	if s.active {
		s.active = false
		s.stopCh <- true
		fmt.Fprint(s.writer, "\n") // New line instead of clearing
	}
}
// Dots creates a dots progress indicator: the message followed by dots
// appended over time by a background goroutine.
type Dots struct {
	writer  io.Writer // destination for output
	message string    // current status text; NOTE(review): shared with the goroutine without synchronization
	active  bool      // whether the indicator is running; also unsynchronized
	stopCh  chan bool // signals the background goroutine to exit
}

// NewDots creates a new dots progress indicator writing to stdout.
func NewDots() *Dots {
	return &Dots{
		writer: os.Stdout,
		stopCh: make(chan bool, 1),
	}
}
// Start begins the dots indicator: prints the message, then appends a
// dot every 500ms, rewriting the line after every third dot.
//
// NOTE(review): message/active are read here and written by Update/Stop
// without synchronization — a data race under `go test -race`.
func (d *Dots) Start(message string) {
	d.message = message
	d.active = true
	fmt.Fprint(d.writer, message)
	go func() {
		count := 0
		for {
			select {
			case <-d.stopCh:
				return
			default:
				if d.active {
					fmt.Fprint(d.writer, ".")
					count++
					if count%3 == 0 {
						// Reset dots
						fmt.Fprint(d.writer, "\r"+d.message)
					}
					time.Sleep(500 * time.Millisecond)
				}
			}
		}
	}()
}
// Update changes the dots message and, while active, prints the new
// message on its own line.
func (d *Dots) Update(message string) {
	d.message = message
	if d.active {
		fmt.Fprintf(d.writer, "\n%s", message)
	}
}

// Complete stops the dots with a success message
func (d *Dots) Complete(message string) {
	d.Stop()
	fmt.Fprintf(d.writer, " ✅ %s\n", message)
}

// Fail stops the dots with a failure message
func (d *Dots) Fail(message string) {
	d.Stop()
	fmt.Fprintf(d.writer, " ❌ %s\n", message)
}

// Stop stops the dots indicator; a second call is a no-op.
func (d *Dots) Stop() {
	if d.active {
		d.active = false
		d.stopCh <- true
	}
}
// ProgressBar creates a visual progress bar rendered as a fixed-width
// row of filled/empty cells plus a percentage.
type ProgressBar struct {
	writer  io.Writer // destination for rendered output
	message string    // current status text, shown left of the bar
	total   int       // number of steps representing 100%
	current int       // steps completed so far
	width   int       // bar width in cells
	active  bool      // whether render() may draw
	stopCh  chan bool // kept for interface symmetry; never used
}

// NewProgressBar creates a new progress bar writing to stdout.
// total is the number of steps representing 100%.
func NewProgressBar(total int) *ProgressBar {
	return &ProgressBar{
		writer: os.Stdout,
		total:  total,
		width:  40,
		stopCh: make(chan bool, 1),
	}
}

// Start begins the progress bar at zero with the given message.
func (p *ProgressBar) Start(message string) {
	p.message = message
	p.active = true
	p.current = 0
	p.render()
}

// Update advances the progress bar by one step (capped at total) and
// replaces the message.
func (p *ProgressBar) Update(message string) {
	if p.current < p.total {
		p.current++
	}
	p.message = message
	p.render()
}

// SetProgress sets a specific progress value. Values outside [0, total]
// are accepted here and clamped at render time.
func (p *ProgressBar) SetProgress(current int, message string) {
	p.current = current
	p.message = message
	p.render()
}

// Complete finishes the progress bar at 100% with a success message.
func (p *ProgressBar) Complete(message string) {
	p.current = p.total
	p.message = message
	p.render()
	fmt.Fprintf(p.writer, " ✅ %s\n", message)
	p.Stop()
}

// Fail stops the progress bar with a failure message.
func (p *ProgressBar) Fail(message string) {
	p.render()
	fmt.Fprintf(p.writer, " ❌ %s\n", message)
	p.Stop()
}

// Stop deactivates the progress bar; subsequent render calls are no-ops.
func (p *ProgressBar) Stop() {
	p.active = false
}

// render draws the progress bar.
//
// Fix: the ratio is clamped to [0, 1] and guarded against total <= 0.
// Previously current > total (possible via SetProgress) made
// width-filled negative and strings.Repeat panicked; total == 0 divided
// by zero (NaN percent).
func (p *ProgressBar) render() {
	if !p.active {
		return
	}
	percent := 0.0
	if p.total > 0 {
		percent = float64(p.current) / float64(p.total)
	}
	if percent < 0 {
		percent = 0
	} else if percent > 1 {
		percent = 1
	}
	filled := int(percent * float64(p.width))
	bar := strings.Repeat("█", filled) + strings.Repeat("░", p.width-filled)
	fmt.Fprintf(p.writer, "\n%s [%s] %d%%", p.message, bar, int(percent*100))
}
// Static creates a simple static progress indicator: each event is
// appended to the writer with no animation and no background goroutine.
type Static struct {
	writer io.Writer // destination for output
}

// NewStatic creates a new static progress indicator writing to stdout.
func NewStatic() *Static {
	return &Static{
		writer: os.Stdout,
	}
}

// Start shows the initial message
func (s *Static) Start(message string) {
	fmt.Fprintf(s.writer, "→ %s", message)
}

// Update shows an update message
func (s *Static) Update(message string) {
	fmt.Fprintf(s.writer, " - %s", message)
}

// Complete shows completion message
func (s *Static) Complete(message string) {
	fmt.Fprintf(s.writer, " ✅ %s\n", message)
}

// Fail shows failure message
func (s *Static) Fail(message string) {
	fmt.Fprintf(s.writer, " ❌ %s\n", message)
}

// Stop does nothing for static indicator
func (s *Static) Stop() {
	// No-op for static indicator
}
// LineByLine creates a line-by-line progress indicator: every event is
// printed on its own line, which suits logs and non-TTY output.
type LineByLine struct {
	writer io.Writer // destination for output
	silent bool      // when true, intermediate Update lines are suppressed
}

// NewLineByLine creates a new line-by-line progress indicator on stdout
// with Update lines enabled.
func NewLineByLine() *LineByLine {
	return &LineByLine{
		writer: os.Stdout,
		silent: false,
	}
}
// Light creates a minimal progress indicator with just essential status.
type Light struct {
	writer io.Writer // destination for output
	silent bool      // suppresses all output when true; NOTE(review): no constructor in this file ever sets it true
}

// NewLight creates a new light progress indicator on stdout.
func NewLight() *Light {
	return &Light{
		writer: os.Stdout,
		silent: false,
	}
}
// NewQuietLineByLine creates a line-by-line indicator that suppresses
// intermediate Update lines, printing only start/complete/fail.
func NewQuietLineByLine() *LineByLine {
	return &LineByLine{
		writer: os.Stdout,
		silent: true,
	}
}
// Start shows the initial message on its own line.
func (l *LineByLine) Start(message string) {
	fmt.Fprintf(l.writer, "\n🔄 %s\n", message)
}

// Update shows an update message, unless silent mode is on.
func (l *LineByLine) Update(message string) {
	if !l.silent {
		fmt.Fprintf(l.writer, "   %s\n", message)
	}
}

// Complete shows completion message (printed even in silent mode).
func (l *LineByLine) Complete(message string) {
	fmt.Fprintf(l.writer, "✅ %s\n\n", message)
}

// Fail shows failure message (printed even in silent mode).
func (l *LineByLine) Fail(message string) {
	fmt.Fprintf(l.writer, "❌ %s\n\n", message)
}

// Stop does nothing for line-by-line (no cleanup needed)
func (l *LineByLine) Stop() {
	// No cleanup needed for line-by-line
}
// Light indicator methods - minimal output. Unlike LineByLine, every
// method (including Complete/Fail) honours the silent flag.
func (l *Light) Start(message string) {
	if !l.silent {
		fmt.Fprintf(l.writer, "▶ %s\n", message)
	}
}

func (l *Light) Update(message string) {
	if !l.silent {
		fmt.Fprintf(l.writer, "  %s\n", message)
	}
}

func (l *Light) Complete(message string) {
	if !l.silent {
		fmt.Fprintf(l.writer, "✓ %s\n", message)
	}
}

func (l *Light) Fail(message string) {
	if !l.silent {
		fmt.Fprintf(l.writer, "✗ %s\n", message)
	}
}

func (l *Light) Stop() {
	// No cleanup needed for light indicator
}
// NewIndicator creates an appropriate progress indicator based on
// environment. Non-interactive callers always get line-by-line output;
// interactive callers get the requested indicatorType ("spinner",
// "dots", "bar", "line", "light"), falling back to line-by-line for
// anything unrecognised.
func NewIndicator(interactive bool, indicatorType string) Indicator {
	if !interactive {
		// Line-by-line output is safest without a TTY.
		return NewLineByLine()
	}
	builders := map[string]func() Indicator{
		"spinner": func() Indicator { return NewSpinner() },
		"dots":    func() Indicator { return NewDots() },
		"bar":     func() Indicator { return NewProgressBar(100) }, // default to 100 steps
		"line":    func() Indicator { return NewLineByLine() },
		"light":   func() Indicator { return NewLight() },
	}
	if build, ok := builders[indicatorType]; ok {
		return build()
	}
	return NewLineByLine()
}

658
internal/tui/menu.go Normal file
View File

@ -0,0 +1,658 @@
package tui
import (
"context"
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/spinner"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"dbbackup/internal/config"
"dbbackup/internal/database"
"dbbackup/internal/logger"
"dbbackup/internal/progress"
)
// Style definitions: the Lip Gloss styles shared by all menu views,
// grouped so the palette lives in one place.
var (
	// Header banner: bold white text on purple.
	titleStyle = lipgloss.NewStyle().
			Bold(true).
			Foreground(lipgloss.Color("#FAFAFA")).
			Background(lipgloss.Color("#7D56F4")).
			Padding(0, 1)

	// Unselected menu entries: muted grey.
	menuStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("#626262"))

	// The currently highlighted menu entry: bold pink.
	selectedStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("#FF75B7")).
			Bold(true)

	// Informational lines (hints, footers): muted grey.
	infoStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("#626262"))

	// Success messages: bold green.
	successStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("#04B575")).
			Bold(true)

	// Error messages: bold red.
	errorStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("#FF6B6B")).
			Bold(true)

	// In-progress status lines: bold yellow.
	progressStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("#FFD93D")).
			Bold(true)

	// Step list header: green, indented.
	stepStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("#6BCF7F")).
			MarginLeft(2)

	// Secondary detail lines: grey italic, further indented.
	detailStyle = lipgloss.NewStyle().
			Foreground(lipgloss.Color("#A8A8A8")).
			MarginLeft(4).
			Italic(true)
)
// MenuModel represents the enhanced menu state with progress tracking.
// It is the Bubble Tea model for the interactive main menu.
type MenuModel struct {
	choices  []string      // menu entries, in display order
	cursor   int           // index of the highlighted entry
	config   *config.Config
	logger   logger.Logger
	quitting bool   // set once the user quits; View shows a goodbye line
	message  string // transient status/hint line shown below the menu
	// Progress tracking
	showProgress        bool   // render the current operation's progress panel
	showCompletion      bool   // render the persistent completion banner
	completionMessage   string // text of the completion banner
	completionDismissed bool   // Track if user manually dismissed completion
	currentOperation    *progress.OperationStatus  // operation shown in the progress panel
	allOperations       []progress.OperationStatus // everything the reporter knows about
	lastUpdate          time.Time
	spinner             spinner.Model
	// Background operations
	ctx    context.Context    // cancelled when the user quits
	cancel context.CancelFunc
	// TUI Progress Reporter
	progressReporter *TUIProgressReporter // polled every 500ms for updates
}
// completionMsg carries completion status delivered directly to Update.
type completionMsg struct {
	success bool
	message string
}

// operationUpdateMsg carries the latest snapshot of all operations,
// produced by the pollOperations tick.
type operationUpdateMsg struct {
	operations []progress.OperationStatus
}

// operationCompleteMsg signals that a specific operation finished.
type operationCompleteMsg struct {
	operation *progress.OperationStatus
	success   bool
}
// NewMenuModel initializes the menu model: menu choices, spinner,
// cancellable context for background backups, and the TUI progress
// reporter that those backups publish into.
func NewMenuModel(cfg *config.Config, log logger.Logger) MenuModel {
	ctx, cancel := context.WithCancel(context.Background())
	s := spinner.New()
	s.Spinner = spinner.Dot
	s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("#FFD93D"))
	// Create TUI progress reporter
	progressReporter := NewTUIProgressReporter()
	model := MenuModel{
		choices: []string{
			"Single Database Backup",
			"Sample Database Backup (with ratio)",
			"Cluster Backup (all databases)",
			"View Active Operations",
			"Show Operation History",
			"Database Status & Health Check",
			"Configuration Settings",
			"Clear Operation History",
			"Quit",
		},
		config:           cfg,
		logger:           log,
		ctx:              ctx,
		cancel:           cancel,
		spinner:          s,
		lastUpdate:       time.Now(),
		progressReporter: progressReporter,
	}
	// Set up progress callback. NOTE(review): the callback body is empty —
	// the TUI pulls updates via pollOperations rather than reacting here.
	progressReporter.AddCallback(func(operations []progress.OperationStatus) {
		// This will be called when operations update
		// The TUI will pick up these updates in the pollOperations method
	})
	return model
}
// Init initializes the model: starts the spinner animation and the
// periodic operation poll.
func (m MenuModel) Init() tea.Cmd {
	return tea.Batch(
		m.spinner.Tick,
		m.pollOperations(),
	)
}
// pollOperations periodically checks for operation updates: every 500ms
// it snapshots the reporter and delivers the result to Update as an
// operationUpdateMsg.
func (m MenuModel) pollOperations() tea.Cmd {
	return tea.Tick(time.Millisecond*500, func(t time.Time) tea.Msg {
		// Get operations from our TUI progress reporter
		operations := m.progressReporter.GetOperations()
		return operationUpdateMsg{operations: operations}
	})
}
// Update handles messages: keyboard input (navigation, selection, quit),
// periodic operation snapshots, explicit completion messages, and
// spinner ticks.
//
// NOTE(review): the View footer says "press any key" to dismiss the
// completion banner, but only up/down/j/k, enter/space and esc actually
// clear it here — confirm which is intended.
func (m MenuModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.KeyMsg:
		switch msg.String() {
		case "ctrl+c", "q":
			// Cancel any background backups before quitting.
			if m.cancel != nil {
				m.cancel()
			}
			m.quitting = true
			return m, tea.Quit
		case "up", "k":
			// Clear completion status and allow navigation
			if m.showCompletion {
				m.showCompletion = false
				m.completionMessage = ""
				m.message = ""
				m.completionDismissed = true // Mark as manually dismissed
			}
			if m.cursor > 0 {
				m.cursor--
			}
		case "down", "j":
			// Clear completion status and allow navigation
			if m.showCompletion {
				m.showCompletion = false
				m.completionMessage = ""
				m.message = ""
				m.completionDismissed = true // Mark as manually dismissed
			}
			if m.cursor < len(m.choices)-1 {
				m.cursor++
			}
		case "enter", " ":
			// Clear completion status and allow selection
			if m.showCompletion {
				m.showCompletion = false
				m.completionMessage = ""
				m.message = ""
				m.completionDismissed = true // Mark as manually dismissed
				return m, m.pollOperations()
			}
			// Dispatch on cursor position; indices match the choices
			// slice built in NewMenuModel.
			switch m.cursor {
			case 0: // Single Database Backup
				return m.handleSingleBackup()
			case 1: // Sample Database Backup
				return m.handleSampleBackup()
			case 2: // Cluster Backup
				return m.handleClusterBackup()
			case 3: // View Active Operations
				return m.handleViewOperations()
			case 4: // Show Operation History
				return m.handleOperationHistory()
			case 5: // Database Status
				return m.handleStatus()
			case 6: // Settings
				return m.handleSettings()
			case 7: // Clear History
				return m.handleClearHistory()
			case 8: // Quit
				if m.cancel != nil {
					m.cancel()
				}
				m.quitting = true
				return m, tea.Quit
			}
		case "esc":
			// Clear completion status on escape
			if m.showCompletion {
				m.showCompletion = false
				m.completionMessage = ""
				m.message = ""
				m.completionDismissed = true // Mark as manually dismissed
			}
		}
	case operationUpdateMsg:
		// Fresh snapshot from the poller. Only the most recent
		// operation drives the progress panel / completion banner.
		m.allOperations = msg.operations
		if len(msg.operations) > 0 {
			latest := msg.operations[len(msg.operations)-1]
			if latest.Status == "running" {
				m.currentOperation = &latest
				m.showProgress = true
				m.showCompletion = false
				m.completionDismissed = false // Reset dismissal flag for new operation
			} else if m.currentOperation != nil && latest.ID == m.currentOperation.ID {
				m.currentOperation = &latest
				m.showProgress = false
				// Only show completion status if user hasn't manually dismissed it
				if !m.completionDismissed {
					if latest.Status == "completed" {
						m.showCompletion = true
						m.completionMessage = fmt.Sprintf("✅ %s", latest.Message)
					} else if latest.Status == "failed" {
						m.showCompletion = true
						m.completionMessage = fmt.Sprintf("❌ %s", latest.Message)
					}
				}
			}
		}
		return m, m.pollOperations()
	case completionMsg:
		m.showProgress = false
		m.showCompletion = true
		if msg.success {
			m.completionMessage = fmt.Sprintf("✅ %s", msg.message)
		} else {
			m.completionMessage = fmt.Sprintf("❌ %s", msg.message)
		}
		return m, m.pollOperations()
	case operationCompleteMsg:
		m.currentOperation = msg.operation
		m.showProgress = false
		if msg.success {
			m.message = fmt.Sprintf("✅ Operation completed: %s", msg.operation.Message)
		} else {
			m.message = fmt.Sprintf("❌ Operation failed: %s", msg.operation.Message)
		}
		return m, m.pollOperations()
	case spinner.TickMsg:
		var cmd tea.Cmd
		m.spinner, cmd = m.spinner.Update(msg)
		return m, cmd
	}
	return m, nil
}
// View renders the enhanced menu with progress tracking: header,
// database info, menu items, optional progress panel, optional
// completion banner, status message, operations summary and footer.
//
// NOTE(review): the completion banner is always rendered with
// successStyle, even when completionMessage reports a failure (the ❌ is
// carried in the text only) — confirm whether errorStyle was intended
// for failures.
func (m MenuModel) View() string {
	if m.quitting {
		return "Thanks for using DB Backup Tool!\n"
	}
	var b strings.Builder
	// Header
	header := titleStyle.Render("🗄️ Database Backup Tool - Interactive Menu")
	b.WriteString(fmt.Sprintf("\n%s\n\n", header))
	// Database info
	dbInfo := infoStyle.Render(fmt.Sprintf("Database: %s@%s:%d (%s)",
		m.config.User, m.config.Host, m.config.Port, m.config.DatabaseType))
	b.WriteString(fmt.Sprintf("%s\n\n", dbInfo))
	// Menu items
	for i, choice := range m.choices {
		cursor := " "
		if m.cursor == i {
			cursor = ">"
			b.WriteString(selectedStyle.Render(fmt.Sprintf("%s %s", cursor, choice)))
		} else {
			b.WriteString(menuStyle.Render(fmt.Sprintf("%s %s", cursor, choice)))
		}
		b.WriteString("\n")
	}
	// Current operation progress
	if m.showProgress && m.currentOperation != nil {
		b.WriteString("\n")
		b.WriteString(m.renderOperationProgress(m.currentOperation))
		b.WriteString("\n")
	}
	// Completion status (persistent until key press)
	if m.showCompletion {
		b.WriteString("\n")
		b.WriteString(successStyle.Render(m.completionMessage))
		b.WriteString("\n")
		b.WriteString(infoStyle.Render("💡 Press any key to continue..."))
		b.WriteString("\n")
	}
	// Message area (suppressed while the completion banner is showing)
	if m.message != "" && !m.showCompletion {
		b.WriteString("\n")
		b.WriteString(m.message)
		b.WriteString("\n")
	}
	// Operations summary
	if len(m.allOperations) > 0 {
		b.WriteString("\n")
		b.WriteString(m.renderOperationsSummary())
		b.WriteString("\n")
	}
	// Footer
	var footer string
	if m.showCompletion {
		footer = infoStyle.Render("\n⌨ Press Enter, ↑/↓ arrows, or Esc to continue...")
	} else {
		footer = infoStyle.Render("\n⌨ Press ↑/↓ to navigate • Enter to select • q to quit")
	}
	b.WriteString(footer)
	return b.String()
}
// renderOperationProgress renders detailed progress for the current
// operation: status line with spinner, progress bar, timing, file/byte
// counters and the per-step list.
//
// NOTE(review): strings.Title is deprecated (it mishandles Unicode
// punctuation/word boundaries); for the single-word op.Type values a
// simple first-rune upper-case would suffice — consider replacing.
func (m MenuModel) renderOperationProgress(op *progress.OperationStatus) string {
	var b strings.Builder
	// Operation header with spinner
	spinnerView := ""
	if op.Status == "running" {
		spinnerView = m.spinner.View() + " "
	}
	status := "🔄"
	if op.Status == "completed" {
		status = "✅"
	} else if op.Status == "failed" {
		status = "❌"
	}
	b.WriteString(progressStyle.Render(fmt.Sprintf("%s%s %s [%d%%]",
		spinnerView, status, strings.Title(op.Type), op.Progress)))
	b.WriteString("\n")
	// Progress bar (clamped so Progress > 100 cannot overflow the width)
	barWidth := 40
	filledWidth := (op.Progress * barWidth) / 100
	if filledWidth > barWidth {
		filledWidth = barWidth
	}
	bar := strings.Repeat("█", filledWidth) + strings.Repeat("░", barWidth-filledWidth)
	b.WriteString(detailStyle.Render(fmt.Sprintf("[%s] %s", bar, op.Message)))
	b.WriteString("\n")
	// Time and details: elapsed while running, final duration once done
	elapsed := time.Since(op.StartTime)
	timeInfo := fmt.Sprintf("Elapsed: %s", formatDuration(elapsed))
	if op.EndTime != nil {
		timeInfo = fmt.Sprintf("Duration: %s", op.Duration.String())
	}
	b.WriteString(detailStyle.Render(timeInfo))
	b.WriteString("\n")
	// File/byte progress (only shown when totals are known)
	if op.FilesTotal > 0 {
		b.WriteString(detailStyle.Render(fmt.Sprintf("Files: %d/%d", op.FilesDone, op.FilesTotal)))
		b.WriteString("\n")
	}
	if op.BytesTotal > 0 {
		b.WriteString(detailStyle.Render(fmt.Sprintf("Data: %s/%s",
			formatBytes(op.BytesDone), formatBytes(op.BytesTotal))))
		b.WriteString("\n")
	}
	// Current steps
	if len(op.Steps) > 0 {
		b.WriteString(stepStyle.Render("Steps:"))
		b.WriteString("\n")
		for _, step := range op.Steps {
			stepStatus := "⏳"
			if step.Status == "completed" {
				stepStatus = "✅"
			} else if step.Status == "failed" {
				stepStatus = "❌"
			}
			b.WriteString(detailStyle.Render(fmt.Sprintf("  %s %s", stepStatus, step.Name)))
			b.WriteString("\n")
		}
	}
	return b.String()
}
// renderOperationsSummary returns a one-line, styled tally of all
// tracked operations, or the empty string when there are none.
func (m MenuModel) renderOperationsSummary() string {
	total := len(m.allOperations)
	if total == 0 {
		return ""
	}
	var completed, failed, running int
	for i := range m.allOperations {
		switch m.allOperations[i].Status {
		case "completed":
			completed++
		case "failed":
			failed++
		case "running":
			running++
		}
	}
	return infoStyle.Render(fmt.Sprintf(
		"📊 Operations: %d total | %d completed | %d failed | %d running",
		total, completed, failed, running))
}
// Enhanced backup handlers with progress tracking

// handleSingleBackup starts a single-database backup in a background
// goroutine; status flows back through m.progressReporter and is picked
// up by pollOperations. Requires cfg.Database to be set.
func (m MenuModel) handleSingleBackup() (tea.Model, tea.Cmd) {
	if m.config.Database == "" {
		m.message = errorStyle.Render("❌ No database specified. Use --database flag or set in config.")
		return m, nil
	}
	m.message = progressStyle.Render(fmt.Sprintf("🔄 Starting single backup for: %s", m.config.Database))
	m.showProgress = true
	m.showCompletion = false
	// Fix: reset the dismissal flag for the new operation, matching
	// handleSampleBackup/handleClusterBackup — otherwise a previously
	// dismissed banner suppresses this operation's completion message.
	m.completionDismissed = false
	// Start backup and return polling command
	go func() {
		err := RunBackupInTUI(m.ctx, m.config, m.logger, "single", m.config.Database, m.progressReporter)
		// The completion will be handled by the progress reporter callback system
		_ = err // Handle error in the progress reporter
	}()
	return m, m.pollOperations()
}
// handleSampleBackup kicks off a sample backup in the background and begins
// polling for progress updates.
func (m MenuModel) handleSampleBackup() (tea.Model, tea.Cmd) {
	m.message = progressStyle.Render("🔄 Starting sample backup...")
	m.showProgress = true
	m.showCompletion = false
	m.completionDismissed = false // Reset for new operation
	// Status is delivered via the progress reporter callback system, so the
	// returned error is deliberately ignored.
	go func() {
		_ = RunBackupInTUI(m.ctx, m.config, m.logger, "sample", "", m.progressReporter)
	}()
	return m, m.pollOperations()
}
// handleClusterBackup kicks off a full-cluster backup (all databases) in the
// background and begins polling for progress updates.
func (m MenuModel) handleClusterBackup() (tea.Model, tea.Cmd) {
	m.message = progressStyle.Render("🔄 Starting cluster backup (all databases)...")
	m.showProgress = true
	m.showCompletion = false
	m.completionDismissed = false // Reset for new operation
	// Status is delivered via the progress reporter callback system, so the
	// returned error is deliberately ignored.
	go func() {
		_ = RunBackupInTUI(m.ctx, m.config, m.logger, "cluster", "", m.progressReporter)
	}()
	return m, m.pollOperations()
}
// handleViewOperations reports how many operations are currently running.
func (m MenuModel) handleViewOperations() (tea.Model, tea.Cmd) {
	if len(m.allOperations) == 0 {
		m.message = infoStyle.Render(" No operations currently running or completed")
		return m, nil
	}
	// Count only the operations still in flight.
	running := 0
	for _, op := range m.allOperations {
		if op.Status == "running" {
			running++
		}
	}
	if running == 0 {
		m.message = infoStyle.Render(" No operations currently running")
	} else {
		m.message = progressStyle.Render(fmt.Sprintf("🔄 %d active operations", running))
	}
	return m, nil
}
// handleOperationHistory renders the most recent operations into the message
// area. Shows up to the last 5 entries with a status icon and start time.
func (m MenuModel) handleOperationHistory() (tea.Model, tea.Cmd) {
	if len(m.allOperations) == 0 {
		m.message = infoStyle.Render(" No operation history available")
		return m, nil
	}
	var history strings.Builder
	history.WriteString("📋 Operation History:\n")
	// Show the last 5 operations. (The previous loop broke after the first
	// 5 entries, which contradicted its own "last 5" comment.)
	start := len(m.allOperations) - 5
	if start < 0 {
		start = 0
	}
	for _, op := range m.allOperations[start:] {
		status := "🔄"
		if op.Status == "completed" {
			status = "✅"
		} else if op.Status == "failed" {
			status = "❌"
		}
		history.WriteString(fmt.Sprintf("%s %s - %s (%s)\n",
			status, op.Name, op.Type, op.StartTime.Format("15:04:05")))
	}
	m.message = history.String()
	return m, nil
}
// handleStatus verifies database connectivity (connect + ping) and reports
// the server version in the message area.
func (m MenuModel) handleStatus() (tea.Model, tea.Cmd) {
	db, err := database.New(m.config, m.logger)
	if err != nil {
		m.message = errorStyle.Render(fmt.Sprintf("❌ Connection failed: %v", err))
		return m, nil
	}
	defer db.Close()
	if err := db.Connect(m.ctx); err != nil {
		m.message = errorStyle.Render(fmt.Sprintf("❌ Connection failed: %v", err))
		return m, nil
	}
	if err := db.Ping(m.ctx); err != nil {
		m.message = errorStyle.Render(fmt.Sprintf("❌ Ping failed: %v", err))
		return m, nil
	}
	version, err := db.GetVersion(m.ctx)
	if err != nil {
		m.message = errorStyle.Render(fmt.Sprintf("❌ Failed to get version: %v", err))
		return m, nil
	}
	m.message = successStyle.Render(fmt.Sprintf("✅ Connected successfully!\nVersion: %s", version))
	return m, nil
}
// handleSettings switches the TUI to the settings screen, passing this menu
// model as the parent to return to.
func (m MenuModel) handleSettings() (tea.Model, tea.Cmd) {
	sm := NewSettingsModel(m.config, m.logger, m)
	return sm, sm.Init()
}
// handleClearHistory wipes all recorded operations and hides the progress pane.
func (m MenuModel) handleClearHistory() (tea.Model, tea.Cmd) {
	m.currentOperation = nil
	m.allOperations = []progress.OperationStatus{}
	m.showProgress = false
	m.message = successStyle.Render("✅ Operation history cleared")
	return m, nil
}
// Utility functions
// formatDuration formats a duration in a human-readable way
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%.1fs", d.Seconds())
} else if d < time.Hour {
return fmt.Sprintf("%.1fm", d.Minutes())
}
return fmt.Sprintf("%.1fh", d.Hours())
}
// formatBytes renders a byte count using binary (1024-based) units,
// e.g. 512 -> "512 B", 1536 -> "1.5 KB".
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	// Find the largest unit that keeps the mantissa below 1024.
	div := int64(unit)
	exp := 0
	for n := bytes / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
// RunInteractiveMenu starts the enhanced TUI with progress tracking.
// It blocks until the user quits and returns any program error.
func RunInteractiveMenu(cfg *config.Config, log logger.Logger) error {
	program := tea.NewProgram(NewMenuModel(cfg, log), tea.WithAltScreen())
	_, err := program.Run()
	if err != nil {
		return fmt.Errorf("error running interactive menu: %w", err)
	}
	return nil
}

212
internal/tui/progress.go Normal file
View File

@ -0,0 +1,212 @@
package tui
import (
"context"
"fmt"
"sync"
"time"
"dbbackup/internal/backup"
"dbbackup/internal/config"
"dbbackup/internal/database"
"dbbackup/internal/logger"
"dbbackup/internal/progress"
)
// TUIProgressReporter is a progress reporter that integrates with the TUI
type TUIProgressReporter struct {
	mu         sync.RWMutex                         // guards operations and callbacks
	operations map[string]*progress.OperationStatus // tracked operations keyed by operation ID
	callbacks  []func([]progress.OperationStatus)   // notified (each in its own goroutine) on every change
}
// NewTUIProgressReporter creates a new TUI-compatible progress reporter with
// no tracked operations and no registered callbacks.
func NewTUIProgressReporter() *TUIProgressReporter {
	r := &TUIProgressReporter{}
	r.operations = make(map[string]*progress.OperationStatus)
	r.callbacks = make([]func([]progress.OperationStatus), 0)
	return r
}
// AddCallback registers a function that is invoked (asynchronously) whenever
// the set of tracked operations changes.
func (t *TUIProgressReporter) AddCallback(callback func([]progress.OperationStatus)) {
	t.mu.Lock()
	t.callbacks = append(t.callbacks, callback)
	t.mu.Unlock()
}
// notifyCallbacks calls all registered callbacks with current operations
//
// Precondition: the caller must hold t.mu (every call site in this file
// locks before invoking this — do not add locking here or it will deadlock).
// All callbacks share one snapshot slice of value copies; each callback runs
// in its own goroutine so slow consumers cannot block the reporter.
func (t *TUIProgressReporter) notifyCallbacks() {
	operations := make([]progress.OperationStatus, 0, len(t.operations))
	for _, op := range t.operations {
		operations = append(operations, *op)
	}
	for _, callback := range t.callbacks {
		go callback(operations)
	}
}
// StartOperation begins tracking a new operation under the given ID and
// returns a tracker handle for reporting its progress. Registered callbacks
// are notified of the new "running" operation.
func (t *TUIProgressReporter) StartOperation(id, name, opType string) *TUIOperationTracker {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.operations[id] = &progress.OperationStatus{
		ID:        id,
		Name:      name,
		Type:      opType,
		Status:    "running",
		StartTime: time.Now(),
		Progress:  0,
		Message:   fmt.Sprintf("Starting %s: %s", opType, name),
		Details:   make(map[string]string),
		Steps:     make([]progress.StepStatus, 0),
	}
	t.notifyCallbacks()
	return &TUIOperationTracker{reporter: t, operationID: id}
}
// TUIOperationTracker tracks progress for TUI display
// It is a lightweight handle bound to one operation ID; all mutable state
// lives in the parent reporter's operations map.
type TUIOperationTracker struct {
	reporter    *TUIProgressReporter // owning reporter that holds the operation state
	operationID string               // key into reporter.operations
}
// UpdateProgress updates the operation progress. pct is the completion
// percentage shown in the TUI and message the current status line.
// (The parameter was renamed from "progress", which shadowed the imported
// progress package inside this function.)
func (t *TUIOperationTracker) UpdateProgress(pct int, message string) {
	t.reporter.mu.Lock()
	defer t.reporter.mu.Unlock()
	if op, exists := t.reporter.operations[t.operationID]; exists {
		op.Progress = pct
		op.Message = message
		t.reporter.notifyCallbacks()
	}
}
// Complete marks the operation as completed, forcing progress to 100% and
// recording its end time, duration and final message.
func (t *TUIOperationTracker) Complete(message string) {
	t.reporter.mu.Lock()
	defer t.reporter.mu.Unlock()
	op, ok := t.reporter.operations[t.operationID]
	if !ok {
		return
	}
	now := time.Now()
	op.Status = "completed"
	op.Progress = 100
	op.Message = message
	op.EndTime = &now
	op.Duration = now.Sub(op.StartTime)
	t.reporter.notifyCallbacks()
}
// Fail marks the operation as failed, recording its end time, duration and
// the failure message. Progress is left at its last reported value.
func (t *TUIOperationTracker) Fail(message string) {
	t.reporter.mu.Lock()
	defer t.reporter.mu.Unlock()
	op, ok := t.reporter.operations[t.operationID]
	if !ok {
		return
	}
	now := time.Now()
	op.Status = "failed"
	op.Message = message
	op.EndTime = &now
	op.Duration = now.Sub(op.StartTime)
	t.reporter.notifyCallbacks()
}
// GetOperations returns a snapshot (value copies) of all tracked operations.
func (t *TUIProgressReporter) GetOperations() []progress.OperationStatus {
	t.mu.RLock()
	defer t.mu.RUnlock()
	snapshot := make([]progress.OperationStatus, 0, len(t.operations))
	for _, op := range t.operations {
		snapshot = append(snapshot, *op)
	}
	return snapshot
}
// SilentLogger implements logger.Logger but doesn't output anything
// It is injected wherever the TUI is active so library log output cannot
// corrupt the terminal display.
type SilentLogger struct{}

// All log levels are no-ops.
func (s *SilentLogger) Info(msg string, args ...any)  {}
func (s *SilentLogger) Warn(msg string, args ...any)  {}
func (s *SilentLogger) Error(msg string, args ...any) {}
func (s *SilentLogger) Debug(msg string, args ...any) {}
func (s *SilentLogger) Time(msg string, args ...any)  {}

// StartOperation returns a no-op operation logger.
func (s *SilentLogger) StartOperation(name string) logger.OperationLogger {
	return &SilentOperation{}
}
// SilentOperation implements logger.OperationLogger but doesn't output anything
type SilentOperation struct{}

// Update, Complete and Fail are no-ops; progress is reported through the
// TUI reporter instead.
func (s *SilentOperation) Update(message string, args ...any)   {}
func (s *SilentOperation) Complete(message string, args ...any) {}
func (s *SilentOperation) Fail(message string, args ...any)     {}
// SilentProgressIndicator implements progress.Indicator but doesn't output anything
// It keeps the backup engine's own progress output off the terminal while
// the TUI owns the screen.
type SilentProgressIndicator struct{}

// All indicator hooks are no-ops.
func (s *SilentProgressIndicator) Start(message string)    {}
func (s *SilentProgressIndicator) Update(message string)   {}
func (s *SilentProgressIndicator) Complete(message string) {}
func (s *SilentProgressIndicator) Fail(message string)     {}
func (s *SilentProgressIndicator) Stop()                   {}
// RunBackupInTUI runs a backup operation with TUI-compatible progress
// reporting. backupType is "single", "cluster" or "sample"; databaseName may
// be empty for cluster backups. All direct output is suppressed (silent
// logger/indicator) so status is delivered exclusively through reporter.
func RunBackupInTUI(ctx context.Context, cfg *config.Config, log logger.Logger,
	backupType string, databaseName string, reporter *TUIProgressReporter) error {
	// Register the operation before any setup work so that even connection
	// failures are surfaced in the TUI. (Previously those errors returned
	// before an operation existed, and callers discard the error, so the UI
	// never showed them.)
	operationID := fmt.Sprintf("%s_%d", backupType, time.Now().Unix())
	tracker := reporter.StartOperation(operationID, databaseName, backupType)

	db, err := database.New(cfg, &SilentLogger{}) // silent logger keeps the TUI clean
	if err != nil {
		err = fmt.Errorf("failed to create database connection: %w", err)
		tracker.Fail(fmt.Sprintf("Backup failed: %v", err))
		return err
	}
	defer db.Close()

	if cerr := db.Connect(ctx); cerr != nil {
		cerr = fmt.Errorf("failed to connect to database: %w", cerr)
		tracker.Fail(fmt.Sprintf("Backup failed: %v", cerr))
		return cerr
	}

	// Backup engine with silent progress indicator and logger.
	engine := backup.NewSilent(cfg, &SilentLogger{}, db, &SilentProgressIndicator{})

	// Run the appropriate backup type.
	switch backupType {
	case "single":
		tracker.UpdateProgress(10, "Preparing single database backup...")
		err = engine.BackupSingle(ctx, databaseName)
	case "cluster":
		tracker.UpdateProgress(10, "Preparing cluster backup...")
		err = engine.BackupCluster(ctx)
	case "sample":
		tracker.UpdateProgress(10, "Preparing sample backup...")
		err = engine.BackupSample(ctx, databaseName)
	default:
		err = fmt.Errorf("unknown backup type: %s", backupType)
	}

	// Propagate the final status to the TUI.
	if err != nil {
		tracker.Fail(fmt.Sprintf("Backup failed: %v", err))
		return err
	}
	tracker.Complete(fmt.Sprintf("%s backup completed successfully", backupType))
	return nil
}

465
internal/tui/settings.go Normal file
View File

@ -0,0 +1,465 @@
package tui
import (
"fmt"
"path/filepath"
"strconv"
"strings"
tea "github.com/charmbracelet/bubbletea"
"dbbackup/internal/config"
"dbbackup/internal/logger"
)
// SettingsModel represents the settings configuration state
type SettingsModel struct {
	config       *config.Config // live configuration being edited (mutated in place)
	logger       logger.Logger
	cursor       int           // index of the currently highlighted setting
	editing      bool          // true while a value is being typed
	editingField string        // Key of the setting under edit
	editingValue string        // text buffer holding the value being typed
	settings     []SettingItem // ordered list of editable settings
	quitting     bool          // set when returning to the parent model
	message      string        // status/error line rendered below the list
	parent       tea.Model     // model to return to when the user quits
}
// SettingItem represents a configurable setting
type SettingItem struct {
	Key         string                             // stable identifier used while editing
	DisplayName string                             // label shown in the settings list
	Value       func(*config.Config) string        // renders the current value for display/editing
	Update      func(*config.Config, string) error // validates and applies a new value
	Type        string                             // "string", "int", "bool", "path"
	Description string                             // one-line help shown for the selected item
}
// NewSettingsModel builds the settings screen backed by cfg. Each
// SettingItem carries its own render (Value) and validate/apply (Update)
// closures; parent is the model to return to when the user quits.
// Numeric and boolean values are rendered with strconv (Itoa/FormatBool)
// instead of fmt.Sprintf, and the bool rendering is now gofmt-clean.
func NewSettingsModel(cfg *config.Config, log logger.Logger, parent tea.Model) SettingsModel {
	settings := []SettingItem{
		{
			Key:         "backup_dir",
			DisplayName: "Backup Directory",
			Value:       func(c *config.Config) string { return c.BackupDir },
			Update: func(c *config.Config, v string) error {
				if v == "" {
					return fmt.Errorf("backup directory cannot be empty")
				}
				c.BackupDir = filepath.Clean(v)
				return nil
			},
			Type:        "path",
			Description: "Directory where backup files will be stored",
		},
		{
			Key:         "compression_level",
			DisplayName: "Compression Level",
			Value:       func(c *config.Config) string { return strconv.Itoa(c.CompressionLevel) },
			Update: func(c *config.Config, v string) error {
				val, err := strconv.Atoi(v)
				if err != nil {
					return fmt.Errorf("compression level must be a number")
				}
				if val < 0 || val > 9 {
					return fmt.Errorf("compression level must be between 0-9")
				}
				c.CompressionLevel = val
				return nil
			},
			Type:        "int",
			Description: "Compression level (0=fastest, 9=smallest)",
		},
		{
			Key:         "jobs",
			DisplayName: "Parallel Jobs",
			Value:       func(c *config.Config) string { return strconv.Itoa(c.Jobs) },
			Update: func(c *config.Config, v string) error {
				val, err := strconv.Atoi(v)
				if err != nil {
					return fmt.Errorf("jobs must be a number")
				}
				if val < 1 {
					return fmt.Errorf("jobs must be at least 1")
				}
				c.Jobs = val
				return nil
			},
			Type:        "int",
			Description: "Number of parallel jobs for backup operations",
		},
		{
			Key:         "dump_jobs",
			DisplayName: "Dump Jobs",
			Value:       func(c *config.Config) string { return strconv.Itoa(c.DumpJobs) },
			Update: func(c *config.Config, v string) error {
				val, err := strconv.Atoi(v)
				if err != nil {
					return fmt.Errorf("dump jobs must be a number")
				}
				if val < 1 {
					return fmt.Errorf("dump jobs must be at least 1")
				}
				c.DumpJobs = val
				return nil
			},
			Type:        "int",
			Description: "Number of parallel jobs for database dumps",
		},
		{
			Key:         "host",
			DisplayName: "Database Host",
			Value:       func(c *config.Config) string { return c.Host },
			Update: func(c *config.Config, v string) error {
				if v == "" {
					return fmt.Errorf("host cannot be empty")
				}
				c.Host = v
				return nil
			},
			Type:        "string",
			Description: "Database server hostname or IP address",
		},
		{
			Key:         "port",
			DisplayName: "Database Port",
			Value:       func(c *config.Config) string { return strconv.Itoa(c.Port) },
			Update: func(c *config.Config, v string) error {
				val, err := strconv.Atoi(v)
				if err != nil {
					return fmt.Errorf("port must be a number")
				}
				if val < 1 || val > 65535 {
					return fmt.Errorf("port must be between 1-65535")
				}
				c.Port = val
				return nil
			},
			Type:        "int",
			Description: "Database server port number",
		},
		{
			Key:         "user",
			DisplayName: "Database User",
			Value:       func(c *config.Config) string { return c.User },
			Update: func(c *config.Config, v string) error {
				if v == "" {
					return fmt.Errorf("user cannot be empty")
				}
				c.User = v
				return nil
			},
			Type:        "string",
			Description: "Database username for connections",
		},
		{
			Key:         "database",
			DisplayName: "Default Database",
			Value:       func(c *config.Config) string { return c.Database },
			Update: func(c *config.Config, v string) error {
				c.Database = v // Can be empty for cluster operations
				return nil
			},
			Type:        "string",
			Description: "Default database name (optional)",
		},
		{
			Key:         "ssl_mode",
			DisplayName: "SSL Mode",
			Value:       func(c *config.Config) string { return c.SSLMode },
			Update: func(c *config.Config, v string) error {
				validModes := []string{"disable", "allow", "prefer", "require", "verify-ca", "verify-full"}
				for _, mode := range validModes {
					if v == mode {
						c.SSLMode = v
						return nil
					}
				}
				return fmt.Errorf("invalid SSL mode. Valid options: %s", strings.Join(validModes, ", "))
			},
			Type:        "string",
			Description: "SSL connection mode (disable, allow, prefer, require, verify-ca, verify-full)",
		},
		{
			Key:         "auto_detect_cores",
			DisplayName: "Auto Detect CPU Cores",
			Value:       func(c *config.Config) string { return strconv.FormatBool(c.AutoDetectCores) },
			Update: func(c *config.Config, v string) error {
				val, err := strconv.ParseBool(v)
				if err != nil {
					return fmt.Errorf("must be true or false")
				}
				c.AutoDetectCores = val
				return nil
			},
			Type:        "bool",
			Description: "Automatically detect and optimize for CPU cores",
		},
	}
	return SettingsModel{
		config:   cfg,
		logger:   log,
		settings: settings,
		parent:   parent,
	}
}
// Init initializes the settings model; no startup command is needed.
func (m SettingsModel) Init() tea.Cmd {
	return nil
}
// Update handles incoming messages. While a value is being edited all key
// input is delegated to handleEditingInput; otherwise keys navigate the
// list, start editing, reset, save, or return to the parent model.
func (m SettingsModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	keyMsg, ok := msg.(tea.KeyMsg)
	if !ok {
		return m, nil
	}
	if m.editing {
		return m.handleEditingInput(keyMsg)
	}
	switch keyMsg.String() {
	case "ctrl+c", "q", "esc":
		m.quitting = true
		return m.parent, nil
	case "up", "k":
		if m.cursor > 0 {
			m.cursor--
		}
	case "down", "j":
		if m.cursor < len(m.settings)-1 {
			m.cursor++
		}
	case "enter", " ":
		return m.startEditing()
	case "r":
		return m.resetToDefaults()
	case "s":
		return m.saveSettings()
	}
	return m, nil
}
// handleEditingInput handles key input while a setting is being edited.
// Enter commits the value, Esc cancels, Backspace deletes the last
// character, and any single-character key is appended to the buffer.
func (m SettingsModel) handleEditingInput(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
	switch msg.String() {
	case "ctrl+c":
		m.quitting = true
		return m.parent, nil
	case "esc":
		// Abort editing without applying the typed value.
		m.editing = false
		m.editingField = ""
		m.editingValue = ""
		m.message = ""
		return m, nil
	case "enter":
		return m.saveEditedValue()
	case "backspace":
		// Truncate by runes, not bytes: byte slicing corrupted multi-byte
		// UTF-8 characters, leaving an invalid trailing sequence.
		if r := []rune(m.editingValue); len(r) > 0 {
			m.editingValue = string(r[:len(r)-1])
		}
	default:
		// Accept single-character input. Counting runes (not bytes) lets
		// multi-byte characters through while still filtering named keys
		// such as "up"/"down".
		if len([]rune(msg.String())) == 1 {
			m.editingValue += msg.String()
		}
	}
	return m, nil
}
// startEditing opens an edit session for the setting under the cursor,
// seeding the edit buffer with the current rendered value.
func (m SettingsModel) startEditing() (tea.Model, tea.Cmd) {
	if m.cursor >= len(m.settings) {
		return m, nil
	}
	item := m.settings[m.cursor]
	m.editing = true
	m.editingField = item.Key
	m.editingValue = item.Value(m.config)
	m.message = ""
	return m, nil
}
// saveEditedValue validates the typed value and applies it to the config.
// On validation failure the edit session stays open so the user can retry.
func (m SettingsModel) saveEditedValue() (tea.Model, tea.Cmd) {
	if m.editingField == "" {
		return m, nil
	}
	// Locate the setting being edited by its key.
	var target *SettingItem
	for i := range m.settings {
		if m.settings[i].Key == m.editingField {
			target = &m.settings[i]
			break
		}
	}
	if target == nil {
		m.message = errorStyle.Render("❌ Setting not found")
		m.editing = false
		return m, nil
	}
	// Apply the new value; keep editing open if it is rejected.
	if err := target.Update(m.config, m.editingValue); err != nil {
		m.message = errorStyle.Render(fmt.Sprintf("❌ %s", err.Error()))
		return m, nil
	}
	m.message = successStyle.Render(fmt.Sprintf("✅ Updated %s", target.DisplayName))
	m.editing = false
	m.editingField = ""
	m.editingValue = ""
	return m, nil
}
// resetToDefaults replaces the configuration with defaults while preserving
// the active connection details (host, port, user, database, type).
func (m SettingsModel) resetToDefaults() (tea.Model, tea.Cmd) {
	defaults := config.New()
	defaults.Host = m.config.Host
	defaults.Port = m.config.Port
	defaults.User = m.config.User
	defaults.Database = m.config.Database
	defaults.DatabaseType = m.config.DatabaseType
	*m.config = *defaults
	m.message = successStyle.Render("✅ Settings reset to defaults")
	return m, nil
}
// saveSettings validates and saves current settings
//
// NOTE(review): nothing is persisted to disk here — "saved" means the
// in-memory config passed Validate (and, when auto-detect is on, was
// re-optimized for the CPU). Confirm whether file persistence is intended.
func (m SettingsModel) saveSettings() (tea.Model, tea.Cmd) {
	if err := m.config.Validate(); err != nil {
		m.message = errorStyle.Render(fmt.Sprintf("❌ Validation failed: %s", err.Error()))
		return m, nil
	}
	// Optimize CPU settings if auto-detect is enabled
	if m.config.AutoDetectCores {
		if err := m.config.OptimizeForCPU(); err != nil {
			m.message = errorStyle.Render(fmt.Sprintf("❌ CPU optimization failed: %s", err.Error()))
			return m, nil
		}
	}
	m.message = successStyle.Render("✅ Settings validated and saved")
	return m, nil
}
// View renders the settings interface
//
// Layout top to bottom: title, the settings list (the highlighted row shows
// the live edit buffer while editing), an optional status message, a
// read-only configuration summary (hidden while editing), and a key-binding
// footer that changes with the editing state.
func (m SettingsModel) View() string {
	if m.quitting {
		return "Returning to main menu...\n"
	}
	var b strings.Builder
	// Header
	header := titleStyle.Render("⚙️ Configuration Settings")
	b.WriteString(fmt.Sprintf("\n%s\n\n", header))
	// Settings list
	for i, setting := range m.settings {
		cursor := " "
		value := setting.Value(m.config)
		if m.cursor == i {
			cursor = ">"
			if m.editing && m.editingField == setting.Key {
				// Show editing interface: the edit buffer replaces the
				// stored value, with a hint for boolean settings.
				editValue := m.editingValue
				if setting.Type == "bool" {
					editValue += " (true/false)"
				}
				line := fmt.Sprintf("%s %s: %s", cursor, setting.DisplayName, editValue)
				b.WriteString(selectedStyle.Render(line))
				b.WriteString(" ✏️")
			} else {
				line := fmt.Sprintf("%s %s: %s", cursor, setting.DisplayName, value)
				b.WriteString(selectedStyle.Render(line))
			}
		} else {
			line := fmt.Sprintf("%s %s: %s", cursor, setting.DisplayName, value)
			b.WriteString(menuStyle.Render(line))
		}
		b.WriteString("\n")
		// Show description for selected item (suppressed while editing)
		if m.cursor == i && !m.editing {
			desc := detailStyle.Render(fmt.Sprintf(" %s", setting.Description))
			b.WriteString(desc)
			b.WriteString("\n")
		}
	}
	// Message area
	if m.message != "" {
		b.WriteString("\n")
		b.WriteString(m.message)
		b.WriteString("\n")
	}
	// Current configuration summary (hidden while editing to reduce noise)
	if !m.editing {
		b.WriteString("\n")
		b.WriteString(infoStyle.Render("📋 Current Configuration:"))
		b.WriteString("\n")
		summary := []string{
			fmt.Sprintf("Database: %s@%s:%d", m.config.User, m.config.Host, m.config.Port),
			fmt.Sprintf("Backup Dir: %s", m.config.BackupDir),
			fmt.Sprintf("Compression: Level %d", m.config.CompressionLevel),
			fmt.Sprintf("Jobs: %d parallel, %d dump", m.config.Jobs, m.config.DumpJobs),
		}
		for _, line := range summary {
			b.WriteString(detailStyle.Render(fmt.Sprintf(" %s", line)))
			b.WriteString("\n")
		}
	}
	// Footer with instructions
	var footer string
	if m.editing {
		footer = infoStyle.Render("\n⌨ Type new value • Enter to save • Esc to cancel")
	} else {
		footer = infoStyle.Render("\n⌨ ↑/↓ navigate • Enter to edit • 's' save • 'r' reset • 'q' back to menu")
	}
	b.WriteString(footer)
	return b.String()
}
// RunSettingsMenu starts the settings configuration interface as a
// standalone program; it blocks until the user quits.
func RunSettingsMenu(cfg *config.Config, log logger.Logger, parent tea.Model) error {
	program := tea.NewProgram(NewSettingsModel(cfg, log, parent), tea.WithAltScreen())
	_, err := program.Run()
	if err != nil {
		return fmt.Errorf("error running settings menu: %w", err)
	}
	return nil
}