package database import ( "context" "database/sql" "fmt" "os" "os/exec" "strconv" "strings" "time" "dbbackup/internal/config" "dbbackup/internal/logger" ) // MySQL implements Database interface for MySQL type MySQL struct { baseDatabase } // NewMySQL creates a new MySQL database instance func NewMySQL(cfg *config.Config, log logger.Logger) *MySQL { return &MySQL{ baseDatabase: baseDatabase{ cfg: cfg, log: log, }, } } // Connect establishes a connection to MySQL func (m *MySQL) Connect(ctx context.Context) error { // Build MySQL DSN dsn := m.buildDSN() m.dsn = dsn m.log.Debug("Connecting to MySQL", "dsn", sanitizeMySQLDSN(dsn)) db, err := sql.Open("mysql", dsn) if err != nil { return fmt.Errorf("failed to open MySQL connection: %w", err) } // Configure connection pool db.SetMaxOpenConns(10) db.SetMaxIdleConns(5) db.SetConnMaxLifetime(time.Hour) // Close connections after 1 hour // Test connection with proper timeout timeoutCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() if err := db.PingContext(timeoutCtx); err != nil { db.Close() return fmt.Errorf("failed to ping MySQL: %w", err) } m.db = db m.log.Info("Connected to MySQL successfully") return nil } // ListDatabases returns list of databases (excluding system databases) func (m *MySQL) ListDatabases(ctx context.Context) ([]string, error) { if m.db == nil { return nil, fmt.Errorf("not connected to database") } query := `SHOW DATABASES` rows, err := m.db.QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("failed to query databases: %w", err) } defer rows.Close() var databases []string systemDbs := map[string]bool{ "information_schema": true, "performance_schema": true, "mysql": true, "sys": true, } for rows.Next() { var name string if err := rows.Scan(&name); err != nil { return nil, fmt.Errorf("failed to scan database name: %w", err) } // Skip system databases if !systemDbs[name] { databases = append(databases, name) } } return databases, rows.Err() } // ListTables returns list of tables in a database func (m *MySQL) ListTables(ctx context.Context, database string) ([]string, error) { if m.db == nil { return nil, fmt.Errorf("not connected to database") } query := `SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = 'BASE TABLE' ORDER BY table_name` rows, err := m.db.QueryContext(ctx, query, database) if err != nil { return nil, fmt.Errorf("failed to query tables: %w", err) } defer rows.Close() var tables []string for rows.Next() { var name string if err := rows.Scan(&name); err != nil { return nil, fmt.Errorf("failed to scan table name: %w", err) } tables = append(tables, name) } return tables, rows.Err() } // CreateDatabase creates a new database func (m *MySQL) CreateDatabase(ctx context.Context, name string) error { if m.db == nil { return fmt.Errorf("not connected to database") } query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", name) _, err := m.db.ExecContext(ctx, query) if err != nil { return fmt.Errorf("failed to create database %s: %w", name, err) } m.log.Info("Created database", "name", name) return nil } // DropDatabase drops a database func (m *MySQL) DropDatabase(ctx context.Context, name string) error { if m.db == nil { return fmt.Errorf("not connected to database") } query := fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", name) _, err := m.db.ExecContext(ctx, query) if err != nil { return fmt.Errorf("failed to drop database %s: %w", name, err) } m.log.Info("Dropped database", "name", name) return nil } // DatabaseExists checks if a database exists func (m *MySQL) DatabaseExists(ctx context.Context, name string) (bool, error) { if m.db == nil { return false, fmt.Errorf("not connected to database") } query := `SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?` var dbName string err := m.db.QueryRowContext(ctx, query, name).Scan(&dbName) if err == sql.ErrNoRows { return false, nil } if err != nil { return false, fmt.Errorf("failed to check database existence: %w", err) } return true, nil } // GetVersion returns MySQL version func (m *MySQL) GetVersion(ctx context.Context) (string, error) { if m.db == nil { return "", fmt.Errorf("not connected to database") } var version string err := m.db.QueryRowContext(ctx, "SELECT VERSION()").Scan(&version) if err != nil { return "", fmt.Errorf("failed to get version: %w", err) } return version, nil } // GetDatabaseSize returns database size in bytes func (m *MySQL) GetDatabaseSize(ctx context.Context, database string) (int64, error) { if m.db == nil { return 0, fmt.Errorf("not connected to database") } query := `SELECT COALESCE(SUM(data_length + index_length), 0) as size_bytes FROM information_schema.tables WHERE table_schema = ?` var size int64 err := m.db.QueryRowContext(ctx, query, database).Scan(&size) if err != nil { return 0, fmt.Errorf("failed to get database size: %w", err) } return size, nil } // GetTableRowCount returns row count for a table func (m *MySQL) GetTableRowCount(ctx context.Context, database, table string) (int64, error) { if m.db == nil { return 0, fmt.Errorf("not connected to database") } // First try information_schema for approximate count (faster) query := `SELECT table_rows FROM information_schema.tables WHERE table_schema = ? AND table_name = ?` var count int64 err := m.db.QueryRowContext(ctx, query, database, table).Scan(&count) if err != nil || count == 0 { // Fallback to exact count exactQuery := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`", database, table) err = m.db.QueryRowContext(ctx, exactQuery).Scan(&count) if err != nil { return 0, fmt.Errorf("failed to get table row count: %w", err) } } return count, nil } // BuildBackupCommand builds mysqldump command func (m *MySQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string { cmd := []string{"mysqldump"} // Connection parameters - handle localhost vs remote differently if m.cfg.Host == "" || m.cfg.Host == "localhost" { // For localhost, use socket connection (don't specify host/port) cmd = append(cmd, "-u", m.cfg.User) } else { // For remote hosts, use TCP/IP cmd = append(cmd, "-h", m.cfg.Host) cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port)) cmd = append(cmd, "-u", m.cfg.User) } if m.cfg.Password != "" { cmd = append(cmd, "-p"+m.cfg.Password) } // SSL options if m.cfg.Insecure { cmd = append(cmd, "--skip-ssl") } else if mode := strings.ToLower(m.cfg.SSLMode); mode != "" { switch mode { case "require", "required": cmd = append(cmd, "--ssl-mode=REQUIRED") case "verify-ca": cmd = append(cmd, "--ssl-mode=VERIFY_CA") case "verify-full", "verify-identity": cmd = append(cmd, "--ssl-mode=VERIFY_IDENTITY") case "disable", "disabled": cmd = append(cmd, "--skip-ssl") } } // Backup options cmd = append(cmd, "--single-transaction") // Consistent backup cmd = append(cmd, "--routines") // Include stored procedures/functions cmd = append(cmd, "--triggers") // Include triggers cmd = append(cmd, "--events") // Include events if options.SchemaOnly { cmd = append(cmd, "--no-data") } else if options.DataOnly { cmd = append(cmd, "--no-create-info") } if options.NoOwner || options.NoPrivileges { cmd = append(cmd, "--skip-add-drop-table") } // Compression (handled externally for MySQL) // Output redirection will be handled by caller // Database cmd = append(cmd, database) return cmd } // BuildRestoreCommand builds mysql restore command func (m *MySQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string { cmd := []string{"mysql"} // Connection parameters - handle localhost vs remote differently if m.cfg.Host == "" || m.cfg.Host == "localhost" { // For localhost, use socket connection (don't specify host/port) cmd = append(cmd, "-u", m.cfg.User) } else { // For remote hosts, use TCP/IP cmd = append(cmd, "-h", m.cfg.Host) cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port)) cmd = append(cmd, "-u", m.cfg.User) } if m.cfg.Password != "" { cmd = append(cmd, "-p"+m.cfg.Password) } // SSL options if m.cfg.Insecure { cmd = append(cmd, "--skip-ssl") } // Options if options.SingleTransaction { cmd = append(cmd, "--single-transaction") } // Database cmd = append(cmd, database) // Input file (will be handled via stdin redirection) return cmd } // BuildSampleQuery builds SQL query for sampling data func (m *MySQL) BuildSampleQuery(database, table string, strategy SampleStrategy) string { switch strategy.Type { case "ratio": // Every Nth record using row_number (MySQL 8.0+) or modulo return fmt.Sprintf("SELECT * FROM (SELECT *, (@row_number:=@row_number + 1) AS rn FROM %s.%s CROSS JOIN (SELECT @row_number:=0) AS t) AS numbered WHERE rn %% %d = 1", database, table, strategy.Value) case "percent": // Percentage sampling using RAND() return fmt.Sprintf("SELECT * FROM %s.%s WHERE RAND() <= %f", database, table, float64(strategy.Value)/100.0) case "count": // First N records return fmt.Sprintf("SELECT * FROM %s.%s LIMIT %d", database, table, strategy.Value) default: return fmt.Sprintf("SELECT * FROM %s.%s LIMIT 1000", database, table) } } // ValidateBackupTools checks if required MySQL tools are available func (m *MySQL) ValidateBackupTools() error { tools := []string{"mysqldump", "mysql"} for _, tool := range tools { if _, err := exec.LookPath(tool); err != nil { return fmt.Errorf("required tool not found: %s", tool) } } return nil } // buildDSN constructs MySQL connection string func (m *MySQL) buildDSN() string { dsn := "" if m.cfg.User != "" { dsn += m.cfg.User } if m.cfg.Password != "" { dsn += ":" + m.cfg.Password } dsn += "@" // Handle localhost with Unix socket vs TCP/IP if m.cfg.Host == "" || m.cfg.Host == "localhost" { // Try common socket paths for localhost connections socketPaths := []string{ "/run/mysqld/mysqld.sock", "/var/run/mysqld/mysqld.sock", "/tmp/mysql.sock", "/var/lib/mysql/mysql.sock", } // Use the first available socket path, fallback to TCP if none found socketFound := false for _, socketPath := range socketPaths { if _, err := os.Stat(socketPath); err == nil { dsn += "unix(" + socketPath + ")" socketFound = true break } } // If no socket found, use TCP localhost if !socketFound { dsn += "tcp(localhost:" + strconv.Itoa(m.cfg.Port) + ")" } } else { // Use TCP/IP for remote connections dsn += "tcp(" + m.cfg.Host + ":" + strconv.Itoa(m.cfg.Port) + ")" } // Add database if specified if m.cfg.Database != "" { dsn += "/" + m.cfg.Database } else { dsn += "/" } // Add connection parameters params := []string{} // Add timeout parameters params = append(params, "timeout=30s") params = append(params, "readTimeout=30s") params = append(params, "writeTimeout=30s") // SSL configuration if m.cfg.Insecure { params = append(params, "tls=false") } else if mode := strings.ToLower(m.cfg.SSLMode); mode != "" { switch mode { case "require", "required": params = append(params, "tls=true") case "verify-ca", "verify-full", "verify-identity": params = append(params, "tls=preferred") case "disable", "disabled": params = append(params, "tls=false") default: // Default to preferred for unknown SSL modes params = append(params, "tls=preferred") } } // Add charset and other connection parameters params = append(params, "charset=utf8mb4") params = append(params, "parseTime=true") params = append(params, "loc=Local") if len(params) > 0 { dsn += "?" + strings.Join(params, "&") } return dsn } // sanitizeMySQLDSN removes password from DSN for logging func sanitizeMySQLDSN(dsn string) string { // Find password part and replace it if idx := strings.Index(dsn, ":"); idx != -1 { if endIdx := strings.Index(dsn[idx:], "@"); endIdx != -1 { return dsn[:idx] + ":***" + dsn[idx+endIdx:] } } return dsn }