diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fa6f3b..8f891e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Provides actionable error messages with root cause ### Fixed +- **P0: SQL Injection** - Added identifier validation for database names in CREATE/DROP DATABASE to prevent SQL injection attacks; uses safe quoting and regex validation (alphanumeric + underscore only) +- **P0: Data Race** - Fixed concurrent goroutines appending to shared error slice in notification manager; now uses mutex synchronization - **P0: psql ON_ERROR_STOP** - Added `-v ON_ERROR_STOP=1` to psql commands to fail fast on first error instead of accumulating millions of errors - **P1: Pipe deadlock** - Fixed streaming compression deadlock when pg_dump blocks on full pipe buffer; now uses goroutine with proper context timeout handling - **P1: SIGPIPE handling** - Detect exit code 141 (broken pipe) and report compressor failure as root cause diff --git a/internal/database/mysql.go b/internal/database/mysql.go index c3d771b..c67813a 100755 --- a/internal/database/mysql.go +++ b/internal/database/mysql.go @@ -126,13 +126,46 @@ func (m *MySQL) ListTables(ctx context.Context, database string) ([]string, erro return tables, rows.Err() } +// validateMySQLIdentifier checks if a database/table name is safe for use in SQL +// Prevents SQL injection by only allowing alphanumeric names with underscores +func validateMySQLIdentifier(name string) error { + if len(name) == 0 { + return fmt.Errorf("identifier cannot be empty") + } + if len(name) > 64 { + return fmt.Errorf("identifier too long (max 64 chars): %s", name) + } + // Only allow alphanumeric, underscores, and must start with letter or underscore + for i, c := range name { + if i == 0 && !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') { + return fmt.Errorf("identifier must start with letter or underscore: %s", name) + } + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { + return fmt.Errorf("identifier contains invalid character %q: %s", c, name) + } + } + return nil +} + +// quoteMySQLIdentifier safely quotes a MySQL identifier +func quoteMySQLIdentifier(name string) string { + // Escape any backticks by doubling them and wrap in backticks + return "`" + strings.ReplaceAll(name, "`", "``") + "`" +} + // CreateDatabase creates a new database func (m *MySQL) CreateDatabase(ctx context.Context, name string) error { if m.db == nil { return fmt.Errorf("not connected to database") } - query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", name) + // Validate identifier to prevent SQL injection + if err := validateMySQLIdentifier(name); err != nil { + return fmt.Errorf("invalid database name: %w", err) + } + + // Use safe quoting for identifier + query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", quoteMySQLIdentifier(name)) _, err := m.db.ExecContext(ctx, query) if err != nil { return fmt.Errorf("failed to create database %s: %w", name, err) @@ -148,7 +181,13 @@ func (m *MySQL) DropDatabase(ctx context.Context, name string) error { return fmt.Errorf("not connected to database") } - query := fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", name) + // Validate identifier to prevent SQL injection + if err := validateMySQLIdentifier(name); err != nil { + return fmt.Errorf("invalid database name: %w", err) + } + + // Use safe quoting for identifier + query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", quoteMySQLIdentifier(name)) _, err := m.db.ExecContext(ctx, query) if err != nil { return fmt.Errorf("failed to drop database %s: %w", name, err) diff --git a/internal/database/postgresql.go b/internal/database/postgresql.go index f60bafd..921b5c1 100755 --- a/internal/database/postgresql.go +++ b/internal/database/postgresql.go @@ -163,14 +163,47 @@ func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string, return tables, rows.Err() } +// validateIdentifier checks if a database/table name is safe for use in SQL +// Prevents SQL injection by only allowing alphanumeric names with underscores +func validateIdentifier(name string) error { + if len(name) == 0 { + return fmt.Errorf("identifier cannot be empty") + } + if len(name) > 63 { + return fmt.Errorf("identifier too long (max 63 chars): %s", name) + } + // Only allow alphanumeric, underscores, and must start with letter or underscore + for i, c := range name { + if i == 0 && !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') { + return fmt.Errorf("identifier must start with letter or underscore: %s", name) + } + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { + return fmt.Errorf("identifier contains invalid character %q: %s", c, name) + } + } + return nil +} + +// quoteIdentifier safely quotes a PostgreSQL identifier +func quoteIdentifier(name string) string { + // Double any existing double quotes and wrap in double quotes + return `"` + strings.ReplaceAll(name, `"`, `""`) + `"` +} + // CreateDatabase creates a new database func (p *PostgreSQL) CreateDatabase(ctx context.Context, name string) error { if p.db == nil { return fmt.Errorf("not connected to database") } + // Validate identifier to prevent SQL injection + if err := validateIdentifier(name); err != nil { + return fmt.Errorf("invalid database name: %w", err) + } + // PostgreSQL doesn't support CREATE DATABASE in transactions or prepared statements - query := fmt.Sprintf("CREATE DATABASE %s", name) + // Use quoted identifier for safety + query := fmt.Sprintf("CREATE DATABASE %s", quoteIdentifier(name)) _, err := p.db.ExecContext(ctx, query) if err != nil { return fmt.Errorf("failed to create database %s: %w", name, err) @@ -186,8 +219,14 @@ func (p *PostgreSQL) DropDatabase(ctx context.Context, name string) error { return fmt.Errorf("not connected to database") } + // Validate identifier to prevent SQL injection + if err := validateIdentifier(name); err != nil { + return fmt.Errorf("invalid database name: %w", err) + } + // Force drop connections and drop database - query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", name) + // Use quoted identifier for safety + query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", quoteIdentifier(name)) _, err := p.db.ExecContext(ctx, query) if err != nil { return fmt.Errorf("failed to drop database %s: %w", name, err) diff --git a/internal/notify/manager.go b/internal/notify/manager.go index ec74951..703d349 100644 --- a/internal/notify/manager.go +++ b/internal/notify/manager.go @@ -69,6 +69,7 @@ func (m *Manager) NotifySync(ctx context.Context, event *Event) error { m.mu.RUnlock() var errors []error + var errMu sync.Mutex var wg sync.WaitGroup for _, n := range notifiers { @@ -80,7 +81,9 @@ func (m *Manager) NotifySync(ctx context.Context, event *Event) error { go func(notifier Notifier) { defer wg.Done() if err := notifier.Send(ctx, event); err != nil { + errMu.Lock() errors = append(errors, fmt.Errorf("%s: %w", notifier.Name(), err)) + errMu.Unlock() } }(n) }