security: P0 fixes - SQL injection prevention + data race fix
- Add identifier validation for database names in PostgreSQL and MySQL - validateIdentifier() rejects names with invalid characters - quoteIdentifier() safely quotes identifiers with proper escaping - Max length: 63 chars (PostgreSQL), 64 chars (MySQL) - Only allows alphanumeric + underscores, must start with letter/underscore - Fix data race in notification manager - Multiple goroutines were appending to shared error slice - Added errMu sync.Mutex to protect concurrent error collection - Security improvements prevent: - SQL injection via malicious database names - CREATE DATABASE `foo`; DROP DATABASE production; --` - Race conditions causing lost or corrupted error data
This commit is contained in:
@@ -126,13 +126,46 @@ func (m *MySQL) ListTables(ctx context.Context, database string) ([]string, erro
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// validateMySQLIdentifier checks if a database/table name is safe for use in SQL
|
||||
// Prevents SQL injection by only allowing alphanumeric names with underscores
|
||||
func validateMySQLIdentifier(name string) error {
|
||||
if len(name) == 0 {
|
||||
return fmt.Errorf("identifier cannot be empty")
|
||||
}
|
||||
if len(name) > 64 {
|
||||
return fmt.Errorf("identifier too long (max 64 chars): %s", name)
|
||||
}
|
||||
// Only allow alphanumeric, underscores, and must start with letter or underscore
|
||||
for i, c := range name {
|
||||
if i == 0 && !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') {
|
||||
return fmt.Errorf("identifier must start with letter or underscore: %s", name)
|
||||
}
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
|
||||
return fmt.Errorf("identifier contains invalid character %q: %s", c, name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// quoteMySQLIdentifier safely quotes a MySQL identifier
|
||||
func quoteMySQLIdentifier(name string) string {
|
||||
// Escape any backticks by doubling them and wrap in backticks
|
||||
return "`" + strings.ReplaceAll(name, "`", "``") + "`"
|
||||
}
|
||||
|
||||
// CreateDatabase creates a new database
|
||||
func (m *MySQL) CreateDatabase(ctx context.Context, name string) error {
|
||||
if m.db == nil {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", name)
|
||||
// Validate identifier to prevent SQL injection
|
||||
if err := validateMySQLIdentifier(name); err != nil {
|
||||
return fmt.Errorf("invalid database name: %w", err)
|
||||
}
|
||||
|
||||
// Use safe quoting for identifier
|
||||
query := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", quoteMySQLIdentifier(name))
|
||||
_, err := m.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database %s: %w", name, err)
|
||||
@@ -148,7 +181,13 @@ func (m *MySQL) DropDatabase(ctx context.Context, name string) error {
|
||||
return fmt.Errorf("not connected to database")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", name)
|
||||
// Validate identifier to prevent SQL injection
|
||||
if err := validateMySQLIdentifier(name); err != nil {
|
||||
return fmt.Errorf("invalid database name: %w", err)
|
||||
}
|
||||
|
||||
// Use safe quoting for identifier
|
||||
query := fmt.Sprintf("DROP DATABASE IF EXISTS %s", quoteMySQLIdentifier(name))
|
||||
_, err := m.db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop database %s: %w", name, err)
|
||||
|
||||
Reference in New Issue
Block a user