Fix: MySQL/MariaDB socket authentication - remove hardcoded -h flag for localhost

Issue: MySQL/MariaDB functions always used '-h hostname' flag, which can cause
issues with Unix socket authentication when connecting to localhost.

Similar to PostgreSQL peer authentication, MySQL prefers Unix socket connections
for localhost rather than TCP connections. Using '-h localhost' forces TCP which
may fail with socket-based authentication configurations.

Fixed locations:
1. internal/restore/safety.go:
   - checkMySQLDatabaseExists() - now conditionally adds -h flag
   - listMySQLUserDatabases() - now conditionally adds -h flag

2. cmd/placeholder.go:
   - mysqlRestoreCommand() - now conditionally adds -h flag

Pattern applied (consistent with PostgreSQL fixes):
- Skip -h flag when host is localhost, 127.0.0.1, or empty
- Only add -h flag for actual remote hosts
- Allows mysql client to use Unix socket connection for local access

This ensures MySQL/MariaDB operations work correctly with both:
- Socket authentication (localhost via Unix socket)
- Password authentication (remote hosts via TCP)
This commit is contained in:
2025-11-12 08:55:06 +00:00
parent 98f483ae11
commit eb3e5c0135
2 changed files with 27 additions and 12 deletions

View File

@@ -730,12 +730,17 @@ func containsSQLKeywords(content string) bool {
} }
func mysqlRestoreCommand(archivePath string, compressed bool) string { func mysqlRestoreCommand(archivePath string, compressed bool) string {
parts := []string{ parts := []string{"mysql"}
"mysql",
"-h", cfg.Host, // Only add -h flag if host is not localhost (to use Unix socket)
if cfg.Host != "localhost" && cfg.Host != "127.0.0.1" && cfg.Host != "" {
parts = append(parts, "-h", cfg.Host)
}
parts = append(parts,
"-P", fmt.Sprintf("%d", cfg.Port), "-P", fmt.Sprintf("%d", cfg.Port),
"-u", cfg.User, "-u", cfg.User,
} )
if cfg.Password != "" { if cfg.Password != "" {
parts = append(parts, fmt.Sprintf("-p'%s'", cfg.Password)) parts = append(parts, fmt.Sprintf("-p'%s'", cfg.Password))

View File

@@ -326,13 +326,18 @@ func (s *Safety) checkPostgresDatabaseExists(ctx context.Context, dbName string)
// checkMySQLDatabaseExists checks if MySQL database exists // checkMySQLDatabaseExists checks if MySQL database exists
func (s *Safety) checkMySQLDatabaseExists(ctx context.Context, dbName string) (bool, error) { func (s *Safety) checkMySQLDatabaseExists(ctx context.Context, dbName string) (bool, error) {
cmd := exec.CommandContext(ctx, args := []string{
"mysql",
"-h", s.cfg.Host,
"-P", fmt.Sprintf("%d", s.cfg.Port), "-P", fmt.Sprintf("%d", s.cfg.Port),
"-u", s.cfg.User, "-u", s.cfg.User,
"-e", fmt.Sprintf("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME='%s'", dbName), "-e", fmt.Sprintf("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME='%s'", dbName),
) }
// Only add -h flag if host is not localhost (to use Unix socket)
if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
args = append([]string{"-h", s.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "mysql", args...)
if s.cfg.Password != "" { if s.cfg.Password != "" {
cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password)) cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password))
@@ -405,14 +410,19 @@ func (s *Safety) listMySQLUserDatabases(ctx context.Context) ([]string, error) {
// Exclude system databases // Exclude system databases
query := "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') ORDER BY SCHEMA_NAME" query := "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') ORDER BY SCHEMA_NAME"
cmd := exec.CommandContext(ctx, args := []string{
"mysql",
"-h", s.cfg.Host,
"-P", fmt.Sprintf("%d", s.cfg.Port), "-P", fmt.Sprintf("%d", s.cfg.Port),
"-u", s.cfg.User, "-u", s.cfg.User,
"-N", // Skip column names "-N", // Skip column names
"-e", query, "-e", query,
) }
// Only add -h flag if host is not localhost (to use Unix socket)
if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" {
args = append([]string{"-h", s.cfg.Host}, args...)
}
cmd := exec.CommandContext(ctx, "mysql", args...)
if s.cfg.Password != "" { if s.cfg.Password != "" {
cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password)) cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password))