diff --git a/internal/restore/safety.go b/internal/restore/safety.go index d478aa5..6cc3a8b 100644 --- a/internal/restore/safety.go +++ b/internal/restore/safety.go @@ -297,16 +297,24 @@ func (s *Safety) CheckDatabaseExists(ctx context.Context, dbName string) (bool, // checkPostgresDatabaseExists checks if PostgreSQL database exists func (s *Safety) checkPostgresDatabaseExists(ctx context.Context, dbName string) (bool, error) { - cmd := exec.CommandContext(ctx, - "psql", - "-h", s.cfg.Host, + args := []string{ "-p", fmt.Sprintf("%d", s.cfg.Port), "-U", s.cfg.User, "-d", "postgres", "-tAc", fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname='%s'", dbName), - ) + } + + // Only add -h flag if host is not localhost (to use Unix socket for peer auth) + if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" { + args = append([]string{"-h", s.cfg.Host}, args...) + } + + cmd := exec.CommandContext(ctx, "psql", args...) - cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", s.cfg.Password)) + // Set password if provided + if s.cfg.Password != "" { + cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", s.cfg.Password)) + } output, err := cmd.Output() if err != nil { @@ -354,17 +362,25 @@ func (s *Safety) listPostgresUserDatabases(ctx context.Context) ([]string, error // Query to get non-template databases excluding 'postgres' system DB query := "SELECT datname FROM pg_database WHERE datistemplate = false AND datname != 'postgres' ORDER BY datname" - cmd := exec.CommandContext(ctx, - "psql", - "-h", s.cfg.Host, + args := []string{ "-p", fmt.Sprintf("%d", s.cfg.Port), "-U", s.cfg.User, "-d", "postgres", "-tA", // Tuples only, unaligned "-c", query, - ) + } + + // Only add -h flag if host is not localhost (to use Unix socket for peer auth) + if s.cfg.Host != "localhost" && s.cfg.Host != "127.0.0.1" && s.cfg.Host != "" { + args = append([]string{"-h", s.cfg.Host}, args...) + } + + cmd := exec.CommandContext(ctx, "psql", args...) - cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", s.cfg.Password)) + // Set password if provided + if s.cfg.Password != "" { + cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", s.cfg.Password)) + } output, err := cmd.Output() if err != nil {