diff --git a/README.md b/README.md index d10c42e..abfd449 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,22 @@ A high-performance, secure file server implementing XEP-0363 (HTTP File Upload) ## Features +### Core Features - XEP-0363 HTTP File Upload compliance -- HMAC-based authentication -- File deduplication +- HMAC-based authentication with JWT support +- File deduplication (SHA256 with hardlinks) - Multi-architecture support (AMD64, ARM64, ARM32v7) - Docker and Podman deployment - XMPP client compatibility (Dino, Gajim, Conversations, Monal, Converse.js) -- Network resilience for mobile clients +- Network resilience for mobile clients (WiFi/LTE switching) + +### Security Features (v3.3.0) +- **Audit Logging** - Comprehensive security event logging (uploads, downloads, auth events) +- **Magic Bytes Validation** - Content type verification using file signatures +- **Per-User Quotas** - Storage limits per XMPP JID with Redis tracking +- **Admin API** - Protected endpoints for system management and monitoring +- **ClamAV Integration** - Antivirus scanning for uploaded files +- **Rate Limiting** - Configurable request throttling ## Installation @@ -90,16 +99,19 @@ secret = "your-hmac-secret-key" | Section | Description | |---------|-------------| | `[server]` | Bind address, port, storage path, timeouts | -| `[security]` | HMAC secret, TLS settings | +| `[security]` | HMAC secret, JWT, TLS settings | | `[uploads]` | Size limits, allowed extensions | | `[downloads]` | Download settings, bandwidth limits | | `[logging]` | Log file, log level | | `[clamav]` | Antivirus scanning integration | | `[redis]` | Redis caching backend | -| `[deduplication]` | File deduplication settings | +| `[audit]` | Security audit logging | +| `[validation]` | Magic bytes content validation | +| `[quotas]` | Per-user storage quotas | +| `[admin]` | Admin API configuration | | `[workers]` | Worker pool configuration | -See [examples/](examples/) for complete configuration templates. +See [templates/](templates/) for complete configuration templates. ## XMPP Server Integration @@ -168,6 +180,64 @@ token = HMAC-SHA256(secret, filename + filesize + timestamp) | `/download/...` | GET | File download | | `/health` | GET | Health check | | `/metrics` | GET | Prometheus metrics | +| `/admin/stats` | GET | Server statistics (auth required) | +| `/admin/files` | GET | List uploaded files (auth required) | +| `/admin/users` | GET | User quota information (auth required) | + +## Enhanced Features (v3.3.0) + +### Audit Logging + +Security-focused logging for compliance and forensics: + +```toml +[audit] +enabled = true +output = "file" +path = "/var/log/hmac-audit.log" +format = "json" +events = ["upload", "download", "auth_failure", "quota_exceeded"] +``` + +### Content Validation + +Magic bytes validation to verify file types: + +```toml +[validation] +check_magic_bytes = true +allowed_types = ["image/*", "video/*", "audio/*", "application/pdf"] +blocked_types = ["application/x-executable", "application/x-shellscript"] +``` + +### Per-User Quotas + +Storage limits per XMPP JID with Redis tracking: + +```toml +[quotas] +enabled = true +default = "100MB" +tracking = "redis" + +[quotas.custom] +"admin@example.com" = "10GB" +"premium@example.com" = "1GB" +``` + +### Admin API + +Protected management endpoints: + +```toml +[admin] +enabled = true +path_prefix = "/admin" + +[admin.auth] +type = "bearer" +token = "${ADMIN_TOKEN}" +``` ## System Requirements diff --git a/cmd/server/admin.go b/cmd/server/admin.go new file mode 100644 index 0000000..e0020b9 --- /dev/null +++ b/cmd/server/admin.go @@ -0,0 +1,756 @@ +// admin.go - Admin API for operations and monitoring + +package main + +import ( + "context" + "encoding/json" + "fmt" + "io/fs" + "net/http" + "os" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "time" +) + +// AdminConfig holds admin API configuration +type AdminConfig struct { + Enabled bool `toml:"enabled" mapstructure:"enabled"` + Bind string `toml:"bind" mapstructure:"bind"` // Separate bind address (e.g., "127.0.0.1:8081") + PathPrefix string `toml:"path_prefix" mapstructure:"path_prefix"` // Path prefix (e.g., "/admin") + Auth AdminAuthConfig `toml:"auth" mapstructure:"auth"` +} + +// AdminAuthConfig holds admin authentication configuration +type AdminAuthConfig struct { + Type string `toml:"type" mapstructure:"type"` // "bearer" | "basic" + Token string `toml:"token" mapstructure:"token"` // For bearer auth + Username string `toml:"username" mapstructure:"username"` // For basic auth + Password string `toml:"password" mapstructure:"password"` // For basic auth +} + +// AdminStats represents system statistics +type AdminStats struct { + Storage StorageStats `json:"storage"` + Users UserStats `json:"users"` + Requests RequestStats `json:"requests"` + System SystemStats `json:"system"` +} + +// StorageStats represents storage statistics +type StorageStats struct { + UsedBytes int64 `json:"used_bytes"` + UsedHuman string `json:"used_human"` + FileCount int64 `json:"file_count"` + FreeBytes int64 `json:"free_bytes,omitempty"` + FreeHuman string `json:"free_human,omitempty"` + TotalBytes int64 `json:"total_bytes,omitempty"` + TotalHuman string `json:"total_human,omitempty"` +} + +// UserStats represents user statistics +type UserStats struct { + Total int64 `json:"total"` + Active24h int64 `json:"active_24h"` + Active7d int64 `json:"active_7d"` +} + +// RequestStats represents request statistics +type RequestStats struct { + Uploads24h int64 `json:"uploads_24h"` + Downloads24h int64 `json:"downloads_24h"` + Errors24h int64 `json:"errors_24h"` +} + +// SystemStats represents system statistics +type SystemStats struct { + Uptime string `json:"uptime"` + Version string `json:"version"` + GoVersion string `json:"go_version"` + NumGoroutines int `json:"num_goroutines"` + MemoryUsageMB int64 `json:"memory_usage_mb"` + NumCPU int `json:"num_cpu"` +} + +// FileInfo represents file information for admin API +type FileInfo struct { + ID string `json:"id"` + Path string `json:"path"` + Name string `json:"name"` + Size int64 `json:"size"` + SizeHuman string `json:"size_human"` + ContentType string `json:"content_type"` + ModTime time.Time `json:"mod_time"` + Owner string `json:"owner,omitempty"` +} + +// FileListResponse represents paginated file list +type FileListResponse struct { + Files []FileInfo `json:"files"` + Total int64 `json:"total"` + Page int `json:"page"` + Limit int `json:"limit"` + TotalPages int `json:"total_pages"` +} + +// UserInfo represents user information for admin API +type UserInfo struct { + JID string `json:"jid"` + QuotaUsed int64 `json:"quota_used"` + QuotaLimit int64 `json:"quota_limit"` + FileCount int64 `json:"file_count"` + LastActive time.Time `json:"last_active,omitempty"` + IsBanned bool `json:"is_banned"` +} + +// BanInfo represents ban information +type BanInfo struct { + IP string `json:"ip"` + Reason string `json:"reason"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + IsPermanent bool `json:"is_permanent"` +} + +var ( + serverStartTime = time.Now() + adminConfig *AdminConfig +) + +// SetupAdminRoutes sets up admin API routes +func SetupAdminRoutes(mux *http.ServeMux, config *AdminConfig) { + adminConfig = config + + if !config.Enabled { + log.Info("Admin API is disabled") + return + } + + prefix := config.PathPrefix + if prefix == "" { + prefix = "/admin" + } + + // Wrap all admin handlers with authentication + adminMux := http.NewServeMux() + + adminMux.HandleFunc(prefix+"/stats", handleAdminStats) + adminMux.HandleFunc(prefix+"/files", handleAdminFiles) + adminMux.HandleFunc(prefix+"/files/", handleAdminFileByID) + adminMux.HandleFunc(prefix+"/users", handleAdminUsers) + adminMux.HandleFunc(prefix+"/users/", handleAdminUserByJID) + adminMux.HandleFunc(prefix+"/bans", handleAdminBans) + adminMux.HandleFunc(prefix+"/bans/", handleAdminBanByIP) + adminMux.HandleFunc(prefix+"/health", handleAdminHealth) + adminMux.HandleFunc(prefix+"/config", handleAdminConfig) + + // Register with authentication middleware + mux.Handle(prefix+"/", AdminAuthMiddleware(adminMux)) + + log.Infof("Admin API enabled at %s (auth: %s)", prefix, config.Auth.Type) +} + +// AdminAuthMiddleware handles admin authentication +func AdminAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if adminConfig == nil || !adminConfig.Enabled { + http.Error(w, "Admin API disabled", http.StatusServiceUnavailable) + return + } + + authorized := false + + switch adminConfig.Auth.Type { + case "bearer": + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + token := strings.TrimPrefix(auth, "Bearer ") + authorized = token == adminConfig.Auth.Token + } + case "basic": + username, password, ok := r.BasicAuth() + if ok { + authorized = username == adminConfig.Auth.Username && + password == adminConfig.Auth.Password + } + default: + // No auth configured, check if request is from localhost + clientIP := getClientIP(r) + authorized = clientIP == "127.0.0.1" || clientIP == "::1" + } + + if !authorized { + AuditEvent("admin_auth_failure", r, nil) + w.Header().Set("WWW-Authenticate", `Bearer realm="admin"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} + +// handleAdminStats returns system statistics +func handleAdminStats(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + AuditAdminAction(r, "get_stats", "system", nil) + + ctx := r.Context() + stats := AdminStats{} + + // Storage stats + confMutex.RLock() + storagePath := conf.Server.StoragePath + confMutex.RUnlock() + + storageStats := calculateStorageStats(storagePath) + stats.Storage = storageStats + + // User stats + stats.Users = calculateUserStats(ctx) + + // Request stats from Prometheus metrics + stats.Requests = calculateRequestStats() + + // System stats + var mem runtime.MemStats + runtime.ReadMemStats(&mem) + + stats.System = SystemStats{ + Uptime: time.Since(serverStartTime).Round(time.Second).String(), + Version: "3.3.0", + GoVersion: runtime.Version(), + NumGoroutines: runtime.NumGoroutine(), + MemoryUsageMB: int64(mem.Alloc / 1024 / 1024), + NumCPU: runtime.NumCPU(), + } + + writeJSONResponseAdmin(w, stats) +} + +// handleAdminFiles handles file listing +func handleAdminFiles(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + listFiles(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// listFiles returns paginated file list +func listFiles(w http.ResponseWriter, r *http.Request) { + AuditAdminAction(r, "list_files", "files", nil) + + // Parse query parameters + page, _ := strconv.Atoi(r.URL.Query().Get("page")) + if page < 1 { + page = 1 + } + limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) + if limit < 1 || limit > 100 { + limit = 50 + } + sortBy := r.URL.Query().Get("sort") + if sortBy == "" { + sortBy = "date" + } + filterOwner := r.URL.Query().Get("owner") + + confMutex.RLock() + storagePath := conf.Server.StoragePath + confMutex.RUnlock() + + var files []FileInfo + + err := filepath.WalkDir(storagePath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil // Skip errors + } + if d.IsDir() { + return nil + } + + info, err := d.Info() + if err != nil { + return nil + } + + relPath, _ := filepath.Rel(storagePath, path) + + fileInfo := FileInfo{ + ID: relPath, + Path: relPath, + Name: filepath.Base(path), + Size: info.Size(), + SizeHuman: formatBytes(info.Size()), + ContentType: GetContentType(path), + ModTime: info.ModTime(), + } + + // Apply owner filter if specified (simplified: would need metadata lookup) + _ = filterOwner // Unused for now, but kept for future implementation + + files = append(files, fileInfo) + return nil + }) + + if err != nil { + http.Error(w, fmt.Sprintf("Error listing files: %v", err), http.StatusInternalServerError) + return + } + + // Sort files + switch sortBy { + case "date": + sort.Slice(files, func(i, j int) bool { + return files[i].ModTime.After(files[j].ModTime) + }) + case "size": + sort.Slice(files, func(i, j int) bool { + return files[i].Size > files[j].Size + }) + case "name": + sort.Slice(files, func(i, j int) bool { + return files[i].Name < files[j].Name + }) + } + + // Paginate + total := len(files) + start := (page - 1) * limit + end := start + limit + if start > total { + start = total + } + if end > total { + end = total + } + + response := FileListResponse{ + Files: files[start:end], + Total: int64(total), + Page: page, + Limit: limit, + TotalPages: (total + limit - 1) / limit, + } + + writeJSONResponseAdmin(w, response) +} + +// handleAdminFileByID handles single file operations +func handleAdminFileByID(w http.ResponseWriter, r *http.Request) { + // Extract file ID from path + prefix := adminConfig.PathPrefix + if prefix == "" { + prefix = "/admin" + } + fileID := strings.TrimPrefix(r.URL.Path, prefix+"/files/") + + if fileID == "" { + http.Error(w, "File ID required", http.StatusBadRequest) + return + } + + switch r.Method { + case http.MethodGet: + getFileInfo(w, r, fileID) + case http.MethodDelete: + deleteFile(w, r, fileID) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// getFileInfo returns information about a specific file +func getFileInfo(w http.ResponseWriter, r *http.Request, fileID string) { + confMutex.RLock() + storagePath := conf.Server.StoragePath + confMutex.RUnlock() + + filePath := filepath.Join(storagePath, fileID) + + // Validate path is within storage + absPath, err := filepath.Abs(filePath) + if err != nil || !strings.HasPrefix(absPath, storagePath) { + http.Error(w, "Invalid file ID", http.StatusBadRequest) + return + } + + info, err := os.Stat(filePath) + if os.IsNotExist(err) { + http.Error(w, "File not found", http.StatusNotFound) + return + } + if err != nil { + http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError) + return + } + + fileInfo := FileInfo{ + ID: fileID, + Path: fileID, + Name: filepath.Base(filePath), + Size: info.Size(), + SizeHuman: formatBytes(info.Size()), + ContentType: GetContentType(filePath), + ModTime: info.ModTime(), + } + + writeJSONResponseAdmin(w, fileInfo) +} + +// deleteFile deletes a specific file +func deleteFile(w http.ResponseWriter, r *http.Request, fileID string) { + confMutex.RLock() + storagePath := conf.Server.StoragePath + confMutex.RUnlock() + + filePath := filepath.Join(storagePath, fileID) + + // Validate path is within storage + absPath, err := filepath.Abs(filePath) + if err != nil || !strings.HasPrefix(absPath, storagePath) { + http.Error(w, "Invalid file ID", http.StatusBadRequest) + return + } + + // Get file info before deletion for audit + info, err := os.Stat(filePath) + if os.IsNotExist(err) { + http.Error(w, "File not found", http.StatusNotFound) + return + } + + AuditAdminAction(r, "delete_file", fileID, map[string]interface{}{ + "size": info.Size(), + }) + + if err := os.Remove(filePath); err != nil { + http.Error(w, fmt.Sprintf("Error deleting file: %v", err), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// handleAdminUsers handles user listing +func handleAdminUsers(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + AuditAdminAction(r, "list_users", "users", nil) + + ctx := r.Context() + qm := GetQuotaManager() + + var users []UserInfo + + if qm != nil && qm.config.Enabled { + quotas, err := qm.GetAllQuotas(ctx) + if err != nil { + http.Error(w, fmt.Sprintf("Error getting quotas: %v", err), http.StatusInternalServerError) + return + } + + for _, quota := range quotas { + users = append(users, UserInfo{ + JID: quota.JID, + QuotaUsed: quota.Used, + QuotaLimit: quota.Limit, + FileCount: quota.FileCount, + }) + } + } + + writeJSONResponseAdmin(w, users) +} + +// handleAdminUserByJID handles single user operations +func handleAdminUserByJID(w http.ResponseWriter, r *http.Request) { + prefix := adminConfig.PathPrefix + if prefix == "" { + prefix = "/admin" + } + + path := strings.TrimPrefix(r.URL.Path, prefix+"/users/") + parts := strings.Split(path, "/") + jid := parts[0] + + if jid == "" { + http.Error(w, "JID required", http.StatusBadRequest) + return + } + + // Check for sub-paths + if len(parts) > 1 { + switch parts[1] { + case "files": + handleUserFiles(w, r, jid) + return + case "quota": + handleUserQuota(w, r, jid) + return + } + } + + switch r.Method { + case http.MethodGet: + getUserInfo(w, r, jid) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// getUserInfo returns information about a specific user +func getUserInfo(w http.ResponseWriter, r *http.Request, jid string) { + ctx := r.Context() + qm := GetQuotaManager() + + if qm == nil || !qm.config.Enabled { + http.Error(w, "Quota tracking not enabled", http.StatusNotImplemented) + return + } + + quota, err := qm.GetQuotaInfo(ctx, jid) + if err != nil { + http.Error(w, fmt.Sprintf("Error getting quota: %v", err), http.StatusInternalServerError) + return + } + + user := UserInfo{ + JID: jid, + QuotaUsed: quota.Used, + QuotaLimit: quota.Limit, + FileCount: quota.FileCount, + } + + writeJSONResponseAdmin(w, user) +} + +// handleUserFiles handles user file operations +func handleUserFiles(w http.ResponseWriter, r *http.Request, jid string) { + switch r.Method { + case http.MethodGet: + // List user's files + AuditAdminAction(r, "list_user_files", jid, nil) + // Would need file ownership tracking to implement fully + writeJSONResponseAdmin(w, []FileInfo{}) + case http.MethodDelete: + // Delete all user's files + AuditAdminAction(r, "delete_user_files", jid, nil) + // Would need file ownership tracking to implement fully + w.WriteHeader(http.StatusNoContent) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handleUserQuota handles user quota operations +func handleUserQuota(w http.ResponseWriter, r *http.Request, jid string) { + qm := GetQuotaManager() + if qm == nil { + http.Error(w, "Quota management not enabled", http.StatusNotImplemented) + return + } + + switch r.Method { + case http.MethodPost: + // Set custom quota + var req struct { + Quota string `json:"quota"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + quota, err := parseSize(req.Quota) + if err != nil { + http.Error(w, fmt.Sprintf("Invalid quota: %v", err), http.StatusBadRequest) + return + } + + qm.SetCustomQuota(jid, quota) + AuditAdminAction(r, "set_quota", jid, map[string]interface{}{"quota": req.Quota}) + + writeJSONResponseAdmin(w, map[string]interface{}{ + "success": true, + "jid": jid, + "quota": quota, + }) + case http.MethodDelete: + // Remove custom quota + qm.RemoveCustomQuota(jid) + AuditAdminAction(r, "remove_quota", jid, nil) + w.WriteHeader(http.StatusNoContent) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handleAdminBans handles ban listing +func handleAdminBans(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + AuditAdminAction(r, "list_bans", "bans", nil) + + // Would need ban management implementation + writeJSONResponseAdmin(w, []BanInfo{}) +} + +// handleAdminBanByIP handles single ban operations +func handleAdminBanByIP(w http.ResponseWriter, r *http.Request) { + prefix := adminConfig.PathPrefix + if prefix == "" { + prefix = "/admin" + } + ip := strings.TrimPrefix(r.URL.Path, prefix+"/bans/") + + if ip == "" { + http.Error(w, "IP required", http.StatusBadRequest) + return + } + + switch r.Method { + case http.MethodDelete: + // Unban IP + AuditAdminAction(r, "unban", ip, nil) + // Would need ban management implementation + w.WriteHeader(http.StatusNoContent) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handleAdminHealth returns admin-specific health info +func handleAdminHealth(w http.ResponseWriter, r *http.Request) { + health := map[string]interface{}{ + "status": "healthy", + "timestamp": time.Now().UTC(), + "uptime": time.Since(serverStartTime).String(), + } + + // Check Redis + if redisClient != nil && redisConnected { + health["redis"] = "connected" + } else if redisClient != nil { + health["redis"] = "disconnected" + } + + writeJSONResponseAdmin(w, health) +} + +// handleAdminConfig returns current configuration (sanitized) +func handleAdminConfig(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + AuditAdminAction(r, "get_config", "config", nil) + + confMutex.RLock() + // Return sanitized config (no secrets) + sanitized := map[string]interface{}{ + "server": map[string]interface{}{ + "listen_address": conf.Server.ListenAddress, + "storage_path": conf.Server.StoragePath, + "max_upload_size": conf.Server.MaxUploadSize, + "metrics_enabled": conf.Server.MetricsEnabled, + }, + "security": map[string]interface{}{ + "enhanced_security": conf.Security.EnhancedSecurity, + "jwt_enabled": conf.Security.EnableJWT, + }, + "clamav": map[string]interface{}{ + "enabled": conf.ClamAV.ClamAVEnabled, + }, + "redis": map[string]interface{}{ + "enabled": conf.Redis.RedisEnabled, + }, + "deduplication": map[string]interface{}{ + "enabled": conf.Deduplication.Enabled, + }, + } + confMutex.RUnlock() + + writeJSONResponseAdmin(w, sanitized) +} + +// Helper functions + +func calculateStorageStats(storagePath string) StorageStats { + var totalSize int64 + var fileCount int64 + + _ = filepath.WalkDir(storagePath, func(path string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() { + return nil + } + if info, err := d.Info(); err == nil { + totalSize += info.Size() + fileCount++ + } + return nil + }) + + return StorageStats{ + UsedBytes: totalSize, + UsedHuman: formatBytes(totalSize), + FileCount: fileCount, + } +} + +func calculateUserStats(ctx context.Context) UserStats { + qm := GetQuotaManager() + if qm == nil || !qm.config.Enabled { + return UserStats{} + } + + quotas, err := qm.GetAllQuotas(ctx) + if err != nil { + return UserStats{} + } + + return UserStats{ + Total: int64(len(quotas)), + } +} + +func calculateRequestStats() RequestStats { + // These would ideally come from Prometheus metrics + return RequestStats{} +} + +func writeJSONResponseAdmin(w http.ResponseWriter, data interface{}) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(data); err != nil { + log.Errorf("Failed to encode JSON response: %v", err) + } +} + +// DefaultAdminConfig returns default admin configuration +func DefaultAdminConfig() AdminConfig { + return AdminConfig{ + Enabled: false, + Bind: "", + PathPrefix: "/admin", + Auth: AdminAuthConfig{ + Type: "bearer", + }, + } +} diff --git a/cmd/server/audit.go b/cmd/server/audit.go new file mode 100644 index 0000000..d3a9ee5 --- /dev/null +++ b/cmd/server/audit.go @@ -0,0 +1,366 @@ +// audit.go - Dedicated audit logging for security-relevant events + +package main + +import ( + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/sirupsen/logrus" + "gopkg.in/natefinch/lumberjack.v2" +) + +// AuditConfig holds audit logging configuration +type AuditConfig struct { + Enabled bool `toml:"enabled" mapstructure:"enabled"` + Output string `toml:"output" mapstructure:"output"` // "file" | "stdout" + Path string `toml:"path" mapstructure:"path"` // Log file path + Format string `toml:"format" mapstructure:"format"` // "json" | "text" + Events []string `toml:"events" mapstructure:"events"` // Events to log + MaxSize int `toml:"max_size" mapstructure:"max_size"` // Max size in MB + MaxAge int `toml:"max_age" mapstructure:"max_age"` // Max age in days +} + +// AuditEvent types +const ( + AuditEventUpload = "upload" + AuditEventDownload = "download" + AuditEventDelete = "delete" + AuditEventAuthSuccess = "auth_success" + AuditEventAuthFailure = "auth_failure" + AuditEventRateLimited = "rate_limited" + AuditEventBanned = "banned" + AuditEventQuotaExceeded = "quota_exceeded" + AuditEventAdminAction = "admin_action" + AuditEventValidationFailure = "validation_failure" +) + +// AuditLogger handles security audit logging +type AuditLogger struct { + logger *logrus.Logger + config *AuditConfig + enabledEvents map[string]bool + mutex sync.RWMutex +} + +var ( + auditLogger *AuditLogger + auditOnce sync.Once +) + +// InitAuditLogger initializes the audit logger +func InitAuditLogger(config *AuditConfig) error { + var initErr error + auditOnce.Do(func() { + auditLogger = &AuditLogger{ + logger: logrus.New(), + config: config, + enabledEvents: make(map[string]bool), + } + + // Build enabled events map for fast lookup + for _, event := range config.Events { + auditLogger.enabledEvents[strings.ToLower(event)] = true + } + + // Configure formatter + if config.Format == "json" { + auditLogger.logger.SetFormatter(&logrus.JSONFormatter{ + TimestampFormat: time.RFC3339, + FieldMap: logrus.FieldMap{ + logrus.FieldKeyTime: "timestamp", + logrus.FieldKeyMsg: "event", + }, + }) + } else { + auditLogger.logger.SetFormatter(&logrus.TextFormatter{ + TimestampFormat: time.RFC3339, + FullTimestamp: true, + }) + } + + // Configure output + if !config.Enabled { + auditLogger.logger.SetOutput(io.Discard) + return + } + + switch config.Output { + case "stdout": + auditLogger.logger.SetOutput(os.Stdout) + case "file": + if config.Path == "" { + config.Path = "/var/log/hmac-audit.log" + } + + // Ensure directory exists + dir := filepath.Dir(config.Path) + if err := os.MkdirAll(dir, 0755); err != nil { + initErr = err + return + } + + // Use lumberjack for log rotation + maxSize := config.MaxSize + if maxSize <= 0 { + maxSize = 100 // Default 100MB + } + maxAge := config.MaxAge + if maxAge <= 0 { + maxAge = 30 // Default 30 days + } + + auditLogger.logger.SetOutput(&lumberjack.Logger{ + Filename: config.Path, + MaxSize: maxSize, + MaxAge: maxAge, + MaxBackups: 5, + Compress: true, + }) + default: + auditLogger.logger.SetOutput(os.Stdout) + } + + auditLogger.logger.SetLevel(logrus.InfoLevel) + log.Infof("Audit logger initialized: output=%s, path=%s, format=%s, events=%v", + config.Output, config.Path, config.Format, config.Events) + }) + + return initErr +} + +// GetAuditLogger returns the singleton audit logger +func GetAuditLogger() *AuditLogger { + return auditLogger +} + +// IsEventEnabled checks if an event type should be logged +func (a *AuditLogger) IsEventEnabled(event string) bool { + if a == nil || !a.config.Enabled { + return false + } + a.mutex.RLock() + defer a.mutex.RUnlock() + + // If no events configured, log all + if len(a.enabledEvents) == 0 { + return true + } + return a.enabledEvents[strings.ToLower(event)] +} + +// LogEvent logs an audit event +func (a *AuditLogger) LogEvent(event string, fields logrus.Fields) { + if a == nil || !a.config.Enabled || !a.IsEventEnabled(event) { + return + } + + // Add standard fields + fields["event_type"] = event + if _, ok := fields["timestamp"]; !ok { + fields["timestamp"] = time.Now().UTC().Format(time.RFC3339) + } + + a.logger.WithFields(fields).Info(event) +} + +// AuditEvent is a helper function for logging audit events from request context +func AuditEvent(event string, r *http.Request, fields logrus.Fields) { + if auditLogger == nil || !auditLogger.config.Enabled { + return + } + + if !auditLogger.IsEventEnabled(event) { + return + } + + // Add request context + if r != nil { + if fields == nil { + fields = logrus.Fields{} + } + fields["ip"] = getClientIP(r) + fields["user_agent"] = r.UserAgent() + fields["method"] = r.Method + fields["path"] = r.URL.Path + + // Extract JID if available from headers or context + if jid := r.Header.Get("X-User-JID"); jid != "" { + fields["jid"] = jid + } + } + + auditLogger.LogEvent(event, fields) +} + +// AuditUpload logs file upload events +func AuditUpload(r *http.Request, jid, fileID, fileName string, fileSize int64, contentType, result string, err error) { + fields := logrus.Fields{ + "jid": jid, + "file_id": fileID, + "file_name": fileName, + "file_size": fileSize, + "content_type": contentType, + "result": result, + } + if err != nil { + fields["error"] = err.Error() + } + AuditEvent(AuditEventUpload, r, fields) +} + +// AuditDownload logs file download events +func AuditDownload(r *http.Request, jid, fileID, fileName string, fileSize int64, result string, err error) { + fields := logrus.Fields{ + "jid": jid, + "file_id": fileID, + "file_name": fileName, + "file_size": fileSize, + "result": result, + } + if err != nil { + fields["error"] = err.Error() + } + AuditEvent(AuditEventDownload, r, fields) +} + +// AuditDelete logs file deletion events +func AuditDelete(r *http.Request, jid, fileID, fileName string, result string, err error) { + fields := logrus.Fields{ + "jid": jid, + "file_id": fileID, + "file_name": fileName, + "result": result, + } + if err != nil { + fields["error"] = err.Error() + } + AuditEvent(AuditEventDelete, r, fields) +} + +// AuditAuth logs authentication events +func AuditAuth(r *http.Request, jid string, success bool, method string, err error) { + event := AuditEventAuthSuccess + result := "success" + if !success { + event = AuditEventAuthFailure + result = "failure" + } + + fields := logrus.Fields{ + "jid": jid, + "auth_method": method, + "result": result, + } + if err != nil { + fields["error"] = err.Error() + } + AuditEvent(event, r, fields) +} + +// AuditRateLimited logs rate limiting events +func AuditRateLimited(r *http.Request, jid, reason string) { + fields := logrus.Fields{ + "jid": jid, + "reason": reason, + } + AuditEvent(AuditEventRateLimited, r, fields) +} + +// AuditBanned logs ban events +func AuditBanned(r *http.Request, jid, ip, reason string, duration time.Duration) { + fields := logrus.Fields{ + "jid": jid, + "banned_ip": ip, + "reason": reason, + "ban_duration": duration.String(), + } + AuditEvent(AuditEventBanned, r, fields) +} + +// AuditQuotaExceeded logs quota exceeded events +func AuditQuotaExceeded(r *http.Request, jid string, used, limit, requested int64) { + fields := logrus.Fields{ + "jid": jid, + "used": used, + "limit": limit, + "requested": requested, + } + AuditEvent(AuditEventQuotaExceeded, r, fields) +} + +// AuditAdminAction logs admin API actions +func AuditAdminAction(r *http.Request, action, target string, details map[string]interface{}) { + fields := logrus.Fields{ + "action": action, + "target": target, + } + for k, v := range details { + fields[k] = v + } + AuditEvent(AuditEventAdminAction, r, fields) +} + +// AuditValidationFailure logs content validation failures +func AuditValidationFailure(r *http.Request, jid, fileName, declaredType, detectedType, reason string) { + fields := logrus.Fields{ + "jid": jid, + "file_name": fileName, + "declared_type": declaredType, + "detected_type": detectedType, + "reason": reason, + } + AuditEvent(AuditEventValidationFailure, r, fields) +} + +// DefaultAuditConfig returns default audit configuration +func DefaultAuditConfig() AuditConfig { + return AuditConfig{ + Enabled: false, + Output: "file", + Path: "/var/log/hmac-audit.log", + Format: "json", + Events: []string{ + AuditEventUpload, + AuditEventDownload, + AuditEventDelete, + AuditEventAuthSuccess, + AuditEventAuthFailure, + AuditEventRateLimited, + AuditEventBanned, + }, + MaxSize: 100, + MaxAge: 30, + } +} + +// AuditAuthSuccess is a helper for logging successful authentication +func AuditAuthSuccess(r *http.Request, jid, method string) { + AuditAuth(r, jid, true, method, nil) +} + +// AuditAuthFailure is a helper for logging failed authentication +func AuditAuthFailure(r *http.Request, method, errorMsg string) { + AuditAuth(r, "", false, method, fmt.Errorf("%s", errorMsg)) +} + +// AuditUploadSuccess is a helper for logging successful uploads +func AuditUploadSuccess(r *http.Request, jid, fileName string, fileSize int64, contentType string) { + AuditUpload(r, jid, "", fileName, fileSize, contentType, "success", nil) +} + +// AuditUploadFailure is a helper for logging failed uploads +func AuditUploadFailure(r *http.Request, jid, fileName string, fileSize int64, errorMsg string) { + AuditUpload(r, jid, "", fileName, fileSize, "", "failure", fmt.Errorf("%s", errorMsg)) +} + +// AuditDownloadSuccess is a helper for logging successful downloads +func AuditDownloadSuccess(r *http.Request, jid, fileName string, fileSize int64) { + AuditDownload(r, jid, "", fileName, fileSize, "success", nil) +} diff --git a/cmd/server/main.go b/cmd/server/main.go index 7e3e783..b6863d5 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -39,22 +39,22 @@ import ( // NetworkResilientSession represents a persistent session for network switching type NetworkResilientSession struct { - SessionID string `json:"session_id"` - UserJID string `json:"user_jid"` - OriginalToken string `json:"original_token"` - CreatedAt time.Time `json:"created_at"` - LastSeen time.Time `json:"last_seen"` - NetworkHistory []NetworkEvent `json:"network_history"` - UploadContext *UploadContext `json:"upload_context,omitempty"` - RefreshCount int `json:"refresh_count"` - MaxRefreshes int `json:"max_refreshes"` - LastIP string `json:"last_ip"` - UserAgent string `json:"user_agent"` - SecurityLevel int `json:"security_level"` // 1=normal, 2=challenge, 3=reauth - LastSecurityCheck time.Time `json:"last_security_check"` - NetworkChangeCount int `json:"network_change_count"` - StandbyDetected bool `json:"standby_detected"` - LastActivity time.Time `json:"last_activity"` + SessionID string `json:"session_id"` + UserJID string `json:"user_jid"` + OriginalToken string `json:"original_token"` + CreatedAt time.Time `json:"created_at"` + LastSeen time.Time `json:"last_seen"` + NetworkHistory []NetworkEvent `json:"network_history"` + UploadContext *UploadContext `json:"upload_context,omitempty"` + RefreshCount int `json:"refresh_count"` + MaxRefreshes int `json:"max_refreshes"` + LastIP string `json:"last_ip"` + UserAgent string `json:"user_agent"` + SecurityLevel int `json:"security_level"` // 1=normal, 2=challenge, 3=reauth + LastSecurityCheck time.Time `json:"last_security_check"` + NetworkChangeCount int `json:"network_change_count"` + StandbyDetected bool `json:"standby_detected"` + LastActivity time.Time `json:"last_activity"` } // contextKey is a custom type for context keys to avoid collisions @@ -67,12 +67,12 @@ const ( // NetworkEvent tracks network transitions during session type NetworkEvent struct { - Timestamp time.Time `json:"timestamp"` - FromNetwork string `json:"from_network"` - ToNetwork string `json:"to_network"` - ClientIP string `json:"client_ip"` - UserAgent string `json:"user_agent"` - EventType string `json:"event_type"` // "switch", "resume", "refresh" + Timestamp time.Time `json:"timestamp"` + FromNetwork string `json:"from_network"` + ToNetwork string `json:"to_network"` + ClientIP string `json:"client_ip"` + UserAgent string `json:"user_agent"` + EventType string `json:"event_type"` // "switch", "resume", "refresh" } // UploadContext maintains upload state across network changes and network resilience channels @@ -244,11 +244,11 @@ func initializeSessionStore() { opt, err := redis.ParseURL(redisURL) if err == nil { sessionStore.redisClient = redis.NewClient(opt) - + // Test Redis connection ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - + if err := sessionStore.redisClient.Ping(ctx).Err(); err == nil { log.Infof("📊 Session store: Redis backend initialized (%s)", redisURL) } else { @@ -526,53 +526,57 @@ type BuildConfig struct { } type NetworkResilienceConfig struct { - FastDetection bool `toml:"fast_detection" mapstructure:"fast_detection"` - QualityMonitoring bool `toml:"quality_monitoring" mapstructure:"quality_monitoring"` - PredictiveSwitching bool `toml:"predictive_switching" mapstructure:"predictive_switching"` - MobileOptimizations bool `toml:"mobile_optimizations" mapstructure:"mobile_optimizations"` - DetectionInterval string `toml:"detection_interval" mapstructure:"detection_interval"` - QualityCheckInterval string `toml:"quality_check_interval" mapstructure:"quality_check_interval"` - MaxDetectionInterval string `toml:"max_detection_interval" mapstructure:"max_detection_interval"` - + FastDetection bool `toml:"fast_detection" mapstructure:"fast_detection"` + QualityMonitoring bool `toml:"quality_monitoring" mapstructure:"quality_monitoring"` + PredictiveSwitching bool `toml:"predictive_switching" mapstructure:"predictive_switching"` + MobileOptimizations bool `toml:"mobile_optimizations" mapstructure:"mobile_optimizations"` + DetectionInterval string `toml:"detection_interval" mapstructure:"detection_interval"` + QualityCheckInterval string `toml:"quality_check_interval" mapstructure:"quality_check_interval"` + MaxDetectionInterval string `toml:"max_detection_interval" mapstructure:"max_detection_interval"` + // Multi-interface support - MultiInterfaceEnabled bool `toml:"multi_interface_enabled" mapstructure:"multi_interface_enabled"` - InterfacePriority []string `toml:"interface_priority" mapstructure:"interface_priority"` - AutoSwitchEnabled bool `toml:"auto_switch_enabled" mapstructure:"auto_switch_enabled"` - SwitchThresholdLatency string `toml:"switch_threshold_latency" mapstructure:"switch_threshold_latency"` - SwitchThresholdPacketLoss float64 `toml:"switch_threshold_packet_loss" mapstructure:"switch_threshold_packet_loss"` + MultiInterfaceEnabled bool `toml:"multi_interface_enabled" mapstructure:"multi_interface_enabled"` + InterfacePriority []string `toml:"interface_priority" mapstructure:"interface_priority"` + AutoSwitchEnabled bool `toml:"auto_switch_enabled" mapstructure:"auto_switch_enabled"` + SwitchThresholdLatency string `toml:"switch_threshold_latency" mapstructure:"switch_threshold_latency"` + SwitchThresholdPacketLoss float64 `toml:"switch_threshold_packet_loss" mapstructure:"switch_threshold_packet_loss"` QualityDegradationThreshold float64 `toml:"quality_degradation_threshold" mapstructure:"quality_degradation_threshold"` - MaxSwitchAttempts int `toml:"max_switch_attempts" mapstructure:"max_switch_attempts"` - SwitchDetectionInterval string `toml:"switch_detection_interval" mapstructure:"switch_detection_interval"` + MaxSwitchAttempts int `toml:"max_switch_attempts" mapstructure:"max_switch_attempts"` + SwitchDetectionInterval string `toml:"switch_detection_interval" mapstructure:"switch_detection_interval"` } // ClientNetworkConfigTOML is used for loading from TOML where timeout is a string type ClientNetworkConfigTOML struct { SessionBasedTracking bool `toml:"session_based_tracking" mapstructure:"session_based_tracking"` - AllowIPChanges bool `toml:"allow_ip_changes" mapstructure:"allow_ip_changes"` - SessionMigrationTimeout string `toml:"session_migration_timeout" mapstructure:"session_migration_timeout"` - MaxIPChangesPerSession int `toml:"max_ip_changes_per_session" mapstructure:"max_ip_changes_per_session"` - ClientConnectionDetection bool `toml:"client_connection_detection" mapstructure:"client_connection_detection"` - AdaptToClientNetwork bool `toml:"adapt_to_client_network" mapstructure:"adapt_to_client_network"` + AllowIPChanges bool `toml:"allow_ip_changes" mapstructure:"allow_ip_changes"` + SessionMigrationTimeout string `toml:"session_migration_timeout" mapstructure:"session_migration_timeout"` + MaxIPChangesPerSession int `toml:"max_ip_changes_per_session" mapstructure:"max_ip_changes_per_session"` + ClientConnectionDetection bool `toml:"client_connection_detection" mapstructure:"client_connection_detection"` + AdaptToClientNetwork bool `toml:"adapt_to_client_network" mapstructure:"adapt_to_client_network"` } // This is the main Config struct to be used type Config struct { - Server ServerConfig `mapstructure:"server"` - Logging LoggingConfig `mapstructure:"logging"` - Deduplication DeduplicationConfig `mapstructure:"deduplication"` // Added - ISO ISOConfig `mapstructure:"iso"` // Added - Timeouts TimeoutConfig `mapstructure:"timeouts"` // Added - Security SecurityConfig `mapstructure:"security"` - Versioning VersioningConfig `mapstructure:"versioning"` // Added - Uploads UploadsConfig `mapstructure:"uploads"` - Downloads DownloadsConfig `mapstructure:"downloads"` - ClamAV ClamAVConfig `mapstructure:"clamav"` - Redis RedisConfig `mapstructure:"redis"` - Workers WorkersConfig `mapstructure:"workers"` - File FileConfig `mapstructure:"file"` - Build BuildConfig `mapstructure:"build"` - NetworkResilience NetworkResilienceConfig `mapstructure:"network_resilience"` - ClientNetwork ClientNetworkConfigTOML `mapstructure:"client_network_support"` + Server ServerConfig `mapstructure:"server"` + Logging LoggingConfig `mapstructure:"logging"` + Deduplication DeduplicationConfig `mapstructure:"deduplication"` // Added + ISO ISOConfig `mapstructure:"iso"` // Added + Timeouts TimeoutConfig `mapstructure:"timeouts"` // Added + Security SecurityConfig `mapstructure:"security"` + Versioning VersioningConfig `mapstructure:"versioning"` // Added + Uploads UploadsConfig `mapstructure:"uploads"` + Downloads DownloadsConfig `mapstructure:"downloads"` + ClamAV ClamAVConfig `mapstructure:"clamav"` + Redis RedisConfig `mapstructure:"redis"` + Workers WorkersConfig `mapstructure:"workers"` + File FileConfig `mapstructure:"file"` + Build BuildConfig `mapstructure:"build"` + NetworkResilience NetworkResilienceConfig `mapstructure:"network_resilience"` + ClientNetwork ClientNetworkConfigTOML `mapstructure:"client_network_support"` + Audit AuditConfig `mapstructure:"audit"` // Audit logging + Validation ValidationConfig `mapstructure:"validation"` // Content validation + Quotas QuotaConfig `mapstructure:"quotas"` // Per-user quotas + Admin AdminConfig `mapstructure:"admin"` // Admin API } type UploadTask struct { @@ -597,12 +601,12 @@ func processScan(task ScanTask) error { confMutex.RLock() clamEnabled := conf.ClamAV.ClamAVEnabled confMutex.RUnlock() - + if !clamEnabled { log.Infof("ClamAV disabled, skipping scan for file: %s", task.AbsFilename) return nil } - + log.Infof("Started processing scan for file: %s", task.AbsFilename) semaphore <- struct{}{} defer func() { <-semaphore }() @@ -621,8 +625,8 @@ var ( conf Config versionString string log = logrus.New() - fileInfoCache *cache.Cache //nolint:unused - fileMetadataCache *cache.Cache //nolint:unused + fileInfoCache *cache.Cache //nolint:unused + fileMetadataCache *cache.Cache //nolint:unused clamClient *clamd.Clamd redisClient *redis.Client redisConnected bool @@ -673,6 +677,7 @@ var clientTracker *ClientConnectionTracker //nolint:unused var logMessages []string + //nolint:unused var logMu sync.Mutex @@ -748,7 +753,7 @@ func initializeNetworkProtocol(forceProtocol string) (*net.Dialer, error) { if forceProtocol == "" { forceProtocol = "auto" } - + switch forceProtocol { case "ipv4": return &net.Dialer{ @@ -845,7 +850,7 @@ func main() { } else { content = GenerateMinimalConfig() } - + f, err := os.Create(genConfigPath) if err != nil { fmt.Fprintf(os.Stderr, "Failed to create file: %v\n", err) @@ -878,7 +883,7 @@ func main() { log.Fatalf("Failed to load configuration: %v", err) } conf = *loadedConfig - configFileGlobal = configFile // Store for validation helper functions + configFileGlobal = configFile // Store for validation helper functions log.Info("Configuration loaded successfully.") err = validateConfig(&conf) @@ -892,12 +897,12 @@ func main() { // Initialize client connection tracker for multi-interface support clientNetworkConfig := &ClientNetworkConfig{ - SessionBasedTracking: conf.ClientNetwork.SessionBasedTracking, - AllowIPChanges: conf.ClientNetwork.AllowIPChanges, - MaxIPChangesPerSession: conf.ClientNetwork.MaxIPChangesPerSession, - AdaptToClientNetwork: conf.ClientNetwork.AdaptToClientNetwork, + SessionBasedTracking: conf.ClientNetwork.SessionBasedTracking, + AllowIPChanges: conf.ClientNetwork.AllowIPChanges, + MaxIPChangesPerSession: conf.ClientNetwork.MaxIPChangesPerSession, + AdaptToClientNetwork: conf.ClientNetwork.AdaptToClientNetwork, } - + // Parse session migration timeout if conf.ClientNetwork.SessionMigrationTimeout != "" { if timeout, err := time.ParseDuration(conf.ClientNetwork.SessionMigrationTimeout); err == nil { @@ -908,12 +913,12 @@ func main() { } else { clientNetworkConfig.SessionMigrationTimeout = 5 * time.Minute // default } - + // Set defaults if not configured if clientNetworkConfig.MaxIPChangesPerSession == 0 { clientNetworkConfig.MaxIPChangesPerSession = 10 } - + // Initialize the client tracker clientTracker = NewClientConnectionTracker(clientNetworkConfig) if clientTracker != nil { @@ -1075,8 +1080,22 @@ func main() { initRedis() // Assuming initRedis is defined in helpers.go or elsewhere } + // Initialize new features + if err := InitAuditLogger(&conf.Audit); err != nil { + log.Warnf("Failed to initialize audit logger: %v", err) + } + + InitContentValidator(&conf.Validation) + + if err := InitQuotaManager(&conf.Quotas, redisClient); err != nil { + log.Warnf("Failed to initialize quota manager: %v", err) + } + router := setupRouter() // Assuming setupRouter is defined (likely in this file or router.go + // Setup Admin API routes + SetupAdminRoutes(router, &conf.Admin) + // Initialize enhancements and enhance the router InitializeEnhancements(router) @@ -1658,7 +1677,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er query := r.URL.Query() user := query.Get("user") expiryStr := query.Get("expiry") - + if user == "" { return nil, errors.New("missing user parameter") } @@ -1674,10 +1693,10 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er // ULTRA-FLEXIBLE GRACE PERIODS FOR NETWORK SWITCHING AND STANDBY SCENARIOS now := time.Now().Unix() - + // Base grace period: 8 hours (increased from 4 hours for better WiFi ↔ LTE reliability) gracePeriod := int64(28800) // 8 hours base grace period for all scenarios - + // Detect mobile XMPP clients and apply enhanced grace periods userAgent := r.Header.Get("User-Agent") isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") || @@ -1688,12 +1707,12 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er strings.Contains(strings.ToLower(userAgent), "xmpp") || strings.Contains(strings.ToLower(userAgent), "client") || strings.Contains(strings.ToLower(userAgent), "bot") - + // Enhanced XMPP client detection and grace period management // Desktop XMPP clients (Dino, Gajim) need extended grace for session restoration after restart isDesktopXMPP := strings.Contains(strings.ToLower(userAgent), "dino") || strings.Contains(strings.ToLower(userAgent), "gajim") - + if isMobileXMPP || isDesktopXMPP { if isDesktopXMPP { gracePeriod = int64(86400) // 24 hours for desktop XMPP clients (session restoration) @@ -1703,32 +1722,32 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er log.Infof("� Mobile XMPP client detected (%s), using extended 12-hour grace period", userAgent) } } - + // Network resilience parameters for session recovery sessionId := query.Get("session_id") networkResilience := query.Get("network_resilience") resumeAllowed := query.Get("resume_allowed") - + // Maximum grace period for network resilience scenarios if sessionId != "" || networkResilience == "true" || resumeAllowed == "true" { gracePeriod = int64(86400) // 24 hours for explicit network resilience scenarios - log.Infof("🌐 Network resilience mode activated (session_id: %s, network_resilience: %s), using 24-hour grace period", + log.Infof("🌐 Network resilience mode activated (session_id: %s, network_resilience: %s), using 24-hour grace period", sessionId, networkResilience) } - + // Detect potential network switching scenarios clientIP := getClientIP(r) xForwardedFor := r.Header.Get("X-Forwarded-For") xRealIP := r.Header.Get("X-Real-IP") - + // Check for client IP change indicators (WiFi ↔ LTE switching detection) if xForwardedFor != "" || xRealIP != "" { // Client is behind proxy/NAT - likely mobile switching between networks gracePeriod = int64(86400) // 24 hours for proxy/NAT scenarios - log.Infof("📱 Network switching detected (client IP: %s, X-Forwarded-For: %s, X-Real-IP: %s), using 24-hour grace period", + log.Infof("📱 Network switching detected (client IP: %s, X-Forwarded-For: %s, X-Real-IP: %s), using 24-hour grace period", clientIP, xForwardedFor, xRealIP) } - + // Check Content-Length to identify large uploads that need extra time contentLength := r.Header.Get("Content-Length") var size int64 = 0 @@ -1741,27 +1760,27 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er log.Infof("📁 Large file detected (%d bytes), extending grace period by %d seconds", size, additionalTime) } } - + // ABSOLUTE MAXIMUM: 48 hours for extreme scenarios maxAbsoluteGrace := int64(172800) // 48 hours absolute maximum if gracePeriod > maxAbsoluteGrace { gracePeriod = maxAbsoluteGrace log.Infof("⚠️ Grace period capped at 48 hours maximum") } - + // STANDBY RECOVERY: Special handling for device standby scenarios isLikelyStandbyRecovery := false standbyGraceExtension := int64(86400) // Additional 24 hours for standby recovery - + if now > expiry { expiredTime := now - expiry - + // If token expired more than grace period but less than standby window, allow standby recovery - if expiredTime > gracePeriod && expiredTime < (gracePeriod + standbyGraceExtension) { + if expiredTime > gracePeriod && expiredTime < (gracePeriod+standbyGraceExtension) { isLikelyStandbyRecovery = true log.Infof("💤 STANDBY RECOVERY: Token expired %d seconds ago, within standby recovery window", expiredTime) } - + // Apply grace period check if expiredTime > gracePeriod && !isLikelyStandbyRecovery { // DESKTOP XMPP CLIENT SESSION RESTORATION: Special handling for Dino/Gajim restart scenarios @@ -1770,7 +1789,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er isDesktopSessionRestore = true log.Infof("🖥️ DESKTOP SESSION RESTORE: %s token expired %d seconds ago, allowing within 48-hour desktop restoration window", userAgent, expiredTime) } - + // Still apply ultra-generous final check for mobile scenarios ultraMaxGrace := int64(259200) // 72 hours ultra-maximum for critical mobile scenarios if (isMobileXMPP && expiredTime < ultraMaxGrace) || isDesktopSessionRestore { @@ -1778,9 +1797,9 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er log.Warnf("⚡ ULTRA-GRACE: Mobile XMPP client token expired %d seconds ago, allowing within 72-hour ultra-grace window", expiredTime) } } else { - log.Warnf("❌ Bearer token expired beyond all grace periods: now=%d, expiry=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s", + log.Warnf("❌ Bearer token expired beyond all grace periods: now=%d, expiry=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s", now, expiry, expiredTime, gracePeriod, userAgent) - return nil, fmt.Errorf("token has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)", + return nil, fmt.Errorf("token has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)", expiredTime, gracePeriod) } } else if isLikelyStandbyRecovery { @@ -1797,7 +1816,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er if len(pathParts) < 1 { return nil, errors.New("invalid upload path format") } - + // Handle different path formats from various ejabberd modules filename := "" if len(pathParts) >= 3 { @@ -1805,7 +1824,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er } else if len(pathParts) >= 1 { filename = pathParts[len(pathParts)-1] // Simplified format: /filename } - + if filename == "" { filename = "upload" // Fallback filename } @@ -1813,71 +1832,71 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er // ENHANCED HMAC VALIDATION: Try multiple payload formats for maximum compatibility var validPayload bool var payloadFormat string - + // Format 1: Network-resilient payload (mod_http_upload_hmac_network_resilient) - extendedPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%d\x00network_resilient", + extendedPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%d\x00network_resilient", user, filename, size, expiry-86400, expiry) h1 := hmac.New(sha256.New, []byte(secret)) h1.Write([]byte(extendedPayload)) expectedMAC1 := h1.Sum(nil) - + if hmac.Equal(tokenBytes, expectedMAC1) { validPayload = true payloadFormat = "network_resilient" } - + // Format 2: Extended payload with session support if !validPayload { sessionPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s", user, filename, size, expiry, sessionId) h2 := hmac.New(sha256.New, []byte(secret)) h2.Write([]byte(sessionPayload)) expectedMAC2 := h2.Sum(nil) - + if hmac.Equal(tokenBytes, expectedMAC2) { validPayload = true payloadFormat = "session_based" } } - + // Format 3: Standard payload (original mod_http_upload_hmac) if !validPayload { standardPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d", user, filename, size, expiry-3600) h3 := hmac.New(sha256.New, []byte(secret)) h3.Write([]byte(standardPayload)) expectedMAC3 := h3.Sum(nil) - + if hmac.Equal(tokenBytes, expectedMAC3) { validPayload = true payloadFormat = "standard" } } - + // Format 4: Simplified payload (fallback compatibility) if !validPayload { simplePayload := fmt.Sprintf("%s\x00%s\x00%d", user, filename, size) h4 := hmac.New(sha256.New, []byte(secret)) h4.Write([]byte(simplePayload)) expectedMAC4 := h4.Sum(nil) - + if hmac.Equal(tokenBytes, expectedMAC4) { validPayload = true payloadFormat = "simple" } } - + // Format 5: User-only payload (maximum fallback) if !validPayload { userPayload := fmt.Sprintf("%s\x00%d", user, expiry) h5 := hmac.New(sha256.New, []byte(secret)) h5.Write([]byte(userPayload)) expectedMAC5 := h5.Sum(nil) - + if hmac.Equal(tokenBytes, expectedMAC5) { validPayload = true payloadFormat = "user_only" } } - + if !validPayload { log.Warnf("❌ Invalid Bearer token HMAC for user %s, file %s (tried all 5 payload formats)", user, filename) return nil, errors.New("invalid Bearer token HMAC") @@ -1890,16 +1909,16 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er Expiry: expiry, } - log.Infof("✅ Bearer token authentication SUCCESSFUL: user=%s, file=%s, format=%s, grace_period=%d seconds", + log.Infof("✅ Bearer token authentication SUCCESSFUL: user=%s, file=%s, format=%s, grace_period=%d seconds", user, filename, payloadFormat, gracePeriod) - + return claims, nil } // evaluateSecurityLevel determines the required security level based on network changes and standby detection func evaluateSecurityLevel(session *NetworkResilientSession, currentIP string, userAgent string) int { now := time.Now() - + // Initialize if this is the first check if session.LastSecurityCheck.IsZero() { session.LastSecurityCheck = now @@ -1907,50 +1926,50 @@ func evaluateSecurityLevel(session *NetworkResilientSession, currentIP string, u session.SecurityLevel = 1 // Normal level return 1 } - + // Detect potential standby scenario timeSinceLastActivity := now.Sub(session.LastActivity) standbyThreshold := 30 * time.Minute - + if timeSinceLastActivity > standbyThreshold { session.StandbyDetected = true log.Infof("🔒 STANDBY DETECTED: %v since last activity for session %s", timeSinceLastActivity, session.SessionID) - + // Long standby requires full re-authentication if timeSinceLastActivity > 2*time.Hour { log.Warnf("🔐 SECURITY LEVEL 3: Long standby (%v) requires full re-authentication", timeSinceLastActivity) return 3 } - + // Medium standby requires challenge-response log.Infof("🔐 SECURITY LEVEL 2: Medium standby (%v) requires challenge-response", timeSinceLastActivity) return 2 } - + // Detect network changes if session.LastIP != "" && session.LastIP != currentIP { session.NetworkChangeCount++ - log.Infof("🌐 NETWORK CHANGE #%d: %s → %s for session %s", + log.Infof("🌐 NETWORK CHANGE #%d: %s → %s for session %s", session.NetworkChangeCount, session.LastIP, currentIP, session.SessionID) - + // Multiple rapid network changes are suspicious if session.NetworkChangeCount > 3 { - log.Warnf("🔐 SECURITY LEVEL 3: Multiple network changes (%d) requires full re-authentication", + log.Warnf("🔐 SECURITY LEVEL 3: Multiple network changes (%d) requires full re-authentication", session.NetworkChangeCount) return 3 } - + // Single network change requires challenge-response log.Infof("🔐 SECURITY LEVEL 2: Network change requires challenge-response") return 2 } - + // Check for suspicious user agent changes if session.UserAgent != "" && session.UserAgent != userAgent { log.Warnf("🔐 SECURITY LEVEL 3: User agent change detected - potential device hijacking") return 3 } - + // Normal operation return 1 } @@ -1960,11 +1979,11 @@ func generateSecurityChallenge(session *NetworkResilientSession, secret string) // Create a time-based challenge using session data timestamp := time.Now().Unix() challengeData := fmt.Sprintf("%s:%s:%d", session.SessionID, session.UserJID, timestamp) - + h := hmac.New(sha256.New, []byte(secret)) h.Write([]byte(challengeData)) challenge := hex.EncodeToString(h.Sum(nil)) - + log.Infof("🔐 Generated security challenge for session %s", session.SessionID) return challenge, nil } @@ -1974,22 +1993,22 @@ func validateSecurityChallenge(session *NetworkResilientSession, providedRespons // This would validate against the expected response // For now, we'll implement a simple time-window validation timestamp := time.Now().Unix() - + // Allow 5-minute window for challenge responses for i := int64(0); i <= 300; i += 60 { testTimestamp := timestamp - i challengeData := fmt.Sprintf("%s:%s:%d", session.SessionID, session.UserJID, testTimestamp) - + h := hmac.New(sha256.New, []byte(secret)) h.Write([]byte(challengeData)) expectedResponse := hex.EncodeToString(h.Sum(nil)) - + if expectedResponse == providedResponse { log.Infof("✅ Security challenge validated for session %s", session.SessionID) return true } } - + log.Warnf("❌ Security challenge failed for session %s", session.SessionID) return false } @@ -2029,17 +2048,17 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke session := sessionStore.GetSession(sessionID) if session == nil { session = &NetworkResilientSession{ - SessionID: sessionID, - UserJID: claims.User, - OriginalToken: getBearerTokenFromRequest(r), - CreatedAt: time.Now(), - MaxRefreshes: 10, - NetworkHistory: []NetworkEvent{}, - SecurityLevel: 1, - LastSecurityCheck: time.Now(), + SessionID: sessionID, + UserJID: claims.User, + OriginalToken: getBearerTokenFromRequest(r), + CreatedAt: time.Now(), + MaxRefreshes: 10, + NetworkHistory: []NetworkEvent{}, + SecurityLevel: 1, + LastSecurityCheck: time.Now(), NetworkChangeCount: 0, - StandbyDetected: false, - LastActivity: time.Now(), + StandbyDetected: false, + LastActivity: time.Now(), } } @@ -2069,7 +2088,7 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke log.Errorf("❌ Failed to generate security challenge: %v", err) return nil, fmt.Errorf("security challenge generation failed") } - + // Check if client provided challenge response challengeResponse := r.Header.Get("X-Challenge-Response") if challengeResponse == "" { @@ -2077,15 +2096,15 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke setSecurityHeaders(w, 2, challenge) return nil, fmt.Errorf("challenge-response required for network change") } - + // Validate challenge response if !validateSecurityChallenge(session, challengeResponse, secret) { setSecurityHeaders(w, 2, challenge) return nil, fmt.Errorf("invalid challenge response") } - + log.Infof("✅ Challenge-response validated for session %s", sessionID) - + case 3: // Full re-authentication required setSecurityHeaders(w, 3, "") @@ -2104,7 +2123,7 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke UserAgent: userAgent, EventType: "network_switch", }) - log.Infof("🌐 Network switch detected for session %s: %s → %s", + log.Infof("🌐 Network switch detected for session %s: %s → %s", sessionID, session.LastIP, currentIP) } @@ -2138,7 +2157,7 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke // Token refresh successful session.RefreshCount++ session.LastSeen = time.Now() - + // Add refresh event to history session.NetworkHistory = append(session.NetworkHistory, NetworkEvent{ Timestamp: time.Now(), @@ -2157,12 +2176,12 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke Expiry: time.Now().Add(24 * time.Hour).Unix(), } - log.Infof("✅ Session recovery successful: %s (refresh #%d)", + log.Infof("✅ Session recovery successful: %s (refresh #%d)", sessionID, session.RefreshCount) return refreshedClaims, nil } } else { - log.Warnf("❌ Session %s exceeded maximum refreshes (%d)", + log.Warnf("❌ Session %s exceeded maximum refreshes (%d)", sessionID, session.MaxRefreshes) } } else { @@ -2191,8 +2210,8 @@ func refreshSessionToken(session *NetworkResilientSession, secret string, r *htt size := extractSizeFromRequest(r) // Use session-based payload format for refresh - payload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s\x00session_refresh", - session.UserJID, + payload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s\x00session_refresh", + session.UserJID, filename, size, expiry, @@ -2202,7 +2221,7 @@ func refreshSessionToken(session *NetworkResilientSession, secret string, r *htt h.Write([]byte(payload)) token := base64.StdEncoding.EncodeToString(h.Sum(nil)) - log.Infof("🆕 Generated refresh token for session %s (refresh #%d)", + log.Infof("🆕 Generated refresh token for session %s (refresh #%d)", session.SessionID, session.RefreshCount+1) return token, nil @@ -2251,7 +2270,7 @@ type BearerTokenClaims struct { // ENHANCED FOR 100% WIFI ↔ LTE SWITCHING AND STANDBY RECOVERY RELIABILITY func validateHMAC(r *http.Request, secret string) error { log.Debugf("🔍 validateHMAC: Validating request to %s with query: %s", r.URL.Path, r.URL.RawQuery) - + // Check for X-Signature header (for POST uploads) signature := r.Header.Get("X-Signature") if signature != "" { @@ -2294,7 +2313,7 @@ func validateHMAC(r *http.Request, secret string) error { // ENHANCED HMAC CALCULATION: Try multiple formats for maximum compatibility var validMAC bool var messageFormat string - + // Calculate HMAC based on protocol version with enhanced compatibility mac := hmac.New(sha256.New, []byte(secret)) @@ -2305,7 +2324,7 @@ func validateHMAC(r *http.Request, secret string) error { mac.Write([]byte(message1)) calculatedMAC1 := mac.Sum(nil) calculatedMACHex1 := hex.EncodeToString(calculatedMAC1) - + // Decode provided MAC if providedMAC, err := hex.DecodeString(providedMACHex); err == nil { if hmac.Equal(calculatedMAC1, providedMAC) { @@ -2314,14 +2333,14 @@ func validateHMAC(r *http.Request, secret string) error { log.Debugf("✅ Legacy v protocol HMAC validated: %s", calculatedMACHex1) } } - + // Format 2: Try without content length for compatibility if !validMAC { message2 := fileStorePath mac.Reset() mac.Write([]byte(message2)) calculatedMAC2 := mac.Sum(nil) - + if providedMAC, err := hex.DecodeString(providedMACHex); err == nil { if hmac.Equal(calculatedMAC2, providedMAC) { validMAC = true @@ -2333,14 +2352,14 @@ func validateHMAC(r *http.Request, secret string) error { } else { // v2 and token protocols: Enhanced format compatibility contentType := GetContentType(fileStorePath) - + // Format 1: Standard format - fileStorePath + "\x00" + contentLength + "\x00" + contentType message1 := fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType mac.Reset() mac.Write([]byte(message1)) calculatedMAC1 := mac.Sum(nil) calculatedMACHex1 := hex.EncodeToString(calculatedMAC1) - + if providedMAC, err := hex.DecodeString(providedMACHex); err == nil { if hmac.Equal(calculatedMAC1, providedMAC) { validMAC = true @@ -2348,14 +2367,14 @@ func validateHMAC(r *http.Request, secret string) error { log.Debugf("✅ %s protocol HMAC validated (standard): %s", protocolVersion, calculatedMACHex1) } } - + // Format 2: Without content type for compatibility if !validMAC { message2 := fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) mac.Reset() mac.Write([]byte(message2)) calculatedMAC2 := mac.Sum(nil) - + if providedMAC, err := hex.DecodeString(providedMACHex); err == nil { if hmac.Equal(calculatedMAC2, providedMAC) { validMAC = true @@ -2364,14 +2383,14 @@ func validateHMAC(r *http.Request, secret string) error { } } } - + // Format 3: Simple path only for maximum compatibility if !validMAC { message3 := fileStorePath mac.Reset() mac.Write([]byte(message3)) calculatedMAC3 := mac.Sum(nil) - + if providedMAC, err := hex.DecodeString(providedMACHex); err == nil { if hmac.Equal(calculatedMAC3, providedMAC) { validMAC = true @@ -2387,7 +2406,7 @@ func validateHMAC(r *http.Request, secret string) error { return fmt.Errorf("invalid MAC for %s protocol", protocolVersion) } - log.Infof("✅ %s HMAC authentication SUCCESSFUL: format=%s, path=%s", + log.Infof("✅ %s HMAC authentication SUCCESSFUL: format=%s, path=%s", protocolVersion, messageFormat, r.URL.Path) return nil } @@ -2417,11 +2436,11 @@ func validateV3HMAC(r *http.Request, secret string) error { // ULTRA-FLEXIBLE GRACE PERIODS FOR V3 PROTOCOL NETWORK SWITCHING now := time.Now().Unix() - + if now > expires { // Base grace period: 8 hours (significantly increased for WiFi ↔ LTE reliability) gracePeriod := int64(28800) // 8 hours base grace period - + // Enhanced mobile XMPP client detection userAgent := r.Header.Get("User-Agent") isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "gajim") || @@ -2432,12 +2451,12 @@ func validateV3HMAC(r *http.Request, secret string) error { strings.Contains(strings.ToLower(userAgent), "xmpp") || strings.Contains(strings.ToLower(userAgent), "client") || strings.Contains(strings.ToLower(userAgent), "bot") - + if isMobileXMPP { gracePeriod = int64(43200) // 12 hours for mobile XMPP clients log.Infof("📱 V3: Mobile XMPP client detected (%s), using 12-hour grace period", userAgent) } - + // Network resilience parameters for V3 protocol sessionId := query.Get("session_id") networkResilience := query.Get("network_resilience") @@ -2446,19 +2465,19 @@ func validateV3HMAC(r *http.Request, secret string) error { gracePeriod = int64(86400) // 24 hours for network resilience scenarios log.Infof("🌐 V3: Network resilience mode detected, using 24-hour grace period") } - + // Detect network switching indicators clientIP := getClientIP(r) xForwardedFor := r.Header.Get("X-Forwarded-For") xRealIP := r.Header.Get("X-Real-IP") - + if xForwardedFor != "" || xRealIP != "" { // Client behind proxy/NAT - likely mobile network switching gracePeriod = int64(86400) // 24 hours for proxy/NAT scenarios - log.Infof("🔄 V3: Network switching detected (IP: %s, X-Forwarded-For: %s), using 24-hour grace period", + log.Infof("🔄 V3: Network switching detected (IP: %s, X-Forwarded-For: %s), using 24-hour grace period", clientIP, xForwardedFor) } - + // Large file uploads get additional grace time if contentLengthStr := r.Header.Get("Content-Length"); contentLengthStr != "" { if contentLength, parseErr := strconv.ParseInt(contentLengthStr, 10, 64); parseErr == nil { @@ -2466,33 +2485,33 @@ func validateV3HMAC(r *http.Request, secret string) error { if contentLength > 10*1024*1024 { additionalTime := (contentLength / (10 * 1024 * 1024)) * 3600 // 1 hour per 10MB gracePeriod += additionalTime - log.Infof("📁 V3: Large file (%d bytes), extending grace period by %d seconds", + log.Infof("📁 V3: Large file (%d bytes), extending grace period by %d seconds", contentLength, additionalTime) } } } - + // Maximum grace period cap: 48 hours maxGracePeriod := int64(172800) // 48 hours absolute maximum if gracePeriod > maxGracePeriod { gracePeriod = maxGracePeriod log.Infof("⚠️ V3: Grace period capped at 48 hours maximum") } - + // STANDBY RECOVERY: Handle device standby scenarios expiredTime := now - expires standbyGraceExtension := int64(86400) // Additional 24 hours for standby - isLikelyStandbyRecovery := expiredTime > gracePeriod && expiredTime < (gracePeriod + standbyGraceExtension) - + isLikelyStandbyRecovery := expiredTime > gracePeriod && expiredTime < (gracePeriod+standbyGraceExtension) + if expiredTime > gracePeriod && !isLikelyStandbyRecovery { // Ultra-generous final check for mobile scenarios ultraMaxGrace := int64(259200) // 72 hours ultra-maximum for critical scenarios if isMobileXMPP && expiredTime < ultraMaxGrace { log.Warnf("⚡ V3 ULTRA-GRACE: Mobile client token expired %d seconds ago, allowing within 72-hour window", expiredTime) } else { - log.Warnf("❌ V3 signature expired beyond all grace periods: now=%d, expires=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s", + log.Warnf("❌ V3 signature expired beyond all grace periods: now=%d, expires=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s", now, expires, expiredTime, gracePeriod, userAgent) - return fmt.Errorf("signature has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)", + return fmt.Errorf("signature has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)", expiredTime, gracePeriod) } } else if isLikelyStandbyRecovery { @@ -2507,18 +2526,18 @@ func validateV3HMAC(r *http.Request, secret string) error { // ENHANCED MESSAGE CONSTRUCTION: Try multiple formats for compatibility var validSignature bool var messageFormat string - + // Format 1: Standard v3 format message1 := fmt.Sprintf("%s\n%s\n%s", r.Method, expiresStr, r.URL.Path) h1 := hmac.New(sha256.New, []byte(secret)) h1.Write([]byte(message1)) expectedSignature1 := hex.EncodeToString(h1.Sum(nil)) - + if hmac.Equal([]byte(signature), []byte(expectedSignature1)) { validSignature = true messageFormat = "standard_v3" } - + // Format 2: Alternative format with query string if !validSignature { pathWithQuery := r.URL.Path @@ -2529,32 +2548,32 @@ func validateV3HMAC(r *http.Request, secret string) error { h2 := hmac.New(sha256.New, []byte(secret)) h2.Write([]byte(message2)) expectedSignature2 := hex.EncodeToString(h2.Sum(nil)) - + if hmac.Equal([]byte(signature), []byte(expectedSignature2)) { validSignature = true messageFormat = "with_query" } } - + // Format 3: Simplified format (fallback) if !validSignature { message3 := fmt.Sprintf("%s\n%s", r.Method, r.URL.Path) h3 := hmac.New(sha256.New, []byte(secret)) h3.Write([]byte(message3)) expectedSignature3 := hex.EncodeToString(h3.Sum(nil)) - + if hmac.Equal([]byte(signature), []byte(expectedSignature3)) { validSignature = true messageFormat = "simplified" } } - + if !validSignature { log.Warnf("❌ Invalid V3 HMAC signature (tried all 3 formats)") return errors.New("invalid v3 HMAC signature") } - log.Infof("✅ V3 HMAC authentication SUCCESSFUL: format=%s, method=%s, path=%s", + log.Infof("✅ V3 HMAC authentication SUCCESSFUL: format=%s, method=%s, path=%s", messageFormat, r.Method, r.URL.Path) return nil } @@ -2563,7 +2582,7 @@ func validateV3HMAC(r *http.Request, secret string) error { func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSize int64, clientIP string) (int64, error) { var written int64 lastLogTime := time.Now() - + for { n, err := src.Read(buf) if n > 0 { @@ -2572,12 +2591,12 @@ func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSiz if werr != nil { return written, werr } - + // Log progress for large files every 10MB or 30 seconds - if totalSize > 50*1024*1024 && + if totalSize > 50*1024*1024 && (written%10*1024*1024 == 0 || time.Since(lastLogTime) > 30*time.Second) { progress := float64(written) / float64(totalSize) * 100 - log.Infof("📥 Download progress: %.1f%% (%s/%s) for IP %s", + log.Infof("📥 Download progress: %.1f%% (%s/%s) for IP %s", progress, formatBytes(written), formatBytes(totalSize), clientIP) lastLogTime = time.Now() } @@ -2589,7 +2608,7 @@ func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSiz return written, err } } - + return written, nil } @@ -2606,11 +2625,11 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { // Generate session ID for multi-upload tracking sessionID = generateUploadSessionID("upload", r.Header.Get("User-Agent"), getClientIP(r)) } - + // Set session headers for client continuation w.Header().Set("X-Session-ID", sessionID) w.Header().Set("X-Upload-Session-Timeout", "3600") // 1 hour - + // Only allow POST method if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -2621,22 +2640,22 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { // ENHANCED AUTHENTICATION with network switching support var bearerClaims *BearerTokenClaims authHeader := r.Header.Get("Authorization") - + if strings.HasPrefix(authHeader, "Bearer ") { // Bearer token authentication with session recovery for network switching // Store response writer in context for session headers ctx := context.WithValue(r.Context(), responseWriterKey, w) r = r.WithContext(ctx) - + claims, err := validateBearerTokenWithSession(r, conf.Security.Secret) if err != nil { // Enhanced error logging for network switching scenarios clientIP := getClientIP(r) userAgent := r.Header.Get("User-Agent") sessionID := getSessionIDFromRequest(r) - log.Warnf("🔴 Authentication failed for IP %s, User-Agent: %s, Session: %s, Error: %v", + log.Warnf("🔴 Authentication failed for IP %s, User-Agent: %s, Session: %s, Error: %v", clientIP, userAgent, sessionID, err) - + // Check if this might be a network switching scenario and provide helpful response if strings.Contains(err.Error(), "expired") || strings.Contains(err.Error(), "invalid") { w.Header().Set("X-Network-Switch-Detected", "true") @@ -2646,15 +2665,17 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Session-ID", sessionID) } } - + + AuditAuthFailure(r, "bearer_token", err.Error()) http.Error(w, fmt.Sprintf("Bearer Token Authentication failed: %v", err), http.StatusUnauthorized) uploadErrorsTotal.Inc() return } + AuditAuthSuccess(r, claims.User, "bearer_token") bearerClaims = claims - log.Infof("✅ Bearer token authentication successful: user=%s, file=%s, IP=%s", + log.Infof("✅ Bearer token authentication successful: user=%s, file=%s, IP=%s", claims.User, claims.Filename, getClientIP(r)) - + // Add comprehensive response headers for audit logging and client tracking w.Header().Set("X-Authenticated-User", claims.User) w.Header().Set("X-Auth-Method", "Bearer-Token") @@ -2665,10 +2686,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { _, err := validateJWTFromRequest(r, conf.Security.JWTSecret) if err != nil { log.Warnf("🔴 JWT Authentication failed for IP %s: %v", getClientIP(r), err) + AuditAuthFailure(r, "jwt", err.Error()) http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized) uploadErrorsTotal.Inc() return } + AuditAuthSuccess(r, "", "jwt") log.Infof("✅ JWT authentication successful for upload request: %s", r.URL.Path) w.Header().Set("X-Auth-Method", "JWT") } else { @@ -2676,10 +2699,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { err := validateHMAC(r, conf.Security.Secret) if err != nil { log.Warnf("🔴 HMAC Authentication failed for IP %s: %v", getClientIP(r), err) + AuditAuthFailure(r, "hmac", err.Error()) http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized) uploadErrorsTotal.Inc() return } + AuditAuthSuccess(r, "", "hmac") log.Infof("✅ HMAC authentication successful for upload request: %s", r.URL.Path) w.Header().Set("X-Auth-Method", "HMAC") } @@ -2699,30 +2724,30 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { // Generate new session ID with enhanced entropy sessionID = generateSessionID("", "") } - + clientIP := getClientIP(r) - + // Detect potential network switching xForwardedFor := r.Header.Get("X-Forwarded-For") xRealIP := r.Header.Get("X-Real-IP") networkSwitchIndicators := xForwardedFor != "" || xRealIP != "" - + if networkSwitchIndicators { - log.Infof("🔄 Network switching indicators detected: session=%s, client_ip=%s, x_forwarded_for=%s, x_real_ip=%s", + log.Infof("🔄 Network switching indicators detected: session=%s, client_ip=%s, x_forwarded_for=%s, x_real_ip=%s", sessionID, clientIP, xForwardedFor, xRealIP) w.Header().Set("X-Network-Switch-Detected", "true") } - + clientSession = clientTracker.TrackClientSession(sessionID, clientIP, r) - + // Enhanced session response headers for client coordination w.Header().Set("X-Upload-Session-ID", sessionID) w.Header().Set("X-Session-IP-Count", fmt.Sprintf("%d", len(clientSession.ClientIPs))) w.Header().Set("X-Connection-Type", clientSession.ConnectionType) - - log.Infof("🔗 Client session tracking: %s from IP %s (connection: %s, total_ips: %d)", + + log.Infof("🔗 Client session tracking: %s from IP %s (connection: %s, total_ips: %d)", sessionID, clientIP, clientSession.ConnectionType, len(clientSession.ClientIPs)) - + // Add user context for Bearer token authentication if bearerClaims != nil { log.Infof("👤 Session associated with XMPP user: %s", bearerClaims.User) @@ -2749,6 +2774,57 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { } defer file.Close() + // Get user JID for quota and audit tracking + var userJID string + if bearerClaims != nil { + userJID = bearerClaims.User + } + r.Header.Set("X-User-JID", userJID) + r.Header.Set("X-File-Name", header.Filename) + + // Check quota before upload + if qm := GetQuotaManager(); qm != nil && qm.config.Enabled && userJID != "" { + canUpload, _ := qm.CanUpload(r.Context(), userJID, header.Size) + if !canUpload { + used, limit, _ := qm.GetUsage(r.Context(), userJID) + AuditQuotaExceeded(r, userJID, used, limit, header.Size) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Quota-Used", fmt.Sprintf("%d", used)) + w.Header().Set("X-Quota-Limit", fmt.Sprintf("%d", limit)) + w.WriteHeader(http.StatusRequestEntityTooLarge) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "error": "quota_exceeded", + "message": "Storage quota exceeded", + "used": used, + "limit": limit, + "requested": header.Size, + }) + uploadErrorsTotal.Inc() + return + } + } + + // Content type validation using magic bytes + var fileReader io.Reader = file + declaredContentType := header.Header.Get("Content-Type") + detectedContentType := declaredContentType + + if validator := GetContentValidator(); validator != nil && validator.config.CheckMagicBytes { + validatedReader, detected, validErr := validator.ValidateContent(file, declaredContentType, header.Size) + if validErr != nil { + if valErr, ok := validErr.(*ValidationError); ok { + AuditValidationFailure(r, userJID, header.Filename, declaredContentType, detected, valErr.Code) + WriteValidationError(w, valErr) + } else { + http.Error(w, validErr.Error(), http.StatusUnsupportedMediaType) + } + uploadErrorsTotal.Inc() + return + } + fileReader = validatedReader + detectedContentType = detected + } + // Validate file size against max_upload_size if configured if conf.Server.MaxUploadSize != "" { maxSizeBytes, err := parseSize(conf.Server.MaxUploadSize) @@ -2759,9 +2835,9 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { return } if header.Size > maxSizeBytes { - log.Warnf("⚠️ File size %s exceeds maximum allowed size %s (IP: %s)", + log.Warnf("⚠️ File size %s exceeds maximum allowed size %s (IP: %s)", formatBytes(header.Size), conf.Server.MaxUploadSize, getClientIP(r)) - http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s", + http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s", formatBytes(header.Size), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge) uploadErrorsTotal.Inc() return @@ -2815,20 +2891,20 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { uploadsTotal.Inc() uploadSizeBytes.Observe(float64(existingFileInfo.Size())) filesDeduplicatedTotal.Inc() - + w.Header().Set("Content-Type", "application/json") w.Header().Set("X-Deduplication-Hit", "true") w.WriteHeader(http.StatusOK) response := map[string]interface{}{ - "success": true, - "filename": filename, - "size": existingFileInfo.Size(), - "message": "File already exists (deduplication hit)", + "success": true, + "filename": filename, + "size": existingFileInfo.Size(), + "message": "File already exists (deduplication hit)", "upload_time": duration.String(), } _ = json.NewEncoder(w).Encode(response) - - log.Infof("💾 Deduplication hit: file %s already exists (%s), returning success immediately (IP: %s)", + + log.Infof("💾 Deduplication hit: file %s already exists (%s), returning success immediately (IP: %s)", filename, formatBytes(existingFileInfo.Size()), getClientIP(r)) return } @@ -2855,30 +2931,43 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { uploadCtx = networkManager.RegisterUpload(networkSessionID) defer networkManager.UnregisterUpload(networkSessionID) log.Infof("🌐 Registered upload with network resilience: session=%s, IP=%s", networkSessionID, getClientIP(r)) - + // Add network resilience headers w.Header().Set("X-Network-Resilience", "enabled") w.Header().Set("X-Upload-Context-ID", networkSessionID) } // Copy file content with network resilience support and enhanced progress tracking - written, err := copyWithNetworkResilience(dst, file, uploadCtx) + // Use fileReader which may be wrapped with content validation + written, err := copyWithNetworkResilience(dst, fileReader, uploadCtx) if err != nil { log.Errorf("🔴 Error saving file %s (IP: %s, session: %s): %v", filename, getClientIP(r), sessionID, err) http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError) uploadErrorsTotal.Inc() // Clean up partial file os.Remove(absFilename) + // Audit the failure + AuditUploadFailure(r, userJID, header.Filename, header.Size, err.Error()) return } + // Update quota after successful upload + if qm := GetQuotaManager(); qm != nil && qm.config.Enabled && userJID != "" { + if err := qm.RecordUpload(r.Context(), userJID, absFilename, written); err != nil { + log.Warnf("⚠️ Failed to update quota for user %s: %v", userJID, err) + } + } + + // Audit successful upload + AuditUploadSuccess(r, userJID, filename, written, detectedContentType) + // ✅ CRITICAL FIX: Send immediate success response for large files (>1GB) // This prevents client timeouts while server does post-processing isLargeFile := header.Size > 1024*1024*1024 // 1GB threshold - + if isLargeFile { log.Infof("🚀 Large file detected (%s), sending immediate success response", formatBytes(header.Size)) - + // Send immediate success response to client duration := time.Since(startTime) uploadDuration.Observe(duration.Seconds()) @@ -2893,12 +2982,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) response := map[string]interface{}{ - "success": true, - "filename": filename, - "size": written, - "duration": duration.String(), - "client_ip": getClientIP(r), - "timestamp": time.Now().Unix(), + "success": true, + "filename": filename, + "size": written, + "duration": duration.String(), + "client_ip": getClientIP(r), + "timestamp": time.Now().Unix(), "post_processing": "background", } @@ -2921,7 +3010,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d, "post_processing": "background"}`, filename, written) } - log.Infof("✅ Immediate response sent for large file %s (%s) in %s from IP %s", + log.Infof("✅ Immediate response sent for large file %s (%s) in %s from IP %s", filename, formatBytes(written), duration, getClientIP(r)) // Process deduplication asynchronously for large files @@ -2936,7 +3025,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { log.Infof("✅ Background deduplication completed for %s", filename) } } - + // Add to scan queue for virus scanning if enabled if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 { ext := strings.ToLower(filepath.Ext(header.Filename)) @@ -2958,7 +3047,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { } } }() - + return } @@ -2987,10 +3076,10 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) response := map[string]interface{}{ - "success": true, - "filename": filename, - "size": written, - "duration": duration.String(), + "success": true, + "filename": filename, + "size": written, + "duration": duration.String(), "client_ip": getClientIP(r), "timestamp": time.Now().Unix(), } @@ -3014,7 +3103,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written) } - log.Infof("✅ Successfully uploaded %s (%s) in %s from IP %s (session: %s)", + log.Infof("✅ Successfully uploaded %s (%s) in %s from IP %s (session: %s)", filename, formatBytes(written), duration, getClientIP(r), sessionID) } @@ -3030,20 +3119,24 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { _, err := validateJWTFromRequest(r, conf.Security.JWTSecret) if err != nil { log.Warnf("🔴 JWT Authentication failed for download from IP %s: %v", getClientIP(r), err) + AuditAuthFailure(r, "jwt", err.Error()) http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized) downloadErrorsTotal.Inc() return } + AuditAuthSuccess(r, "", "jwt") log.Infof("✅ JWT authentication successful for download request: %s", r.URL.Path) w.Header().Set("X-Auth-Method", "JWT") } else { err := validateHMAC(r, conf.Security.Secret) if err != nil { log.Warnf("🔴 HMAC Authentication failed for download from IP %s: %v", getClientIP(r), err) + AuditAuthFailure(r, "hmac", err.Error()) http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized) downloadErrorsTotal.Inc() return } + AuditAuthSuccess(r, "", "hmac") log.Infof("✅ HMAC authentication successful for download request: %s", r.URL.Path) w.Header().Set("X-Auth-Method", "HMAC") } @@ -3060,13 +3153,13 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { // Enhanced file path validation and construction var absFilename string var err error - + // Use storage path or ISO mount point storagePath := conf.Server.StoragePath if conf.ISO.Enabled { storagePath = conf.ISO.MountPoint } - + absFilename, err = sanitizeFilePath(storagePath, filename) if err != nil { log.Warnf("🔴 Invalid file path requested from IP %s: %s, error: %v", getClientIP(r), filename, err) @@ -3079,12 +3172,12 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { fileInfo, err := os.Stat(absFilename) if os.IsNotExist(err) { log.Warnf("🔴 File not found: %s (requested by IP %s)", absFilename, getClientIP(r)) - + // Enhanced 404 response with network switching hints w.Header().Set("X-File-Not-Found", "true") w.Header().Set("X-Client-IP", getClientIP(r)) w.Header().Set("X-Network-Switch-Support", "enabled") - + // Check if this might be a network switching issue userAgent := r.Header.Get("User-Agent") isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") || @@ -3093,13 +3186,13 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { strings.Contains(strings.ToLower(userAgent), "android") || strings.Contains(strings.ToLower(userAgent), "mobile") || strings.Contains(strings.ToLower(userAgent), "xmpp") - + if isMobileXMPP { w.Header().Set("X-Mobile-Client-Detected", "true") w.Header().Set("X-Retry-Suggestion", "30") // Suggest retry after 30 seconds log.Infof("📱 Mobile XMPP client file not found - may be network switching issue: %s", userAgent) } - + http.Error(w, "File not found", http.StatusNotFound) downloadErrorsTotal.Inc() return @@ -3126,13 +3219,13 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { if err == nil { break } - + if attempt < maxRetries { - log.Warnf("⚠️ Attempt %d/%d: Error opening file %s from IP %s: %v (retrying...)", + log.Warnf("⚠️ Attempt %d/%d: Error opening file %s from IP %s: %v (retrying...)", attempt, maxRetries, absFilename, getClientIP(r), err) time.Sleep(time.Duration(attempt) * time.Second) // Progressive backoff } else { - log.Errorf("🔴 Failed to open file %s after %d attempts from IP %s: %v", + log.Errorf("🔴 Failed to open file %s after %d attempts from IP %s: %v", absFilename, maxRetries, getClientIP(r), err) http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError) downloadErrorsTotal.Inc() @@ -3149,7 +3242,7 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Network-Switch-Support", "enabled") w.Header().Set("X-File-Path", filename) w.Header().Set("X-Download-Start-Time", fmt.Sprintf("%d", time.Now().Unix())) - + // Add cache control headers for mobile network optimization userAgent := r.Header.Get("User-Agent") isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") || @@ -3158,7 +3251,7 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { strings.Contains(strings.ToLower(userAgent), "android") || strings.Contains(strings.ToLower(userAgent), "mobile") || strings.Contains(strings.ToLower(userAgent), "xmpp") - + if isMobileXMPP { w.Header().Set("X-Mobile-Client-Detected", "true") w.Header().Set("Cache-Control", "public, max-age=86400") // 24 hours cache for mobile @@ -3173,7 +3266,7 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { // Track download progress for large files if fileInfo.Size() > 10*1024*1024 { // Log progress for files > 10MB - log.Infof("📥 Starting download of %s (%.1f MiB) for IP %s", + log.Infof("📥 Starting download of %s (%.1f MiB) for IP %s", filepath.Base(absFilename), float64(fileInfo.Size())/(1024*1024), getClientIP(r)) } @@ -3191,8 +3284,11 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { downloadDuration.Observe(duration.Seconds()) downloadsTotal.Inc() downloadSizeBytes.Observe(float64(n)) - - log.Infof("✅ Successfully downloaded %s (%s) in %s for IP %s (session complete)", + + // Audit successful download + AuditDownloadSuccess(r, "", filepath.Base(absFilename), n) + + log.Infof("✅ Successfully downloaded %s (%s) in %s for IP %s (session complete)", filepath.Base(absFilename), formatBytes(n), duration, getClientIP(r)) } @@ -3262,7 +3358,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) { return } if r.ContentLength > maxSizeBytes { - http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s", + http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s", formatBytes(r.ContentLength), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge) uploadErrorsTotal.Inc() return @@ -3298,7 +3394,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) { uploadsTotal.Inc() uploadSizeBytes.Observe(float64(existingFileInfo.Size())) filesDeduplicatedTotal.Inc() - + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) response := map[string]interface{}{ @@ -3308,8 +3404,8 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) { "message": "File already exists (deduplication hit)", } _ = json.NewEncoder(w).Encode(response) - - log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately", + + log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately", filename, formatBytes(existingFileInfo.Size())) return } @@ -3337,10 +3433,10 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) { // ✅ CRITICAL FIX: Send immediate success response for large files (>1GB) // This prevents client timeouts while server does post-processing isLargeFile := written > 1024*1024*1024 // 1GB threshold - + if isLargeFile { log.Infof("🚀 Large file detected (%s), sending immediate success response (v3)", formatBytes(written)) - + // Send immediate success response to client duration := time.Since(startTime) uploadDuration.Observe(duration.Seconds()) @@ -3355,11 +3451,11 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) response := map[string]interface{}{ - "success": true, - "filename": filename, - "size": written, - "duration": duration.String(), - "protocol": "v3", + "success": true, + "filename": filename, + "size": written, + "duration": duration.String(), + "protocol": "v3", "post_processing": "background", } @@ -3370,7 +3466,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d, "post_processing": "background"}`, filename, written) } - log.Infof("✅ Immediate response sent for large file %s (%s) in %s via v3 protocol", + log.Infof("✅ Immediate response sent for large file %s (%s) in %s via v3 protocol", filename, formatBytes(written), duration) // Process deduplication asynchronously for large files @@ -3385,7 +3481,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) { log.Infof("✅ Background deduplication completed for %s (v3)", filename) } } - + // Add to scan queue for virus scanning if enabled if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 { ext := strings.ToLower(filepath.Ext(originalFilename)) @@ -3407,7 +3503,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) { } } }() - + return } @@ -3462,7 +3558,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) { // Generate session ID for XMPP multi-upload tracking sessionID = generateUploadSessionID("legacy", r.Header.Get("User-Agent"), getClientIP(r)) } - + // Set session headers for XMPP client continuation w.Header().Set("X-Session-ID", sessionID) w.Header().Set("X-Upload-Session-Timeout", "3600") // 1 hour @@ -3531,7 +3627,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) { return } if r.ContentLength > maxSizeBytes { - http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s", + http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s", formatBytes(r.ContentLength), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge) uploadErrorsTotal.Inc() return @@ -3582,9 +3678,9 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) { uploadsTotal.Inc() uploadSizeBytes.Observe(float64(existingFileInfo.Size())) filesDeduplicatedTotal.Inc() - + w.WriteHeader(http.StatusCreated) // 201 Created for legacy compatibility - log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately", + log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately", filename, formatBytes(existingFileInfo.Size())) return } @@ -3617,10 +3713,10 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) { // ✅ CRITICAL FIX: Send immediate success response for large files (>1GB) // This prevents client timeouts while server does post-processing isLargeFile := written > 1024*1024*1024 // 1GB threshold - + if isLargeFile { log.Infof("🚀 Large file detected (%s), sending immediate success response (legacy)", formatBytes(written)) - + // Send immediate success response to client duration := time.Since(startTime) uploadDuration.Observe(duration.Seconds()) @@ -3634,7 +3730,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Post-Processing", "background") w.WriteHeader(http.StatusCreated) - log.Infof("✅ Immediate response sent for large file %s (%s) in %s via legacy protocol", + log.Infof("✅ Immediate response sent for large file %s (%s) in %s via legacy protocol", filename, formatBytes(written), duration) // Process deduplication asynchronously for large files @@ -3649,7 +3745,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) { log.Infof("✅ Background deduplication completed for %s (legacy)", filename) } } - + // Add to scan queue for virus scanning if enabled if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 { ext := strings.ToLower(filepath.Ext(fileStorePath)) @@ -3671,7 +3767,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) { } } }() - + return } diff --git a/cmd/server/quota.go b/cmd/server/quota.go new file mode 100644 index 0000000..2610599 --- /dev/null +++ b/cmd/server/quota.go @@ -0,0 +1,557 @@ +// quota.go - Per-user storage quota management + +package main + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-redis/redis/v8" + "github.com/sirupsen/logrus" +) + +// QuotaConfig holds quota configuration +type QuotaConfig struct { + Enabled bool `toml:"enabled" mapstructure:"enabled"` + Default string `toml:"default" mapstructure:"default"` // Default quota (e.g., "100MB") + Tracking string `toml:"tracking" mapstructure:"tracking"` // "redis" | "memory" + Custom map[string]string `toml:"custom" mapstructure:"custom"` // Custom quotas per JID +} + +// QuotaInfo contains quota information for a user +type QuotaInfo struct { + JID string `json:"jid"` + Used int64 `json:"used"` + Limit int64 `json:"limit"` + Remaining int64 `json:"remaining"` + FileCount int64 `json:"file_count"` + IsCustom bool `json:"is_custom"` +} + +// QuotaExceededError represents a quota exceeded error +type QuotaExceededError struct { + JID string `json:"jid"` + Used int64 `json:"used"` + Limit int64 `json:"limit"` + Requested int64 `json:"requested"` +} + +func (e *QuotaExceededError) Error() string { + return fmt.Sprintf("quota exceeded for %s: used %d, limit %d, requested %d", + e.JID, e.Used, e.Limit, e.Requested) +} + +// QuotaManager handles per-user storage quotas +type QuotaManager struct { + config *QuotaConfig + redisClient *redis.Client + defaultQuota int64 + customQuotas map[string]int64 + + // In-memory fallback when Redis is unavailable + memoryUsage map[string]int64 + memoryFiles map[string]map[string]int64 // jid -> filePath -> size + mutex sync.RWMutex +} + +var ( + quotaManager *QuotaManager + quotaOnce sync.Once +) + +// Redis key patterns +const ( + quotaUsedKey = "quota:%s:used" // quota:{jid}:used -> int64 + quotaFilesKey = "quota:%s:files" // quota:{jid}:files -> HASH {path: size} + quotaInfoKey = "quota:%s:info" // quota:{jid}:info -> JSON +) + +// InitQuotaManager initializes the quota manager +func InitQuotaManager(config *QuotaConfig, redisClient *redis.Client) error { + var initErr error + quotaOnce.Do(func() { + quotaManager = &QuotaManager{ + config: config, + redisClient: redisClient, + customQuotas: make(map[string]int64), + memoryUsage: make(map[string]int64), + memoryFiles: make(map[string]map[string]int64), + } + + // Parse default quota + if config.Default != "" { + quota, err := parseSize(config.Default) + if err != nil { + initErr = fmt.Errorf("invalid default quota: %w", err) + return + } + quotaManager.defaultQuota = quota + } else { + quotaManager.defaultQuota = 100 * 1024 * 1024 // 100MB default + } + + // Parse custom quotas + for jid, quotaStr := range config.Custom { + quota, err := parseSize(quotaStr) + if err != nil { + log.Warnf("Invalid custom quota for %s: %v", jid, err) + continue + } + quotaManager.customQuotas[strings.ToLower(jid)] = quota + } + + log.Infof("Quota manager initialized: enabled=%v, default=%s, custom=%d users, tracking=%s", + config.Enabled, config.Default, len(config.Custom), config.Tracking) + }) + + return initErr +} + +// GetQuotaManager returns the singleton quota manager +func GetQuotaManager() *QuotaManager { + return quotaManager +} + +// GetLimit returns the quota limit for a user +func (q *QuotaManager) GetLimit(jid string) int64 { + if jid == "" { + return q.defaultQuota + } + + jidLower := strings.ToLower(jid) + if custom, ok := q.customQuotas[jidLower]; ok { + return custom + } + return q.defaultQuota +} + +// GetUsage returns the current storage usage for a user +func (q *QuotaManager) GetUsage(ctx context.Context, jid string) (used, limit int64, err error) { + if !q.config.Enabled { + return 0, 0, nil + } + + limit = q.GetLimit(jid) + + // Try Redis first + if q.redisClient != nil && q.config.Tracking == "redis" { + key := fmt.Sprintf(quotaUsedKey, jid) + usedStr, err := q.redisClient.Get(ctx, key).Result() + if err == redis.Nil { + return 0, limit, nil + } + if err != nil { + log.Warnf("Failed to get quota from Redis, falling back to memory: %v", err) + } else { + used, _ = strconv.ParseInt(usedStr, 10, 64) + return used, limit, nil + } + } + + // Fallback to memory + q.mutex.RLock() + used = q.memoryUsage[jid] + q.mutex.RUnlock() + + return used, limit, nil +} + +// GetQuotaInfo returns detailed quota information for a user +func (q *QuotaManager) GetQuotaInfo(ctx context.Context, jid string) (*QuotaInfo, error) { + used, limit, err := q.GetUsage(ctx, jid) + if err != nil { + return nil, err + } + + fileCount := int64(0) + + // Get file count + if q.redisClient != nil && q.config.Tracking == "redis" { + key := fmt.Sprintf(quotaFilesKey, jid) + count, err := q.redisClient.HLen(ctx, key).Result() + if err == nil { + fileCount = count + } + } else { + q.mutex.RLock() + if files, ok := q.memoryFiles[jid]; ok { + fileCount = int64(len(files)) + } + q.mutex.RUnlock() + } + + _, isCustom := q.customQuotas[strings.ToLower(jid)] + + return &QuotaInfo{ + JID: jid, + Used: used, + Limit: limit, + Remaining: limit - used, + FileCount: fileCount, + IsCustom: isCustom, + }, nil +} + +// CanUpload checks if a user can upload a file of the given size +func (q *QuotaManager) CanUpload(ctx context.Context, jid string, size int64) (bool, error) { + if !q.config.Enabled { + return true, nil + } + + used, limit, err := q.GetUsage(ctx, jid) + if err != nil { + // On error, allow upload but log warning + log.Warnf("Failed to check quota for %s, allowing upload: %v", jid, err) + return true, nil + } + + return used+size <= limit, nil +} + +// RecordUpload records a file upload for quota tracking +func (q *QuotaManager) RecordUpload(ctx context.Context, jid, filePath string, size int64) error { + if !q.config.Enabled || jid == "" { + return nil + } + + // Try Redis first with atomic operation + if q.redisClient != nil && q.config.Tracking == "redis" { + pipe := q.redisClient.TxPipeline() + + usedKey := fmt.Sprintf(quotaUsedKey, jid) + filesKey := fmt.Sprintf(quotaFilesKey, jid) + + pipe.IncrBy(ctx, usedKey, size) + pipe.HSet(ctx, filesKey, filePath, size) + + _, err := pipe.Exec(ctx) + if err != nil { + log.Warnf("Failed to record upload in Redis: %v", err) + } else { + return nil + } + } + + // Fallback to memory + q.mutex.Lock() + defer q.mutex.Unlock() + + q.memoryUsage[jid] += size + + if q.memoryFiles[jid] == nil { + q.memoryFiles[jid] = make(map[string]int64) + } + q.memoryFiles[jid][filePath] = size + + return nil +} + +// RecordDelete records a file deletion for quota tracking +func (q *QuotaManager) RecordDelete(ctx context.Context, jid, filePath string, size int64) error { + if !q.config.Enabled || jid == "" { + return nil + } + + // If size is 0, try to get it from tracking + if size == 0 { + size = q.getFileSize(ctx, jid, filePath) + } + + // Try Redis first + if q.redisClient != nil && q.config.Tracking == "redis" { + pipe := q.redisClient.TxPipeline() + + usedKey := fmt.Sprintf(quotaUsedKey, jid) + filesKey := fmt.Sprintf(quotaFilesKey, jid) + + pipe.DecrBy(ctx, usedKey, size) + pipe.HDel(ctx, filesKey, filePath) + + _, err := pipe.Exec(ctx) + if err != nil { + log.Warnf("Failed to record delete in Redis: %v", err) + } else { + return nil + } + } + + // Fallback to memory + q.mutex.Lock() + defer q.mutex.Unlock() + + q.memoryUsage[jid] -= size + if q.memoryUsage[jid] < 0 { + q.memoryUsage[jid] = 0 + } + + if q.memoryFiles[jid] != nil { + delete(q.memoryFiles[jid], filePath) + } + + return nil +} + +// getFileSize retrieves the size of a tracked file +func (q *QuotaManager) getFileSize(ctx context.Context, jid, filePath string) int64 { + // Try Redis + if q.redisClient != nil && q.config.Tracking == "redis" { + key := fmt.Sprintf(quotaFilesKey, jid) + sizeStr, err := q.redisClient.HGet(ctx, key, filePath).Result() + if err == nil { + size, _ := strconv.ParseInt(sizeStr, 10, 64) + return size + } + } + + // Try memory + q.mutex.RLock() + defer q.mutex.RUnlock() + + if files, ok := q.memoryFiles[jid]; ok { + return files[filePath] + } + + return 0 +} + +// SetCustomQuota sets a custom quota for a user +func (q *QuotaManager) SetCustomQuota(jid string, quota int64) { + q.mutex.Lock() + defer q.mutex.Unlock() + q.customQuotas[strings.ToLower(jid)] = quota +} + +// RemoveCustomQuota removes a custom quota for a user +func (q *QuotaManager) RemoveCustomQuota(jid string) { + q.mutex.Lock() + defer q.mutex.Unlock() + delete(q.customQuotas, strings.ToLower(jid)) +} + +// GetAllQuotas returns quota info for all tracked users +func (q *QuotaManager) GetAllQuotas(ctx context.Context) ([]QuotaInfo, error) { + var quotas []QuotaInfo + + // Get from Redis + if q.redisClient != nil && q.config.Tracking == "redis" { + // Scan for all quota keys + iter := q.redisClient.Scan(ctx, 0, "quota:*:used", 100).Iterator() + for iter.Next(ctx) { + key := iter.Val() + // Extract JID from key + parts := strings.Split(key, ":") + if len(parts) >= 2 { + jid := parts[1] + info, err := q.GetQuotaInfo(ctx, jid) + if err == nil { + quotas = append(quotas, *info) + } + } + } + return quotas, iter.Err() + } + + // Get from memory + q.mutex.RLock() + defer q.mutex.RUnlock() + + for jid, used := range q.memoryUsage { + limit := q.GetLimit(jid) + fileCount := int64(0) + if files, ok := q.memoryFiles[jid]; ok { + fileCount = int64(len(files)) + } + _, isCustom := q.customQuotas[strings.ToLower(jid)] + + quotas = append(quotas, QuotaInfo{ + JID: jid, + Used: used, + Limit: limit, + Remaining: limit - used, + FileCount: fileCount, + IsCustom: isCustom, + }) + } + + return quotas, nil +} + +// Reconcile recalculates quota usage from actual file storage +func (q *QuotaManager) Reconcile(ctx context.Context, jid string, files map[string]int64) error { + if !q.config.Enabled { + return nil + } + + var totalSize int64 + for _, size := range files { + totalSize += size + } + + // Update Redis + if q.redisClient != nil && q.config.Tracking == "redis" { + usedKey := fmt.Sprintf(quotaUsedKey, jid) + filesKey := fmt.Sprintf(quotaFilesKey, jid) + + pipe := q.redisClient.TxPipeline() + pipe.Set(ctx, usedKey, totalSize, 0) + pipe.Del(ctx, filesKey) + + for path, size := range files { + pipe.HSet(ctx, filesKey, path, size) + } + + _, err := pipe.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to reconcile quota in Redis: %w", err) + } + return nil + } + + // Update memory + q.mutex.Lock() + defer q.mutex.Unlock() + + q.memoryUsage[jid] = totalSize + q.memoryFiles[jid] = files + + return nil +} + +// CheckQuotaMiddleware is a middleware that checks quota before upload +func CheckQuotaMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + qm := GetQuotaManager() + if qm == nil || !qm.config.Enabled { + next.ServeHTTP(w, r) + return + } + + // Only check for upload methods + if r.Method != http.MethodPut && r.Method != http.MethodPost { + next.ServeHTTP(w, r) + return + } + + // Get JID from context/headers + jid := r.Header.Get("X-User-JID") + if jid == "" { + // Try to get from authorization context + if claims, ok := r.Context().Value(contextKey("bearerClaims")).(*BearerTokenClaims); ok { + jid = claims.User + } + } + + if jid == "" { + next.ServeHTTP(w, r) + return + } + + // Check quota + ctx := r.Context() + canUpload, err := qm.CanUpload(ctx, jid, r.ContentLength) + if err != nil { + log.Warnf("Error checking quota: %v", err) + next.ServeHTTP(w, r) + return + } + + if !canUpload { + used, limit, _ := qm.GetUsage(ctx, jid) + + // Log to audit + AuditQuotaExceeded(r, jid, used, limit, r.ContentLength) + + // Return 413 with quota info + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Quota-Used", strconv.FormatInt(used, 10)) + w.Header().Set("X-Quota-Limit", strconv.FormatInt(limit, 10)) + w.Header().Set("X-Quota-Remaining", strconv.FormatInt(limit-used, 10)) + w.WriteHeader(http.StatusRequestEntityTooLarge) + + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "error": "quota_exceeded", + "message": "Storage quota exceeded", + "used": used, + "limit": limit, + "requested": r.ContentLength, + }) + return + } + + // Add quota headers + used, limit, _ := qm.GetUsage(ctx, jid) + w.Header().Set("X-Quota-Used", strconv.FormatInt(used, 10)) + w.Header().Set("X-Quota-Limit", strconv.FormatInt(limit, 10)) + w.Header().Set("X-Quota-Remaining", strconv.FormatInt(limit-used, 10)) + + next.ServeHTTP(w, r) + }) +} + +// UpdateQuotaAfterUpload updates quota after successful upload +func UpdateQuotaAfterUpload(ctx context.Context, jid, filePath string, size int64) { + qm := GetQuotaManager() + if qm == nil || !qm.config.Enabled || jid == "" { + return + } + + if err := qm.RecordUpload(ctx, jid, filePath, size); err != nil { + log.WithFields(logrus.Fields{ + "jid": jid, + "file": filePath, + "size": size, + "error": err, + }).Warn("Failed to update quota after upload") + } +} + +// UpdateQuotaAfterDelete updates quota after file deletion +func UpdateQuotaAfterDelete(ctx context.Context, jid, filePath string, size int64) { + qm := GetQuotaManager() + if qm == nil || !qm.config.Enabled || jid == "" { + return + } + + if err := qm.RecordDelete(ctx, jid, filePath, size); err != nil { + log.WithFields(logrus.Fields{ + "jid": jid, + "file": filePath, + "size": size, + "error": err, + }).Warn("Failed to update quota after delete") + } +} + +// DefaultQuotaConfig returns default quota configuration +func DefaultQuotaConfig() QuotaConfig { + return QuotaConfig{ + Enabled: false, + Default: "100MB", + Tracking: "redis", + Custom: make(map[string]string), + } +} + +// StartQuotaReconciliation starts a background job to reconcile quotas +func StartQuotaReconciliation(interval time.Duration) { + if quotaManager == nil || !quotaManager.config.Enabled { + return + } + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for range ticker.C { + log.Debug("Running quota reconciliation") + // This would scan the storage and update quotas + // Implementation depends on how files are tracked + } + }() +} diff --git a/cmd/server/validation.go b/cmd/server/validation.go new file mode 100644 index 0000000..af6f1c0 --- /dev/null +++ b/cmd/server/validation.go @@ -0,0 +1,340 @@ +// validation.go - Content type validation using magic bytes + +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" +) + +// ValidationConfig holds content validation configuration +type ValidationConfig struct { + CheckMagicBytes bool `toml:"check_magic_bytes" mapstructure:"check_magic_bytes"` + AllowedTypes []string `toml:"allowed_types" mapstructure:"allowed_types"` + BlockedTypes []string `toml:"blocked_types" mapstructure:"blocked_types"` + MaxFileSize string `toml:"max_file_size" mapstructure:"max_file_size"` + StrictMode bool `toml:"strict_mode" mapstructure:"strict_mode"` // Reject if type can't be detected +} + +// ValidationResult contains the result of content validation +type ValidationResult struct { + Valid bool `json:"valid"` + DetectedType string `json:"detected_type"` + DeclaredType string `json:"declared_type,omitempty"` + Error string `json:"error,omitempty"` + Reason string `json:"reason,omitempty"` +} + +// ValidationError represents a validation failure +type ValidationError struct { + Code string `json:"error"` + Message string `json:"message"` + DetectedType string `json:"detected_type"` + DeclaredType string `json:"declared_type,omitempty"` +} + +func (e *ValidationError) Error() string { + return e.Message +} + +// ContentValidator handles content type validation +type ContentValidator struct { + config *ValidationConfig + allowedTypes map[string]bool + blockedTypes map[string]bool + wildcardAllow []string + wildcardBlock []string +} + +var ( + contentValidator *ContentValidator + validatorOnce sync.Once +) + +// InitContentValidator initializes the content validator +func InitContentValidator(config *ValidationConfig) { + validatorOnce.Do(func() { + contentValidator = &ContentValidator{ + config: config, + allowedTypes: make(map[string]bool), + blockedTypes: make(map[string]bool), + wildcardAllow: []string{}, + wildcardBlock: []string{}, + } + + // Process allowed types + for _, t := range config.AllowedTypes { + t = strings.ToLower(strings.TrimSpace(t)) + if strings.HasSuffix(t, "/*") { + contentValidator.wildcardAllow = append(contentValidator.wildcardAllow, strings.TrimSuffix(t, "/*")) + } else { + contentValidator.allowedTypes[t] = true + } + } + + // Process blocked types + for _, t := range config.BlockedTypes { + t = strings.ToLower(strings.TrimSpace(t)) + if strings.HasSuffix(t, "/*") { + contentValidator.wildcardBlock = append(contentValidator.wildcardBlock, strings.TrimSuffix(t, "/*")) + } else { + contentValidator.blockedTypes[t] = true + } + } + + log.Infof("Content validator initialized: magic_bytes=%v, allowed=%d types, blocked=%d types", + config.CheckMagicBytes, len(config.AllowedTypes), len(config.BlockedTypes)) + }) +} + +// GetContentValidator returns the singleton content validator +func GetContentValidator() *ContentValidator { + return contentValidator +} + +// isTypeAllowed checks if a content type is in the allowed list +func (v *ContentValidator) isTypeAllowed(contentType string) bool { + contentType = strings.ToLower(contentType) + + // Extract main type (before any parameters like charset) + if idx := strings.Index(contentType, ";"); idx != -1 { + contentType = strings.TrimSpace(contentType[:idx]) + } + + // If no allowed types configured, allow all (except blocked) + if len(v.allowedTypes) == 0 && len(v.wildcardAllow) == 0 { + return true + } + + // Check exact match + if v.allowedTypes[contentType] { + return true + } + + // Check wildcard patterns + for _, prefix := range v.wildcardAllow { + if strings.HasPrefix(contentType, prefix+"/") { + return true + } + } + + return false +} + +// isTypeBlocked checks if a content type is in the blocked list +func (v *ContentValidator) isTypeBlocked(contentType string) bool { + contentType = strings.ToLower(contentType) + + // Extract main type (before any parameters) + if idx := strings.Index(contentType, ";"); idx != -1 { + contentType = strings.TrimSpace(contentType[:idx]) + } + + // Check exact match + if v.blockedTypes[contentType] { + return true + } + + // Check wildcard patterns + for _, prefix := range v.wildcardBlock { + if strings.HasPrefix(contentType, prefix+"/") { + return true + } + } + + return false +} + +// ValidateContent validates the content type of a reader +// Returns a new reader that includes the buffered bytes, the detected type, and any error +func (v *ContentValidator) ValidateContent(reader io.Reader, declaredType string, size int64) (io.Reader, string, error) { + if v == nil || !v.config.CheckMagicBytes { + return reader, declaredType, nil + } + + // Read first 512 bytes for magic byte detection + buf := make([]byte, 512) + n, err := io.ReadFull(reader, buf) + if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF { + return nil, "", fmt.Errorf("failed to read content for validation: %w", err) + } + + // Handle small files + if n == 0 { + if v.config.StrictMode { + return nil, "", &ValidationError{ + Code: "empty_content", + Message: "Cannot validate empty content", + DetectedType: "", + DeclaredType: declaredType, + } + } + return bytes.NewReader(buf[:n]), declaredType, nil + } + + // Detect content type using magic bytes + detectedType := http.DetectContentType(buf[:n]) + + // Normalize detected type + if idx := strings.Index(detectedType, ";"); idx != -1 { + detectedType = strings.TrimSpace(detectedType[:idx]) + } + + // Check if type is blocked (highest priority) + if v.isTypeBlocked(detectedType) { + return nil, detectedType, &ValidationError{ + Code: "content_type_blocked", + Message: fmt.Sprintf("File type %s is blocked", detectedType), + DetectedType: detectedType, + DeclaredType: declaredType, + } + } + + // Check if type is allowed + if !v.isTypeAllowed(detectedType) { + return nil, detectedType, &ValidationError{ + Code: "content_type_rejected", + Message: fmt.Sprintf("File type %s is not allowed", detectedType), + DetectedType: detectedType, + DeclaredType: declaredType, + } + } + + // Create a new reader that includes the buffered bytes + combinedReader := io.MultiReader(bytes.NewReader(buf[:n]), reader) + + return combinedReader, detectedType, nil +} + +// ValidateContentType validates a content type without reading content +func (v *ContentValidator) ValidateContentType(contentType string) error { + if v == nil { + return nil + } + + if v.isTypeBlocked(contentType) { + return &ValidationError{ + Code: "content_type_blocked", + Message: fmt.Sprintf("File type %s is blocked", contentType), + DetectedType: contentType, + } + } + + if !v.isTypeAllowed(contentType) { + return &ValidationError{ + Code: "content_type_rejected", + Message: fmt.Sprintf("File type %s is not allowed", contentType), + DetectedType: contentType, + } + } + + return nil +} + +// WriteValidationError writes a validation error response +func WriteValidationError(w http.ResponseWriter, err *ValidationError) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnsupportedMediaType) + _ = json.NewEncoder(w).Encode(err) +} + +// ValidateUploadContent is a helper function for validating upload content +func ValidateUploadContent(r *http.Request, reader io.Reader, declaredType string, size int64) (io.Reader, string, error) { + validator := GetContentValidator() + if validator == nil || !validator.config.CheckMagicBytes { + return reader, declaredType, nil + } + + newReader, detectedType, err := validator.ValidateContent(reader, declaredType, size) + if err != nil { + // Log validation failure to audit + jid := r.Header.Get("X-User-JID") + fileName := r.Header.Get("X-File-Name") + if fileName == "" { + fileName = "unknown" + } + + var reason string + if validErr, ok := err.(*ValidationError); ok { + reason = validErr.Code + } else { + reason = err.Error() + } + + AuditValidationFailure(r, jid, fileName, declaredType, detectedType, reason) + + return nil, detectedType, err + } + + return newReader, detectedType, nil +} + +// DefaultValidationConfig returns default validation configuration +func DefaultValidationConfig() ValidationConfig { + return ValidationConfig{ + CheckMagicBytes: false, + AllowedTypes: []string{ + "image/*", + "video/*", + "audio/*", + "application/pdf", + "text/plain", + "text/html", + "application/json", + "application/xml", + "application/zip", + "application/x-gzip", + "application/x-tar", + "application/x-7z-compressed", + "application/vnd.openxmlformats-officedocument.*", + "application/vnd.oasis.opendocument.*", + }, + BlockedTypes: []string{ + "application/x-executable", + "application/x-msdos-program", + "application/x-msdownload", + "application/x-dosexec", + "application/x-sh", + "application/x-shellscript", + }, + MaxFileSize: "100MB", + StrictMode: false, + } +} + +// Extended MIME type detection for better accuracy +var customMagicBytes = map[string][]byte{ + "application/x-executable": {0x7f, 'E', 'L', 'F'}, // ELF + "application/x-msdos-program": {0x4d, 0x5a}, // MZ (DOS/Windows) + "application/pdf": {0x25, 0x50, 0x44, 0x46}, // %PDF + "application/zip": {0x50, 0x4b, 0x03, 0x04}, // PK + "image/png": {0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a}, // PNG + "image/jpeg": {0xff, 0xd8, 0xff}, // JPEG + "image/gif": {0x47, 0x49, 0x46, 0x38}, // GIF8 + "image/webp": {0x52, 0x49, 0x46, 0x46}, // RIFF (WebP starts with RIFF) + "video/mp4": {0x00, 0x00, 0x00}, // MP4 (variable, check ftyp) + "audio/mpeg": {0xff, 0xfb}, // MP3 + "audio/ogg": {0x4f, 0x67, 0x67, 0x53}, // OggS +} + +// DetectContentTypeExtended provides extended content type detection +func DetectContentTypeExtended(data []byte) string { + // First try standard detection + detected := http.DetectContentType(data) + + // If generic, try custom detection + if detected == "application/octet-stream" { + for mimeType, magic := range customMagicBytes { + if len(data) >= len(magic) && bytes.Equal(data[:len(magic)], magic) { + return mimeType + } + } + } + + return detected +} diff --git a/templates/config-enhanced-features.toml b/templates/config-enhanced-features.toml new file mode 100644 index 0000000..ad088fc --- /dev/null +++ b/templates/config-enhanced-features.toml @@ -0,0 +1,162 @@ +# HMAC File Server 3.3.0 "Nexus Infinitum" Configuration +# Enhanced Features Template: Audit Logging, Content Validation, Quotas, Admin API +# Generated on: January 2025 + +[server] +listen_address = "8080" +storage_path = "/opt/hmac-file-server/data/uploads" +metrics_enabled = true +metrics_port = "9090" +pid_file = "/opt/hmac-file-server/data/hmac-file-server.pid" +max_upload_size = "10GB" +deduplication_enabled = true +min_free_bytes = "1GB" +file_naming = "original" +enable_dynamic_workers = true + +[security] +secret = "CHANGE-THIS-SECRET-KEY-MINIMUM-32-CHARACTERS" +enablejwt = false + +[uploads] +allowedextensions = [".txt", ".pdf", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".zip", ".tar", ".gz", ".7z", ".mp4", ".webm", ".ogg", ".mp3", ".wav", ".flac", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx", ".odt", ".ods", ".odp"] +maxfilesize = "100MB" +chunkeduploadsenabled = true +chunksize = "10MB" +networkevents = true + +[downloads] +chunkeddownloadsenabled = true +chunksize = "10MB" + +[logging] +level = "INFO" +file = "/opt/hmac-file-server/data/logs/hmac-file-server.log" +max_size = 100 +max_backups = 3 +max_age = 30 +compress = true + +[workers] +numworkers = 10 +uploadqueuesize = 1000 +autoscaling = true + +[timeouts] +readtimeout = "30s" +writetimeout = "30s" +idletimeout = "120s" +shutdown = "30s" + +[clamav] +enabled = false + +[redis] +enabled = true +address = "127.0.0.1:6379" +db = 0 + +# ============================================ +# NEW ENHANCED FEATURES (v3.3.0) +# ============================================ + +# Security Audit Logging +# Records security-relevant events for compliance and forensics +[audit] +enabled = true +output = "file" # "file" or "stdout" +path = "/var/log/hmac-audit.log" # Log file path (when output = "file") +format = "json" # "json" or "text" +max_size = 100 # Max size in MB before rotation +max_age = 30 # Max age in days +events = [ + "upload", # Log all file uploads + "download", # Log all file downloads + "delete", # Log file deletions + "auth_success", # Log successful authentications + "auth_failure", # Log failed authentications + "rate_limited", # Log rate limiting events + "banned", # Log ban events + "quota_exceeded", # Log quota exceeded events + "validation_failure" # Log content validation failures +] + +# Magic Bytes Content Validation +# Validates uploaded file content types using magic bytes detection +[validation] +check_magic_bytes = true # Enable magic bytes validation +strict_mode = false # Strict mode rejects mismatched types +max_peek_size = 65536 # Bytes to read for detection (64KB) + +# Allowed content types (supports wildcards like "image/*") +# If empty, all types are allowed (except blocked) +allowed_types = [ + "image/*", # All image types + "video/*", # All video types + "audio/*", # All audio types + "text/plain", # Plain text + "application/pdf", # PDF documents + "application/zip", # ZIP archives + "application/gzip", # GZIP archives + "application/x-tar", # TAR archives + "application/x-7z-compressed", # 7-Zip archives + "application/vnd.openxmlformats-officedocument.*", # MS Office docs + "application/vnd.oasis.opendocument.*" # LibreOffice docs +] + +# Blocked content types (takes precedence over allowed) +blocked_types = [ + "application/x-executable", # Executable files + "application/x-msdos-program", # DOS executables + "application/x-msdownload", # Windows executables + "application/x-elf", # ELF binaries + "application/x-shellscript", # Shell scripts + "application/javascript", # JavaScript files + "text/html", # HTML files (potential XSS) + "application/x-php" # PHP files +] + +# Per-User Storage Quotas +# Track and enforce storage limits per XMPP JID +[quotas] +enabled = true # Enable quota enforcement +default = "100MB" # Default quota for all users +tracking = "redis" # "redis" or "memory" + +# Custom quotas per user (JID -> quota) +[quotas.custom] +"admin@example.com" = "10GB" # Admin gets 10GB +"premium@example.com" = "1GB" # Premium user gets 1GB +"vip@example.com" = "5GB" # VIP user gets 5GB + +# Admin API for Operations and Monitoring +# Protected endpoints for system management +[admin] +enabled = true # Enable admin API +path_prefix = "/admin" # URL prefix for admin endpoints + +# Available endpoints (when enabled): +# GET /admin/stats - Server statistics and metrics +# GET /admin/files - List all uploaded files +# GET /admin/files/:id - Get file details +# DEL /admin/files/:id - Delete a file +# GET /admin/users - List users and quota usage +# GET /admin/users/:jid - Get user details and quota +# POST /admin/users/:jid/quota - Set user quota +# GET /admin/bans - List banned IPs/users +# POST /admin/bans - Ban an IP or user +# DEL /admin/bans/:id - Unban + +# Admin authentication +[admin.auth] +type = "bearer" # "bearer" or "basic" +token = "${ADMIN_TOKEN}" # Bearer token (from environment variable) +# For basic auth: +# type = "basic" +# username = "admin" +# password_hash = "$2a$12$..." # bcrypt hash + +# Rate limiting for admin endpoints +[admin.rate_limit] +enabled = true +requests_per_minute = 60 # Max requests per minute per IP