Fix MySQL connection handling: socket detection, timeouts, localhost vs remote
This commit is contained in:
@ -4,9 +4,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"dbbackup/internal/config"
|
"dbbackup/internal/config"
|
||||||
"dbbackup/internal/logger"
|
"dbbackup/internal/logger"
|
||||||
@ -43,10 +45,10 @@ func (m *MySQL) Connect(ctx context.Context) error {
|
|||||||
// Configure connection pool
|
// Configure connection pool
|
||||||
db.SetMaxOpenConns(10)
|
db.SetMaxOpenConns(10)
|
||||||
db.SetMaxIdleConns(5)
|
db.SetMaxIdleConns(5)
|
||||||
db.SetConnMaxLifetime(0)
|
db.SetConnMaxLifetime(time.Hour) // Close connections after 1 hour
|
||||||
|
|
||||||
// Test connection
|
// Test connection with proper timeout
|
||||||
timeoutCtx, cancel := buildTimeout(ctx, 0)
|
timeoutCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := db.PingContext(timeoutCtx); err != nil {
|
if err := db.PingContext(timeoutCtx); err != nil {
|
||||||
@ -237,10 +239,16 @@ func (m *MySQL) GetTableRowCount(ctx context.Context, database, table string) (i
|
|||||||
func (m *MySQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
|
func (m *MySQL) BuildBackupCommand(database, outputFile string, options BackupOptions) []string {
|
||||||
cmd := []string{"mysqldump"}
|
cmd := []string{"mysqldump"}
|
||||||
|
|
||||||
// Connection parameters
|
// Connection parameters - handle localhost vs remote differently
|
||||||
|
if m.cfg.Host == "" || m.cfg.Host == "localhost" {
|
||||||
|
// For localhost, use socket connection (don't specify host/port)
|
||||||
|
cmd = append(cmd, "-u", m.cfg.User)
|
||||||
|
} else {
|
||||||
|
// For remote hosts, use TCP/IP
|
||||||
cmd = append(cmd, "-h", m.cfg.Host)
|
cmd = append(cmd, "-h", m.cfg.Host)
|
||||||
cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
|
cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
|
||||||
cmd = append(cmd, "-u", m.cfg.User)
|
cmd = append(cmd, "-u", m.cfg.User)
|
||||||
|
}
|
||||||
|
|
||||||
if m.cfg.Password != "" {
|
if m.cfg.Password != "" {
|
||||||
cmd = append(cmd, "-p"+m.cfg.Password)
|
cmd = append(cmd, "-p"+m.cfg.Password)
|
||||||
@ -291,10 +299,16 @@ func (m *MySQL) BuildBackupCommand(database, outputFile string, options BackupOp
|
|||||||
func (m *MySQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
|
func (m *MySQL) BuildRestoreCommand(database, inputFile string, options RestoreOptions) []string {
|
||||||
cmd := []string{"mysql"}
|
cmd := []string{"mysql"}
|
||||||
|
|
||||||
// Connection parameters
|
// Connection parameters - handle localhost vs remote differently
|
||||||
|
if m.cfg.Host == "" || m.cfg.Host == "localhost" {
|
||||||
|
// For localhost, use socket connection (don't specify host/port)
|
||||||
|
cmd = append(cmd, "-u", m.cfg.User)
|
||||||
|
} else {
|
||||||
|
// For remote hosts, use TCP/IP
|
||||||
cmd = append(cmd, "-h", m.cfg.Host)
|
cmd = append(cmd, "-h", m.cfg.Host)
|
||||||
cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
|
cmd = append(cmd, "-P", strconv.Itoa(m.cfg.Port))
|
||||||
cmd = append(cmd, "-u", m.cfg.User)
|
cmd = append(cmd, "-u", m.cfg.User)
|
||||||
|
}
|
||||||
|
|
||||||
if m.cfg.Password != "" {
|
if m.cfg.Password != "" {
|
||||||
cmd = append(cmd, "-p"+m.cfg.Password)
|
cmd = append(cmd, "-p"+m.cfg.Password)
|
||||||
@ -364,27 +378,71 @@ func (m *MySQL) buildDSN() string {
|
|||||||
|
|
||||||
dsn += "@"
|
dsn += "@"
|
||||||
|
|
||||||
if m.cfg.Host != "" && m.cfg.Host != "localhost" {
|
// Handle localhost with Unix socket vs TCP/IP
|
||||||
|
if m.cfg.Host == "" || m.cfg.Host == "localhost" {
|
||||||
|
// Try common socket paths for localhost connections
|
||||||
|
socketPaths := []string{
|
||||||
|
"/run/mysqld/mysqld.sock",
|
||||||
|
"/var/run/mysqld/mysqld.sock",
|
||||||
|
"/tmp/mysql.sock",
|
||||||
|
"/var/lib/mysql/mysql.sock",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the first available socket path, fallback to TCP if none found
|
||||||
|
socketFound := false
|
||||||
|
for _, socketPath := range socketPaths {
|
||||||
|
if _, err := os.Stat(socketPath); err == nil {
|
||||||
|
dsn += "unix(" + socketPath + ")"
|
||||||
|
socketFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no socket found, use TCP localhost
|
||||||
|
if !socketFound {
|
||||||
|
dsn += "tcp(localhost:" + strconv.Itoa(m.cfg.Port) + ")"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Use TCP/IP for remote connections
|
||||||
dsn += "tcp(" + m.cfg.Host + ":" + strconv.Itoa(m.cfg.Port) + ")"
|
dsn += "tcp(" + m.cfg.Host + ":" + strconv.Itoa(m.cfg.Port) + ")"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add database if specified
|
||||||
|
if m.cfg.Database != "" {
|
||||||
dsn += "/" + m.cfg.Database
|
dsn += "/" + m.cfg.Database
|
||||||
|
} else {
|
||||||
|
dsn += "/"
|
||||||
|
}
|
||||||
|
|
||||||
// Add connection parameters
|
// Add connection parameters
|
||||||
params := []string{}
|
params := []string{}
|
||||||
|
|
||||||
if !m.cfg.Insecure {
|
// Add timeout parameters
|
||||||
switch strings.ToLower(m.cfg.SSLMode) {
|
params = append(params, "timeout=30s")
|
||||||
|
params = append(params, "readTimeout=30s")
|
||||||
|
params = append(params, "writeTimeout=30s")
|
||||||
|
|
||||||
|
// SSL configuration
|
||||||
|
if m.cfg.Insecure {
|
||||||
|
params = append(params, "tls=false")
|
||||||
|
} else if mode := strings.ToLower(m.cfg.SSLMode); mode != "" {
|
||||||
|
switch mode {
|
||||||
case "require", "required":
|
case "require", "required":
|
||||||
params = append(params, "tls=true")
|
params = append(params, "tls=true")
|
||||||
case "verify-ca", "verify-full", "verify-identity":
|
case "verify-ca", "verify-full", "verify-identity":
|
||||||
params = append(params, "tls=preferred")
|
params = append(params, "tls=preferred")
|
||||||
|
case "disable", "disabled":
|
||||||
|
params = append(params, "tls=false")
|
||||||
|
default:
|
||||||
|
// Default to preferred for unknown SSL modes
|
||||||
|
params = append(params, "tls=preferred")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add charset
|
// Add charset and other connection parameters
|
||||||
params = append(params, "charset=utf8mb4")
|
params = append(params, "charset=utf8mb4")
|
||||||
params = append(params, "parseTime=true")
|
params = append(params, "parseTime=true")
|
||||||
|
params = append(params, "loc=Local")
|
||||||
|
|
||||||
if len(params) > 0 {
|
if len(params) > 0 {
|
||||||
dsn += "?" + strings.Join(params, "&")
|
dsn += "?" + strings.Join(params, "&")
|
||||||
|
|||||||
Reference in New Issue
Block a user