feat: Add audit logging, magic bytes validation, per-user quotas, and admin API
All checks were successful
CI/CD / Test (push) Successful in 31s
CI/CD / Lint (push) Successful in 42s
CI/CD / Generate SBOM (push) Successful in 17s
CI/CD / Build (darwin-amd64) (push) Successful in 22s
CI/CD / Build (linux-amd64) (push) Successful in 22s
CI/CD / Build (darwin-arm64) (push) Successful in 23s
CI/CD / Build (linux-arm64) (push) Successful in 22s
CI/CD / Build & Push Docker Image (push) Successful in 22s
CI/CD / Mirror to GitHub (push) Successful in 16s
CI/CD / Release (push) Has been skipped
All checks were successful
CI/CD / Test (push) Successful in 31s
CI/CD / Lint (push) Successful in 42s
CI/CD / Generate SBOM (push) Successful in 17s
CI/CD / Build (darwin-amd64) (push) Successful in 22s
CI/CD / Build (linux-amd64) (push) Successful in 22s
CI/CD / Build (darwin-arm64) (push) Successful in 23s
CI/CD / Build (linux-arm64) (push) Successful in 22s
CI/CD / Build & Push Docker Image (push) Successful in 22s
CI/CD / Mirror to GitHub (push) Successful in 16s
CI/CD / Release (push) Has been skipped
New features in v3.3.0: - audit.go: Security audit logging with JSON/text format, log rotation - validation.go: Magic bytes content validation with wildcard patterns - quota.go: Per-user storage quotas with Redis/memory tracking - admin.go: Admin API for stats, file management, user quotas, bans Integration: - Updated main.go with feature initialization and handler integration - Added audit logging for auth success/failure, uploads, downloads - Added quota checking before upload, tracking after successful upload - Added content validation with magic bytes detection Config: - New template: config-enhanced-features.toml with all new options - Updated README.md with feature documentation
This commit is contained in:
82
README.md
82
README.md
@@ -8,13 +8,22 @@ A high-performance, secure file server implementing XEP-0363 (HTTP File Upload)
|
||||
|
||||
## Features
|
||||
|
||||
### Core Features
|
||||
- XEP-0363 HTTP File Upload compliance
|
||||
- HMAC-based authentication
|
||||
- File deduplication
|
||||
- HMAC-based authentication with JWT support
|
||||
- File deduplication (SHA256 with hardlinks)
|
||||
- Multi-architecture support (AMD64, ARM64, ARM32v7)
|
||||
- Docker and Podman deployment
|
||||
- XMPP client compatibility (Dino, Gajim, Conversations, Monal, Converse.js)
|
||||
- Network resilience for mobile clients
|
||||
- Network resilience for mobile clients (WiFi/LTE switching)
|
||||
|
||||
### Security Features (v3.3.0)
|
||||
- **Audit Logging** - Comprehensive security event logging (uploads, downloads, auth events)
|
||||
- **Magic Bytes Validation** - Content type verification using file signatures
|
||||
- **Per-User Quotas** - Storage limits per XMPP JID with Redis tracking
|
||||
- **Admin API** - Protected endpoints for system management and monitoring
|
||||
- **ClamAV Integration** - Antivirus scanning for uploaded files
|
||||
- **Rate Limiting** - Configurable request throttling
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -90,16 +99,19 @@ secret = "your-hmac-secret-key"
|
||||
| Section | Description |
|
||||
|---------|-------------|
|
||||
| `[server]` | Bind address, port, storage path, timeouts |
|
||||
| `[security]` | HMAC secret, TLS settings |
|
||||
| `[security]` | HMAC secret, JWT, TLS settings |
|
||||
| `[uploads]` | Size limits, allowed extensions |
|
||||
| `[downloads]` | Download settings, bandwidth limits |
|
||||
| `[logging]` | Log file, log level |
|
||||
| `[clamav]` | Antivirus scanning integration |
|
||||
| `[redis]` | Redis caching backend |
|
||||
| `[deduplication]` | File deduplication settings |
|
||||
| `[audit]` | Security audit logging |
|
||||
| `[validation]` | Magic bytes content validation |
|
||||
| `[quotas]` | Per-user storage quotas |
|
||||
| `[admin]` | Admin API configuration |
|
||||
| `[workers]` | Worker pool configuration |
|
||||
|
||||
See [examples/](examples/) for complete configuration templates.
|
||||
See [templates/](templates/) for complete configuration templates.
|
||||
|
||||
## XMPP Server Integration
|
||||
|
||||
@@ -168,6 +180,64 @@ token = HMAC-SHA256(secret, filename + filesize + timestamp)
|
||||
| `/download/...` | GET | File download |
|
||||
| `/health` | GET | Health check |
|
||||
| `/metrics` | GET | Prometheus metrics |
|
||||
| `/admin/stats` | GET | Server statistics (auth required) |
|
||||
| `/admin/files` | GET | List uploaded files (auth required) |
|
||||
| `/admin/users` | GET | User quota information (auth required) |
|
||||
|
||||
## Enhanced Features (v3.3.0)
|
||||
|
||||
### Audit Logging
|
||||
|
||||
Security-focused logging for compliance and forensics:
|
||||
|
||||
```toml
|
||||
[audit]
|
||||
enabled = true
|
||||
output = "file"
|
||||
path = "/var/log/hmac-audit.log"
|
||||
format = "json"
|
||||
events = ["upload", "download", "auth_failure", "quota_exceeded"]
|
||||
```
|
||||
|
||||
### Content Validation
|
||||
|
||||
Magic bytes validation to verify file types:
|
||||
|
||||
```toml
|
||||
[validation]
|
||||
check_magic_bytes = true
|
||||
allowed_types = ["image/*", "video/*", "audio/*", "application/pdf"]
|
||||
blocked_types = ["application/x-executable", "application/x-shellscript"]
|
||||
```
|
||||
|
||||
### Per-User Quotas
|
||||
|
||||
Storage limits per XMPP JID with Redis tracking:
|
||||
|
||||
```toml
|
||||
[quotas]
|
||||
enabled = true
|
||||
default = "100MB"
|
||||
tracking = "redis"
|
||||
|
||||
[quotas.custom]
|
||||
"admin@example.com" = "10GB"
|
||||
"premium@example.com" = "1GB"
|
||||
```
|
||||
|
||||
### Admin API
|
||||
|
||||
Protected management endpoints:
|
||||
|
||||
```toml
|
||||
[admin]
|
||||
enabled = true
|
||||
path_prefix = "/admin"
|
||||
|
||||
[admin.auth]
|
||||
type = "bearer"
|
||||
token = "${ADMIN_TOKEN}"
|
||||
```
|
||||
|
||||
## System Requirements
|
||||
|
||||
|
||||
756
cmd/server/admin.go
Normal file
756
cmd/server/admin.go
Normal file
@@ -0,0 +1,756 @@
|
||||
// admin.go - Admin API for operations and monitoring
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AdminConfig holds admin API configuration
|
||||
type AdminConfig struct {
|
||||
Enabled bool `toml:"enabled" mapstructure:"enabled"`
|
||||
Bind string `toml:"bind" mapstructure:"bind"` // Separate bind address (e.g., "127.0.0.1:8081")
|
||||
PathPrefix string `toml:"path_prefix" mapstructure:"path_prefix"` // Path prefix (e.g., "/admin")
|
||||
Auth AdminAuthConfig `toml:"auth" mapstructure:"auth"`
|
||||
}
|
||||
|
||||
// AdminAuthConfig holds admin authentication configuration
|
||||
type AdminAuthConfig struct {
|
||||
Type string `toml:"type" mapstructure:"type"` // "bearer" | "basic"
|
||||
Token string `toml:"token" mapstructure:"token"` // For bearer auth
|
||||
Username string `toml:"username" mapstructure:"username"` // For basic auth
|
||||
Password string `toml:"password" mapstructure:"password"` // For basic auth
|
||||
}
|
||||
|
||||
// AdminStats represents system statistics
|
||||
type AdminStats struct {
|
||||
Storage StorageStats `json:"storage"`
|
||||
Users UserStats `json:"users"`
|
||||
Requests RequestStats `json:"requests"`
|
||||
System SystemStats `json:"system"`
|
||||
}
|
||||
|
||||
// StorageStats represents storage statistics
|
||||
type StorageStats struct {
|
||||
UsedBytes int64 `json:"used_bytes"`
|
||||
UsedHuman string `json:"used_human"`
|
||||
FileCount int64 `json:"file_count"`
|
||||
FreeBytes int64 `json:"free_bytes,omitempty"`
|
||||
FreeHuman string `json:"free_human,omitempty"`
|
||||
TotalBytes int64 `json:"total_bytes,omitempty"`
|
||||
TotalHuman string `json:"total_human,omitempty"`
|
||||
}
|
||||
|
||||
// UserStats represents user statistics
|
||||
type UserStats struct {
|
||||
Total int64 `json:"total"`
|
||||
Active24h int64 `json:"active_24h"`
|
||||
Active7d int64 `json:"active_7d"`
|
||||
}
|
||||
|
||||
// RequestStats represents request statistics
|
||||
type RequestStats struct {
|
||||
Uploads24h int64 `json:"uploads_24h"`
|
||||
Downloads24h int64 `json:"downloads_24h"`
|
||||
Errors24h int64 `json:"errors_24h"`
|
||||
}
|
||||
|
||||
// SystemStats represents system statistics
|
||||
type SystemStats struct {
|
||||
Uptime string `json:"uptime"`
|
||||
Version string `json:"version"`
|
||||
GoVersion string `json:"go_version"`
|
||||
NumGoroutines int `json:"num_goroutines"`
|
||||
MemoryUsageMB int64 `json:"memory_usage_mb"`
|
||||
NumCPU int `json:"num_cpu"`
|
||||
}
|
||||
|
||||
// FileInfo represents file information for admin API
|
||||
type FileInfo struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
SizeHuman string `json:"size_human"`
|
||||
ContentType string `json:"content_type"`
|
||||
ModTime time.Time `json:"mod_time"`
|
||||
Owner string `json:"owner,omitempty"`
|
||||
}
|
||||
|
||||
// FileListResponse represents paginated file list
|
||||
type FileListResponse struct {
|
||||
Files []FileInfo `json:"files"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
Limit int `json:"limit"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// UserInfo represents user information for admin API
|
||||
type UserInfo struct {
|
||||
JID string `json:"jid"`
|
||||
QuotaUsed int64 `json:"quota_used"`
|
||||
QuotaLimit int64 `json:"quota_limit"`
|
||||
FileCount int64 `json:"file_count"`
|
||||
LastActive time.Time `json:"last_active,omitempty"`
|
||||
IsBanned bool `json:"is_banned"`
|
||||
}
|
||||
|
||||
// BanInfo represents ban information
|
||||
type BanInfo struct {
|
||||
IP string `json:"ip"`
|
||||
Reason string `json:"reason"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||
IsPermanent bool `json:"is_permanent"`
|
||||
}
|
||||
|
||||
var (
|
||||
serverStartTime = time.Now()
|
||||
adminConfig *AdminConfig
|
||||
)
|
||||
|
||||
// SetupAdminRoutes sets up admin API routes
|
||||
func SetupAdminRoutes(mux *http.ServeMux, config *AdminConfig) {
|
||||
adminConfig = config
|
||||
|
||||
if !config.Enabled {
|
||||
log.Info("Admin API is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
prefix := config.PathPrefix
|
||||
if prefix == "" {
|
||||
prefix = "/admin"
|
||||
}
|
||||
|
||||
// Wrap all admin handlers with authentication
|
||||
adminMux := http.NewServeMux()
|
||||
|
||||
adminMux.HandleFunc(prefix+"/stats", handleAdminStats)
|
||||
adminMux.HandleFunc(prefix+"/files", handleAdminFiles)
|
||||
adminMux.HandleFunc(prefix+"/files/", handleAdminFileByID)
|
||||
adminMux.HandleFunc(prefix+"/users", handleAdminUsers)
|
||||
adminMux.HandleFunc(prefix+"/users/", handleAdminUserByJID)
|
||||
adminMux.HandleFunc(prefix+"/bans", handleAdminBans)
|
||||
adminMux.HandleFunc(prefix+"/bans/", handleAdminBanByIP)
|
||||
adminMux.HandleFunc(prefix+"/health", handleAdminHealth)
|
||||
adminMux.HandleFunc(prefix+"/config", handleAdminConfig)
|
||||
|
||||
// Register with authentication middleware
|
||||
mux.Handle(prefix+"/", AdminAuthMiddleware(adminMux))
|
||||
|
||||
log.Infof("Admin API enabled at %s (auth: %s)", prefix, config.Auth.Type)
|
||||
}
|
||||
|
||||
// AdminAuthMiddleware handles admin authentication
|
||||
func AdminAuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if adminConfig == nil || !adminConfig.Enabled {
|
||||
http.Error(w, "Admin API disabled", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
authorized := false
|
||||
|
||||
switch adminConfig.Auth.Type {
|
||||
case "bearer":
|
||||
auth := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
authorized = token == adminConfig.Auth.Token
|
||||
}
|
||||
case "basic":
|
||||
username, password, ok := r.BasicAuth()
|
||||
if ok {
|
||||
authorized = username == adminConfig.Auth.Username &&
|
||||
password == adminConfig.Auth.Password
|
||||
}
|
||||
default:
|
||||
// No auth configured, check if request is from localhost
|
||||
clientIP := getClientIP(r)
|
||||
authorized = clientIP == "127.0.0.1" || clientIP == "::1"
|
||||
}
|
||||
|
||||
if !authorized {
|
||||
AuditEvent("admin_auth_failure", r, nil)
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="admin"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// handleAdminStats returns system statistics
|
||||
func handleAdminStats(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
AuditAdminAction(r, "get_stats", "system", nil)
|
||||
|
||||
ctx := r.Context()
|
||||
stats := AdminStats{}
|
||||
|
||||
// Storage stats
|
||||
confMutex.RLock()
|
||||
storagePath := conf.Server.StoragePath
|
||||
confMutex.RUnlock()
|
||||
|
||||
storageStats := calculateStorageStats(storagePath)
|
||||
stats.Storage = storageStats
|
||||
|
||||
// User stats
|
||||
stats.Users = calculateUserStats(ctx)
|
||||
|
||||
// Request stats from Prometheus metrics
|
||||
stats.Requests = calculateRequestStats()
|
||||
|
||||
// System stats
|
||||
var mem runtime.MemStats
|
||||
runtime.ReadMemStats(&mem)
|
||||
|
||||
stats.System = SystemStats{
|
||||
Uptime: time.Since(serverStartTime).Round(time.Second).String(),
|
||||
Version: "3.3.0",
|
||||
GoVersion: runtime.Version(),
|
||||
NumGoroutines: runtime.NumGoroutine(),
|
||||
MemoryUsageMB: int64(mem.Alloc / 1024 / 1024),
|
||||
NumCPU: runtime.NumCPU(),
|
||||
}
|
||||
|
||||
writeJSONResponseAdmin(w, stats)
|
||||
}
|
||||
|
||||
// handleAdminFiles handles file listing
|
||||
func handleAdminFiles(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
listFiles(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// listFiles returns paginated file list
|
||||
func listFiles(w http.ResponseWriter, r *http.Request) {
|
||||
AuditAdminAction(r, "list_files", "files", nil)
|
||||
|
||||
// Parse query parameters
|
||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
||||
if limit < 1 || limit > 100 {
|
||||
limit = 50
|
||||
}
|
||||
sortBy := r.URL.Query().Get("sort")
|
||||
if sortBy == "" {
|
||||
sortBy = "date"
|
||||
}
|
||||
filterOwner := r.URL.Query().Get("owner")
|
||||
|
||||
confMutex.RLock()
|
||||
storagePath := conf.Server.StoragePath
|
||||
confMutex.RUnlock()
|
||||
|
||||
var files []FileInfo
|
||||
|
||||
err := filepath.WalkDir(storagePath, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip errors
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
relPath, _ := filepath.Rel(storagePath, path)
|
||||
|
||||
fileInfo := FileInfo{
|
||||
ID: relPath,
|
||||
Path: relPath,
|
||||
Name: filepath.Base(path),
|
||||
Size: info.Size(),
|
||||
SizeHuman: formatBytes(info.Size()),
|
||||
ContentType: GetContentType(path),
|
||||
ModTime: info.ModTime(),
|
||||
}
|
||||
|
||||
// Apply owner filter if specified (simplified: would need metadata lookup)
|
||||
_ = filterOwner // Unused for now, but kept for future implementation
|
||||
|
||||
files = append(files, fileInfo)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error listing files: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Sort files
|
||||
switch sortBy {
|
||||
case "date":
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].ModTime.After(files[j].ModTime)
|
||||
})
|
||||
case "size":
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].Size > files[j].Size
|
||||
})
|
||||
case "name":
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].Name < files[j].Name
|
||||
})
|
||||
}
|
||||
|
||||
// Paginate
|
||||
total := len(files)
|
||||
start := (page - 1) * limit
|
||||
end := start + limit
|
||||
if start > total {
|
||||
start = total
|
||||
}
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
|
||||
response := FileListResponse{
|
||||
Files: files[start:end],
|
||||
Total: int64(total),
|
||||
Page: page,
|
||||
Limit: limit,
|
||||
TotalPages: (total + limit - 1) / limit,
|
||||
}
|
||||
|
||||
writeJSONResponseAdmin(w, response)
|
||||
}
|
||||
|
||||
// handleAdminFileByID handles single file operations
|
||||
func handleAdminFileByID(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract file ID from path
|
||||
prefix := adminConfig.PathPrefix
|
||||
if prefix == "" {
|
||||
prefix = "/admin"
|
||||
}
|
||||
fileID := strings.TrimPrefix(r.URL.Path, prefix+"/files/")
|
||||
|
||||
if fileID == "" {
|
||||
http.Error(w, "File ID required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
getFileInfo(w, r, fileID)
|
||||
case http.MethodDelete:
|
||||
deleteFile(w, r, fileID)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// getFileInfo returns information about a specific file
|
||||
func getFileInfo(w http.ResponseWriter, r *http.Request, fileID string) {
|
||||
confMutex.RLock()
|
||||
storagePath := conf.Server.StoragePath
|
||||
confMutex.RUnlock()
|
||||
|
||||
filePath := filepath.Join(storagePath, fileID)
|
||||
|
||||
// Validate path is within storage
|
||||
absPath, err := filepath.Abs(filePath)
|
||||
if err != nil || !strings.HasPrefix(absPath, storagePath) {
|
||||
http.Error(w, "Invalid file ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(filePath)
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "File not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
fileInfo := FileInfo{
|
||||
ID: fileID,
|
||||
Path: fileID,
|
||||
Name: filepath.Base(filePath),
|
||||
Size: info.Size(),
|
||||
SizeHuman: formatBytes(info.Size()),
|
||||
ContentType: GetContentType(filePath),
|
||||
ModTime: info.ModTime(),
|
||||
}
|
||||
|
||||
writeJSONResponseAdmin(w, fileInfo)
|
||||
}
|
||||
|
||||
// deleteFile deletes a specific file
|
||||
func deleteFile(w http.ResponseWriter, r *http.Request, fileID string) {
|
||||
confMutex.RLock()
|
||||
storagePath := conf.Server.StoragePath
|
||||
confMutex.RUnlock()
|
||||
|
||||
filePath := filepath.Join(storagePath, fileID)
|
||||
|
||||
// Validate path is within storage
|
||||
absPath, err := filepath.Abs(filePath)
|
||||
if err != nil || !strings.HasPrefix(absPath, storagePath) {
|
||||
http.Error(w, "Invalid file ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get file info before deletion for audit
|
||||
info, err := os.Stat(filePath)
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "File not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
AuditAdminAction(r, "delete_file", fileID, map[string]interface{}{
|
||||
"size": info.Size(),
|
||||
})
|
||||
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error deleting file: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// handleAdminUsers handles user listing
|
||||
func handleAdminUsers(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
AuditAdminAction(r, "list_users", "users", nil)
|
||||
|
||||
ctx := r.Context()
|
||||
qm := GetQuotaManager()
|
||||
|
||||
var users []UserInfo
|
||||
|
||||
if qm != nil && qm.config.Enabled {
|
||||
quotas, err := qm.GetAllQuotas(ctx)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error getting quotas: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for _, quota := range quotas {
|
||||
users = append(users, UserInfo{
|
||||
JID: quota.JID,
|
||||
QuotaUsed: quota.Used,
|
||||
QuotaLimit: quota.Limit,
|
||||
FileCount: quota.FileCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
writeJSONResponseAdmin(w, users)
|
||||
}
|
||||
|
||||
// handleAdminUserByJID handles single user operations
|
||||
func handleAdminUserByJID(w http.ResponseWriter, r *http.Request) {
|
||||
prefix := adminConfig.PathPrefix
|
||||
if prefix == "" {
|
||||
prefix = "/admin"
|
||||
}
|
||||
|
||||
path := strings.TrimPrefix(r.URL.Path, prefix+"/users/")
|
||||
parts := strings.Split(path, "/")
|
||||
jid := parts[0]
|
||||
|
||||
if jid == "" {
|
||||
http.Error(w, "JID required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for sub-paths
|
||||
if len(parts) > 1 {
|
||||
switch parts[1] {
|
||||
case "files":
|
||||
handleUserFiles(w, r, jid)
|
||||
return
|
||||
case "quota":
|
||||
handleUserQuota(w, r, jid)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
getUserInfo(w, r, jid)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// getUserInfo returns information about a specific user
|
||||
func getUserInfo(w http.ResponseWriter, r *http.Request, jid string) {
|
||||
ctx := r.Context()
|
||||
qm := GetQuotaManager()
|
||||
|
||||
if qm == nil || !qm.config.Enabled {
|
||||
http.Error(w, "Quota tracking not enabled", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
quota, err := qm.GetQuotaInfo(ctx, jid)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error getting quota: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user := UserInfo{
|
||||
JID: jid,
|
||||
QuotaUsed: quota.Used,
|
||||
QuotaLimit: quota.Limit,
|
||||
FileCount: quota.FileCount,
|
||||
}
|
||||
|
||||
writeJSONResponseAdmin(w, user)
|
||||
}
|
||||
|
||||
// handleUserFiles handles user file operations
|
||||
func handleUserFiles(w http.ResponseWriter, r *http.Request, jid string) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// List user's files
|
||||
AuditAdminAction(r, "list_user_files", jid, nil)
|
||||
// Would need file ownership tracking to implement fully
|
||||
writeJSONResponseAdmin(w, []FileInfo{})
|
||||
case http.MethodDelete:
|
||||
// Delete all user's files
|
||||
AuditAdminAction(r, "delete_user_files", jid, nil)
|
||||
// Would need file ownership tracking to implement fully
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUserQuota handles user quota operations
|
||||
func handleUserQuota(w http.ResponseWriter, r *http.Request, jid string) {
|
||||
qm := GetQuotaManager()
|
||||
if qm == nil {
|
||||
http.Error(w, "Quota management not enabled", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
// Set custom quota
|
||||
var req struct {
|
||||
Quota string `json:"quota"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
quota, err := parseSize(req.Quota)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid quota: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
qm.SetCustomQuota(jid, quota)
|
||||
AuditAdminAction(r, "set_quota", jid, map[string]interface{}{"quota": req.Quota})
|
||||
|
||||
writeJSONResponseAdmin(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"jid": jid,
|
||||
"quota": quota,
|
||||
})
|
||||
case http.MethodDelete:
|
||||
// Remove custom quota
|
||||
qm.RemoveCustomQuota(jid)
|
||||
AuditAdminAction(r, "remove_quota", jid, nil)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// handleAdminBans handles ban listing
|
||||
func handleAdminBans(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
AuditAdminAction(r, "list_bans", "bans", nil)
|
||||
|
||||
// Would need ban management implementation
|
||||
writeJSONResponseAdmin(w, []BanInfo{})
|
||||
}
|
||||
|
||||
// handleAdminBanByIP handles single ban operations
|
||||
func handleAdminBanByIP(w http.ResponseWriter, r *http.Request) {
|
||||
prefix := adminConfig.PathPrefix
|
||||
if prefix == "" {
|
||||
prefix = "/admin"
|
||||
}
|
||||
ip := strings.TrimPrefix(r.URL.Path, prefix+"/bans/")
|
||||
|
||||
if ip == "" {
|
||||
http.Error(w, "IP required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodDelete:
|
||||
// Unban IP
|
||||
AuditAdminAction(r, "unban", ip, nil)
|
||||
// Would need ban management implementation
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// handleAdminHealth returns admin-specific health info
|
||||
func handleAdminHealth(w http.ResponseWriter, r *http.Request) {
|
||||
health := map[string]interface{}{
|
||||
"status": "healthy",
|
||||
"timestamp": time.Now().UTC(),
|
||||
"uptime": time.Since(serverStartTime).String(),
|
||||
}
|
||||
|
||||
// Check Redis
|
||||
if redisClient != nil && redisConnected {
|
||||
health["redis"] = "connected"
|
||||
} else if redisClient != nil {
|
||||
health["redis"] = "disconnected"
|
||||
}
|
||||
|
||||
writeJSONResponseAdmin(w, health)
|
||||
}
|
||||
|
||||
// handleAdminConfig returns current configuration (sanitized)
|
||||
func handleAdminConfig(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
AuditAdminAction(r, "get_config", "config", nil)
|
||||
|
||||
confMutex.RLock()
|
||||
// Return sanitized config (no secrets)
|
||||
sanitized := map[string]interface{}{
|
||||
"server": map[string]interface{}{
|
||||
"listen_address": conf.Server.ListenAddress,
|
||||
"storage_path": conf.Server.StoragePath,
|
||||
"max_upload_size": conf.Server.MaxUploadSize,
|
||||
"metrics_enabled": conf.Server.MetricsEnabled,
|
||||
},
|
||||
"security": map[string]interface{}{
|
||||
"enhanced_security": conf.Security.EnhancedSecurity,
|
||||
"jwt_enabled": conf.Security.EnableJWT,
|
||||
},
|
||||
"clamav": map[string]interface{}{
|
||||
"enabled": conf.ClamAV.ClamAVEnabled,
|
||||
},
|
||||
"redis": map[string]interface{}{
|
||||
"enabled": conf.Redis.RedisEnabled,
|
||||
},
|
||||
"deduplication": map[string]interface{}{
|
||||
"enabled": conf.Deduplication.Enabled,
|
||||
},
|
||||
}
|
||||
confMutex.RUnlock()
|
||||
|
||||
writeJSONResponseAdmin(w, sanitized)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func calculateStorageStats(storagePath string) StorageStats {
|
||||
var totalSize int64
|
||||
var fileCount int64
|
||||
|
||||
_ = filepath.WalkDir(storagePath, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil || d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if info, err := d.Info(); err == nil {
|
||||
totalSize += info.Size()
|
||||
fileCount++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return StorageStats{
|
||||
UsedBytes: totalSize,
|
||||
UsedHuman: formatBytes(totalSize),
|
||||
FileCount: fileCount,
|
||||
}
|
||||
}
|
||||
|
||||
func calculateUserStats(ctx context.Context) UserStats {
|
||||
qm := GetQuotaManager()
|
||||
if qm == nil || !qm.config.Enabled {
|
||||
return UserStats{}
|
||||
}
|
||||
|
||||
quotas, err := qm.GetAllQuotas(ctx)
|
||||
if err != nil {
|
||||
return UserStats{}
|
||||
}
|
||||
|
||||
return UserStats{
|
||||
Total: int64(len(quotas)),
|
||||
}
|
||||
}
|
||||
|
||||
func calculateRequestStats() RequestStats {
|
||||
// These would ideally come from Prometheus metrics
|
||||
return RequestStats{}
|
||||
}
|
||||
|
||||
func writeJSONResponseAdmin(w http.ResponseWriter, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
log.Errorf("Failed to encode JSON response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultAdminConfig returns default admin configuration
|
||||
func DefaultAdminConfig() AdminConfig {
|
||||
return AdminConfig{
|
||||
Enabled: false,
|
||||
Bind: "",
|
||||
PathPrefix: "/admin",
|
||||
Auth: AdminAuthConfig{
|
||||
Type: "bearer",
|
||||
},
|
||||
}
|
||||
}
|
||||
366
cmd/server/audit.go
Normal file
366
cmd/server/audit.go
Normal file
@@ -0,0 +1,366 @@
|
||||
// audit.go - Dedicated audit logging for security-relevant events
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
// AuditConfig holds audit logging configuration
|
||||
type AuditConfig struct {
|
||||
Enabled bool `toml:"enabled" mapstructure:"enabled"`
|
||||
Output string `toml:"output" mapstructure:"output"` // "file" | "stdout"
|
||||
Path string `toml:"path" mapstructure:"path"` // Log file path
|
||||
Format string `toml:"format" mapstructure:"format"` // "json" | "text"
|
||||
Events []string `toml:"events" mapstructure:"events"` // Events to log
|
||||
MaxSize int `toml:"max_size" mapstructure:"max_size"` // Max size in MB
|
||||
MaxAge int `toml:"max_age" mapstructure:"max_age"` // Max age in days
|
||||
}
|
||||
|
||||
// AuditEvent types
|
||||
const (
|
||||
AuditEventUpload = "upload"
|
||||
AuditEventDownload = "download"
|
||||
AuditEventDelete = "delete"
|
||||
AuditEventAuthSuccess = "auth_success"
|
||||
AuditEventAuthFailure = "auth_failure"
|
||||
AuditEventRateLimited = "rate_limited"
|
||||
AuditEventBanned = "banned"
|
||||
AuditEventQuotaExceeded = "quota_exceeded"
|
||||
AuditEventAdminAction = "admin_action"
|
||||
AuditEventValidationFailure = "validation_failure"
|
||||
)
|
||||
|
||||
// AuditLogger handles security audit logging
|
||||
type AuditLogger struct {
|
||||
logger *logrus.Logger
|
||||
config *AuditConfig
|
||||
enabledEvents map[string]bool
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
auditLogger *AuditLogger
|
||||
auditOnce sync.Once
|
||||
)
|
||||
|
||||
// InitAuditLogger initializes the audit logger
|
||||
func InitAuditLogger(config *AuditConfig) error {
|
||||
var initErr error
|
||||
auditOnce.Do(func() {
|
||||
auditLogger = &AuditLogger{
|
||||
logger: logrus.New(),
|
||||
config: config,
|
||||
enabledEvents: make(map[string]bool),
|
||||
}
|
||||
|
||||
// Build enabled events map for fast lookup
|
||||
for _, event := range config.Events {
|
||||
auditLogger.enabledEvents[strings.ToLower(event)] = true
|
||||
}
|
||||
|
||||
// Configure formatter
|
||||
if config.Format == "json" {
|
||||
auditLogger.logger.SetFormatter(&logrus.JSONFormatter{
|
||||
TimestampFormat: time.RFC3339,
|
||||
FieldMap: logrus.FieldMap{
|
||||
logrus.FieldKeyTime: "timestamp",
|
||||
logrus.FieldKeyMsg: "event",
|
||||
},
|
||||
})
|
||||
} else {
|
||||
auditLogger.logger.SetFormatter(&logrus.TextFormatter{
|
||||
TimestampFormat: time.RFC3339,
|
||||
FullTimestamp: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Configure output
|
||||
if !config.Enabled {
|
||||
auditLogger.logger.SetOutput(io.Discard)
|
||||
return
|
||||
}
|
||||
|
||||
switch config.Output {
|
||||
case "stdout":
|
||||
auditLogger.logger.SetOutput(os.Stdout)
|
||||
case "file":
|
||||
if config.Path == "" {
|
||||
config.Path = "/var/log/hmac-audit.log"
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(config.Path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
initErr = err
|
||||
return
|
||||
}
|
||||
|
||||
// Use lumberjack for log rotation
|
||||
maxSize := config.MaxSize
|
||||
if maxSize <= 0 {
|
||||
maxSize = 100 // Default 100MB
|
||||
}
|
||||
maxAge := config.MaxAge
|
||||
if maxAge <= 0 {
|
||||
maxAge = 30 // Default 30 days
|
||||
}
|
||||
|
||||
auditLogger.logger.SetOutput(&lumberjack.Logger{
|
||||
Filename: config.Path,
|
||||
MaxSize: maxSize,
|
||||
MaxAge: maxAge,
|
||||
MaxBackups: 5,
|
||||
Compress: true,
|
||||
})
|
||||
default:
|
||||
auditLogger.logger.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
auditLogger.logger.SetLevel(logrus.InfoLevel)
|
||||
log.Infof("Audit logger initialized: output=%s, path=%s, format=%s, events=%v",
|
||||
config.Output, config.Path, config.Format, config.Events)
|
||||
})
|
||||
|
||||
return initErr
|
||||
}
|
||||
|
||||
// GetAuditLogger returns the singleton audit logger
|
||||
func GetAuditLogger() *AuditLogger {
|
||||
return auditLogger
|
||||
}
|
||||
|
||||
// IsEventEnabled checks if an event type should be logged
|
||||
func (a *AuditLogger) IsEventEnabled(event string) bool {
|
||||
if a == nil || !a.config.Enabled {
|
||||
return false
|
||||
}
|
||||
a.mutex.RLock()
|
||||
defer a.mutex.RUnlock()
|
||||
|
||||
// If no events configured, log all
|
||||
if len(a.enabledEvents) == 0 {
|
||||
return true
|
||||
}
|
||||
return a.enabledEvents[strings.ToLower(event)]
|
||||
}
|
||||
|
||||
// LogEvent logs an audit event
|
||||
func (a *AuditLogger) LogEvent(event string, fields logrus.Fields) {
|
||||
if a == nil || !a.config.Enabled || !a.IsEventEnabled(event) {
|
||||
return
|
||||
}
|
||||
|
||||
// Add standard fields
|
||||
fields["event_type"] = event
|
||||
if _, ok := fields["timestamp"]; !ok {
|
||||
fields["timestamp"] = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
a.logger.WithFields(fields).Info(event)
|
||||
}
|
||||
|
||||
// AuditEvent is a helper function for logging audit events from request context
|
||||
func AuditEvent(event string, r *http.Request, fields logrus.Fields) {
|
||||
if auditLogger == nil || !auditLogger.config.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
if !auditLogger.IsEventEnabled(event) {
|
||||
return
|
||||
}
|
||||
|
||||
// Add request context
|
||||
if r != nil {
|
||||
if fields == nil {
|
||||
fields = logrus.Fields{}
|
||||
}
|
||||
fields["ip"] = getClientIP(r)
|
||||
fields["user_agent"] = r.UserAgent()
|
||||
fields["method"] = r.Method
|
||||
fields["path"] = r.URL.Path
|
||||
|
||||
// Extract JID if available from headers or context
|
||||
if jid := r.Header.Get("X-User-JID"); jid != "" {
|
||||
fields["jid"] = jid
|
||||
}
|
||||
}
|
||||
|
||||
auditLogger.LogEvent(event, fields)
|
||||
}
|
||||
|
||||
// AuditUpload logs file upload events
|
||||
func AuditUpload(r *http.Request, jid, fileID, fileName string, fileSize int64, contentType, result string, err error) {
|
||||
fields := logrus.Fields{
|
||||
"jid": jid,
|
||||
"file_id": fileID,
|
||||
"file_name": fileName,
|
||||
"file_size": fileSize,
|
||||
"content_type": contentType,
|
||||
"result": result,
|
||||
}
|
||||
if err != nil {
|
||||
fields["error"] = err.Error()
|
||||
}
|
||||
AuditEvent(AuditEventUpload, r, fields)
|
||||
}
|
||||
|
||||
// AuditDownload logs file download events
|
||||
func AuditDownload(r *http.Request, jid, fileID, fileName string, fileSize int64, result string, err error) {
|
||||
fields := logrus.Fields{
|
||||
"jid": jid,
|
||||
"file_id": fileID,
|
||||
"file_name": fileName,
|
||||
"file_size": fileSize,
|
||||
"result": result,
|
||||
}
|
||||
if err != nil {
|
||||
fields["error"] = err.Error()
|
||||
}
|
||||
AuditEvent(AuditEventDownload, r, fields)
|
||||
}
|
||||
|
||||
// AuditDelete logs file deletion events
|
||||
func AuditDelete(r *http.Request, jid, fileID, fileName string, result string, err error) {
|
||||
fields := logrus.Fields{
|
||||
"jid": jid,
|
||||
"file_id": fileID,
|
||||
"file_name": fileName,
|
||||
"result": result,
|
||||
}
|
||||
if err != nil {
|
||||
fields["error"] = err.Error()
|
||||
}
|
||||
AuditEvent(AuditEventDelete, r, fields)
|
||||
}
|
||||
|
||||
// AuditAuth logs authentication events
|
||||
func AuditAuth(r *http.Request, jid string, success bool, method string, err error) {
|
||||
event := AuditEventAuthSuccess
|
||||
result := "success"
|
||||
if !success {
|
||||
event = AuditEventAuthFailure
|
||||
result = "failure"
|
||||
}
|
||||
|
||||
fields := logrus.Fields{
|
||||
"jid": jid,
|
||||
"auth_method": method,
|
||||
"result": result,
|
||||
}
|
||||
if err != nil {
|
||||
fields["error"] = err.Error()
|
||||
}
|
||||
AuditEvent(event, r, fields)
|
||||
}
|
||||
|
||||
// AuditRateLimited logs rate limiting events
|
||||
func AuditRateLimited(r *http.Request, jid, reason string) {
|
||||
fields := logrus.Fields{
|
||||
"jid": jid,
|
||||
"reason": reason,
|
||||
}
|
||||
AuditEvent(AuditEventRateLimited, r, fields)
|
||||
}
|
||||
|
||||
// AuditBanned logs ban events
|
||||
func AuditBanned(r *http.Request, jid, ip, reason string, duration time.Duration) {
|
||||
fields := logrus.Fields{
|
||||
"jid": jid,
|
||||
"banned_ip": ip,
|
||||
"reason": reason,
|
||||
"ban_duration": duration.String(),
|
||||
}
|
||||
AuditEvent(AuditEventBanned, r, fields)
|
||||
}
|
||||
|
||||
// AuditQuotaExceeded logs quota exceeded events
|
||||
func AuditQuotaExceeded(r *http.Request, jid string, used, limit, requested int64) {
|
||||
fields := logrus.Fields{
|
||||
"jid": jid,
|
||||
"used": used,
|
||||
"limit": limit,
|
||||
"requested": requested,
|
||||
}
|
||||
AuditEvent(AuditEventQuotaExceeded, r, fields)
|
||||
}
|
||||
|
||||
// AuditAdminAction logs admin API actions
|
||||
func AuditAdminAction(r *http.Request, action, target string, details map[string]interface{}) {
|
||||
fields := logrus.Fields{
|
||||
"action": action,
|
||||
"target": target,
|
||||
}
|
||||
for k, v := range details {
|
||||
fields[k] = v
|
||||
}
|
||||
AuditEvent(AuditEventAdminAction, r, fields)
|
||||
}
|
||||
|
||||
// AuditValidationFailure logs content validation failures
|
||||
func AuditValidationFailure(r *http.Request, jid, fileName, declaredType, detectedType, reason string) {
|
||||
fields := logrus.Fields{
|
||||
"jid": jid,
|
||||
"file_name": fileName,
|
||||
"declared_type": declaredType,
|
||||
"detected_type": detectedType,
|
||||
"reason": reason,
|
||||
}
|
||||
AuditEvent(AuditEventValidationFailure, r, fields)
|
||||
}
|
||||
|
||||
// DefaultAuditConfig returns default audit configuration
|
||||
func DefaultAuditConfig() AuditConfig {
|
||||
return AuditConfig{
|
||||
Enabled: false,
|
||||
Output: "file",
|
||||
Path: "/var/log/hmac-audit.log",
|
||||
Format: "json",
|
||||
Events: []string{
|
||||
AuditEventUpload,
|
||||
AuditEventDownload,
|
||||
AuditEventDelete,
|
||||
AuditEventAuthSuccess,
|
||||
AuditEventAuthFailure,
|
||||
AuditEventRateLimited,
|
||||
AuditEventBanned,
|
||||
},
|
||||
MaxSize: 100,
|
||||
MaxAge: 30,
|
||||
}
|
||||
}
|
||||
|
||||
// AuditAuthSuccess is a helper for logging successful authentication
|
||||
func AuditAuthSuccess(r *http.Request, jid, method string) {
|
||||
AuditAuth(r, jid, true, method, nil)
|
||||
}
|
||||
|
||||
// AuditAuthFailure is a helper for logging failed authentication
|
||||
func AuditAuthFailure(r *http.Request, method, errorMsg string) {
|
||||
AuditAuth(r, "", false, method, fmt.Errorf("%s", errorMsg))
|
||||
}
|
||||
|
||||
// AuditUploadSuccess is a helper for logging successful uploads
|
||||
func AuditUploadSuccess(r *http.Request, jid, fileName string, fileSize int64, contentType string) {
|
||||
AuditUpload(r, jid, "", fileName, fileSize, contentType, "success", nil)
|
||||
}
|
||||
|
||||
// AuditUploadFailure is a helper for logging failed uploads
|
||||
func AuditUploadFailure(r *http.Request, jid, fileName string, fileSize int64, errorMsg string) {
|
||||
AuditUpload(r, jid, "", fileName, fileSize, "", "failure", fmt.Errorf("%s", errorMsg))
|
||||
}
|
||||
|
||||
// AuditDownloadSuccess is a helper for logging successful downloads
|
||||
func AuditDownloadSuccess(r *http.Request, jid, fileName string, fileSize int64) {
|
||||
AuditDownload(r, jid, "", fileName, fileSize, "success", nil)
|
||||
}
|
||||
@@ -39,22 +39,22 @@ import (
|
||||
|
||||
// NetworkResilientSession represents a persistent session for network switching
|
||||
type NetworkResilientSession struct {
|
||||
SessionID string `json:"session_id"`
|
||||
UserJID string `json:"user_jid"`
|
||||
OriginalToken string `json:"original_token"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
NetworkHistory []NetworkEvent `json:"network_history"`
|
||||
UploadContext *UploadContext `json:"upload_context,omitempty"`
|
||||
RefreshCount int `json:"refresh_count"`
|
||||
MaxRefreshes int `json:"max_refreshes"`
|
||||
LastIP string `json:"last_ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
SecurityLevel int `json:"security_level"` // 1=normal, 2=challenge, 3=reauth
|
||||
LastSecurityCheck time.Time `json:"last_security_check"`
|
||||
NetworkChangeCount int `json:"network_change_count"`
|
||||
StandbyDetected bool `json:"standby_detected"`
|
||||
LastActivity time.Time `json:"last_activity"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserJID string `json:"user_jid"`
|
||||
OriginalToken string `json:"original_token"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
NetworkHistory []NetworkEvent `json:"network_history"`
|
||||
UploadContext *UploadContext `json:"upload_context,omitempty"`
|
||||
RefreshCount int `json:"refresh_count"`
|
||||
MaxRefreshes int `json:"max_refreshes"`
|
||||
LastIP string `json:"last_ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
SecurityLevel int `json:"security_level"` // 1=normal, 2=challenge, 3=reauth
|
||||
LastSecurityCheck time.Time `json:"last_security_check"`
|
||||
NetworkChangeCount int `json:"network_change_count"`
|
||||
StandbyDetected bool `json:"standby_detected"`
|
||||
LastActivity time.Time `json:"last_activity"`
|
||||
}
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
@@ -67,12 +67,12 @@ const (
|
||||
|
||||
// NetworkEvent tracks network transitions during session
|
||||
type NetworkEvent struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
FromNetwork string `json:"from_network"`
|
||||
ToNetwork string `json:"to_network"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
EventType string `json:"event_type"` // "switch", "resume", "refresh"
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
FromNetwork string `json:"from_network"`
|
||||
ToNetwork string `json:"to_network"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
EventType string `json:"event_type"` // "switch", "resume", "refresh"
|
||||
}
|
||||
|
||||
// UploadContext maintains upload state across network changes and network resilience channels
|
||||
@@ -244,11 +244,11 @@ func initializeSessionStore() {
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err == nil {
|
||||
sessionStore.redisClient = redis.NewClient(opt)
|
||||
|
||||
|
||||
// Test Redis connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
||||
if err := sessionStore.redisClient.Ping(ctx).Err(); err == nil {
|
||||
log.Infof("📊 Session store: Redis backend initialized (%s)", redisURL)
|
||||
} else {
|
||||
@@ -526,53 +526,57 @@ type BuildConfig struct {
|
||||
}
|
||||
|
||||
type NetworkResilienceConfig struct {
|
||||
FastDetection bool `toml:"fast_detection" mapstructure:"fast_detection"`
|
||||
QualityMonitoring bool `toml:"quality_monitoring" mapstructure:"quality_monitoring"`
|
||||
PredictiveSwitching bool `toml:"predictive_switching" mapstructure:"predictive_switching"`
|
||||
MobileOptimizations bool `toml:"mobile_optimizations" mapstructure:"mobile_optimizations"`
|
||||
DetectionInterval string `toml:"detection_interval" mapstructure:"detection_interval"`
|
||||
QualityCheckInterval string `toml:"quality_check_interval" mapstructure:"quality_check_interval"`
|
||||
MaxDetectionInterval string `toml:"max_detection_interval" mapstructure:"max_detection_interval"`
|
||||
|
||||
FastDetection bool `toml:"fast_detection" mapstructure:"fast_detection"`
|
||||
QualityMonitoring bool `toml:"quality_monitoring" mapstructure:"quality_monitoring"`
|
||||
PredictiveSwitching bool `toml:"predictive_switching" mapstructure:"predictive_switching"`
|
||||
MobileOptimizations bool `toml:"mobile_optimizations" mapstructure:"mobile_optimizations"`
|
||||
DetectionInterval string `toml:"detection_interval" mapstructure:"detection_interval"`
|
||||
QualityCheckInterval string `toml:"quality_check_interval" mapstructure:"quality_check_interval"`
|
||||
MaxDetectionInterval string `toml:"max_detection_interval" mapstructure:"max_detection_interval"`
|
||||
|
||||
// Multi-interface support
|
||||
MultiInterfaceEnabled bool `toml:"multi_interface_enabled" mapstructure:"multi_interface_enabled"`
|
||||
InterfacePriority []string `toml:"interface_priority" mapstructure:"interface_priority"`
|
||||
AutoSwitchEnabled bool `toml:"auto_switch_enabled" mapstructure:"auto_switch_enabled"`
|
||||
SwitchThresholdLatency string `toml:"switch_threshold_latency" mapstructure:"switch_threshold_latency"`
|
||||
SwitchThresholdPacketLoss float64 `toml:"switch_threshold_packet_loss" mapstructure:"switch_threshold_packet_loss"`
|
||||
MultiInterfaceEnabled bool `toml:"multi_interface_enabled" mapstructure:"multi_interface_enabled"`
|
||||
InterfacePriority []string `toml:"interface_priority" mapstructure:"interface_priority"`
|
||||
AutoSwitchEnabled bool `toml:"auto_switch_enabled" mapstructure:"auto_switch_enabled"`
|
||||
SwitchThresholdLatency string `toml:"switch_threshold_latency" mapstructure:"switch_threshold_latency"`
|
||||
SwitchThresholdPacketLoss float64 `toml:"switch_threshold_packet_loss" mapstructure:"switch_threshold_packet_loss"`
|
||||
QualityDegradationThreshold float64 `toml:"quality_degradation_threshold" mapstructure:"quality_degradation_threshold"`
|
||||
MaxSwitchAttempts int `toml:"max_switch_attempts" mapstructure:"max_switch_attempts"`
|
||||
SwitchDetectionInterval string `toml:"switch_detection_interval" mapstructure:"switch_detection_interval"`
|
||||
MaxSwitchAttempts int `toml:"max_switch_attempts" mapstructure:"max_switch_attempts"`
|
||||
SwitchDetectionInterval string `toml:"switch_detection_interval" mapstructure:"switch_detection_interval"`
|
||||
}
|
||||
|
||||
// ClientNetworkConfigTOML is used for loading from TOML where timeout is a string
|
||||
type ClientNetworkConfigTOML struct {
|
||||
SessionBasedTracking bool `toml:"session_based_tracking" mapstructure:"session_based_tracking"`
|
||||
AllowIPChanges bool `toml:"allow_ip_changes" mapstructure:"allow_ip_changes"`
|
||||
SessionMigrationTimeout string `toml:"session_migration_timeout" mapstructure:"session_migration_timeout"`
|
||||
MaxIPChangesPerSession int `toml:"max_ip_changes_per_session" mapstructure:"max_ip_changes_per_session"`
|
||||
ClientConnectionDetection bool `toml:"client_connection_detection" mapstructure:"client_connection_detection"`
|
||||
AdaptToClientNetwork bool `toml:"adapt_to_client_network" mapstructure:"adapt_to_client_network"`
|
||||
AllowIPChanges bool `toml:"allow_ip_changes" mapstructure:"allow_ip_changes"`
|
||||
SessionMigrationTimeout string `toml:"session_migration_timeout" mapstructure:"session_migration_timeout"`
|
||||
MaxIPChangesPerSession int `toml:"max_ip_changes_per_session" mapstructure:"max_ip_changes_per_session"`
|
||||
ClientConnectionDetection bool `toml:"client_connection_detection" mapstructure:"client_connection_detection"`
|
||||
AdaptToClientNetwork bool `toml:"adapt_to_client_network" mapstructure:"adapt_to_client_network"`
|
||||
}
|
||||
|
||||
// This is the main Config struct to be used
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Logging LoggingConfig `mapstructure:"logging"`
|
||||
Deduplication DeduplicationConfig `mapstructure:"deduplication"` // Added
|
||||
ISO ISOConfig `mapstructure:"iso"` // Added
|
||||
Timeouts TimeoutConfig `mapstructure:"timeouts"` // Added
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Versioning VersioningConfig `mapstructure:"versioning"` // Added
|
||||
Uploads UploadsConfig `mapstructure:"uploads"`
|
||||
Downloads DownloadsConfig `mapstructure:"downloads"`
|
||||
ClamAV ClamAVConfig `mapstructure:"clamav"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Workers WorkersConfig `mapstructure:"workers"`
|
||||
File FileConfig `mapstructure:"file"`
|
||||
Build BuildConfig `mapstructure:"build"`
|
||||
NetworkResilience NetworkResilienceConfig `mapstructure:"network_resilience"`
|
||||
ClientNetwork ClientNetworkConfigTOML `mapstructure:"client_network_support"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Logging LoggingConfig `mapstructure:"logging"`
|
||||
Deduplication DeduplicationConfig `mapstructure:"deduplication"` // Added
|
||||
ISO ISOConfig `mapstructure:"iso"` // Added
|
||||
Timeouts TimeoutConfig `mapstructure:"timeouts"` // Added
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Versioning VersioningConfig `mapstructure:"versioning"` // Added
|
||||
Uploads UploadsConfig `mapstructure:"uploads"`
|
||||
Downloads DownloadsConfig `mapstructure:"downloads"`
|
||||
ClamAV ClamAVConfig `mapstructure:"clamav"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Workers WorkersConfig `mapstructure:"workers"`
|
||||
File FileConfig `mapstructure:"file"`
|
||||
Build BuildConfig `mapstructure:"build"`
|
||||
NetworkResilience NetworkResilienceConfig `mapstructure:"network_resilience"`
|
||||
ClientNetwork ClientNetworkConfigTOML `mapstructure:"client_network_support"`
|
||||
Audit AuditConfig `mapstructure:"audit"` // Audit logging
|
||||
Validation ValidationConfig `mapstructure:"validation"` // Content validation
|
||||
Quotas QuotaConfig `mapstructure:"quotas"` // Per-user quotas
|
||||
Admin AdminConfig `mapstructure:"admin"` // Admin API
|
||||
}
|
||||
|
||||
type UploadTask struct {
|
||||
@@ -597,12 +601,12 @@ func processScan(task ScanTask) error {
|
||||
confMutex.RLock()
|
||||
clamEnabled := conf.ClamAV.ClamAVEnabled
|
||||
confMutex.RUnlock()
|
||||
|
||||
|
||||
if !clamEnabled {
|
||||
log.Infof("ClamAV disabled, skipping scan for file: %s", task.AbsFilename)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
log.Infof("Started processing scan for file: %s", task.AbsFilename)
|
||||
semaphore <- struct{}{}
|
||||
defer func() { <-semaphore }()
|
||||
@@ -621,8 +625,8 @@ var (
|
||||
conf Config
|
||||
versionString string
|
||||
log = logrus.New()
|
||||
fileInfoCache *cache.Cache //nolint:unused
|
||||
fileMetadataCache *cache.Cache //nolint:unused
|
||||
fileInfoCache *cache.Cache //nolint:unused
|
||||
fileMetadataCache *cache.Cache //nolint:unused
|
||||
clamClient *clamd.Clamd
|
||||
redisClient *redis.Client
|
||||
redisConnected bool
|
||||
@@ -673,6 +677,7 @@ var clientTracker *ClientConnectionTracker
|
||||
|
||||
//nolint:unused
|
||||
var logMessages []string
|
||||
|
||||
//nolint:unused
|
||||
var logMu sync.Mutex
|
||||
|
||||
@@ -748,7 +753,7 @@ func initializeNetworkProtocol(forceProtocol string) (*net.Dialer, error) {
|
||||
if forceProtocol == "" {
|
||||
forceProtocol = "auto"
|
||||
}
|
||||
|
||||
|
||||
switch forceProtocol {
|
||||
case "ipv4":
|
||||
return &net.Dialer{
|
||||
@@ -845,7 +850,7 @@ func main() {
|
||||
} else {
|
||||
content = GenerateMinimalConfig()
|
||||
}
|
||||
|
||||
|
||||
f, err := os.Create(genConfigPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create file: %v\n", err)
|
||||
@@ -878,7 +883,7 @@ func main() {
|
||||
log.Fatalf("Failed to load configuration: %v", err)
|
||||
}
|
||||
conf = *loadedConfig
|
||||
configFileGlobal = configFile // Store for validation helper functions
|
||||
configFileGlobal = configFile // Store for validation helper functions
|
||||
log.Info("Configuration loaded successfully.")
|
||||
|
||||
err = validateConfig(&conf)
|
||||
@@ -892,12 +897,12 @@ func main() {
|
||||
|
||||
// Initialize client connection tracker for multi-interface support
|
||||
clientNetworkConfig := &ClientNetworkConfig{
|
||||
SessionBasedTracking: conf.ClientNetwork.SessionBasedTracking,
|
||||
AllowIPChanges: conf.ClientNetwork.AllowIPChanges,
|
||||
MaxIPChangesPerSession: conf.ClientNetwork.MaxIPChangesPerSession,
|
||||
AdaptToClientNetwork: conf.ClientNetwork.AdaptToClientNetwork,
|
||||
SessionBasedTracking: conf.ClientNetwork.SessionBasedTracking,
|
||||
AllowIPChanges: conf.ClientNetwork.AllowIPChanges,
|
||||
MaxIPChangesPerSession: conf.ClientNetwork.MaxIPChangesPerSession,
|
||||
AdaptToClientNetwork: conf.ClientNetwork.AdaptToClientNetwork,
|
||||
}
|
||||
|
||||
|
||||
// Parse session migration timeout
|
||||
if conf.ClientNetwork.SessionMigrationTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(conf.ClientNetwork.SessionMigrationTimeout); err == nil {
|
||||
@@ -908,12 +913,12 @@ func main() {
|
||||
} else {
|
||||
clientNetworkConfig.SessionMigrationTimeout = 5 * time.Minute // default
|
||||
}
|
||||
|
||||
|
||||
// Set defaults if not configured
|
||||
if clientNetworkConfig.MaxIPChangesPerSession == 0 {
|
||||
clientNetworkConfig.MaxIPChangesPerSession = 10
|
||||
}
|
||||
|
||||
|
||||
// Initialize the client tracker
|
||||
clientTracker = NewClientConnectionTracker(clientNetworkConfig)
|
||||
if clientTracker != nil {
|
||||
@@ -1075,8 +1080,22 @@ func main() {
|
||||
initRedis() // Assuming initRedis is defined in helpers.go or elsewhere
|
||||
}
|
||||
|
||||
// Initialize new features
|
||||
if err := InitAuditLogger(&conf.Audit); err != nil {
|
||||
log.Warnf("Failed to initialize audit logger: %v", err)
|
||||
}
|
||||
|
||||
InitContentValidator(&conf.Validation)
|
||||
|
||||
if err := InitQuotaManager(&conf.Quotas, redisClient); err != nil {
|
||||
log.Warnf("Failed to initialize quota manager: %v", err)
|
||||
}
|
||||
|
||||
router := setupRouter() // Assuming setupRouter is defined (likely in this file or router.go
|
||||
|
||||
// Setup Admin API routes
|
||||
SetupAdminRoutes(router, &conf.Admin)
|
||||
|
||||
// Initialize enhancements and enhance the router
|
||||
InitializeEnhancements(router)
|
||||
|
||||
@@ -1658,7 +1677,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
query := r.URL.Query()
|
||||
user := query.Get("user")
|
||||
expiryStr := query.Get("expiry")
|
||||
|
||||
|
||||
if user == "" {
|
||||
return nil, errors.New("missing user parameter")
|
||||
}
|
||||
@@ -1674,10 +1693,10 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
|
||||
// ULTRA-FLEXIBLE GRACE PERIODS FOR NETWORK SWITCHING AND STANDBY SCENARIOS
|
||||
now := time.Now().Unix()
|
||||
|
||||
|
||||
// Base grace period: 8 hours (increased from 4 hours for better WiFi ↔ LTE reliability)
|
||||
gracePeriod := int64(28800) // 8 hours base grace period for all scenarios
|
||||
|
||||
|
||||
// Detect mobile XMPP clients and apply enhanced grace periods
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") ||
|
||||
@@ -1688,12 +1707,12 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
strings.Contains(strings.ToLower(userAgent), "xmpp") ||
|
||||
strings.Contains(strings.ToLower(userAgent), "client") ||
|
||||
strings.Contains(strings.ToLower(userAgent), "bot")
|
||||
|
||||
|
||||
// Enhanced XMPP client detection and grace period management
|
||||
// Desktop XMPP clients (Dino, Gajim) need extended grace for session restoration after restart
|
||||
isDesktopXMPP := strings.Contains(strings.ToLower(userAgent), "dino") ||
|
||||
strings.Contains(strings.ToLower(userAgent), "gajim")
|
||||
|
||||
|
||||
if isMobileXMPP || isDesktopXMPP {
|
||||
if isDesktopXMPP {
|
||||
gracePeriod = int64(86400) // 24 hours for desktop XMPP clients (session restoration)
|
||||
@@ -1703,32 +1722,32 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
log.Infof("<22> Mobile XMPP client detected (%s), using extended 12-hour grace period", userAgent)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Network resilience parameters for session recovery
|
||||
sessionId := query.Get("session_id")
|
||||
networkResilience := query.Get("network_resilience")
|
||||
resumeAllowed := query.Get("resume_allowed")
|
||||
|
||||
|
||||
// Maximum grace period for network resilience scenarios
|
||||
if sessionId != "" || networkResilience == "true" || resumeAllowed == "true" {
|
||||
gracePeriod = int64(86400) // 24 hours for explicit network resilience scenarios
|
||||
log.Infof("🌐 Network resilience mode activated (session_id: %s, network_resilience: %s), using 24-hour grace period",
|
||||
log.Infof("🌐 Network resilience mode activated (session_id: %s, network_resilience: %s), using 24-hour grace period",
|
||||
sessionId, networkResilience)
|
||||
}
|
||||
|
||||
|
||||
// Detect potential network switching scenarios
|
||||
clientIP := getClientIP(r)
|
||||
xForwardedFor := r.Header.Get("X-Forwarded-For")
|
||||
xRealIP := r.Header.Get("X-Real-IP")
|
||||
|
||||
|
||||
// Check for client IP change indicators (WiFi ↔ LTE switching detection)
|
||||
if xForwardedFor != "" || xRealIP != "" {
|
||||
// Client is behind proxy/NAT - likely mobile switching between networks
|
||||
gracePeriod = int64(86400) // 24 hours for proxy/NAT scenarios
|
||||
log.Infof("📱 Network switching detected (client IP: %s, X-Forwarded-For: %s, X-Real-IP: %s), using 24-hour grace period",
|
||||
log.Infof("📱 Network switching detected (client IP: %s, X-Forwarded-For: %s, X-Real-IP: %s), using 24-hour grace period",
|
||||
clientIP, xForwardedFor, xRealIP)
|
||||
}
|
||||
|
||||
|
||||
// Check Content-Length to identify large uploads that need extra time
|
||||
contentLength := r.Header.Get("Content-Length")
|
||||
var size int64 = 0
|
||||
@@ -1741,27 +1760,27 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
log.Infof("📁 Large file detected (%d bytes), extending grace period by %d seconds", size, additionalTime)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ABSOLUTE MAXIMUM: 48 hours for extreme scenarios
|
||||
maxAbsoluteGrace := int64(172800) // 48 hours absolute maximum
|
||||
if gracePeriod > maxAbsoluteGrace {
|
||||
gracePeriod = maxAbsoluteGrace
|
||||
log.Infof("⚠️ Grace period capped at 48 hours maximum")
|
||||
}
|
||||
|
||||
|
||||
// STANDBY RECOVERY: Special handling for device standby scenarios
|
||||
isLikelyStandbyRecovery := false
|
||||
standbyGraceExtension := int64(86400) // Additional 24 hours for standby recovery
|
||||
|
||||
|
||||
if now > expiry {
|
||||
expiredTime := now - expiry
|
||||
|
||||
|
||||
// If token expired more than grace period but less than standby window, allow standby recovery
|
||||
if expiredTime > gracePeriod && expiredTime < (gracePeriod + standbyGraceExtension) {
|
||||
if expiredTime > gracePeriod && expiredTime < (gracePeriod+standbyGraceExtension) {
|
||||
isLikelyStandbyRecovery = true
|
||||
log.Infof("💤 STANDBY RECOVERY: Token expired %d seconds ago, within standby recovery window", expiredTime)
|
||||
}
|
||||
|
||||
|
||||
// Apply grace period check
|
||||
if expiredTime > gracePeriod && !isLikelyStandbyRecovery {
|
||||
// DESKTOP XMPP CLIENT SESSION RESTORATION: Special handling for Dino/Gajim restart scenarios
|
||||
@@ -1770,7 +1789,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
isDesktopSessionRestore = true
|
||||
log.Infof("🖥️ DESKTOP SESSION RESTORE: %s token expired %d seconds ago, allowing within 48-hour desktop restoration window", userAgent, expiredTime)
|
||||
}
|
||||
|
||||
|
||||
// Still apply ultra-generous final check for mobile scenarios
|
||||
ultraMaxGrace := int64(259200) // 72 hours ultra-maximum for critical mobile scenarios
|
||||
if (isMobileXMPP && expiredTime < ultraMaxGrace) || isDesktopSessionRestore {
|
||||
@@ -1778,9 +1797,9 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
log.Warnf("⚡ ULTRA-GRACE: Mobile XMPP client token expired %d seconds ago, allowing within 72-hour ultra-grace window", expiredTime)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("❌ Bearer token expired beyond all grace periods: now=%d, expiry=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s",
|
||||
log.Warnf("❌ Bearer token expired beyond all grace periods: now=%d, expiry=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s",
|
||||
now, expiry, expiredTime, gracePeriod, userAgent)
|
||||
return nil, fmt.Errorf("token has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)",
|
||||
return nil, fmt.Errorf("token has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)",
|
||||
expiredTime, gracePeriod)
|
||||
}
|
||||
} else if isLikelyStandbyRecovery {
|
||||
@@ -1797,7 +1816,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
if len(pathParts) < 1 {
|
||||
return nil, errors.New("invalid upload path format")
|
||||
}
|
||||
|
||||
|
||||
// Handle different path formats from various ejabberd modules
|
||||
filename := ""
|
||||
if len(pathParts) >= 3 {
|
||||
@@ -1805,7 +1824,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
} else if len(pathParts) >= 1 {
|
||||
filename = pathParts[len(pathParts)-1] // Simplified format: /filename
|
||||
}
|
||||
|
||||
|
||||
if filename == "" {
|
||||
filename = "upload" // Fallback filename
|
||||
}
|
||||
@@ -1813,71 +1832,71 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
// ENHANCED HMAC VALIDATION: Try multiple payload formats for maximum compatibility
|
||||
var validPayload bool
|
||||
var payloadFormat string
|
||||
|
||||
|
||||
// Format 1: Network-resilient payload (mod_http_upload_hmac_network_resilient)
|
||||
extendedPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%d\x00network_resilient",
|
||||
extendedPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%d\x00network_resilient",
|
||||
user, filename, size, expiry-86400, expiry)
|
||||
h1 := hmac.New(sha256.New, []byte(secret))
|
||||
h1.Write([]byte(extendedPayload))
|
||||
expectedMAC1 := h1.Sum(nil)
|
||||
|
||||
|
||||
if hmac.Equal(tokenBytes, expectedMAC1) {
|
||||
validPayload = true
|
||||
payloadFormat = "network_resilient"
|
||||
}
|
||||
|
||||
|
||||
// Format 2: Extended payload with session support
|
||||
if !validPayload {
|
||||
sessionPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s", user, filename, size, expiry, sessionId)
|
||||
h2 := hmac.New(sha256.New, []byte(secret))
|
||||
h2.Write([]byte(sessionPayload))
|
||||
expectedMAC2 := h2.Sum(nil)
|
||||
|
||||
|
||||
if hmac.Equal(tokenBytes, expectedMAC2) {
|
||||
validPayload = true
|
||||
payloadFormat = "session_based"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Format 3: Standard payload (original mod_http_upload_hmac)
|
||||
if !validPayload {
|
||||
standardPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d", user, filename, size, expiry-3600)
|
||||
h3 := hmac.New(sha256.New, []byte(secret))
|
||||
h3.Write([]byte(standardPayload))
|
||||
expectedMAC3 := h3.Sum(nil)
|
||||
|
||||
|
||||
if hmac.Equal(tokenBytes, expectedMAC3) {
|
||||
validPayload = true
|
||||
payloadFormat = "standard"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Format 4: Simplified payload (fallback compatibility)
|
||||
if !validPayload {
|
||||
simplePayload := fmt.Sprintf("%s\x00%s\x00%d", user, filename, size)
|
||||
h4 := hmac.New(sha256.New, []byte(secret))
|
||||
h4.Write([]byte(simplePayload))
|
||||
expectedMAC4 := h4.Sum(nil)
|
||||
|
||||
|
||||
if hmac.Equal(tokenBytes, expectedMAC4) {
|
||||
validPayload = true
|
||||
payloadFormat = "simple"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Format 5: User-only payload (maximum fallback)
|
||||
if !validPayload {
|
||||
userPayload := fmt.Sprintf("%s\x00%d", user, expiry)
|
||||
h5 := hmac.New(sha256.New, []byte(secret))
|
||||
h5.Write([]byte(userPayload))
|
||||
expectedMAC5 := h5.Sum(nil)
|
||||
|
||||
|
||||
if hmac.Equal(tokenBytes, expectedMAC5) {
|
||||
validPayload = true
|
||||
payloadFormat = "user_only"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if !validPayload {
|
||||
log.Warnf("❌ Invalid Bearer token HMAC for user %s, file %s (tried all 5 payload formats)", user, filename)
|
||||
return nil, errors.New("invalid Bearer token HMAC")
|
||||
@@ -1890,16 +1909,16 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
|
||||
Expiry: expiry,
|
||||
}
|
||||
|
||||
log.Infof("✅ Bearer token authentication SUCCESSFUL: user=%s, file=%s, format=%s, grace_period=%d seconds",
|
||||
log.Infof("✅ Bearer token authentication SUCCESSFUL: user=%s, file=%s, format=%s, grace_period=%d seconds",
|
||||
user, filename, payloadFormat, gracePeriod)
|
||||
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// evaluateSecurityLevel determines the required security level based on network changes and standby detection
|
||||
func evaluateSecurityLevel(session *NetworkResilientSession, currentIP string, userAgent string) int {
|
||||
now := time.Now()
|
||||
|
||||
|
||||
// Initialize if this is the first check
|
||||
if session.LastSecurityCheck.IsZero() {
|
||||
session.LastSecurityCheck = now
|
||||
@@ -1907,50 +1926,50 @@ func evaluateSecurityLevel(session *NetworkResilientSession, currentIP string, u
|
||||
session.SecurityLevel = 1 // Normal level
|
||||
return 1
|
||||
}
|
||||
|
||||
|
||||
// Detect potential standby scenario
|
||||
timeSinceLastActivity := now.Sub(session.LastActivity)
|
||||
standbyThreshold := 30 * time.Minute
|
||||
|
||||
|
||||
if timeSinceLastActivity > standbyThreshold {
|
||||
session.StandbyDetected = true
|
||||
log.Infof("🔒 STANDBY DETECTED: %v since last activity for session %s", timeSinceLastActivity, session.SessionID)
|
||||
|
||||
|
||||
// Long standby requires full re-authentication
|
||||
if timeSinceLastActivity > 2*time.Hour {
|
||||
log.Warnf("🔐 SECURITY LEVEL 3: Long standby (%v) requires full re-authentication", timeSinceLastActivity)
|
||||
return 3
|
||||
}
|
||||
|
||||
|
||||
// Medium standby requires challenge-response
|
||||
log.Infof("🔐 SECURITY LEVEL 2: Medium standby (%v) requires challenge-response", timeSinceLastActivity)
|
||||
return 2
|
||||
}
|
||||
|
||||
|
||||
// Detect network changes
|
||||
if session.LastIP != "" && session.LastIP != currentIP {
|
||||
session.NetworkChangeCount++
|
||||
log.Infof("🌐 NETWORK CHANGE #%d: %s → %s for session %s",
|
||||
log.Infof("🌐 NETWORK CHANGE #%d: %s → %s for session %s",
|
||||
session.NetworkChangeCount, session.LastIP, currentIP, session.SessionID)
|
||||
|
||||
|
||||
// Multiple rapid network changes are suspicious
|
||||
if session.NetworkChangeCount > 3 {
|
||||
log.Warnf("🔐 SECURITY LEVEL 3: Multiple network changes (%d) requires full re-authentication",
|
||||
log.Warnf("🔐 SECURITY LEVEL 3: Multiple network changes (%d) requires full re-authentication",
|
||||
session.NetworkChangeCount)
|
||||
return 3
|
||||
}
|
||||
|
||||
|
||||
// Single network change requires challenge-response
|
||||
log.Infof("🔐 SECURITY LEVEL 2: Network change requires challenge-response")
|
||||
return 2
|
||||
}
|
||||
|
||||
|
||||
// Check for suspicious user agent changes
|
||||
if session.UserAgent != "" && session.UserAgent != userAgent {
|
||||
log.Warnf("🔐 SECURITY LEVEL 3: User agent change detected - potential device hijacking")
|
||||
return 3
|
||||
}
|
||||
|
||||
|
||||
// Normal operation
|
||||
return 1
|
||||
}
|
||||
@@ -1960,11 +1979,11 @@ func generateSecurityChallenge(session *NetworkResilientSession, secret string)
|
||||
// Create a time-based challenge using session data
|
||||
timestamp := time.Now().Unix()
|
||||
challengeData := fmt.Sprintf("%s:%s:%d", session.SessionID, session.UserJID, timestamp)
|
||||
|
||||
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(challengeData))
|
||||
challenge := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
|
||||
log.Infof("🔐 Generated security challenge for session %s", session.SessionID)
|
||||
return challenge, nil
|
||||
}
|
||||
@@ -1974,22 +1993,22 @@ func validateSecurityChallenge(session *NetworkResilientSession, providedRespons
|
||||
// This would validate against the expected response
|
||||
// For now, we'll implement a simple time-window validation
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
|
||||
// Allow 5-minute window for challenge responses
|
||||
for i := int64(0); i <= 300; i += 60 {
|
||||
testTimestamp := timestamp - i
|
||||
challengeData := fmt.Sprintf("%s:%s:%d", session.SessionID, session.UserJID, testTimestamp)
|
||||
|
||||
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(challengeData))
|
||||
expectedResponse := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
|
||||
if expectedResponse == providedResponse {
|
||||
log.Infof("✅ Security challenge validated for session %s", session.SessionID)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
log.Warnf("❌ Security challenge failed for session %s", session.SessionID)
|
||||
return false
|
||||
}
|
||||
@@ -2029,17 +2048,17 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
|
||||
session := sessionStore.GetSession(sessionID)
|
||||
if session == nil {
|
||||
session = &NetworkResilientSession{
|
||||
SessionID: sessionID,
|
||||
UserJID: claims.User,
|
||||
OriginalToken: getBearerTokenFromRequest(r),
|
||||
CreatedAt: time.Now(),
|
||||
MaxRefreshes: 10,
|
||||
NetworkHistory: []NetworkEvent{},
|
||||
SecurityLevel: 1,
|
||||
LastSecurityCheck: time.Now(),
|
||||
SessionID: sessionID,
|
||||
UserJID: claims.User,
|
||||
OriginalToken: getBearerTokenFromRequest(r),
|
||||
CreatedAt: time.Now(),
|
||||
MaxRefreshes: 10,
|
||||
NetworkHistory: []NetworkEvent{},
|
||||
SecurityLevel: 1,
|
||||
LastSecurityCheck: time.Now(),
|
||||
NetworkChangeCount: 0,
|
||||
StandbyDetected: false,
|
||||
LastActivity: time.Now(),
|
||||
StandbyDetected: false,
|
||||
LastActivity: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2069,7 +2088,7 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
|
||||
log.Errorf("❌ Failed to generate security challenge: %v", err)
|
||||
return nil, fmt.Errorf("security challenge generation failed")
|
||||
}
|
||||
|
||||
|
||||
// Check if client provided challenge response
|
||||
challengeResponse := r.Header.Get("X-Challenge-Response")
|
||||
if challengeResponse == "" {
|
||||
@@ -2077,15 +2096,15 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
|
||||
setSecurityHeaders(w, 2, challenge)
|
||||
return nil, fmt.Errorf("challenge-response required for network change")
|
||||
}
|
||||
|
||||
|
||||
// Validate challenge response
|
||||
if !validateSecurityChallenge(session, challengeResponse, secret) {
|
||||
setSecurityHeaders(w, 2, challenge)
|
||||
return nil, fmt.Errorf("invalid challenge response")
|
||||
}
|
||||
|
||||
|
||||
log.Infof("✅ Challenge-response validated for session %s", sessionID)
|
||||
|
||||
|
||||
case 3:
|
||||
// Full re-authentication required
|
||||
setSecurityHeaders(w, 3, "")
|
||||
@@ -2104,7 +2123,7 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
|
||||
UserAgent: userAgent,
|
||||
EventType: "network_switch",
|
||||
})
|
||||
log.Infof("🌐 Network switch detected for session %s: %s → %s",
|
||||
log.Infof("🌐 Network switch detected for session %s: %s → %s",
|
||||
sessionID, session.LastIP, currentIP)
|
||||
}
|
||||
|
||||
@@ -2138,7 +2157,7 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
|
||||
// Token refresh successful
|
||||
session.RefreshCount++
|
||||
session.LastSeen = time.Now()
|
||||
|
||||
|
||||
// Add refresh event to history
|
||||
session.NetworkHistory = append(session.NetworkHistory, NetworkEvent{
|
||||
Timestamp: time.Now(),
|
||||
@@ -2157,12 +2176,12 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
|
||||
Expiry: time.Now().Add(24 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
log.Infof("✅ Session recovery successful: %s (refresh #%d)",
|
||||
log.Infof("✅ Session recovery successful: %s (refresh #%d)",
|
||||
sessionID, session.RefreshCount)
|
||||
return refreshedClaims, nil
|
||||
}
|
||||
} else {
|
||||
log.Warnf("❌ Session %s exceeded maximum refreshes (%d)",
|
||||
log.Warnf("❌ Session %s exceeded maximum refreshes (%d)",
|
||||
sessionID, session.MaxRefreshes)
|
||||
}
|
||||
} else {
|
||||
@@ -2191,8 +2210,8 @@ func refreshSessionToken(session *NetworkResilientSession, secret string, r *htt
|
||||
size := extractSizeFromRequest(r)
|
||||
|
||||
// Use session-based payload format for refresh
|
||||
payload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s\x00session_refresh",
|
||||
session.UserJID,
|
||||
payload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s\x00session_refresh",
|
||||
session.UserJID,
|
||||
filename,
|
||||
size,
|
||||
expiry,
|
||||
@@ -2202,7 +2221,7 @@ func refreshSessionToken(session *NetworkResilientSession, secret string, r *htt
|
||||
h.Write([]byte(payload))
|
||||
token := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
log.Infof("🆕 Generated refresh token for session %s (refresh #%d)",
|
||||
log.Infof("🆕 Generated refresh token for session %s (refresh #%d)",
|
||||
session.SessionID, session.RefreshCount+1)
|
||||
|
||||
return token, nil
|
||||
@@ -2251,7 +2270,7 @@ type BearerTokenClaims struct {
|
||||
// ENHANCED FOR 100% WIFI ↔ LTE SWITCHING AND STANDBY RECOVERY RELIABILITY
|
||||
func validateHMAC(r *http.Request, secret string) error {
|
||||
log.Debugf("🔍 validateHMAC: Validating request to %s with query: %s", r.URL.Path, r.URL.RawQuery)
|
||||
|
||||
|
||||
// Check for X-Signature header (for POST uploads)
|
||||
signature := r.Header.Get("X-Signature")
|
||||
if signature != "" {
|
||||
@@ -2294,7 +2313,7 @@ func validateHMAC(r *http.Request, secret string) error {
|
||||
// ENHANCED HMAC CALCULATION: Try multiple formats for maximum compatibility
|
||||
var validMAC bool
|
||||
var messageFormat string
|
||||
|
||||
|
||||
// Calculate HMAC based on protocol version with enhanced compatibility
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
|
||||
@@ -2305,7 +2324,7 @@ func validateHMAC(r *http.Request, secret string) error {
|
||||
mac.Write([]byte(message1))
|
||||
calculatedMAC1 := mac.Sum(nil)
|
||||
calculatedMACHex1 := hex.EncodeToString(calculatedMAC1)
|
||||
|
||||
|
||||
// Decode provided MAC
|
||||
if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
|
||||
if hmac.Equal(calculatedMAC1, providedMAC) {
|
||||
@@ -2314,14 +2333,14 @@ func validateHMAC(r *http.Request, secret string) error {
|
||||
log.Debugf("✅ Legacy v protocol HMAC validated: %s", calculatedMACHex1)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Format 2: Try without content length for compatibility
|
||||
if !validMAC {
|
||||
message2 := fileStorePath
|
||||
mac.Reset()
|
||||
mac.Write([]byte(message2))
|
||||
calculatedMAC2 := mac.Sum(nil)
|
||||
|
||||
|
||||
if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
|
||||
if hmac.Equal(calculatedMAC2, providedMAC) {
|
||||
validMAC = true
|
||||
@@ -2333,14 +2352,14 @@ func validateHMAC(r *http.Request, secret string) error {
|
||||
} else {
|
||||
// v2 and token protocols: Enhanced format compatibility
|
||||
contentType := GetContentType(fileStorePath)
|
||||
|
||||
|
||||
// Format 1: Standard format - fileStorePath + "\x00" + contentLength + "\x00" + contentType
|
||||
message1 := fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType
|
||||
mac.Reset()
|
||||
mac.Write([]byte(message1))
|
||||
calculatedMAC1 := mac.Sum(nil)
|
||||
calculatedMACHex1 := hex.EncodeToString(calculatedMAC1)
|
||||
|
||||
|
||||
if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
|
||||
if hmac.Equal(calculatedMAC1, providedMAC) {
|
||||
validMAC = true
|
||||
@@ -2348,14 +2367,14 @@ func validateHMAC(r *http.Request, secret string) error {
|
||||
log.Debugf("✅ %s protocol HMAC validated (standard): %s", protocolVersion, calculatedMACHex1)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Format 2: Without content type for compatibility
|
||||
if !validMAC {
|
||||
message2 := fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10)
|
||||
mac.Reset()
|
||||
mac.Write([]byte(message2))
|
||||
calculatedMAC2 := mac.Sum(nil)
|
||||
|
||||
|
||||
if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
|
||||
if hmac.Equal(calculatedMAC2, providedMAC) {
|
||||
validMAC = true
|
||||
@@ -2364,14 +2383,14 @@ func validateHMAC(r *http.Request, secret string) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Format 3: Simple path only for maximum compatibility
|
||||
if !validMAC {
|
||||
message3 := fileStorePath
|
||||
mac.Reset()
|
||||
mac.Write([]byte(message3))
|
||||
calculatedMAC3 := mac.Sum(nil)
|
||||
|
||||
|
||||
if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
|
||||
if hmac.Equal(calculatedMAC3, providedMAC) {
|
||||
validMAC = true
|
||||
@@ -2387,7 +2406,7 @@ func validateHMAC(r *http.Request, secret string) error {
|
||||
return fmt.Errorf("invalid MAC for %s protocol", protocolVersion)
|
||||
}
|
||||
|
||||
log.Infof("✅ %s HMAC authentication SUCCESSFUL: format=%s, path=%s",
|
||||
log.Infof("✅ %s HMAC authentication SUCCESSFUL: format=%s, path=%s",
|
||||
protocolVersion, messageFormat, r.URL.Path)
|
||||
return nil
|
||||
}
|
||||
@@ -2417,11 +2436,11 @@ func validateV3HMAC(r *http.Request, secret string) error {
|
||||
|
||||
// ULTRA-FLEXIBLE GRACE PERIODS FOR V3 PROTOCOL NETWORK SWITCHING
|
||||
now := time.Now().Unix()
|
||||
|
||||
|
||||
if now > expires {
|
||||
// Base grace period: 8 hours (significantly increased for WiFi ↔ LTE reliability)
|
||||
gracePeriod := int64(28800) // 8 hours base grace period
|
||||
|
||||
|
||||
// Enhanced mobile XMPP client detection
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "gajim") ||
|
||||
@@ -2432,12 +2451,12 @@ func validateV3HMAC(r *http.Request, secret string) error {
|
||||
strings.Contains(strings.ToLower(userAgent), "xmpp") ||
|
||||
strings.Contains(strings.ToLower(userAgent), "client") ||
|
||||
strings.Contains(strings.ToLower(userAgent), "bot")
|
||||
|
||||
|
||||
if isMobileXMPP {
|
||||
gracePeriod = int64(43200) // 12 hours for mobile XMPP clients
|
||||
log.Infof("📱 V3: Mobile XMPP client detected (%s), using 12-hour grace period", userAgent)
|
||||
}
|
||||
|
||||
|
||||
// Network resilience parameters for V3 protocol
|
||||
sessionId := query.Get("session_id")
|
||||
networkResilience := query.Get("network_resilience")
|
||||
@@ -2446,19 +2465,19 @@ func validateV3HMAC(r *http.Request, secret string) error {
|
||||
gracePeriod = int64(86400) // 24 hours for network resilience scenarios
|
||||
log.Infof("🌐 V3: Network resilience mode detected, using 24-hour grace period")
|
||||
}
|
||||
|
||||
|
||||
// Detect network switching indicators
|
||||
clientIP := getClientIP(r)
|
||||
xForwardedFor := r.Header.Get("X-Forwarded-For")
|
||||
xRealIP := r.Header.Get("X-Real-IP")
|
||||
|
||||
|
||||
if xForwardedFor != "" || xRealIP != "" {
|
||||
// Client behind proxy/NAT - likely mobile network switching
|
||||
gracePeriod = int64(86400) // 24 hours for proxy/NAT scenarios
|
||||
log.Infof("🔄 V3: Network switching detected (IP: %s, X-Forwarded-For: %s), using 24-hour grace period",
|
||||
log.Infof("🔄 V3: Network switching detected (IP: %s, X-Forwarded-For: %s), using 24-hour grace period",
|
||||
clientIP, xForwardedFor)
|
||||
}
|
||||
|
||||
|
||||
// Large file uploads get additional grace time
|
||||
if contentLengthStr := r.Header.Get("Content-Length"); contentLengthStr != "" {
|
||||
if contentLength, parseErr := strconv.ParseInt(contentLengthStr, 10, 64); parseErr == nil {
|
||||
@@ -2466,33 +2485,33 @@ func validateV3HMAC(r *http.Request, secret string) error {
|
||||
if contentLength > 10*1024*1024 {
|
||||
additionalTime := (contentLength / (10 * 1024 * 1024)) * 3600 // 1 hour per 10MB
|
||||
gracePeriod += additionalTime
|
||||
log.Infof("📁 V3: Large file (%d bytes), extending grace period by %d seconds",
|
||||
log.Infof("📁 V3: Large file (%d bytes), extending grace period by %d seconds",
|
||||
contentLength, additionalTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Maximum grace period cap: 48 hours
|
||||
maxGracePeriod := int64(172800) // 48 hours absolute maximum
|
||||
if gracePeriod > maxGracePeriod {
|
||||
gracePeriod = maxGracePeriod
|
||||
log.Infof("⚠️ V3: Grace period capped at 48 hours maximum")
|
||||
}
|
||||
|
||||
|
||||
// STANDBY RECOVERY: Handle device standby scenarios
|
||||
expiredTime := now - expires
|
||||
standbyGraceExtension := int64(86400) // Additional 24 hours for standby
|
||||
isLikelyStandbyRecovery := expiredTime > gracePeriod && expiredTime < (gracePeriod + standbyGraceExtension)
|
||||
|
||||
isLikelyStandbyRecovery := expiredTime > gracePeriod && expiredTime < (gracePeriod+standbyGraceExtension)
|
||||
|
||||
if expiredTime > gracePeriod && !isLikelyStandbyRecovery {
|
||||
// Ultra-generous final check for mobile scenarios
|
||||
ultraMaxGrace := int64(259200) // 72 hours ultra-maximum for critical scenarios
|
||||
if isMobileXMPP && expiredTime < ultraMaxGrace {
|
||||
log.Warnf("⚡ V3 ULTRA-GRACE: Mobile client token expired %d seconds ago, allowing within 72-hour window", expiredTime)
|
||||
} else {
|
||||
log.Warnf("❌ V3 signature expired beyond all grace periods: now=%d, expires=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s",
|
||||
log.Warnf("❌ V3 signature expired beyond all grace periods: now=%d, expires=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s",
|
||||
now, expires, expiredTime, gracePeriod, userAgent)
|
||||
return fmt.Errorf("signature has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)",
|
||||
return fmt.Errorf("signature has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)",
|
||||
expiredTime, gracePeriod)
|
||||
}
|
||||
} else if isLikelyStandbyRecovery {
|
||||
@@ -2507,18 +2526,18 @@ func validateV3HMAC(r *http.Request, secret string) error {
|
||||
// ENHANCED MESSAGE CONSTRUCTION: Try multiple formats for compatibility
|
||||
var validSignature bool
|
||||
var messageFormat string
|
||||
|
||||
|
||||
// Format 1: Standard v3 format
|
||||
message1 := fmt.Sprintf("%s\n%s\n%s", r.Method, expiresStr, r.URL.Path)
|
||||
h1 := hmac.New(sha256.New, []byte(secret))
|
||||
h1.Write([]byte(message1))
|
||||
expectedSignature1 := hex.EncodeToString(h1.Sum(nil))
|
||||
|
||||
|
||||
if hmac.Equal([]byte(signature), []byte(expectedSignature1)) {
|
||||
validSignature = true
|
||||
messageFormat = "standard_v3"
|
||||
}
|
||||
|
||||
|
||||
// Format 2: Alternative format with query string
|
||||
if !validSignature {
|
||||
pathWithQuery := r.URL.Path
|
||||
@@ -2529,32 +2548,32 @@ func validateV3HMAC(r *http.Request, secret string) error {
|
||||
h2 := hmac.New(sha256.New, []byte(secret))
|
||||
h2.Write([]byte(message2))
|
||||
expectedSignature2 := hex.EncodeToString(h2.Sum(nil))
|
||||
|
||||
|
||||
if hmac.Equal([]byte(signature), []byte(expectedSignature2)) {
|
||||
validSignature = true
|
||||
messageFormat = "with_query"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Format 3: Simplified format (fallback)
|
||||
if !validSignature {
|
||||
message3 := fmt.Sprintf("%s\n%s", r.Method, r.URL.Path)
|
||||
h3 := hmac.New(sha256.New, []byte(secret))
|
||||
h3.Write([]byte(message3))
|
||||
expectedSignature3 := hex.EncodeToString(h3.Sum(nil))
|
||||
|
||||
|
||||
if hmac.Equal([]byte(signature), []byte(expectedSignature3)) {
|
||||
validSignature = true
|
||||
messageFormat = "simplified"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if !validSignature {
|
||||
log.Warnf("❌ Invalid V3 HMAC signature (tried all 3 formats)")
|
||||
return errors.New("invalid v3 HMAC signature")
|
||||
}
|
||||
|
||||
log.Infof("✅ V3 HMAC authentication SUCCESSFUL: format=%s, method=%s, path=%s",
|
||||
log.Infof("✅ V3 HMAC authentication SUCCESSFUL: format=%s, method=%s, path=%s",
|
||||
messageFormat, r.Method, r.URL.Path)
|
||||
return nil
|
||||
}
|
||||
@@ -2563,7 +2582,7 @@ func validateV3HMAC(r *http.Request, secret string) error {
|
||||
func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSize int64, clientIP string) (int64, error) {
|
||||
var written int64
|
||||
lastLogTime := time.Now()
|
||||
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
@@ -2572,12 +2591,12 @@ func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSiz
|
||||
if werr != nil {
|
||||
return written, werr
|
||||
}
|
||||
|
||||
|
||||
// Log progress for large files every 10MB or 30 seconds
|
||||
if totalSize > 50*1024*1024 &&
|
||||
if totalSize > 50*1024*1024 &&
|
||||
(written%10*1024*1024 == 0 || time.Since(lastLogTime) > 30*time.Second) {
|
||||
progress := float64(written) / float64(totalSize) * 100
|
||||
log.Infof("📥 Download progress: %.1f%% (%s/%s) for IP %s",
|
||||
log.Infof("📥 Download progress: %.1f%% (%s/%s) for IP %s",
|
||||
progress, formatBytes(written), formatBytes(totalSize), clientIP)
|
||||
lastLogTime = time.Now()
|
||||
}
|
||||
@@ -2589,7 +2608,7 @@ func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSiz
|
||||
return written, err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
@@ -2606,11 +2625,11 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate session ID for multi-upload tracking
|
||||
sessionID = generateUploadSessionID("upload", r.Header.Get("User-Agent"), getClientIP(r))
|
||||
}
|
||||
|
||||
|
||||
// Set session headers for client continuation
|
||||
w.Header().Set("X-Session-ID", sessionID)
|
||||
w.Header().Set("X-Upload-Session-Timeout", "3600") // 1 hour
|
||||
|
||||
|
||||
// Only allow POST method
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
@@ -2621,22 +2640,22 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
// ENHANCED AUTHENTICATION with network switching support
|
||||
var bearerClaims *BearerTokenClaims
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
|
||||
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
// Bearer token authentication with session recovery for network switching
|
||||
// Store response writer in context for session headers
|
||||
ctx := context.WithValue(r.Context(), responseWriterKey, w)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
|
||||
claims, err := validateBearerTokenWithSession(r, conf.Security.Secret)
|
||||
if err != nil {
|
||||
// Enhanced error logging for network switching scenarios
|
||||
clientIP := getClientIP(r)
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
sessionID := getSessionIDFromRequest(r)
|
||||
log.Warnf("🔴 Authentication failed for IP %s, User-Agent: %s, Session: %s, Error: %v",
|
||||
log.Warnf("🔴 Authentication failed for IP %s, User-Agent: %s, Session: %s, Error: %v",
|
||||
clientIP, userAgent, sessionID, err)
|
||||
|
||||
|
||||
// Check if this might be a network switching scenario and provide helpful response
|
||||
if strings.Contains(err.Error(), "expired") || strings.Contains(err.Error(), "invalid") {
|
||||
w.Header().Set("X-Network-Switch-Detected", "true")
|
||||
@@ -2646,15 +2665,17 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Session-ID", sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
AuditAuthFailure(r, "bearer_token", err.Error())
|
||||
http.Error(w, fmt.Sprintf("Bearer Token Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
AuditAuthSuccess(r, claims.User, "bearer_token")
|
||||
bearerClaims = claims
|
||||
log.Infof("✅ Bearer token authentication successful: user=%s, file=%s, IP=%s",
|
||||
log.Infof("✅ Bearer token authentication successful: user=%s, file=%s, IP=%s",
|
||||
claims.User, claims.Filename, getClientIP(r))
|
||||
|
||||
|
||||
// Add comprehensive response headers for audit logging and client tracking
|
||||
w.Header().Set("X-Authenticated-User", claims.User)
|
||||
w.Header().Set("X-Auth-Method", "Bearer-Token")
|
||||
@@ -2665,10 +2686,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
|
||||
if err != nil {
|
||||
log.Warnf("🔴 JWT Authentication failed for IP %s: %v", getClientIP(r), err)
|
||||
AuditAuthFailure(r, "jwt", err.Error())
|
||||
http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
AuditAuthSuccess(r, "", "jwt")
|
||||
log.Infof("✅ JWT authentication successful for upload request: %s", r.URL.Path)
|
||||
w.Header().Set("X-Auth-Method", "JWT")
|
||||
} else {
|
||||
@@ -2676,10 +2699,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
err := validateHMAC(r, conf.Security.Secret)
|
||||
if err != nil {
|
||||
log.Warnf("🔴 HMAC Authentication failed for IP %s: %v", getClientIP(r), err)
|
||||
AuditAuthFailure(r, "hmac", err.Error())
|
||||
http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
AuditAuthSuccess(r, "", "hmac")
|
||||
log.Infof("✅ HMAC authentication successful for upload request: %s", r.URL.Path)
|
||||
w.Header().Set("X-Auth-Method", "HMAC")
|
||||
}
|
||||
@@ -2699,30 +2724,30 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate new session ID with enhanced entropy
|
||||
sessionID = generateSessionID("", "")
|
||||
}
|
||||
|
||||
|
||||
clientIP := getClientIP(r)
|
||||
|
||||
|
||||
// Detect potential network switching
|
||||
xForwardedFor := r.Header.Get("X-Forwarded-For")
|
||||
xRealIP := r.Header.Get("X-Real-IP")
|
||||
networkSwitchIndicators := xForwardedFor != "" || xRealIP != ""
|
||||
|
||||
|
||||
if networkSwitchIndicators {
|
||||
log.Infof("🔄 Network switching indicators detected: session=%s, client_ip=%s, x_forwarded_for=%s, x_real_ip=%s",
|
||||
log.Infof("🔄 Network switching indicators detected: session=%s, client_ip=%s, x_forwarded_for=%s, x_real_ip=%s",
|
||||
sessionID, clientIP, xForwardedFor, xRealIP)
|
||||
w.Header().Set("X-Network-Switch-Detected", "true")
|
||||
}
|
||||
|
||||
|
||||
clientSession = clientTracker.TrackClientSession(sessionID, clientIP, r)
|
||||
|
||||
|
||||
// Enhanced session response headers for client coordination
|
||||
w.Header().Set("X-Upload-Session-ID", sessionID)
|
||||
w.Header().Set("X-Session-IP-Count", fmt.Sprintf("%d", len(clientSession.ClientIPs)))
|
||||
w.Header().Set("X-Connection-Type", clientSession.ConnectionType)
|
||||
|
||||
log.Infof("🔗 Client session tracking: %s from IP %s (connection: %s, total_ips: %d)",
|
||||
|
||||
log.Infof("🔗 Client session tracking: %s from IP %s (connection: %s, total_ips: %d)",
|
||||
sessionID, clientIP, clientSession.ConnectionType, len(clientSession.ClientIPs))
|
||||
|
||||
|
||||
// Add user context for Bearer token authentication
|
||||
if bearerClaims != nil {
|
||||
log.Infof("👤 Session associated with XMPP user: %s", bearerClaims.User)
|
||||
@@ -2749,6 +2774,57 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Get user JID for quota and audit tracking
|
||||
var userJID string
|
||||
if bearerClaims != nil {
|
||||
userJID = bearerClaims.User
|
||||
}
|
||||
r.Header.Set("X-User-JID", userJID)
|
||||
r.Header.Set("X-File-Name", header.Filename)
|
||||
|
||||
// Check quota before upload
|
||||
if qm := GetQuotaManager(); qm != nil && qm.config.Enabled && userJID != "" {
|
||||
canUpload, _ := qm.CanUpload(r.Context(), userJID, header.Size)
|
||||
if !canUpload {
|
||||
used, limit, _ := qm.GetUsage(r.Context(), userJID)
|
||||
AuditQuotaExceeded(r, userJID, used, limit, header.Size)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Quota-Used", fmt.Sprintf("%d", used))
|
||||
w.Header().Set("X-Quota-Limit", fmt.Sprintf("%d", limit))
|
||||
w.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"error": "quota_exceeded",
|
||||
"message": "Storage quota exceeded",
|
||||
"used": used,
|
||||
"limit": limit,
|
||||
"requested": header.Size,
|
||||
})
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Content type validation using magic bytes
|
||||
var fileReader io.Reader = file
|
||||
declaredContentType := header.Header.Get("Content-Type")
|
||||
detectedContentType := declaredContentType
|
||||
|
||||
if validator := GetContentValidator(); validator != nil && validator.config.CheckMagicBytes {
|
||||
validatedReader, detected, validErr := validator.ValidateContent(file, declaredContentType, header.Size)
|
||||
if validErr != nil {
|
||||
if valErr, ok := validErr.(*ValidationError); ok {
|
||||
AuditValidationFailure(r, userJID, header.Filename, declaredContentType, detected, valErr.Code)
|
||||
WriteValidationError(w, valErr)
|
||||
} else {
|
||||
http.Error(w, validErr.Error(), http.StatusUnsupportedMediaType)
|
||||
}
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
fileReader = validatedReader
|
||||
detectedContentType = detected
|
||||
}
|
||||
|
||||
// Validate file size against max_upload_size if configured
|
||||
if conf.Server.MaxUploadSize != "" {
|
||||
maxSizeBytes, err := parseSize(conf.Server.MaxUploadSize)
|
||||
@@ -2759,9 +2835,9 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if header.Size > maxSizeBytes {
|
||||
log.Warnf("⚠️ File size %s exceeds maximum allowed size %s (IP: %s)",
|
||||
log.Warnf("⚠️ File size %s exceeds maximum allowed size %s (IP: %s)",
|
||||
formatBytes(header.Size), conf.Server.MaxUploadSize, getClientIP(r))
|
||||
http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
|
||||
http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
|
||||
formatBytes(header.Size), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
@@ -2815,20 +2891,20 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
uploadsTotal.Inc()
|
||||
uploadSizeBytes.Observe(float64(existingFileInfo.Size()))
|
||||
filesDeduplicatedTotal.Inc()
|
||||
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Deduplication-Hit", "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": existingFileInfo.Size(),
|
||||
"message": "File already exists (deduplication hit)",
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": existingFileInfo.Size(),
|
||||
"message": "File already exists (deduplication hit)",
|
||||
"upload_time": duration.String(),
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
|
||||
log.Infof("💾 Deduplication hit: file %s already exists (%s), returning success immediately (IP: %s)",
|
||||
|
||||
log.Infof("💾 Deduplication hit: file %s already exists (%s), returning success immediately (IP: %s)",
|
||||
filename, formatBytes(existingFileInfo.Size()), getClientIP(r))
|
||||
return
|
||||
}
|
||||
@@ -2855,30 +2931,43 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
uploadCtx = networkManager.RegisterUpload(networkSessionID)
|
||||
defer networkManager.UnregisterUpload(networkSessionID)
|
||||
log.Infof("🌐 Registered upload with network resilience: session=%s, IP=%s", networkSessionID, getClientIP(r))
|
||||
|
||||
|
||||
// Add network resilience headers
|
||||
w.Header().Set("X-Network-Resilience", "enabled")
|
||||
w.Header().Set("X-Upload-Context-ID", networkSessionID)
|
||||
}
|
||||
|
||||
// Copy file content with network resilience support and enhanced progress tracking
|
||||
written, err := copyWithNetworkResilience(dst, file, uploadCtx)
|
||||
// Use fileReader which may be wrapped with content validation
|
||||
written, err := copyWithNetworkResilience(dst, fileReader, uploadCtx)
|
||||
if err != nil {
|
||||
log.Errorf("🔴 Error saving file %s (IP: %s, session: %s): %v", filename, getClientIP(r), sessionID, err)
|
||||
http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
|
||||
uploadErrorsTotal.Inc()
|
||||
// Clean up partial file
|
||||
os.Remove(absFilename)
|
||||
// Audit the failure
|
||||
AuditUploadFailure(r, userJID, header.Filename, header.Size, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update quota after successful upload
|
||||
if qm := GetQuotaManager(); qm != nil && qm.config.Enabled && userJID != "" {
|
||||
if err := qm.RecordUpload(r.Context(), userJID, absFilename, written); err != nil {
|
||||
log.Warnf("⚠️ Failed to update quota for user %s: %v", userJID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Audit successful upload
|
||||
AuditUploadSuccess(r, userJID, filename, written, detectedContentType)
|
||||
|
||||
// ✅ CRITICAL FIX: Send immediate success response for large files (>1GB)
|
||||
// This prevents client timeouts while server does post-processing
|
||||
isLargeFile := header.Size > 1024*1024*1024 // 1GB threshold
|
||||
|
||||
|
||||
if isLargeFile {
|
||||
log.Infof("🚀 Large file detected (%s), sending immediate success response", formatBytes(header.Size))
|
||||
|
||||
|
||||
// Send immediate success response to client
|
||||
duration := time.Since(startTime)
|
||||
uploadDuration.Observe(duration.Seconds())
|
||||
@@ -2893,12 +2982,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": written,
|
||||
"duration": duration.String(),
|
||||
"client_ip": getClientIP(r),
|
||||
"timestamp": time.Now().Unix(),
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": written,
|
||||
"duration": duration.String(),
|
||||
"client_ip": getClientIP(r),
|
||||
"timestamp": time.Now().Unix(),
|
||||
"post_processing": "background",
|
||||
}
|
||||
|
||||
@@ -2921,7 +3010,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d, "post_processing": "background"}`, filename, written)
|
||||
}
|
||||
|
||||
log.Infof("✅ Immediate response sent for large file %s (%s) in %s from IP %s",
|
||||
log.Infof("✅ Immediate response sent for large file %s (%s) in %s from IP %s",
|
||||
filename, formatBytes(written), duration, getClientIP(r))
|
||||
|
||||
// Process deduplication asynchronously for large files
|
||||
@@ -2936,7 +3025,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
log.Infof("✅ Background deduplication completed for %s", filename)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Add to scan queue for virus scanning if enabled
|
||||
if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 {
|
||||
ext := strings.ToLower(filepath.Ext(header.Filename))
|
||||
@@ -2958,7 +3047,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2987,10 +3076,10 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": written,
|
||||
"duration": duration.String(),
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": written,
|
||||
"duration": duration.String(),
|
||||
"client_ip": getClientIP(r),
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
@@ -3014,7 +3103,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
|
||||
}
|
||||
|
||||
log.Infof("✅ Successfully uploaded %s (%s) in %s from IP %s (session: %s)",
|
||||
log.Infof("✅ Successfully uploaded %s (%s) in %s from IP %s (session: %s)",
|
||||
filename, formatBytes(written), duration, getClientIP(r), sessionID)
|
||||
}
|
||||
|
||||
@@ -3030,20 +3119,24 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
|
||||
if err != nil {
|
||||
log.Warnf("🔴 JWT Authentication failed for download from IP %s: %v", getClientIP(r), err)
|
||||
AuditAuthFailure(r, "jwt", err.Error())
|
||||
http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
AuditAuthSuccess(r, "", "jwt")
|
||||
log.Infof("✅ JWT authentication successful for download request: %s", r.URL.Path)
|
||||
w.Header().Set("X-Auth-Method", "JWT")
|
||||
} else {
|
||||
err := validateHMAC(r, conf.Security.Secret)
|
||||
if err != nil {
|
||||
log.Warnf("🔴 HMAC Authentication failed for download from IP %s: %v", getClientIP(r), err)
|
||||
AuditAuthFailure(r, "hmac", err.Error())
|
||||
http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
}
|
||||
AuditAuthSuccess(r, "", "hmac")
|
||||
log.Infof("✅ HMAC authentication successful for download request: %s", r.URL.Path)
|
||||
w.Header().Set("X-Auth-Method", "HMAC")
|
||||
}
|
||||
@@ -3060,13 +3153,13 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
// Enhanced file path validation and construction
|
||||
var absFilename string
|
||||
var err error
|
||||
|
||||
|
||||
// Use storage path or ISO mount point
|
||||
storagePath := conf.Server.StoragePath
|
||||
if conf.ISO.Enabled {
|
||||
storagePath = conf.ISO.MountPoint
|
||||
}
|
||||
|
||||
|
||||
absFilename, err = sanitizeFilePath(storagePath, filename)
|
||||
if err != nil {
|
||||
log.Warnf("🔴 Invalid file path requested from IP %s: %s, error: %v", getClientIP(r), filename, err)
|
||||
@@ -3079,12 +3172,12 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
fileInfo, err := os.Stat(absFilename)
|
||||
if os.IsNotExist(err) {
|
||||
log.Warnf("🔴 File not found: %s (requested by IP %s)", absFilename, getClientIP(r))
|
||||
|
||||
|
||||
// Enhanced 404 response with network switching hints
|
||||
w.Header().Set("X-File-Not-Found", "true")
|
||||
w.Header().Set("X-Client-IP", getClientIP(r))
|
||||
w.Header().Set("X-Network-Switch-Support", "enabled")
|
||||
|
||||
|
||||
// Check if this might be a network switching issue
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") ||
|
||||
@@ -3093,13 +3186,13 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
strings.Contains(strings.ToLower(userAgent), "android") ||
|
||||
strings.Contains(strings.ToLower(userAgent), "mobile") ||
|
||||
strings.Contains(strings.ToLower(userAgent), "xmpp")
|
||||
|
||||
|
||||
if isMobileXMPP {
|
||||
w.Header().Set("X-Mobile-Client-Detected", "true")
|
||||
w.Header().Set("X-Retry-Suggestion", "30") // Suggest retry after 30 seconds
|
||||
log.Infof("📱 Mobile XMPP client file not found - may be network switching issue: %s", userAgent)
|
||||
}
|
||||
|
||||
|
||||
http.Error(w, "File not found", http.StatusNotFound)
|
||||
downloadErrorsTotal.Inc()
|
||||
return
|
||||
@@ -3126,13 +3219,13 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
if attempt < maxRetries {
|
||||
log.Warnf("⚠️ Attempt %d/%d: Error opening file %s from IP %s: %v (retrying...)",
|
||||
log.Warnf("⚠️ Attempt %d/%d: Error opening file %s from IP %s: %v (retrying...)",
|
||||
attempt, maxRetries, absFilename, getClientIP(r), err)
|
||||
time.Sleep(time.Duration(attempt) * time.Second) // Progressive backoff
|
||||
} else {
|
||||
log.Errorf("🔴 Failed to open file %s after %d attempts from IP %s: %v",
|
||||
log.Errorf("🔴 Failed to open file %s after %d attempts from IP %s: %v",
|
||||
absFilename, maxRetries, getClientIP(r), err)
|
||||
http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
|
||||
downloadErrorsTotal.Inc()
|
||||
@@ -3149,7 +3242,7 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Network-Switch-Support", "enabled")
|
||||
w.Header().Set("X-File-Path", filename)
|
||||
w.Header().Set("X-Download-Start-Time", fmt.Sprintf("%d", time.Now().Unix()))
|
||||
|
||||
|
||||
// Add cache control headers for mobile network optimization
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") ||
|
||||
@@ -3158,7 +3251,7 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
strings.Contains(strings.ToLower(userAgent), "android") ||
|
||||
strings.Contains(strings.ToLower(userAgent), "mobile") ||
|
||||
strings.Contains(strings.ToLower(userAgent), "xmpp")
|
||||
|
||||
|
||||
if isMobileXMPP {
|
||||
w.Header().Set("X-Mobile-Client-Detected", "true")
|
||||
w.Header().Set("Cache-Control", "public, max-age=86400") // 24 hours cache for mobile
|
||||
@@ -3173,7 +3266,7 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Track download progress for large files
|
||||
if fileInfo.Size() > 10*1024*1024 { // Log progress for files > 10MB
|
||||
log.Infof("📥 Starting download of %s (%.1f MiB) for IP %s",
|
||||
log.Infof("📥 Starting download of %s (%.1f MiB) for IP %s",
|
||||
filepath.Base(absFilename), float64(fileInfo.Size())/(1024*1024), getClientIP(r))
|
||||
}
|
||||
|
||||
@@ -3191,8 +3284,11 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
downloadDuration.Observe(duration.Seconds())
|
||||
downloadsTotal.Inc()
|
||||
downloadSizeBytes.Observe(float64(n))
|
||||
|
||||
log.Infof("✅ Successfully downloaded %s (%s) in %s for IP %s (session complete)",
|
||||
|
||||
// Audit successful download
|
||||
AuditDownloadSuccess(r, "", filepath.Base(absFilename), n)
|
||||
|
||||
log.Infof("✅ Successfully downloaded %s (%s) in %s for IP %s (session complete)",
|
||||
filepath.Base(absFilename), formatBytes(n), duration, getClientIP(r))
|
||||
}
|
||||
|
||||
@@ -3262,7 +3358,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if r.ContentLength > maxSizeBytes {
|
||||
http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
|
||||
http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
|
||||
formatBytes(r.ContentLength), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
@@ -3298,7 +3394,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
|
||||
uploadsTotal.Inc()
|
||||
uploadSizeBytes.Observe(float64(existingFileInfo.Size()))
|
||||
filesDeduplicatedTotal.Inc()
|
||||
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := map[string]interface{}{
|
||||
@@ -3308,8 +3404,8 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
|
||||
"message": "File already exists (deduplication hit)",
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
|
||||
log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately",
|
||||
|
||||
log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately",
|
||||
filename, formatBytes(existingFileInfo.Size()))
|
||||
return
|
||||
}
|
||||
@@ -3337,10 +3433,10 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
|
||||
// ✅ CRITICAL FIX: Send immediate success response for large files (>1GB)
|
||||
// This prevents client timeouts while server does post-processing
|
||||
isLargeFile := written > 1024*1024*1024 // 1GB threshold
|
||||
|
||||
|
||||
if isLargeFile {
|
||||
log.Infof("🚀 Large file detected (%s), sending immediate success response (v3)", formatBytes(written))
|
||||
|
||||
|
||||
// Send immediate success response to client
|
||||
duration := time.Since(startTime)
|
||||
uploadDuration.Observe(duration.Seconds())
|
||||
@@ -3355,11 +3451,11 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": written,
|
||||
"duration": duration.String(),
|
||||
"protocol": "v3",
|
||||
"success": true,
|
||||
"filename": filename,
|
||||
"size": written,
|
||||
"duration": duration.String(),
|
||||
"protocol": "v3",
|
||||
"post_processing": "background",
|
||||
}
|
||||
|
||||
@@ -3370,7 +3466,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d, "post_processing": "background"}`, filename, written)
|
||||
}
|
||||
|
||||
log.Infof("✅ Immediate response sent for large file %s (%s) in %s via v3 protocol",
|
||||
log.Infof("✅ Immediate response sent for large file %s (%s) in %s via v3 protocol",
|
||||
filename, formatBytes(written), duration)
|
||||
|
||||
// Process deduplication asynchronously for large files
|
||||
@@ -3385,7 +3481,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
|
||||
log.Infof("✅ Background deduplication completed for %s (v3)", filename)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Add to scan queue for virus scanning if enabled
|
||||
if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 {
|
||||
ext := strings.ToLower(filepath.Ext(originalFilename))
|
||||
@@ -3407,7 +3503,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3462,7 +3558,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate session ID for XMPP multi-upload tracking
|
||||
sessionID = generateUploadSessionID("legacy", r.Header.Get("User-Agent"), getClientIP(r))
|
||||
}
|
||||
|
||||
|
||||
// Set session headers for XMPP client continuation
|
||||
w.Header().Set("X-Session-ID", sessionID)
|
||||
w.Header().Set("X-Upload-Session-Timeout", "3600") // 1 hour
|
||||
@@ -3531,7 +3627,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if r.ContentLength > maxSizeBytes {
|
||||
http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
|
||||
http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
|
||||
formatBytes(r.ContentLength), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
|
||||
uploadErrorsTotal.Inc()
|
||||
return
|
||||
@@ -3582,9 +3678,9 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
|
||||
uploadsTotal.Inc()
|
||||
uploadSizeBytes.Observe(float64(existingFileInfo.Size()))
|
||||
filesDeduplicatedTotal.Inc()
|
||||
|
||||
|
||||
w.WriteHeader(http.StatusCreated) // 201 Created for legacy compatibility
|
||||
log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately",
|
||||
log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately",
|
||||
filename, formatBytes(existingFileInfo.Size()))
|
||||
return
|
||||
}
|
||||
@@ -3617,10 +3713,10 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
|
||||
// ✅ CRITICAL FIX: Send immediate success response for large files (>1GB)
|
||||
// This prevents client timeouts while server does post-processing
|
||||
isLargeFile := written > 1024*1024*1024 // 1GB threshold
|
||||
|
||||
|
||||
if isLargeFile {
|
||||
log.Infof("🚀 Large file detected (%s), sending immediate success response (legacy)", formatBytes(written))
|
||||
|
||||
|
||||
// Send immediate success response to client
|
||||
duration := time.Since(startTime)
|
||||
uploadDuration.Observe(duration.Seconds())
|
||||
@@ -3634,7 +3730,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Post-Processing", "background")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
|
||||
log.Infof("✅ Immediate response sent for large file %s (%s) in %s via legacy protocol",
|
||||
log.Infof("✅ Immediate response sent for large file %s (%s) in %s via legacy protocol",
|
||||
filename, formatBytes(written), duration)
|
||||
|
||||
// Process deduplication asynchronously for large files
|
||||
@@ -3649,7 +3745,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
|
||||
log.Infof("✅ Background deduplication completed for %s (legacy)", filename)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Add to scan queue for virus scanning if enabled
|
||||
if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 {
|
||||
ext := strings.ToLower(filepath.Ext(fileStorePath))
|
||||
@@ -3671,7 +3767,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
557
cmd/server/quota.go
Normal file
557
cmd/server/quota.go
Normal file
@@ -0,0 +1,557 @@
|
||||
// quota.go - Per-user storage quota management
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// QuotaConfig holds quota configuration
|
||||
type QuotaConfig struct {
|
||||
Enabled bool `toml:"enabled" mapstructure:"enabled"`
|
||||
Default string `toml:"default" mapstructure:"default"` // Default quota (e.g., "100MB")
|
||||
Tracking string `toml:"tracking" mapstructure:"tracking"` // "redis" | "memory"
|
||||
Custom map[string]string `toml:"custom" mapstructure:"custom"` // Custom quotas per JID
|
||||
}
|
||||
|
||||
// QuotaInfo contains quota information for a user
|
||||
type QuotaInfo struct {
|
||||
JID string `json:"jid"`
|
||||
Used int64 `json:"used"`
|
||||
Limit int64 `json:"limit"`
|
||||
Remaining int64 `json:"remaining"`
|
||||
FileCount int64 `json:"file_count"`
|
||||
IsCustom bool `json:"is_custom"`
|
||||
}
|
||||
|
||||
// QuotaExceededError represents a quota exceeded error
|
||||
type QuotaExceededError struct {
|
||||
JID string `json:"jid"`
|
||||
Used int64 `json:"used"`
|
||||
Limit int64 `json:"limit"`
|
||||
Requested int64 `json:"requested"`
|
||||
}
|
||||
|
||||
func (e *QuotaExceededError) Error() string {
|
||||
return fmt.Sprintf("quota exceeded for %s: used %d, limit %d, requested %d",
|
||||
e.JID, e.Used, e.Limit, e.Requested)
|
||||
}
|
||||
|
||||
// QuotaManager handles per-user storage quotas
|
||||
type QuotaManager struct {
|
||||
config *QuotaConfig
|
||||
redisClient *redis.Client
|
||||
defaultQuota int64
|
||||
customQuotas map[string]int64
|
||||
|
||||
// In-memory fallback when Redis is unavailable
|
||||
memoryUsage map[string]int64
|
||||
memoryFiles map[string]map[string]int64 // jid -> filePath -> size
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
quotaManager *QuotaManager
|
||||
quotaOnce sync.Once
|
||||
)
|
||||
|
||||
// Redis key patterns
|
||||
const (
|
||||
quotaUsedKey = "quota:%s:used" // quota:{jid}:used -> int64
|
||||
quotaFilesKey = "quota:%s:files" // quota:{jid}:files -> HASH {path: size}
|
||||
quotaInfoKey = "quota:%s:info" // quota:{jid}:info -> JSON
|
||||
)
|
||||
|
||||
// InitQuotaManager initializes the quota manager
|
||||
func InitQuotaManager(config *QuotaConfig, redisClient *redis.Client) error {
|
||||
var initErr error
|
||||
quotaOnce.Do(func() {
|
||||
quotaManager = &QuotaManager{
|
||||
config: config,
|
||||
redisClient: redisClient,
|
||||
customQuotas: make(map[string]int64),
|
||||
memoryUsage: make(map[string]int64),
|
||||
memoryFiles: make(map[string]map[string]int64),
|
||||
}
|
||||
|
||||
// Parse default quota
|
||||
if config.Default != "" {
|
||||
quota, err := parseSize(config.Default)
|
||||
if err != nil {
|
||||
initErr = fmt.Errorf("invalid default quota: %w", err)
|
||||
return
|
||||
}
|
||||
quotaManager.defaultQuota = quota
|
||||
} else {
|
||||
quotaManager.defaultQuota = 100 * 1024 * 1024 // 100MB default
|
||||
}
|
||||
|
||||
// Parse custom quotas
|
||||
for jid, quotaStr := range config.Custom {
|
||||
quota, err := parseSize(quotaStr)
|
||||
if err != nil {
|
||||
log.Warnf("Invalid custom quota for %s: %v", jid, err)
|
||||
continue
|
||||
}
|
||||
quotaManager.customQuotas[strings.ToLower(jid)] = quota
|
||||
}
|
||||
|
||||
log.Infof("Quota manager initialized: enabled=%v, default=%s, custom=%d users, tracking=%s",
|
||||
config.Enabled, config.Default, len(config.Custom), config.Tracking)
|
||||
})
|
||||
|
||||
return initErr
|
||||
}
|
||||
|
||||
// GetQuotaManager returns the singleton quota manager
|
||||
func GetQuotaManager() *QuotaManager {
|
||||
return quotaManager
|
||||
}
|
||||
|
||||
// GetLimit returns the quota limit for a user
|
||||
func (q *QuotaManager) GetLimit(jid string) int64 {
|
||||
if jid == "" {
|
||||
return q.defaultQuota
|
||||
}
|
||||
|
||||
jidLower := strings.ToLower(jid)
|
||||
if custom, ok := q.customQuotas[jidLower]; ok {
|
||||
return custom
|
||||
}
|
||||
return q.defaultQuota
|
||||
}
|
||||
|
||||
// GetUsage returns the current storage usage for a user
|
||||
func (q *QuotaManager) GetUsage(ctx context.Context, jid string) (used, limit int64, err error) {
|
||||
if !q.config.Enabled {
|
||||
return 0, 0, nil
|
||||
}
|
||||
|
||||
limit = q.GetLimit(jid)
|
||||
|
||||
// Try Redis first
|
||||
if q.redisClient != nil && q.config.Tracking == "redis" {
|
||||
key := fmt.Sprintf(quotaUsedKey, jid)
|
||||
usedStr, err := q.redisClient.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return 0, limit, nil
|
||||
}
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get quota from Redis, falling back to memory: %v", err)
|
||||
} else {
|
||||
used, _ = strconv.ParseInt(usedStr, 10, 64)
|
||||
return used, limit, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to memory
|
||||
q.mutex.RLock()
|
||||
used = q.memoryUsage[jid]
|
||||
q.mutex.RUnlock()
|
||||
|
||||
return used, limit, nil
|
||||
}
|
||||
|
||||
// GetQuotaInfo returns detailed quota information for a user
|
||||
func (q *QuotaManager) GetQuotaInfo(ctx context.Context, jid string) (*QuotaInfo, error) {
|
||||
used, limit, err := q.GetUsage(ctx, jid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileCount := int64(0)
|
||||
|
||||
// Get file count
|
||||
if q.redisClient != nil && q.config.Tracking == "redis" {
|
||||
key := fmt.Sprintf(quotaFilesKey, jid)
|
||||
count, err := q.redisClient.HLen(ctx, key).Result()
|
||||
if err == nil {
|
||||
fileCount = count
|
||||
}
|
||||
} else {
|
||||
q.mutex.RLock()
|
||||
if files, ok := q.memoryFiles[jid]; ok {
|
||||
fileCount = int64(len(files))
|
||||
}
|
||||
q.mutex.RUnlock()
|
||||
}
|
||||
|
||||
_, isCustom := q.customQuotas[strings.ToLower(jid)]
|
||||
|
||||
return &QuotaInfo{
|
||||
JID: jid,
|
||||
Used: used,
|
||||
Limit: limit,
|
||||
Remaining: limit - used,
|
||||
FileCount: fileCount,
|
||||
IsCustom: isCustom,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CanUpload checks if a user can upload a file of the given size
|
||||
func (q *QuotaManager) CanUpload(ctx context.Context, jid string, size int64) (bool, error) {
|
||||
if !q.config.Enabled {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
used, limit, err := q.GetUsage(ctx, jid)
|
||||
if err != nil {
|
||||
// On error, allow upload but log warning
|
||||
log.Warnf("Failed to check quota for %s, allowing upload: %v", jid, err)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return used+size <= limit, nil
|
||||
}
|
||||
|
||||
// RecordUpload records a file upload for quota tracking
|
||||
func (q *QuotaManager) RecordUpload(ctx context.Context, jid, filePath string, size int64) error {
|
||||
if !q.config.Enabled || jid == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try Redis first with atomic operation
|
||||
if q.redisClient != nil && q.config.Tracking == "redis" {
|
||||
pipe := q.redisClient.TxPipeline()
|
||||
|
||||
usedKey := fmt.Sprintf(quotaUsedKey, jid)
|
||||
filesKey := fmt.Sprintf(quotaFilesKey, jid)
|
||||
|
||||
pipe.IncrBy(ctx, usedKey, size)
|
||||
pipe.HSet(ctx, filesKey, filePath, size)
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to record upload in Redis: %v", err)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to memory
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
q.memoryUsage[jid] += size
|
||||
|
||||
if q.memoryFiles[jid] == nil {
|
||||
q.memoryFiles[jid] = make(map[string]int64)
|
||||
}
|
||||
q.memoryFiles[jid][filePath] = size
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordDelete records a file deletion for quota tracking
|
||||
func (q *QuotaManager) RecordDelete(ctx context.Context, jid, filePath string, size int64) error {
|
||||
if !q.config.Enabled || jid == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If size is 0, try to get it from tracking
|
||||
if size == 0 {
|
||||
size = q.getFileSize(ctx, jid, filePath)
|
||||
}
|
||||
|
||||
// Try Redis first
|
||||
if q.redisClient != nil && q.config.Tracking == "redis" {
|
||||
pipe := q.redisClient.TxPipeline()
|
||||
|
||||
usedKey := fmt.Sprintf(quotaUsedKey, jid)
|
||||
filesKey := fmt.Sprintf(quotaFilesKey, jid)
|
||||
|
||||
pipe.DecrBy(ctx, usedKey, size)
|
||||
pipe.HDel(ctx, filesKey, filePath)
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to record delete in Redis: %v", err)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to memory
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
q.memoryUsage[jid] -= size
|
||||
if q.memoryUsage[jid] < 0 {
|
||||
q.memoryUsage[jid] = 0
|
||||
}
|
||||
|
||||
if q.memoryFiles[jid] != nil {
|
||||
delete(q.memoryFiles[jid], filePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getFileSize retrieves the size of a tracked file
|
||||
func (q *QuotaManager) getFileSize(ctx context.Context, jid, filePath string) int64 {
|
||||
// Try Redis
|
||||
if q.redisClient != nil && q.config.Tracking == "redis" {
|
||||
key := fmt.Sprintf(quotaFilesKey, jid)
|
||||
sizeStr, err := q.redisClient.HGet(ctx, key, filePath).Result()
|
||||
if err == nil {
|
||||
size, _ := strconv.ParseInt(sizeStr, 10, 64)
|
||||
return size
|
||||
}
|
||||
}
|
||||
|
||||
// Try memory
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
if files, ok := q.memoryFiles[jid]; ok {
|
||||
return files[filePath]
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// SetCustomQuota sets a custom quota for a user
|
||||
func (q *QuotaManager) SetCustomQuota(jid string, quota int64) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.customQuotas[strings.ToLower(jid)] = quota
|
||||
}
|
||||
|
||||
// RemoveCustomQuota removes a custom quota for a user
|
||||
func (q *QuotaManager) RemoveCustomQuota(jid string) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
delete(q.customQuotas, strings.ToLower(jid))
|
||||
}
|
||||
|
||||
// GetAllQuotas returns quota info for all tracked users
|
||||
func (q *QuotaManager) GetAllQuotas(ctx context.Context) ([]QuotaInfo, error) {
|
||||
var quotas []QuotaInfo
|
||||
|
||||
// Get from Redis
|
||||
if q.redisClient != nil && q.config.Tracking == "redis" {
|
||||
// Scan for all quota keys
|
||||
iter := q.redisClient.Scan(ctx, 0, "quota:*:used", 100).Iterator()
|
||||
for iter.Next(ctx) {
|
||||
key := iter.Val()
|
||||
// Extract JID from key
|
||||
parts := strings.Split(key, ":")
|
||||
if len(parts) >= 2 {
|
||||
jid := parts[1]
|
||||
info, err := q.GetQuotaInfo(ctx, jid)
|
||||
if err == nil {
|
||||
quotas = append(quotas, *info)
|
||||
}
|
||||
}
|
||||
}
|
||||
return quotas, iter.Err()
|
||||
}
|
||||
|
||||
// Get from memory
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for jid, used := range q.memoryUsage {
|
||||
limit := q.GetLimit(jid)
|
||||
fileCount := int64(0)
|
||||
if files, ok := q.memoryFiles[jid]; ok {
|
||||
fileCount = int64(len(files))
|
||||
}
|
||||
_, isCustom := q.customQuotas[strings.ToLower(jid)]
|
||||
|
||||
quotas = append(quotas, QuotaInfo{
|
||||
JID: jid,
|
||||
Used: used,
|
||||
Limit: limit,
|
||||
Remaining: limit - used,
|
||||
FileCount: fileCount,
|
||||
IsCustom: isCustom,
|
||||
})
|
||||
}
|
||||
|
||||
return quotas, nil
|
||||
}
|
||||
|
||||
// Reconcile recalculates quota usage from actual file storage
|
||||
func (q *QuotaManager) Reconcile(ctx context.Context, jid string, files map[string]int64) error {
|
||||
if !q.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
for _, size := range files {
|
||||
totalSize += size
|
||||
}
|
||||
|
||||
// Update Redis
|
||||
if q.redisClient != nil && q.config.Tracking == "redis" {
|
||||
usedKey := fmt.Sprintf(quotaUsedKey, jid)
|
||||
filesKey := fmt.Sprintf(quotaFilesKey, jid)
|
||||
|
||||
pipe := q.redisClient.TxPipeline()
|
||||
pipe.Set(ctx, usedKey, totalSize, 0)
|
||||
pipe.Del(ctx, filesKey)
|
||||
|
||||
for path, size := range files {
|
||||
pipe.HSet(ctx, filesKey, path, size)
|
||||
}
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reconcile quota in Redis: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update memory
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
q.memoryUsage[jid] = totalSize
|
||||
q.memoryFiles[jid] = files
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckQuotaMiddleware is a middleware that checks quota before upload
|
||||
func CheckQuotaMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
qm := GetQuotaManager()
|
||||
if qm == nil || !qm.config.Enabled {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Only check for upload methods
|
||||
if r.Method != http.MethodPut && r.Method != http.MethodPost {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Get JID from context/headers
|
||||
jid := r.Header.Get("X-User-JID")
|
||||
if jid == "" {
|
||||
// Try to get from authorization context
|
||||
if claims, ok := r.Context().Value(contextKey("bearerClaims")).(*BearerTokenClaims); ok {
|
||||
jid = claims.User
|
||||
}
|
||||
}
|
||||
|
||||
if jid == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Check quota
|
||||
ctx := r.Context()
|
||||
canUpload, err := qm.CanUpload(ctx, jid, r.ContentLength)
|
||||
if err != nil {
|
||||
log.Warnf("Error checking quota: %v", err)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if !canUpload {
|
||||
used, limit, _ := qm.GetUsage(ctx, jid)
|
||||
|
||||
// Log to audit
|
||||
AuditQuotaExceeded(r, jid, used, limit, r.ContentLength)
|
||||
|
||||
// Return 413 with quota info
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Quota-Used", strconv.FormatInt(used, 10))
|
||||
w.Header().Set("X-Quota-Limit", strconv.FormatInt(limit, 10))
|
||||
w.Header().Set("X-Quota-Remaining", strconv.FormatInt(limit-used, 10))
|
||||
w.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"error": "quota_exceeded",
|
||||
"message": "Storage quota exceeded",
|
||||
"used": used,
|
||||
"limit": limit,
|
||||
"requested": r.ContentLength,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Add quota headers
|
||||
used, limit, _ := qm.GetUsage(ctx, jid)
|
||||
w.Header().Set("X-Quota-Used", strconv.FormatInt(used, 10))
|
||||
w.Header().Set("X-Quota-Limit", strconv.FormatInt(limit, 10))
|
||||
w.Header().Set("X-Quota-Remaining", strconv.FormatInt(limit-used, 10))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateQuotaAfterUpload updates quota after successful upload
|
||||
func UpdateQuotaAfterUpload(ctx context.Context, jid, filePath string, size int64) {
|
||||
qm := GetQuotaManager()
|
||||
if qm == nil || !qm.config.Enabled || jid == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if err := qm.RecordUpload(ctx, jid, filePath, size); err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"jid": jid,
|
||||
"file": filePath,
|
||||
"size": size,
|
||||
"error": err,
|
||||
}).Warn("Failed to update quota after upload")
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateQuotaAfterDelete updates quota after file deletion
|
||||
func UpdateQuotaAfterDelete(ctx context.Context, jid, filePath string, size int64) {
|
||||
qm := GetQuotaManager()
|
||||
if qm == nil || !qm.config.Enabled || jid == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if err := qm.RecordDelete(ctx, jid, filePath, size); err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"jid": jid,
|
||||
"file": filePath,
|
||||
"size": size,
|
||||
"error": err,
|
||||
}).Warn("Failed to update quota after delete")
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultQuotaConfig returns default quota configuration
|
||||
func DefaultQuotaConfig() QuotaConfig {
|
||||
return QuotaConfig{
|
||||
Enabled: false,
|
||||
Default: "100MB",
|
||||
Tracking: "redis",
|
||||
Custom: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// StartQuotaReconciliation starts a background job to reconcile quotas
|
||||
func StartQuotaReconciliation(interval time.Duration) {
|
||||
if quotaManager == nil || !quotaManager.config.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
log.Debug("Running quota reconciliation")
|
||||
// This would scan the storage and update quotas
|
||||
// Implementation depends on how files are tracked
|
||||
}
|
||||
}()
|
||||
}
|
||||
340
cmd/server/validation.go
Normal file
340
cmd/server/validation.go
Normal file
@@ -0,0 +1,340 @@
|
||||
// validation.go - Content type validation using magic bytes
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ValidationConfig holds content validation configuration
|
||||
type ValidationConfig struct {
|
||||
CheckMagicBytes bool `toml:"check_magic_bytes" mapstructure:"check_magic_bytes"`
|
||||
AllowedTypes []string `toml:"allowed_types" mapstructure:"allowed_types"`
|
||||
BlockedTypes []string `toml:"blocked_types" mapstructure:"blocked_types"`
|
||||
MaxFileSize string `toml:"max_file_size" mapstructure:"max_file_size"`
|
||||
StrictMode bool `toml:"strict_mode" mapstructure:"strict_mode"` // Reject if type can't be detected
|
||||
}
|
||||
|
||||
// ValidationResult contains the result of content validation
|
||||
type ValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
DetectedType string `json:"detected_type"`
|
||||
DeclaredType string `json:"declared_type,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// ValidationError represents a validation failure
|
||||
type ValidationError struct {
|
||||
Code string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
DetectedType string `json:"detected_type"`
|
||||
DeclaredType string `json:"declared_type,omitempty"`
|
||||
}
|
||||
|
||||
func (e *ValidationError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// ContentValidator handles content type validation
|
||||
type ContentValidator struct {
|
||||
config *ValidationConfig
|
||||
allowedTypes map[string]bool
|
||||
blockedTypes map[string]bool
|
||||
wildcardAllow []string
|
||||
wildcardBlock []string
|
||||
}
|
||||
|
||||
var (
|
||||
contentValidator *ContentValidator
|
||||
validatorOnce sync.Once
|
||||
)
|
||||
|
||||
// InitContentValidator initializes the content validator
|
||||
func InitContentValidator(config *ValidationConfig) {
|
||||
validatorOnce.Do(func() {
|
||||
contentValidator = &ContentValidator{
|
||||
config: config,
|
||||
allowedTypes: make(map[string]bool),
|
||||
blockedTypes: make(map[string]bool),
|
||||
wildcardAllow: []string{},
|
||||
wildcardBlock: []string{},
|
||||
}
|
||||
|
||||
// Process allowed types
|
||||
for _, t := range config.AllowedTypes {
|
||||
t = strings.ToLower(strings.TrimSpace(t))
|
||||
if strings.HasSuffix(t, "/*") {
|
||||
contentValidator.wildcardAllow = append(contentValidator.wildcardAllow, strings.TrimSuffix(t, "/*"))
|
||||
} else {
|
||||
contentValidator.allowedTypes[t] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Process blocked types
|
||||
for _, t := range config.BlockedTypes {
|
||||
t = strings.ToLower(strings.TrimSpace(t))
|
||||
if strings.HasSuffix(t, "/*") {
|
||||
contentValidator.wildcardBlock = append(contentValidator.wildcardBlock, strings.TrimSuffix(t, "/*"))
|
||||
} else {
|
||||
contentValidator.blockedTypes[t] = true
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Content validator initialized: magic_bytes=%v, allowed=%d types, blocked=%d types",
|
||||
config.CheckMagicBytes, len(config.AllowedTypes), len(config.BlockedTypes))
|
||||
})
|
||||
}
|
||||
|
||||
// GetContentValidator returns the singleton content validator
|
||||
func GetContentValidator() *ContentValidator {
|
||||
return contentValidator
|
||||
}
|
||||
|
||||
// isTypeAllowed checks if a content type is in the allowed list
|
||||
func (v *ContentValidator) isTypeAllowed(contentType string) bool {
|
||||
contentType = strings.ToLower(contentType)
|
||||
|
||||
// Extract main type (before any parameters like charset)
|
||||
if idx := strings.Index(contentType, ";"); idx != -1 {
|
||||
contentType = strings.TrimSpace(contentType[:idx])
|
||||
}
|
||||
|
||||
// If no allowed types configured, allow all (except blocked)
|
||||
if len(v.allowedTypes) == 0 && len(v.wildcardAllow) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check exact match
|
||||
if v.allowedTypes[contentType] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check wildcard patterns
|
||||
for _, prefix := range v.wildcardAllow {
|
||||
if strings.HasPrefix(contentType, prefix+"/") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isTypeBlocked checks if a content type is in the blocked list
|
||||
func (v *ContentValidator) isTypeBlocked(contentType string) bool {
|
||||
contentType = strings.ToLower(contentType)
|
||||
|
||||
// Extract main type (before any parameters)
|
||||
if idx := strings.Index(contentType, ";"); idx != -1 {
|
||||
contentType = strings.TrimSpace(contentType[:idx])
|
||||
}
|
||||
|
||||
// Check exact match
|
||||
if v.blockedTypes[contentType] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check wildcard patterns
|
||||
for _, prefix := range v.wildcardBlock {
|
||||
if strings.HasPrefix(contentType, prefix+"/") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateContent validates the content type of a reader
|
||||
// Returns a new reader that includes the buffered bytes, the detected type, and any error
|
||||
func (v *ContentValidator) ValidateContent(reader io.Reader, declaredType string, size int64) (io.Reader, string, error) {
|
||||
if v == nil || !v.config.CheckMagicBytes {
|
||||
return reader, declaredType, nil
|
||||
}
|
||||
|
||||
// Read first 512 bytes for magic byte detection
|
||||
buf := make([]byte, 512)
|
||||
n, err := io.ReadFull(reader, buf)
|
||||
if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
|
||||
return nil, "", fmt.Errorf("failed to read content for validation: %w", err)
|
||||
}
|
||||
|
||||
// Handle small files
|
||||
if n == 0 {
|
||||
if v.config.StrictMode {
|
||||
return nil, "", &ValidationError{
|
||||
Code: "empty_content",
|
||||
Message: "Cannot validate empty content",
|
||||
DetectedType: "",
|
||||
DeclaredType: declaredType,
|
||||
}
|
||||
}
|
||||
return bytes.NewReader(buf[:n]), declaredType, nil
|
||||
}
|
||||
|
||||
// Detect content type using magic bytes
|
||||
detectedType := http.DetectContentType(buf[:n])
|
||||
|
||||
// Normalize detected type
|
||||
if idx := strings.Index(detectedType, ";"); idx != -1 {
|
||||
detectedType = strings.TrimSpace(detectedType[:idx])
|
||||
}
|
||||
|
||||
// Check if type is blocked (highest priority)
|
||||
if v.isTypeBlocked(detectedType) {
|
||||
return nil, detectedType, &ValidationError{
|
||||
Code: "content_type_blocked",
|
||||
Message: fmt.Sprintf("File type %s is blocked", detectedType),
|
||||
DetectedType: detectedType,
|
||||
DeclaredType: declaredType,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if type is allowed
|
||||
if !v.isTypeAllowed(detectedType) {
|
||||
return nil, detectedType, &ValidationError{
|
||||
Code: "content_type_rejected",
|
||||
Message: fmt.Sprintf("File type %s is not allowed", detectedType),
|
||||
DetectedType: detectedType,
|
||||
DeclaredType: declaredType,
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new reader that includes the buffered bytes
|
||||
combinedReader := io.MultiReader(bytes.NewReader(buf[:n]), reader)
|
||||
|
||||
return combinedReader, detectedType, nil
|
||||
}
|
||||
|
||||
// ValidateContentType validates a content type without reading content
|
||||
func (v *ContentValidator) ValidateContentType(contentType string) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if v.isTypeBlocked(contentType) {
|
||||
return &ValidationError{
|
||||
Code: "content_type_blocked",
|
||||
Message: fmt.Sprintf("File type %s is blocked", contentType),
|
||||
DetectedType: contentType,
|
||||
}
|
||||
}
|
||||
|
||||
if !v.isTypeAllowed(contentType) {
|
||||
return &ValidationError{
|
||||
Code: "content_type_rejected",
|
||||
Message: fmt.Sprintf("File type %s is not allowed", contentType),
|
||||
DetectedType: contentType,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteValidationError writes a validation error response
|
||||
func WriteValidationError(w http.ResponseWriter, err *ValidationError) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnsupportedMediaType)
|
||||
_ = json.NewEncoder(w).Encode(err)
|
||||
}
|
||||
|
||||
// ValidateUploadContent is a helper function for validating upload content
|
||||
func ValidateUploadContent(r *http.Request, reader io.Reader, declaredType string, size int64) (io.Reader, string, error) {
|
||||
validator := GetContentValidator()
|
||||
if validator == nil || !validator.config.CheckMagicBytes {
|
||||
return reader, declaredType, nil
|
||||
}
|
||||
|
||||
newReader, detectedType, err := validator.ValidateContent(reader, declaredType, size)
|
||||
if err != nil {
|
||||
// Log validation failure to audit
|
||||
jid := r.Header.Get("X-User-JID")
|
||||
fileName := r.Header.Get("X-File-Name")
|
||||
if fileName == "" {
|
||||
fileName = "unknown"
|
||||
}
|
||||
|
||||
var reason string
|
||||
if validErr, ok := err.(*ValidationError); ok {
|
||||
reason = validErr.Code
|
||||
} else {
|
||||
reason = err.Error()
|
||||
}
|
||||
|
||||
AuditValidationFailure(r, jid, fileName, declaredType, detectedType, reason)
|
||||
|
||||
return nil, detectedType, err
|
||||
}
|
||||
|
||||
return newReader, detectedType, nil
|
||||
}
|
||||
|
||||
// DefaultValidationConfig returns default validation configuration
|
||||
func DefaultValidationConfig() ValidationConfig {
|
||||
return ValidationConfig{
|
||||
CheckMagicBytes: false,
|
||||
AllowedTypes: []string{
|
||||
"image/*",
|
||||
"video/*",
|
||||
"audio/*",
|
||||
"application/pdf",
|
||||
"text/plain",
|
||||
"text/html",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/zip",
|
||||
"application/x-gzip",
|
||||
"application/x-tar",
|
||||
"application/x-7z-compressed",
|
||||
"application/vnd.openxmlformats-officedocument.*",
|
||||
"application/vnd.oasis.opendocument.*",
|
||||
},
|
||||
BlockedTypes: []string{
|
||||
"application/x-executable",
|
||||
"application/x-msdos-program",
|
||||
"application/x-msdownload",
|
||||
"application/x-dosexec",
|
||||
"application/x-sh",
|
||||
"application/x-shellscript",
|
||||
},
|
||||
MaxFileSize: "100MB",
|
||||
StrictMode: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Extended MIME type detection for better accuracy
|
||||
var customMagicBytes = map[string][]byte{
|
||||
"application/x-executable": {0x7f, 'E', 'L', 'F'}, // ELF
|
||||
"application/x-msdos-program": {0x4d, 0x5a}, // MZ (DOS/Windows)
|
||||
"application/pdf": {0x25, 0x50, 0x44, 0x46}, // %PDF
|
||||
"application/zip": {0x50, 0x4b, 0x03, 0x04}, // PK
|
||||
"image/png": {0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a}, // PNG
|
||||
"image/jpeg": {0xff, 0xd8, 0xff}, // JPEG
|
||||
"image/gif": {0x47, 0x49, 0x46, 0x38}, // GIF8
|
||||
"image/webp": {0x52, 0x49, 0x46, 0x46}, // RIFF (WebP starts with RIFF)
|
||||
"video/mp4": {0x00, 0x00, 0x00}, // MP4 (variable, check ftyp)
|
||||
"audio/mpeg": {0xff, 0xfb}, // MP3
|
||||
"audio/ogg": {0x4f, 0x67, 0x67, 0x53}, // OggS
|
||||
}
|
||||
|
||||
// DetectContentTypeExtended provides extended content type detection
|
||||
func DetectContentTypeExtended(data []byte) string {
|
||||
// First try standard detection
|
||||
detected := http.DetectContentType(data)
|
||||
|
||||
// If generic, try custom detection
|
||||
if detected == "application/octet-stream" {
|
||||
for mimeType, magic := range customMagicBytes {
|
||||
if len(data) >= len(magic) && bytes.Equal(data[:len(magic)], magic) {
|
||||
return mimeType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return detected
|
||||
}
|
||||
162
templates/config-enhanced-features.toml
Normal file
162
templates/config-enhanced-features.toml
Normal file
@@ -0,0 +1,162 @@
|
||||
# HMAC File Server 3.3.0 "Nexus Infinitum" Configuration
|
||||
# Enhanced Features Template: Audit Logging, Content Validation, Quotas, Admin API
|
||||
# Generated on: January 2025
|
||||
|
||||
[server]
|
||||
listen_address = "8080"
|
||||
storage_path = "/opt/hmac-file-server/data/uploads"
|
||||
metrics_enabled = true
|
||||
metrics_port = "9090"
|
||||
pid_file = "/opt/hmac-file-server/data/hmac-file-server.pid"
|
||||
max_upload_size = "10GB"
|
||||
deduplication_enabled = true
|
||||
min_free_bytes = "1GB"
|
||||
file_naming = "original"
|
||||
enable_dynamic_workers = true
|
||||
|
||||
[security]
|
||||
secret = "CHANGE-THIS-SECRET-KEY-MINIMUM-32-CHARACTERS"
|
||||
enablejwt = false
|
||||
|
||||
[uploads]
|
||||
allowedextensions = [".txt", ".pdf", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".zip", ".tar", ".gz", ".7z", ".mp4", ".webm", ".ogg", ".mp3", ".wav", ".flac", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx", ".odt", ".ods", ".odp"]
|
||||
maxfilesize = "100MB"
|
||||
chunkeduploadsenabled = true
|
||||
chunksize = "10MB"
|
||||
networkevents = true
|
||||
|
||||
[downloads]
|
||||
chunkeddownloadsenabled = true
|
||||
chunksize = "10MB"
|
||||
|
||||
[logging]
|
||||
level = "INFO"
|
||||
file = "/opt/hmac-file-server/data/logs/hmac-file-server.log"
|
||||
max_size = 100
|
||||
max_backups = 3
|
||||
max_age = 30
|
||||
compress = true
|
||||
|
||||
[workers]
|
||||
numworkers = 10
|
||||
uploadqueuesize = 1000
|
||||
autoscaling = true
|
||||
|
||||
[timeouts]
|
||||
readtimeout = "30s"
|
||||
writetimeout = "30s"
|
||||
idletimeout = "120s"
|
||||
shutdown = "30s"
|
||||
|
||||
[clamav]
|
||||
enabled = false
|
||||
|
||||
[redis]
|
||||
enabled = true
|
||||
address = "127.0.0.1:6379"
|
||||
db = 0
|
||||
|
||||
# ============================================
|
||||
# NEW ENHANCED FEATURES (v3.3.0)
|
||||
# ============================================
|
||||
|
||||
# Security Audit Logging
|
||||
# Records security-relevant events for compliance and forensics
|
||||
[audit]
|
||||
enabled = true
|
||||
output = "file" # "file" or "stdout"
|
||||
path = "/var/log/hmac-audit.log" # Log file path (when output = "file")
|
||||
format = "json" # "json" or "text"
|
||||
max_size = 100 # Max size in MB before rotation
|
||||
max_age = 30 # Max age in days
|
||||
events = [
|
||||
"upload", # Log all file uploads
|
||||
"download", # Log all file downloads
|
||||
"delete", # Log file deletions
|
||||
"auth_success", # Log successful authentications
|
||||
"auth_failure", # Log failed authentications
|
||||
"rate_limited", # Log rate limiting events
|
||||
"banned", # Log ban events
|
||||
"quota_exceeded", # Log quota exceeded events
|
||||
"validation_failure" # Log content validation failures
|
||||
]
|
||||
|
||||
# Magic Bytes Content Validation
|
||||
# Validates uploaded file content types using magic bytes detection
|
||||
[validation]
|
||||
check_magic_bytes = true # Enable magic bytes validation
|
||||
strict_mode = false # Strict mode rejects mismatched types
|
||||
max_peek_size = 65536 # Bytes to read for detection (64KB)
|
||||
|
||||
# Allowed content types (supports wildcards like "image/*")
|
||||
# If empty, all types are allowed (except blocked)
|
||||
allowed_types = [
|
||||
"image/*", # All image types
|
||||
"video/*", # All video types
|
||||
"audio/*", # All audio types
|
||||
"text/plain", # Plain text
|
||||
"application/pdf", # PDF documents
|
||||
"application/zip", # ZIP archives
|
||||
"application/gzip", # GZIP archives
|
||||
"application/x-tar", # TAR archives
|
||||
"application/x-7z-compressed", # 7-Zip archives
|
||||
"application/vnd.openxmlformats-officedocument.*", # MS Office docs
|
||||
"application/vnd.oasis.opendocument.*" # LibreOffice docs
|
||||
]
|
||||
|
||||
# Blocked content types (takes precedence over allowed)
|
||||
blocked_types = [
|
||||
"application/x-executable", # Executable files
|
||||
"application/x-msdos-program", # DOS executables
|
||||
"application/x-msdownload", # Windows executables
|
||||
"application/x-elf", # ELF binaries
|
||||
"application/x-shellscript", # Shell scripts
|
||||
"application/javascript", # JavaScript files
|
||||
"text/html", # HTML files (potential XSS)
|
||||
"application/x-php" # PHP files
|
||||
]
|
||||
|
||||
# Per-User Storage Quotas
|
||||
# Track and enforce storage limits per XMPP JID
|
||||
[quotas]
|
||||
enabled = true # Enable quota enforcement
|
||||
default = "100MB" # Default quota for all users
|
||||
tracking = "redis" # "redis" or "memory"
|
||||
|
||||
# Custom quotas per user (JID -> quota)
|
||||
[quotas.custom]
|
||||
"admin@example.com" = "10GB" # Admin gets 10GB
|
||||
"premium@example.com" = "1GB" # Premium user gets 1GB
|
||||
"vip@example.com" = "5GB" # VIP user gets 5GB
|
||||
|
||||
# Admin API for Operations and Monitoring
|
||||
# Protected endpoints for system management
|
||||
[admin]
|
||||
enabled = true # Enable admin API
|
||||
path_prefix = "/admin" # URL prefix for admin endpoints
|
||||
|
||||
# Available endpoints (when enabled):
|
||||
# GET /admin/stats - Server statistics and metrics
|
||||
# GET /admin/files - List all uploaded files
|
||||
# GET /admin/files/:id - Get file details
|
||||
# DEL /admin/files/:id - Delete a file
|
||||
# GET /admin/users - List users and quota usage
|
||||
# GET /admin/users/:jid - Get user details and quota
|
||||
# POST /admin/users/:jid/quota - Set user quota
|
||||
# GET /admin/bans - List banned IPs/users
|
||||
# POST /admin/bans - Ban an IP or user
|
||||
# DEL /admin/bans/:id - Unban
|
||||
|
||||
# Admin authentication
|
||||
[admin.auth]
|
||||
type = "bearer" # "bearer" or "basic"
|
||||
token = "${ADMIN_TOKEN}" # Bearer token (from environment variable)
|
||||
# For basic auth:
|
||||
# type = "basic"
|
||||
# username = "admin"
|
||||
# password_hash = "$2a$12$..." # bcrypt hash
|
||||
|
||||
# Rate limiting for admin endpoints
|
||||
[admin.rate_limit]
|
||||
enabled = true
|
||||
requests_per_minute = 60 # Max requests per minute per IP
|
||||
Reference in New Issue
Block a user