security: P0 fixes - SQL injection prevention + data race fix
- Add identifier validation for database names in PostgreSQL and MySQL - validateIdentifier() rejects names with invalid characters - quoteIdentifier() safely quotes identifiers with proper escaping - Max length: 63 chars (PostgreSQL), 64 chars (MySQL) - Only allows alphanumeric + underscores, must start with letter/underscore - Fix data race in notification manager - Multiple goroutines were appending to shared error slice - Added errMu sync.Mutex to protect concurrent error collection - Security improvements prevent: - SQL injection via malicious database names - CREATE DATABASE `foo`; DROP DATABASE production; --` - Race conditions causing lost or corrupted error data
This commit is contained in:
@@ -163,14 +163,47 @@ func (p *PostgreSQL) ListTables(ctx context.Context, database string) ([]string,
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// validateIdentifier checks if a database/table name is safe for use in SQL
|
||||
// Prevents SQL injection by only allowing alphanumeric names with underscores
|
||||
func validateIdentifier(name string) error {
|
||||
if len(name) == 0 {
|
||||
return fmt.Errorf("identifier cannot be empty")
|
||||
}
|
||||
if len(name) > 63 {
|
||||
return fmt.Errorf("identifier too long (max 63 chars): %s", name)
|
||||
}
|
||||
// Only allow alphanumeric, underscores, and must start with letter or underscore
|
||||
for i, c := range name {
|
||||
if i == 0 && !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') {
|
||||
return fmt.Errorf("identifier must start with letter or underscore: %s", name)
|
||||
}
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
|
||||
return fmt.Errorf("identifier contains invalid character %q: %s", c, name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// quoteIdentifier safely quotes a PostgreSQL identifier
|
||||
func quoteIdentifier(name string) string {
|
||||
// Double any existing double quotes and wrap in double quotes
|
||||
return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
|
||||
}
|
||||
|
||||
// CreateDatabase creates a new database
|
||||
func (p *PostgreSQL) CreateDatabase(ctx context.Context, name string) error {
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// Validate identifier to prevent SQL injection
|
||||
if err := validateIdentifier(name); err != nil {
|
||||
return fmt.Errorf("invalid database name: %w", err)
|
||||
}
|
||||
|
||||
// PostgreSQL doesn't support CREATE DATABASE in transactions or prepared statements
|
||||
query := fmt.Sprintf("CREATE DATABASE %s", name)
|
||||
// Use quoted identifier for safety
|
||||
query := fmt.Sprintf("CREATE DATABASE %s", quoteIdentifier(name))
|
||||
_, err := p.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database %s: %w", name, err)
|
||||
@@ -186,8 +219,14 @@ func (p *PostgreSQL) DropDatabase(ctx context.Context, name string) error {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
// Validate identifier to prevent SQL injection
|
||||
if err := validateIdentifier(name); err != nil {
|
||||
return fmt.Errorf("invalid database name: %w", err)
|
||||
}
|
||||
|
||||
// Force drop connections and drop database
|
||||
query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", name)
|
||||
// Use quoted identifier for safety
|
||||
query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", quoteIdentifier(name))
|
||||
_, err := p.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop database %s: %w", name, err)
|
||||
|
||||
Reference in New Issue
Block a user