From c352eb1f1bf818d472de7a27adfe5958f5cd94bf Mon Sep 17 00:00:00 2001 From: Renz Date: Sat, 25 Oct 2025 10:33:29 +0000 Subject: [PATCH] Fix MySQL connection handling: socket detection, timeouts, localhost vs remote --- internal/database/mysql.go | 90 +++++++++++++++++++++++++++++++------- 1 file changed, 74 insertions(+), 16 deletions(-) diff --git a/internal/database/mysql.go b/internal/database/mysql.go index f57abe7..68045e5 100644 --- a/internal/database/mysql.go +++ b/internal/database/mysql.go @@ -4,9 +4,11 @@ import ( "context" "database/sql" "fmt" + "os" "os/exec" "strconv" "strings" + "time" "dbbackup/internal/config" "dbbackup/internal/logger" @@ -43,10 +45,10 @@ func (m *MySQL) Connect(ctx context.Context) error { // Configure connection pool db.SetMaxOpenConns(10) db.SetMaxIdleConns(5) - db.SetConnMaxLifetime(0) + db.SetConnMaxLifetime(time.Hour) // Close connections after 1 hour - // Test connection - timeoutCtx, cancel := buildTimeout(ctx, 0) + // Test connection with proper timeout + timeoutCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() if err := db.PingContext(timeoutCtx); err != nil { @@ -237,10 +239,16 @@ func (m *MySQL) GetTableRowCount(ctx context.Context, database, table string) (i func (m *MySQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string { cmd := []string{"mysqldump"} - // Connection parameters - cmd = append(cmd, "-h", m.cfg.Host) - cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port)) - cmd = append(cmd, "-u", m.cfg.User) + // Connection parameters - handle localhost vs remote differently + if m.cfg.Host == "" || m.cfg.Host == "localhost" { + // For localhost, use socket connection (don't specify host/port) + cmd = append(cmd, "-u", m.cfg.User) + } else { + // For remote hosts, use TCP/IP + cmd = append(cmd, "-h", m.cfg.Host) + cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port)) + cmd = append(cmd, "-u", m.cfg.User) + } if m.cfg.Password != "" { cmd = append(cmd, "-p"+m.cfg.Password) @@ -291,10 +299,16 @@ func (m *MySQL) BuildBackupCommand(database, outputFile string, options BackupOp func (m *MySQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string { cmd := []string{"mysql"} - // Connection parameters - cmd = append(cmd, "-h", m.cfg.Host) - cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port)) - cmd = append(cmd, "-u", m.cfg.User) + // Connection parameters - handle localhost vs remote differently + if m.cfg.Host == "" || m.cfg.Host == "localhost" { + // For localhost, use socket connection (don't specify host/port) + cmd = append(cmd, "-u", m.cfg.User) + } else { + // For remote hosts, use TCP/IP + cmd = append(cmd, "-h", m.cfg.Host) + cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port)) + cmd = append(cmd, "-u", m.cfg.User) + } if m.cfg.Password != "" { cmd = append(cmd, "-p"+m.cfg.Password) @@ -364,27 +378,71 @@ func (m *MySQL) buildDSN() string { dsn += "@" - if m.cfg.Host != "" && m.cfg.Host != "localhost" { + // Handle localhost with Unix socket vs TCP/IP + if m.cfg.Host == "" || m.cfg.Host == "localhost" { + // Try common socket paths for localhost connections + socketPaths := []string{ + "/run/mysqld/mysqld.sock", + "/var/run/mysqld/mysqld.sock", + "/tmp/mysql.sock", + "/var/lib/mysql/mysql.sock", + } + + // Use the first available socket path, fallback to TCP if none found + socketFound := false + for _, socketPath := range socketPaths { + if _, err := os.Stat(socketPath); err == nil { + dsn += "unix(" + socketPath + ")" + socketFound = true + break + } + } + + // If no socket found, use TCP localhost + if !socketFound { + dsn += "tcp(localhost:" + strconv.Itoa(m.cfg.Port) + ")" + } + } else { + // Use TCP/IP for remote connections dsn += "tcp(" + m.cfg.Host + ":" + strconv.Itoa(m.cfg.Port) + ")" } - dsn += "/" + m.cfg.Database + // Add database if specified + if m.cfg.Database != "" { + dsn += "/" + m.cfg.Database + } else { + dsn += "/" + } // Add connection parameters params := []string{} - if !m.cfg.Insecure { - switch strings.ToLower(m.cfg.SSLMode) { + // Add timeout parameters + params = append(params, "timeout=30s") + params = append(params, "readTimeout=30s") + params = append(params, "writeTimeout=30s") + + // SSL configuration + if m.cfg.Insecure { + params = append(params, "tls=false") + } else if mode := strings.ToLower(m.cfg.SSLMode); mode != "" { + switch mode { case "require", "required": params = append(params, "tls=true") case "verify-ca", "verify-full", "verify-identity": params = append(params, "tls=preferred") + case "disable", "disabled": + params = append(params, "tls=false") + default: + // Default to preferred for unknown SSL modes + params = append(params, "tls=preferred") } } - // Add charset + // Add charset and other connection parameters params = append(params, "charset=utf8mb4") params = append(params, "parseTime=true") + params = append(params, "loc=Local") if len(params) > 0 { dsn += "?" + strings.Join(params, "&")