diff --git a/cmd/placeholder.go b/cmd/placeholder.go index 00223ab..66e12a2 100644 --- a/cmd/placeholder.go +++ b/cmd/placeholder.go @@ -730,12 +730,17 @@ func containsSQLKeywords(content string) bool { } func mysqlRestoreCommand(archivePath string, compressed bool) string { - parts := []string{ - "mysql", - "-h", cfg.Host, + parts := []string{"mysql"} + + // Only add -h flag if host is not localhost (to use Unix socket) + if cfg.Host != "localhost" && cfg.Host != "127.0.0.1" && cfg.Host != "" { + parts = append(parts, "-h", cfg.Host) + } + + parts = append(parts, "-P", fmt.Sprintf("%d", cfg.Port), "-u", cfg.User, - } + ) if cfg.Password != "" { parts = append(parts, fmt.Sprintf("-p'%s'", cfg.Password)) diff --git a/internal/restore/safety.go b/internal/restore/safety.go index 6cc3a8b..9fa1de7 100644 --- a/internal/restore/safety.go +++ b/internal/restore/safety.go @@ -326,13 +326,18 @@ func (s *Safety) checkPostgresDatabaseExists(ctx context.Context, dbName string) // checkMySQLDatabaseExists checks if MySQL database exists func (s *Safety) checkMySQLDatabaseExists(ctx context.Context, dbName string) (bool, error) { - cmd := exec.CommandContext(ctx, - "mysql", - "-h", s.cfg.Host, + args := []string{ "-P", fmt.Sprintf("%d", s.cfg.Port), "-u", s.cfg.User, "-e", fmt.Sprintf("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME='%s'", dbName), - ) + } + + // Only add -h flag if host is not localhost (to use Unix socket) + if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" { + args = append([]string{"-h", s.cfg.Host}, args...) + } + + cmd := exec.CommandContext(ctx, "mysql", args...) if s.cfg.Password != "" { cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password)) @@ -405,14 +410,19 @@ func (s *Safety) listMySQLUserDatabases(ctx context.Context) ([]string, error) { // Exclude system databases query := "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') ORDER BY SCHEMA_NAME" - cmd := exec.CommandContext(ctx, - "mysql", - "-h", s.cfg.Host, + args := []string{ "-P", fmt.Sprintf("%d", s.cfg.Port), "-u", s.cfg.User, "-N", // Skip column names "-e", query, - ) + } + + // Only add -h flag if host is not localhost (to use Unix socket) + if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" { + args = append([]string{"-h", s.cfg.Host}, args...) + } + + cmd := exec.CommandContext(ctx, "mysql", args...) if s.cfg.Password != "" { cmd.Env = append(os.Environ(), fmt.Sprintf("MYSQL_PWD=%s", s.cfg.Password))