feat: Add audit logging, magic bytes validation, per-user quotas, and admin API
All checks were successful
CI/CD / Test (push) Successful in 31s
CI/CD / Lint (push) Successful in 42s
CI/CD / Generate SBOM (push) Successful in 17s
CI/CD / Build (darwin-amd64) (push) Successful in 22s
CI/CD / Build (linux-amd64) (push) Successful in 22s
CI/CD / Build (darwin-arm64) (push) Successful in 23s
CI/CD / Build (linux-arm64) (push) Successful in 22s
CI/CD / Build & Push Docker Image (push) Successful in 22s
CI/CD / Mirror to GitHub (push) Successful in 16s
CI/CD / Release (push) Has been skipped
New features in v3.3.0:

- audit.go: Security audit logging with JSON/text format, log rotation
- validation.go: Magic bytes content validation with wildcard patterns
- quota.go: Per-user storage quotas with Redis/memory tracking
- admin.go: Admin API for stats, file management, user quotas, bans

Integration:

- Updated main.go with feature initialization and handler integration
- Added audit logging for auth success/failure, uploads, downloads
- Added quota checking before upload, tracking after successful upload
- Added content validation with magic bytes detection

Config:

- New template: config-enhanced-features.toml with all new options
- Updated README.md with feature documentation
README.md (82 lines changed)
@@ -8,13 +8,22 @@ A high-performance, secure file server implementing XEP-0363 (HTTP File Upload)
 
 ## Features
 
+### Core Features
+
 - XEP-0363 HTTP File Upload compliance
-- HMAC-based authentication
-- File deduplication
+- HMAC-based authentication with JWT support
+- File deduplication (SHA256 with hardlinks)
 - Multi-architecture support (AMD64, ARM64, ARM32v7)
 - Docker and Podman deployment
 - XMPP client compatibility (Dino, Gajim, Conversations, Monal, Converse.js)
-- Network resilience for mobile clients
+- Network resilience for mobile clients (WiFi/LTE switching)
+
+### Security Features (v3.3.0)
+
+- **Audit Logging** - Comprehensive security event logging (uploads, downloads, auth events)
+- **Magic Bytes Validation** - Content type verification using file signatures
+- **Per-User Quotas** - Storage limits per XMPP JID with Redis tracking
+- **Admin API** - Protected endpoints for system management and monitoring
+- **ClamAV Integration** - Antivirus scanning for uploaded files
+- **Rate Limiting** - Configurable request throttling
 
 ## Installation

@@ -90,16 +99,19 @@ secret = "your-hmac-secret-key"
 
 | Section | Description |
 |---------|-------------|
 | `[server]` | Bind address, port, storage path, timeouts |
-| `[security]` | HMAC secret, TLS settings |
+| `[security]` | HMAC secret, JWT, TLS settings |
 | `[uploads]` | Size limits, allowed extensions |
 | `[downloads]` | Download settings, bandwidth limits |
 | `[logging]` | Log file, log level |
 | `[clamav]` | Antivirus scanning integration |
 | `[redis]` | Redis caching backend |
-| `[deduplication]` | File deduplication settings |
+| `[audit]` | Security audit logging |
+| `[validation]` | Magic bytes content validation |
+| `[quotas]` | Per-user storage quotas |
+| `[admin]` | Admin API configuration |
 | `[workers]` | Worker pool configuration |
 
-See [examples/](examples/) for complete configuration templates.
+See [templates/](templates/) for complete configuration templates.
 
 ## XMPP Server Integration

@@ -168,6 +180,64 @@ token = HMAC-SHA256(secret, filename + filesize + timestamp)
 | `/download/...` | GET | File download |
 | `/health` | GET | Health check |
 | `/metrics` | GET | Prometheus metrics |
+| `/admin/stats` | GET | Server statistics (auth required) |
+| `/admin/files` | GET | List uploaded files (auth required) |
+| `/admin/users` | GET | User quota information (auth required) |
+
+## Enhanced Features (v3.3.0)
+
+### Audit Logging
+
+Security-focused logging for compliance and forensics:
+
+```toml
+[audit]
+enabled = true
+output = "file"
+path = "/var/log/hmac-audit.log"
+format = "json"
+events = ["upload", "download", "auth_failure", "quota_exceeded"]
+```
+
+### Content Validation
+
+Magic bytes validation to verify file types:
+
+```toml
+[validation]
+check_magic_bytes = true
+allowed_types = ["image/*", "video/*", "audio/*", "application/pdf"]
+blocked_types = ["application/x-executable", "application/x-shellscript"]
+```
+
+### Per-User Quotas
+
+Storage limits per XMPP JID with Redis tracking:
+
+```toml
+[quotas]
+enabled = true
+default = "100MB"
+tracking = "redis"
+
+[quotas.custom]
+"admin@example.com" = "10GB"
+"premium@example.com" = "1GB"
+```
+
+### Admin API
+
+Protected management endpoints:
+
+```toml
+[admin]
+enabled = true
+path_prefix = "/admin"
+
+[admin.auth]
+type = "bearer"
+token = "${ADMIN_TOKEN}"
+```
+
 ## System Requirements
cmd/server/admin.go (new file, 756 lines)
@@ -0,0 +1,756 @@
// admin.go - Admin API for operations and monitoring

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"io/fs"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"time"
)

// AdminConfig holds admin API configuration
type AdminConfig struct {
	Enabled    bool            `toml:"enabled" mapstructure:"enabled"`
	Bind       string          `toml:"bind" mapstructure:"bind"`               // Separate bind address (e.g., "127.0.0.1:8081")
	PathPrefix string          `toml:"path_prefix" mapstructure:"path_prefix"` // Path prefix (e.g., "/admin")
	Auth       AdminAuthConfig `toml:"auth" mapstructure:"auth"`
}

// AdminAuthConfig holds admin authentication configuration
type AdminAuthConfig struct {
	Type     string `toml:"type" mapstructure:"type"`         // "bearer" | "basic"
	Token    string `toml:"token" mapstructure:"token"`       // For bearer auth
	Username string `toml:"username" mapstructure:"username"` // For basic auth
	Password string `toml:"password" mapstructure:"password"` // For basic auth
}

// AdminStats represents system statistics
type AdminStats struct {
	Storage  StorageStats `json:"storage"`
	Users    UserStats    `json:"users"`
	Requests RequestStats `json:"requests"`
	System   SystemStats  `json:"system"`
}

// StorageStats represents storage statistics
type StorageStats struct {
	UsedBytes  int64  `json:"used_bytes"`
	UsedHuman  string `json:"used_human"`
	FileCount  int64  `json:"file_count"`
	FreeBytes  int64  `json:"free_bytes,omitempty"`
	FreeHuman  string `json:"free_human,omitempty"`
	TotalBytes int64  `json:"total_bytes,omitempty"`
	TotalHuman string `json:"total_human,omitempty"`
}

// UserStats represents user statistics
type UserStats struct {
	Total     int64 `json:"total"`
	Active24h int64 `json:"active_24h"`
	Active7d  int64 `json:"active_7d"`
}

// RequestStats represents request statistics
type RequestStats struct {
	Uploads24h   int64 `json:"uploads_24h"`
	Downloads24h int64 `json:"downloads_24h"`
	Errors24h    int64 `json:"errors_24h"`
}

// SystemStats represents system statistics
type SystemStats struct {
	Uptime        string `json:"uptime"`
	Version       string `json:"version"`
	GoVersion     string `json:"go_version"`
	NumGoroutines int    `json:"num_goroutines"`
	MemoryUsageMB int64  `json:"memory_usage_mb"`
	NumCPU        int    `json:"num_cpu"`
}

// FileInfo represents file information for admin API
type FileInfo struct {
	ID          string    `json:"id"`
	Path        string    `json:"path"`
	Name        string    `json:"name"`
	Size        int64     `json:"size"`
	SizeHuman   string    `json:"size_human"`
	ContentType string    `json:"content_type"`
	ModTime     time.Time `json:"mod_time"`
	Owner       string    `json:"owner,omitempty"`
}

// FileListResponse represents paginated file list
type FileListResponse struct {
	Files      []FileInfo `json:"files"`
	Total      int64      `json:"total"`
	Page       int        `json:"page"`
	Limit      int        `json:"limit"`
	TotalPages int        `json:"total_pages"`
}

// UserInfo represents user information for admin API
type UserInfo struct {
	JID        string    `json:"jid"`
	QuotaUsed  int64     `json:"quota_used"`
	QuotaLimit int64     `json:"quota_limit"`
	FileCount  int64     `json:"file_count"`
	LastActive time.Time `json:"last_active,omitempty"`
	IsBanned   bool      `json:"is_banned"`
}

// BanInfo represents ban information
type BanInfo struct {
	IP          string    `json:"ip"`
	Reason      string    `json:"reason"`
	CreatedAt   time.Time `json:"created_at"`
	ExpiresAt   time.Time `json:"expires_at,omitempty"`
	IsPermanent bool      `json:"is_permanent"`
}

var (
	serverStartTime = time.Now()
	adminConfig     *AdminConfig
)

// SetupAdminRoutes sets up admin API routes
func SetupAdminRoutes(mux *http.ServeMux, config *AdminConfig) {
	adminConfig = config

	if !config.Enabled {
		log.Info("Admin API is disabled")
		return
	}

	prefix := config.PathPrefix
	if prefix == "" {
		prefix = "/admin"
	}

	// Wrap all admin handlers with authentication
	adminMux := http.NewServeMux()

	adminMux.HandleFunc(prefix+"/stats", handleAdminStats)
	adminMux.HandleFunc(prefix+"/files", handleAdminFiles)
	adminMux.HandleFunc(prefix+"/files/", handleAdminFileByID)
	adminMux.HandleFunc(prefix+"/users", handleAdminUsers)
	adminMux.HandleFunc(prefix+"/users/", handleAdminUserByJID)
	adminMux.HandleFunc(prefix+"/bans", handleAdminBans)
	adminMux.HandleFunc(prefix+"/bans/", handleAdminBanByIP)
	adminMux.HandleFunc(prefix+"/health", handleAdminHealth)
	adminMux.HandleFunc(prefix+"/config", handleAdminConfig)

	// Register with authentication middleware
	mux.Handle(prefix+"/", AdminAuthMiddleware(adminMux))

	log.Infof("Admin API enabled at %s (auth: %s)", prefix, config.Auth.Type)
}
// AdminAuthMiddleware handles admin authentication
func AdminAuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if adminConfig == nil || !adminConfig.Enabled {
			http.Error(w, "Admin API disabled", http.StatusServiceUnavailable)
			return
		}

		authorized := false

		switch adminConfig.Auth.Type {
		case "bearer":
			auth := r.Header.Get("Authorization")
			if strings.HasPrefix(auth, "Bearer ") {
				token := strings.TrimPrefix(auth, "Bearer ")
				authorized = token == adminConfig.Auth.Token
			}
		case "basic":
			username, password, ok := r.BasicAuth()
			if ok {
				authorized = username == adminConfig.Auth.Username &&
					password == adminConfig.Auth.Password
			}
		default:
			// No auth configured, check if request is from localhost
			clientIP := getClientIP(r)
			authorized = clientIP == "127.0.0.1" || clientIP == "::1"
		}

		if !authorized {
			AuditEvent("admin_auth_failure", r, nil)
			w.Header().Set("WWW-Authenticate", `Bearer realm="admin"`)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		next.ServeHTTP(w, r)
	})
}
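The token and password comparisons in the middleware above use plain `==`, which can leak timing information to an attacker probing credentials. A hedged sketch of a constant-time variant using the standard `crypto/subtle` package (the helper name is illustrative, not part of the original file):

```go
package main

import (
	"crypto/subtle"
	"fmt"
)

// secureCompare is a hypothetical helper: it reports whether two
// credential strings are equal without early-exit timing leaks.
// subtle.ConstantTimeCompare returns 1 only for equal-length, equal
// byte slices, and examines every byte regardless of mismatches.
func secureCompare(a, b string) bool {
	return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}

func main() {
	fmt.Println(secureCompare("s3cret", "s3cret")) // true
	fmt.Println(secureCompare("s3cret", "s3creX")) // false
	fmt.Println(secureCompare("s3cret", "s3c"))    // false: lengths differ
}
```

Swapping this in for the `token == adminConfig.Auth.Token` and basic-auth comparisons would not change the API's behavior for correct credentials.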
// handleAdminStats returns system statistics
func handleAdminStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	AuditAdminAction(r, "get_stats", "system", nil)

	ctx := r.Context()
	stats := AdminStats{}

	// Storage stats
	confMutex.RLock()
	storagePath := conf.Server.StoragePath
	confMutex.RUnlock()

	storageStats := calculateStorageStats(storagePath)
	stats.Storage = storageStats

	// User stats
	stats.Users = calculateUserStats(ctx)

	// Request stats from Prometheus metrics
	stats.Requests = calculateRequestStats()

	// System stats
	var mem runtime.MemStats
	runtime.ReadMemStats(&mem)

	stats.System = SystemStats{
		Uptime:        time.Since(serverStartTime).Round(time.Second).String(),
		Version:       "3.3.0",
		GoVersion:     runtime.Version(),
		NumGoroutines: runtime.NumGoroutine(),
		MemoryUsageMB: int64(mem.Alloc / 1024 / 1024),
		NumCPU:        runtime.NumCPU(),
	}

	writeJSONResponseAdmin(w, stats)
}
// handleAdminFiles handles file listing
func handleAdminFiles(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		listFiles(w, r)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}

// listFiles returns paginated file list
func listFiles(w http.ResponseWriter, r *http.Request) {
	AuditAdminAction(r, "list_files", "files", nil)

	// Parse query parameters
	page, _ := strconv.Atoi(r.URL.Query().Get("page"))
	if page < 1 {
		page = 1
	}
	limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
	if limit < 1 || limit > 100 {
		limit = 50
	}
	sortBy := r.URL.Query().Get("sort")
	if sortBy == "" {
		sortBy = "date"
	}
	filterOwner := r.URL.Query().Get("owner")

	confMutex.RLock()
	storagePath := conf.Server.StoragePath
	confMutex.RUnlock()

	var files []FileInfo

	err := filepath.WalkDir(storagePath, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return nil // Skip errors
		}
		if d.IsDir() {
			return nil
		}

		info, err := d.Info()
		if err != nil {
			return nil
		}

		relPath, _ := filepath.Rel(storagePath, path)

		fileInfo := FileInfo{
			ID:          relPath,
			Path:        relPath,
			Name:        filepath.Base(path),
			Size:        info.Size(),
			SizeHuman:   formatBytes(info.Size()),
			ContentType: GetContentType(path),
			ModTime:     info.ModTime(),
		}

		// Apply owner filter if specified (simplified: would need metadata lookup)
		_ = filterOwner // Unused for now, but kept for future implementation

		files = append(files, fileInfo)
		return nil
	})

	if err != nil {
		http.Error(w, fmt.Sprintf("Error listing files: %v", err), http.StatusInternalServerError)
		return
	}

	// Sort files
	switch sortBy {
	case "date":
		sort.Slice(files, func(i, j int) bool {
			return files[i].ModTime.After(files[j].ModTime)
		})
	case "size":
		sort.Slice(files, func(i, j int) bool {
			return files[i].Size > files[j].Size
		})
	case "name":
		sort.Slice(files, func(i, j int) bool {
			return files[i].Name < files[j].Name
		})
	}

	// Paginate
	total := len(files)
	start := (page - 1) * limit
	end := start + limit
	if start > total {
		start = total
	}
	if end > total {
		end = total
	}

	response := FileListResponse{
		Files:      files[start:end],
		Total:      int64(total),
		Page:       page,
		Limit:      limit,
		TotalPages: (total + limit - 1) / limit,
	}

	writeJSONResponseAdmin(w, response)
}
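The clamp-and-slice pagination in `listFiles` can be checked in isolation. This standalone sketch (the helper name is illustrative) mirrors that logic and shows the boundary behavior when a client requests a page past the end of the list:

```go
package main

import "fmt"

// pageBounds is a hypothetical helper mirroring the clamping logic in
// listFiles: it converts a 1-based page number and a per-page limit
// into safe slice bounds for a list of length total.
func pageBounds(page, limit, total int) (start, end int) {
	start = (page - 1) * limit
	end = start + limit
	if start > total {
		start = total
	}
	if end > total {
		end = total
	}
	return start, end
}

func main() {
	s, e := pageBounds(2, 50, 120) // second page of 120 items
	fmt.Println(s, e)              // 50 100
	s, e = pageBounds(9, 50, 120) // page beyond the end clamps to an empty slice
	fmt.Println(s, e)              // 120 120
}
```

Because both bounds clamp to `total`, an out-of-range page yields an empty `files[start:end]` rather than a panic.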
// handleAdminFileByID handles single file operations
func handleAdminFileByID(w http.ResponseWriter, r *http.Request) {
	// Extract file ID from path
	prefix := adminConfig.PathPrefix
	if prefix == "" {
		prefix = "/admin"
	}
	fileID := strings.TrimPrefix(r.URL.Path, prefix+"/files/")

	if fileID == "" {
		http.Error(w, "File ID required", http.StatusBadRequest)
		return
	}

	switch r.Method {
	case http.MethodGet:
		getFileInfo(w, r, fileID)
	case http.MethodDelete:
		deleteFile(w, r, fileID)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}

// getFileInfo returns information about a specific file
func getFileInfo(w http.ResponseWriter, r *http.Request, fileID string) {
	confMutex.RLock()
	storagePath := conf.Server.StoragePath
	confMutex.RUnlock()

	filePath := filepath.Join(storagePath, fileID)

	// Validate path is within storage
	absPath, err := filepath.Abs(filePath)
	if err != nil || !strings.HasPrefix(absPath, storagePath) {
		http.Error(w, "Invalid file ID", http.StatusBadRequest)
		return
	}

	info, err := os.Stat(filePath)
	if os.IsNotExist(err) {
		http.Error(w, "File not found", http.StatusNotFound)
		return
	}
	if err != nil {
		http.Error(w, fmt.Sprintf("Error accessing file: %v", err), http.StatusInternalServerError)
		return
	}

	fileInfo := FileInfo{
		ID:          fileID,
		Path:        fileID,
		Name:        filepath.Base(filePath),
		Size:        info.Size(),
		SizeHuman:   formatBytes(info.Size()),
		ContentType: GetContentType(filePath),
		ModTime:     info.ModTime(),
	}

	writeJSONResponseAdmin(w, fileInfo)
}

// deleteFile deletes a specific file
func deleteFile(w http.ResponseWriter, r *http.Request, fileID string) {
	confMutex.RLock()
	storagePath := conf.Server.StoragePath
	confMutex.RUnlock()

	filePath := filepath.Join(storagePath, fileID)

	// Validate path is within storage
	absPath, err := filepath.Abs(filePath)
	if err != nil || !strings.HasPrefix(absPath, storagePath) {
		http.Error(w, "Invalid file ID", http.StatusBadRequest)
		return
	}

	// Get file info before deletion for audit
	info, err := os.Stat(filePath)
	if os.IsNotExist(err) {
		http.Error(w, "File not found", http.StatusNotFound)
		return
	}

	AuditAdminAction(r, "delete_file", fileID, map[string]interface{}{
		"size": info.Size(),
	})

	if err := os.Remove(filePath); err != nil {
		http.Error(w, fmt.Sprintf("Error deleting file: %v", err), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusNoContent)
}
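The `strings.HasPrefix(absPath, storagePath)` containment check above has a known weakness: prefix matching ignores path-separator boundaries, so a storage path of `/data/uploads` would also accept a sibling directory like `/data/uploads-old`. A hedged alternative (the helper name is illustrative) inspects the relative path instead:

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// insideDir is a hypothetical containment check: it reports whether
// path lies within dir by examining filepath.Rel's result rather than
// a raw string prefix, so sibling directories and ".." escapes fail.
func insideDir(dir, path string) bool {
	rel, err := filepath.Rel(dir, path)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

func main() {
	fmt.Println(insideDir("/data/uploads", "/data/uploads/ab/cd.png")) // true
	fmt.Println(insideDir("/data/uploads", "/data/uploads-old/x"))     // false: sibling dir
	fmt.Println(insideDir("/data/uploads", "/data/uploads/../etc"))    // false: dot-dot escape
}
```

`filepath.Rel` cleans both arguments first, so `..` components inside the candidate path are resolved before the check.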
// handleAdminUsers handles user listing
func handleAdminUsers(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	AuditAdminAction(r, "list_users", "users", nil)

	ctx := r.Context()
	qm := GetQuotaManager()

	var users []UserInfo

	if qm != nil && qm.config.Enabled {
		quotas, err := qm.GetAllQuotas(ctx)
		if err != nil {
			http.Error(w, fmt.Sprintf("Error getting quotas: %v", err), http.StatusInternalServerError)
			return
		}

		for _, quota := range quotas {
			users = append(users, UserInfo{
				JID:        quota.JID,
				QuotaUsed:  quota.Used,
				QuotaLimit: quota.Limit,
				FileCount:  quota.FileCount,
			})
		}
	}

	writeJSONResponseAdmin(w, users)
}

// handleAdminUserByJID handles single user operations
func handleAdminUserByJID(w http.ResponseWriter, r *http.Request) {
	prefix := adminConfig.PathPrefix
	if prefix == "" {
		prefix = "/admin"
	}

	path := strings.TrimPrefix(r.URL.Path, prefix+"/users/")
	parts := strings.Split(path, "/")
	jid := parts[0]

	if jid == "" {
		http.Error(w, "JID required", http.StatusBadRequest)
		return
	}

	// Check for sub-paths
	if len(parts) > 1 {
		switch parts[1] {
		case "files":
			handleUserFiles(w, r, jid)
			return
		case "quota":
			handleUserQuota(w, r, jid)
			return
		}
	}

	switch r.Method {
	case http.MethodGet:
		getUserInfo(w, r, jid)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}

// getUserInfo returns information about a specific user
func getUserInfo(w http.ResponseWriter, r *http.Request, jid string) {
	ctx := r.Context()
	qm := GetQuotaManager()

	if qm == nil || !qm.config.Enabled {
		http.Error(w, "Quota tracking not enabled", http.StatusNotImplemented)
		return
	}

	quota, err := qm.GetQuotaInfo(ctx, jid)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error getting quota: %v", err), http.StatusInternalServerError)
		return
	}

	user := UserInfo{
		JID:        jid,
		QuotaUsed:  quota.Used,
		QuotaLimit: quota.Limit,
		FileCount:  quota.FileCount,
	}

	writeJSONResponseAdmin(w, user)
}

// handleUserFiles handles user file operations
func handleUserFiles(w http.ResponseWriter, r *http.Request, jid string) {
	switch r.Method {
	case http.MethodGet:
		// List user's files
		AuditAdminAction(r, "list_user_files", jid, nil)
		// Would need file ownership tracking to implement fully
		writeJSONResponseAdmin(w, []FileInfo{})
	case http.MethodDelete:
		// Delete all user's files
		AuditAdminAction(r, "delete_user_files", jid, nil)
		// Would need file ownership tracking to implement fully
		w.WriteHeader(http.StatusNoContent)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}

// handleUserQuota handles user quota operations
func handleUserQuota(w http.ResponseWriter, r *http.Request, jid string) {
	qm := GetQuotaManager()
	if qm == nil {
		http.Error(w, "Quota management not enabled", http.StatusNotImplemented)
		return
	}

	switch r.Method {
	case http.MethodPost:
		// Set custom quota
		var req struct {
			Quota string `json:"quota"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}

		quota, err := parseSize(req.Quota)
		if err != nil {
			http.Error(w, fmt.Sprintf("Invalid quota: %v", err), http.StatusBadRequest)
			return
		}

		qm.SetCustomQuota(jid, quota)
		AuditAdminAction(r, "set_quota", jid, map[string]interface{}{"quota": req.Quota})

		writeJSONResponseAdmin(w, map[string]interface{}{
			"success": true,
			"jid":     jid,
			"quota":   quota,
		})
	case http.MethodDelete:
		// Remove custom quota
		qm.RemoveCustomQuota(jid)
		AuditAdminAction(r, "remove_quota", jid, nil)
		w.WriteHeader(http.StatusNoContent)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
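`handleUserQuota` relies on a `parseSize` helper defined elsewhere in the package, whose exact semantics are not shown in this diff. As an assumption-laden sketch, a parser for strings like "100MB" or "10GB" (decimal 1000-based multipliers here, which may differ from the real implementation) could look like:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseSizeSketch is a hypothetical stand-in for the package's
// parseSize helper: it converts "100MB"-style strings into a byte
// count using decimal (1000-based) multipliers.
func parseSizeSketch(s string) (int64, error) {
	s = strings.TrimSpace(strings.ToUpper(s))
	multipliers := []struct {
		suffix string
		factor int64
	}{
		{"GB", 1000 * 1000 * 1000},
		{"MB", 1000 * 1000},
		{"KB", 1000},
		{"B", 1},
	}
	for _, m := range multipliers {
		if strings.HasSuffix(s, m.suffix) {
			n, err := strconv.ParseInt(strings.TrimSuffix(s, m.suffix), 10, 64)
			if err != nil {
				return 0, err
			}
			return n * m.factor, nil
		}
	}
	return 0, fmt.Errorf("unrecognized size: %q", s)
}

func main() {
	n, _ := parseSizeSketch("100MB")
	fmt.Println(n) // 100000000
}
```

Checking the longer suffixes first matters: "100MB" must match "MB" before the bare "B" case strips only the final byte marker.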
// handleAdminBans handles ban listing
func handleAdminBans(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	AuditAdminAction(r, "list_bans", "bans", nil)

	// Would need ban management implementation
	writeJSONResponseAdmin(w, []BanInfo{})
}

// handleAdminBanByIP handles single ban operations
func handleAdminBanByIP(w http.ResponseWriter, r *http.Request) {
	prefix := adminConfig.PathPrefix
	if prefix == "" {
		prefix = "/admin"
	}
	ip := strings.TrimPrefix(r.URL.Path, prefix+"/bans/")

	if ip == "" {
		http.Error(w, "IP required", http.StatusBadRequest)
		return
	}

	switch r.Method {
	case http.MethodDelete:
		// Unban IP
		AuditAdminAction(r, "unban", ip, nil)
		// Would need ban management implementation
		w.WriteHeader(http.StatusNoContent)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}

// handleAdminHealth returns admin-specific health info
func handleAdminHealth(w http.ResponseWriter, r *http.Request) {
	health := map[string]interface{}{
		"status":    "healthy",
		"timestamp": time.Now().UTC(),
		"uptime":    time.Since(serverStartTime).String(),
	}

	// Check Redis
	if redisClient != nil && redisConnected {
		health["redis"] = "connected"
	} else if redisClient != nil {
		health["redis"] = "disconnected"
	}

	writeJSONResponseAdmin(w, health)
}

// handleAdminConfig returns current configuration (sanitized)
func handleAdminConfig(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	AuditAdminAction(r, "get_config", "config", nil)

	confMutex.RLock()
	// Return sanitized config (no secrets)
	sanitized := map[string]interface{}{
		"server": map[string]interface{}{
			"listen_address":  conf.Server.ListenAddress,
			"storage_path":    conf.Server.StoragePath,
			"max_upload_size": conf.Server.MaxUploadSize,
			"metrics_enabled": conf.Server.MetricsEnabled,
		},
		"security": map[string]interface{}{
			"enhanced_security": conf.Security.EnhancedSecurity,
			"jwt_enabled":       conf.Security.EnableJWT,
		},
		"clamav": map[string]interface{}{
			"enabled": conf.ClamAV.ClamAVEnabled,
		},
		"redis": map[string]interface{}{
			"enabled": conf.Redis.RedisEnabled,
		},
		"deduplication": map[string]interface{}{
			"enabled": conf.Deduplication.Enabled,
		},
	}
	confMutex.RUnlock()

	writeJSONResponseAdmin(w, sanitized)
}
// Helper functions
|
||||||
|
|
||||||
|
func calculateStorageStats(storagePath string) StorageStats {
|
||||||
|
var totalSize int64
|
||||||
|
var fileCount int64
|
||||||
|
|
||||||
|
_ = filepath.WalkDir(storagePath, func(path string, d fs.DirEntry, err error) error {
|
||||||
|
if err != nil || d.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if info, err := d.Info(); err == nil {
|
||||||
|
totalSize += info.Size()
|
||||||
|
fileCount++
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return StorageStats{
|
||||||
|
UsedBytes: totalSize,
|
||||||
|
UsedHuman: formatBytes(totalSize),
|
||||||
|
FileCount: fileCount,
|
||||||
|
}
|
||||||
|
}
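StorageStats pairs the raw byte count with a human-readable string via `formatBytes`, which is defined elsewhere in the repo. A minimal sketch of a typical binary-unit implementation (an assumption, not the project's actual helper):

```go
package main

import "fmt"

// formatBytes renders a byte count as a human-readable string
// (hypothetical sketch; the real helper lives elsewhere in the repo).
func formatBytes(b int64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}
	div, exp := int64(unit), 0
	for n := b / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp])
}

func main() {
	fmt.Println(formatBytes(1536))    // 1.5 KiB
	fmt.Println(formatBytes(1048576)) // 1.0 MiB
}
```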
func calculateUserStats(ctx context.Context) UserStats {
	qm := GetQuotaManager()
	if qm == nil || !qm.config.Enabled {
		return UserStats{}
	}

	quotas, err := qm.GetAllQuotas(ctx)
	if err != nil {
		return UserStats{}
	}

	return UserStats{
		Total: int64(len(quotas)),
	}
}

func calculateRequestStats() RequestStats {
	// These would ideally come from Prometheus metrics
	return RequestStats{}
}

func writeJSONResponseAdmin(w http.ResponseWriter, data interface{}) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Errorf("Failed to encode JSON response: %v", err)
	}
}

// DefaultAdminConfig returns default admin configuration
func DefaultAdminConfig() AdminConfig {
	return AdminConfig{
		Enabled:    false,
		Bind:       "",
		PathPrefix: "/admin",
		Auth: AdminAuthConfig{
			Type: "bearer",
		},
	}
}
366 cmd/server/audit.go Normal file
@@ -0,0 +1,366 @@
// audit.go - Dedicated audit logging for security-relevant events

package main

import (
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"gopkg.in/natefinch/lumberjack.v2"
)

// AuditConfig holds audit logging configuration
type AuditConfig struct {
	Enabled bool     `toml:"enabled" mapstructure:"enabled"`
	Output  string   `toml:"output" mapstructure:"output"`     // "file" | "stdout"
	Path    string   `toml:"path" mapstructure:"path"`         // Log file path
	Format  string   `toml:"format" mapstructure:"format"`     // "json" | "text"
	Events  []string `toml:"events" mapstructure:"events"`     // Events to log
	MaxSize int      `toml:"max_size" mapstructure:"max_size"` // Max size in MB
	MaxAge  int      `toml:"max_age" mapstructure:"max_age"`   // Max age in days
}
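The `toml` tags above map the struct onto an `[audit]` table. An illustrative fragment (the values are examples that mirror the defaults in DefaultAuditConfig):

```toml
[audit]
enabled = true
output = "file"
path = "/var/log/hmac-audit.log"
format = "json"
events = ["upload", "download", "delete", "auth_success", "auth_failure"]
max_size = 100  # MB before rotation
max_age = 30    # days to retain rotated logs
```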
// AuditEvent types
const (
	AuditEventUpload            = "upload"
	AuditEventDownload          = "download"
	AuditEventDelete            = "delete"
	AuditEventAuthSuccess       = "auth_success"
	AuditEventAuthFailure       = "auth_failure"
	AuditEventRateLimited       = "rate_limited"
	AuditEventBanned            = "banned"
	AuditEventQuotaExceeded     = "quota_exceeded"
	AuditEventAdminAction       = "admin_action"
	AuditEventValidationFailure = "validation_failure"
)

// AuditLogger handles security audit logging
type AuditLogger struct {
	logger        *logrus.Logger
	config        *AuditConfig
	enabledEvents map[string]bool
	mutex         sync.RWMutex
}

var (
	auditLogger *AuditLogger
	auditOnce   sync.Once
)

// InitAuditLogger initializes the audit logger
func InitAuditLogger(config *AuditConfig) error {
	var initErr error
	auditOnce.Do(func() {
		auditLogger = &AuditLogger{
			logger:        logrus.New(),
			config:        config,
			enabledEvents: make(map[string]bool),
		}

		// Build enabled events map for fast lookup
		for _, event := range config.Events {
			auditLogger.enabledEvents[strings.ToLower(event)] = true
		}

		// Configure formatter
		if config.Format == "json" {
			auditLogger.logger.SetFormatter(&logrus.JSONFormatter{
				TimestampFormat: time.RFC3339,
				FieldMap: logrus.FieldMap{
					logrus.FieldKeyTime: "timestamp",
					logrus.FieldKeyMsg:  "event",
				},
			})
		} else {
			auditLogger.logger.SetFormatter(&logrus.TextFormatter{
				TimestampFormat: time.RFC3339,
				FullTimestamp:   true,
			})
		}

		// Configure output
		if !config.Enabled {
			auditLogger.logger.SetOutput(io.Discard)
			return
		}

		switch config.Output {
		case "stdout":
			auditLogger.logger.SetOutput(os.Stdout)
		case "file":
			if config.Path == "" {
				config.Path = "/var/log/hmac-audit.log"
			}

			// Ensure directory exists
			dir := filepath.Dir(config.Path)
			if err := os.MkdirAll(dir, 0755); err != nil {
				initErr = err
				return
			}

			// Use lumberjack for log rotation
			maxSize := config.MaxSize
			if maxSize <= 0 {
				maxSize = 100 // Default 100MB
			}
			maxAge := config.MaxAge
			if maxAge <= 0 {
				maxAge = 30 // Default 30 days
			}

			auditLogger.logger.SetOutput(&lumberjack.Logger{
				Filename:   config.Path,
				MaxSize:    maxSize,
				MaxAge:     maxAge,
				MaxBackups: 5,
				Compress:   true,
			})
		default:
			auditLogger.logger.SetOutput(os.Stdout)
		}

		auditLogger.logger.SetLevel(logrus.InfoLevel)
		log.Infof("Audit logger initialized: output=%s, path=%s, format=%s, events=%v",
			config.Output, config.Path, config.Format, config.Events)
	})

	return initErr
}

// GetAuditLogger returns the singleton audit logger
func GetAuditLogger() *AuditLogger {
	return auditLogger
}

// IsEventEnabled checks if an event type should be logged
func (a *AuditLogger) IsEventEnabled(event string) bool {
	if a == nil || !a.config.Enabled {
		return false
	}
	a.mutex.RLock()
	defer a.mutex.RUnlock()

	// If no events configured, log all
	if len(a.enabledEvents) == 0 {
		return true
	}
	return a.enabledEvents[strings.ToLower(event)]
}
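The filter semantics above — an empty list logs everything, otherwise events match case-insensitively — can be exercised in isolation. A standalone sketch mirroring the lookup in IsEventEnabled:

```go
package main

import (
	"fmt"
	"strings"
)

// isEventEnabled mirrors AuditLogger.IsEventEnabled's lookup:
// an empty filter logs all events; otherwise match case-insensitively.
func isEventEnabled(enabled map[string]bool, event string) bool {
	if len(enabled) == 0 {
		return true
	}
	return enabled[strings.ToLower(event)]
}

func main() {
	filter := map[string]bool{"upload": true, "auth_failure": true}
	fmt.Println(isEventEnabled(filter, "Upload"))   // true
	fmt.Println(isEventEnabled(filter, "download")) // false
	fmt.Println(isEventEnabled(nil, "anything"))    // true
}
```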
// LogEvent logs an audit event
func (a *AuditLogger) LogEvent(event string, fields logrus.Fields) {
	if a == nil || !a.config.Enabled || !a.IsEventEnabled(event) {
		return
	}

	// Add standard fields
	fields["event_type"] = event
	if _, ok := fields["timestamp"]; !ok {
		fields["timestamp"] = time.Now().UTC().Format(time.RFC3339)
	}

	a.logger.WithFields(fields).Info(event)
}

// AuditEvent is a helper function for logging audit events from request context
func AuditEvent(event string, r *http.Request, fields logrus.Fields) {
	if auditLogger == nil || !auditLogger.config.Enabled {
		return
	}

	if !auditLogger.IsEventEnabled(event) {
		return
	}

	// Add request context
	if r != nil {
		if fields == nil {
			fields = logrus.Fields{}
		}
		fields["ip"] = getClientIP(r)
		fields["user_agent"] = r.UserAgent()
		fields["method"] = r.Method
		fields["path"] = r.URL.Path

		// Extract JID if available from headers or context
		if jid := r.Header.Get("X-User-JID"); jid != "" {
			fields["jid"] = jid
		}
	}

	auditLogger.LogEvent(event, fields)
}
// AuditUpload logs file upload events
func AuditUpload(r *http.Request, jid, fileID, fileName string, fileSize int64, contentType, result string, err error) {
	fields := logrus.Fields{
		"jid":          jid,
		"file_id":      fileID,
		"file_name":    fileName,
		"file_size":    fileSize,
		"content_type": contentType,
		"result":       result,
	}
	if err != nil {
		fields["error"] = err.Error()
	}
	AuditEvent(AuditEventUpload, r, fields)
}

// AuditDownload logs file download events
func AuditDownload(r *http.Request, jid, fileID, fileName string, fileSize int64, result string, err error) {
	fields := logrus.Fields{
		"jid":       jid,
		"file_id":   fileID,
		"file_name": fileName,
		"file_size": fileSize,
		"result":    result,
	}
	if err != nil {
		fields["error"] = err.Error()
	}
	AuditEvent(AuditEventDownload, r, fields)
}

// AuditDelete logs file deletion events
func AuditDelete(r *http.Request, jid, fileID, fileName string, result string, err error) {
	fields := logrus.Fields{
		"jid":       jid,
		"file_id":   fileID,
		"file_name": fileName,
		"result":    result,
	}
	if err != nil {
		fields["error"] = err.Error()
	}
	AuditEvent(AuditEventDelete, r, fields)
}

// AuditAuth logs authentication events
func AuditAuth(r *http.Request, jid string, success bool, method string, err error) {
	event := AuditEventAuthSuccess
	result := "success"
	if !success {
		event = AuditEventAuthFailure
		result = "failure"
	}

	fields := logrus.Fields{
		"jid":         jid,
		"auth_method": method,
		"result":      result,
	}
	if err != nil {
		fields["error"] = err.Error()
	}
	AuditEvent(event, r, fields)
}

// AuditRateLimited logs rate limiting events
func AuditRateLimited(r *http.Request, jid, reason string) {
	fields := logrus.Fields{
		"jid":    jid,
		"reason": reason,
	}
	AuditEvent(AuditEventRateLimited, r, fields)
}

// AuditBanned logs ban events
func AuditBanned(r *http.Request, jid, ip, reason string, duration time.Duration) {
	fields := logrus.Fields{
		"jid":          jid,
		"banned_ip":    ip,
		"reason":       reason,
		"ban_duration": duration.String(),
	}
	AuditEvent(AuditEventBanned, r, fields)
}

// AuditQuotaExceeded logs quota exceeded events
func AuditQuotaExceeded(r *http.Request, jid string, used, limit, requested int64) {
	fields := logrus.Fields{
		"jid":       jid,
		"used":      used,
		"limit":     limit,
		"requested": requested,
	}
	AuditEvent(AuditEventQuotaExceeded, r, fields)
}

// AuditAdminAction logs admin API actions
func AuditAdminAction(r *http.Request, action, target string, details map[string]interface{}) {
	fields := logrus.Fields{
		"action": action,
		"target": target,
	}
	for k, v := range details {
		fields[k] = v
	}
	AuditEvent(AuditEventAdminAction, r, fields)
}

// AuditValidationFailure logs content validation failures
func AuditValidationFailure(r *http.Request, jid, fileName, declaredType, detectedType, reason string) {
	fields := logrus.Fields{
		"jid":           jid,
		"file_name":     fileName,
		"declared_type": declaredType,
		"detected_type": detectedType,
		"reason":        reason,
	}
	AuditEvent(AuditEventValidationFailure, r, fields)
}

// DefaultAuditConfig returns default audit configuration
func DefaultAuditConfig() AuditConfig {
	return AuditConfig{
		Enabled: false,
		Output:  "file",
		Path:    "/var/log/hmac-audit.log",
		Format:  "json",
		Events: []string{
			AuditEventUpload,
			AuditEventDownload,
			AuditEventDelete,
			AuditEventAuthSuccess,
			AuditEventAuthFailure,
			AuditEventRateLimited,
			AuditEventBanned,
		},
		MaxSize: 100,
		MaxAge:  30,
	}
}

// AuditAuthSuccess is a helper for logging successful authentication
func AuditAuthSuccess(r *http.Request, jid, method string) {
	AuditAuth(r, jid, true, method, nil)
}

// AuditAuthFailure is a helper for logging failed authentication
func AuditAuthFailure(r *http.Request, method, errorMsg string) {
	AuditAuth(r, "", false, method, fmt.Errorf("%s", errorMsg))
}

// AuditUploadSuccess is a helper for logging successful uploads
func AuditUploadSuccess(r *http.Request, jid, fileName string, fileSize int64, contentType string) {
	AuditUpload(r, jid, "", fileName, fileSize, contentType, "success", nil)
}

// AuditUploadFailure is a helper for logging failed uploads
func AuditUploadFailure(r *http.Request, jid, fileName string, fileSize int64, errorMsg string) {
	AuditUpload(r, jid, "", fileName, fileSize, "", "failure", fmt.Errorf("%s", errorMsg))
}

// AuditDownloadSuccess is a helper for logging successful downloads
func AuditDownloadSuccess(r *http.Request, jid, fileName string, fileSize int64) {
	AuditDownload(r, jid, "", fileName, fileSize, "success", nil)
}
@@ -39,22 +39,22 @@ import (

// NetworkResilientSession represents a persistent session for network switching
type NetworkResilientSession struct {
	SessionID          string         `json:"session_id"`
	UserJID            string         `json:"user_jid"`
	OriginalToken      string         `json:"original_token"`
	CreatedAt          time.Time      `json:"created_at"`
	LastSeen           time.Time      `json:"last_seen"`
	NetworkHistory     []NetworkEvent `json:"network_history"`
	UploadContext      *UploadContext `json:"upload_context,omitempty"`
	RefreshCount       int            `json:"refresh_count"`
	MaxRefreshes       int            `json:"max_refreshes"`
	LastIP             string         `json:"last_ip"`
	UserAgent          string         `json:"user_agent"`
	SecurityLevel      int            `json:"security_level"` // 1=normal, 2=challenge, 3=reauth
	LastSecurityCheck  time.Time      `json:"last_security_check"`
	NetworkChangeCount int            `json:"network_change_count"`
	StandbyDetected    bool           `json:"standby_detected"`
	LastActivity       time.Time      `json:"last_activity"`
}

// contextKey is a custom type for context keys to avoid collisions

@@ -67,12 +67,12 @@ const (

// NetworkEvent tracks network transitions during session
type NetworkEvent struct {
	Timestamp   time.Time `json:"timestamp"`
	FromNetwork string    `json:"from_network"`
	ToNetwork   string    `json:"to_network"`
	ClientIP    string    `json:"client_ip"`
	UserAgent   string    `json:"user_agent"`
	EventType   string    `json:"event_type"` // "switch", "resume", "refresh"
}

// UploadContext maintains upload state across network changes and network resilience channels

@@ -244,11 +244,11 @@ func initializeSessionStore() {
	opt, err := redis.ParseURL(redisURL)
	if err == nil {
		sessionStore.redisClient = redis.NewClient(opt)

		// Test Redis connection
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		if err := sessionStore.redisClient.Ping(ctx).Err(); err == nil {
			log.Infof("📊 Session store: Redis backend initialized (%s)", redisURL)
		} else {

@@ -526,53 +526,57 @@ type BuildConfig struct {
}

type NetworkResilienceConfig struct {
	FastDetection        bool   `toml:"fast_detection" mapstructure:"fast_detection"`
	QualityMonitoring    bool   `toml:"quality_monitoring" mapstructure:"quality_monitoring"`
	PredictiveSwitching  bool   `toml:"predictive_switching" mapstructure:"predictive_switching"`
	MobileOptimizations  bool   `toml:"mobile_optimizations" mapstructure:"mobile_optimizations"`
	DetectionInterval    string `toml:"detection_interval" mapstructure:"detection_interval"`
	QualityCheckInterval string `toml:"quality_check_interval" mapstructure:"quality_check_interval"`
	MaxDetectionInterval string `toml:"max_detection_interval" mapstructure:"max_detection_interval"`

	// Multi-interface support
	MultiInterfaceEnabled       bool     `toml:"multi_interface_enabled" mapstructure:"multi_interface_enabled"`
	InterfacePriority           []string `toml:"interface_priority" mapstructure:"interface_priority"`
	AutoSwitchEnabled           bool     `toml:"auto_switch_enabled" mapstructure:"auto_switch_enabled"`
	SwitchThresholdLatency      string   `toml:"switch_threshold_latency" mapstructure:"switch_threshold_latency"`
	SwitchThresholdPacketLoss   float64  `toml:"switch_threshold_packet_loss" mapstructure:"switch_threshold_packet_loss"`
	QualityDegradationThreshold float64  `toml:"quality_degradation_threshold" mapstructure:"quality_degradation_threshold"`
	MaxSwitchAttempts           int      `toml:"max_switch_attempts" mapstructure:"max_switch_attempts"`
	SwitchDetectionInterval     string   `toml:"switch_detection_interval" mapstructure:"switch_detection_interval"`
}

// ClientNetworkConfigTOML is used for loading from TOML where timeout is a string
type ClientNetworkConfigTOML struct {
	SessionBasedTracking      bool   `toml:"session_based_tracking" mapstructure:"session_based_tracking"`
	AllowIPChanges            bool   `toml:"allow_ip_changes" mapstructure:"allow_ip_changes"`
	SessionMigrationTimeout   string `toml:"session_migration_timeout" mapstructure:"session_migration_timeout"`
	MaxIPChangesPerSession    int    `toml:"max_ip_changes_per_session" mapstructure:"max_ip_changes_per_session"`
	ClientConnectionDetection bool   `toml:"client_connection_detection" mapstructure:"client_connection_detection"`
	AdaptToClientNetwork      bool   `toml:"adapt_to_client_network" mapstructure:"adapt_to_client_network"`
}

// This is the main Config struct to be used
type Config struct {
	Server            ServerConfig            `mapstructure:"server"`
	Logging           LoggingConfig           `mapstructure:"logging"`
	Deduplication     DeduplicationConfig     `mapstructure:"deduplication"` // Added
	ISO               ISOConfig               `mapstructure:"iso"`           // Added
	Timeouts          TimeoutConfig           `mapstructure:"timeouts"`      // Added
	Security          SecurityConfig          `mapstructure:"security"`
	Versioning        VersioningConfig        `mapstructure:"versioning"` // Added
	Uploads           UploadsConfig           `mapstructure:"uploads"`
	Downloads         DownloadsConfig         `mapstructure:"downloads"`
	ClamAV            ClamAVConfig            `mapstructure:"clamav"`
	Redis             RedisConfig             `mapstructure:"redis"`
	Workers           WorkersConfig           `mapstructure:"workers"`
	File              FileConfig              `mapstructure:"file"`
	Build             BuildConfig             `mapstructure:"build"`
	NetworkResilience NetworkResilienceConfig `mapstructure:"network_resilience"`
	ClientNetwork     ClientNetworkConfigTOML `mapstructure:"client_network_support"`
	Audit             AuditConfig             `mapstructure:"audit"`      // Audit logging
	Validation        ValidationConfig        `mapstructure:"validation"` // Content validation
	Quotas            QuotaConfig             `mapstructure:"quotas"`     // Per-user quotas
	Admin             AdminConfig             `mapstructure:"admin"`      // Admin API
}
||||||
|
|
||||||
type UploadTask struct {
|
type UploadTask struct {
|
||||||
@@ -597,12 +601,12 @@ func processScan(task ScanTask) error {
|
|||||||
confMutex.RLock()
|
confMutex.RLock()
|
||||||
clamEnabled := conf.ClamAV.ClamAVEnabled
|
clamEnabled := conf.ClamAV.ClamAVEnabled
|
||||||
confMutex.RUnlock()
|
confMutex.RUnlock()
|
||||||
|
|
||||||
if !clamEnabled {
|
if !clamEnabled {
|
||||||
log.Infof("ClamAV disabled, skipping scan for file: %s", task.AbsFilename)
|
log.Infof("ClamAV disabled, skipping scan for file: %s", task.AbsFilename)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Started processing scan for file: %s", task.AbsFilename)
|
log.Infof("Started processing scan for file: %s", task.AbsFilename)
|
||||||
semaphore <- struct{}{}
|
semaphore <- struct{}{}
|
||||||
defer func() { <-semaphore }()
|
defer func() { <-semaphore }()
|
||||||
@@ -621,8 +625,8 @@ var (
|
|||||||
conf Config
|
conf Config
|
||||||
versionString string
|
versionString string
|
||||||
log = logrus.New()
|
log = logrus.New()
|
||||||
fileInfoCache *cache.Cache //nolint:unused
|
fileInfoCache *cache.Cache //nolint:unused
|
||||||
fileMetadataCache *cache.Cache //nolint:unused
|
fileMetadataCache *cache.Cache //nolint:unused
|
||||||
clamClient *clamd.Clamd
|
clamClient *clamd.Clamd
|
||||||
redisClient *redis.Client
|
redisClient *redis.Client
|
||||||
redisConnected bool
|
redisConnected bool
|
||||||
@@ -673,6 +677,7 @@ var clientTracker *ClientConnectionTracker
|
|||||||
|
|
||||||
//nolint:unused
|
//nolint:unused
|
||||||
var logMessages []string
|
var logMessages []string
|
||||||
|
|
||||||
//nolint:unused
|
//nolint:unused
|
||||||
var logMu sync.Mutex
|
var logMu sync.Mutex
|
||||||
|
|
||||||
@@ -748,7 +753,7 @@ func initializeNetworkProtocol(forceProtocol string) (*net.Dialer, error) {
|
|||||||
if forceProtocol == "" {
|
if forceProtocol == "" {
|
||||||
forceProtocol = "auto"
|
forceProtocol = "auto"
|
||||||
}
|
}
|
||||||
|
|
||||||
switch forceProtocol {
|
switch forceProtocol {
|
||||||
case "ipv4":
|
case "ipv4":
|
||||||
return &net.Dialer{
|
return &net.Dialer{
|
||||||
@@ -845,7 +850,7 @@ func main() {
|
|||||||
} else {
|
} else {
|
||||||
content = GenerateMinimalConfig()
|
content = GenerateMinimalConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.Create(genConfigPath)
|
f, err := os.Create(genConfigPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Failed to create file: %v\n", err)
|
fmt.Fprintf(os.Stderr, "Failed to create file: %v\n", err)
|
||||||
@@ -878,7 +883,7 @@ func main() {
|
|||||||
log.Fatalf("Failed to load configuration: %v", err)
|
log.Fatalf("Failed to load configuration: %v", err)
|
||||||
}
|
}
|
||||||
conf = *loadedConfig
|
conf = *loadedConfig
|
||||||
configFileGlobal = configFile // Store for validation helper functions
|
configFileGlobal = configFile // Store for validation helper functions
|
||||||
log.Info("Configuration loaded successfully.")
|
log.Info("Configuration loaded successfully.")
|
||||||
|
|
||||||
err = validateConfig(&conf)
|
err = validateConfig(&conf)
|
||||||
@@ -892,12 +897,12 @@ func main() {
|
|||||||
|
|
||||||
// Initialize client connection tracker for multi-interface support
|
// Initialize client connection tracker for multi-interface support
|
||||||
clientNetworkConfig := &ClientNetworkConfig{
|
clientNetworkConfig := &ClientNetworkConfig{
|
||||||
SessionBasedTracking: conf.ClientNetwork.SessionBasedTracking,
|
SessionBasedTracking: conf.ClientNetwork.SessionBasedTracking,
|
||||||
AllowIPChanges: conf.ClientNetwork.AllowIPChanges,
|
AllowIPChanges: conf.ClientNetwork.AllowIPChanges,
|
||||||
MaxIPChangesPerSession: conf.ClientNetwork.MaxIPChangesPerSession,
|
MaxIPChangesPerSession: conf.ClientNetwork.MaxIPChangesPerSession,
|
||||||
AdaptToClientNetwork: conf.ClientNetwork.AdaptToClientNetwork,
|
AdaptToClientNetwork: conf.ClientNetwork.AdaptToClientNetwork,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse session migration timeout
if conf.ClientNetwork.SessionMigrationTimeout != "" {
if timeout, err := time.ParseDuration(conf.ClientNetwork.SessionMigrationTimeout); err == nil {
@@ -908,12 +913,12 @@ func main() {
} else {
clientNetworkConfig.SessionMigrationTimeout = 5 * time.Minute // default
}

// Set defaults if not configured
if clientNetworkConfig.MaxIPChangesPerSession == 0 {
clientNetworkConfig.MaxIPChangesPerSession = 10
}

// Initialize the client tracker
clientTracker = NewClientConnectionTracker(clientNetworkConfig)
if clientTracker != nil {
@@ -1075,8 +1080,22 @@ func main() {
initRedis() // Assuming initRedis is defined in helpers.go or elsewhere
}

+// Initialize new features
+if err := InitAuditLogger(&conf.Audit); err != nil {
+log.Warnf("Failed to initialize audit logger: %v", err)
+}
+
+InitContentValidator(&conf.Validation)
+
+if err := InitQuotaManager(&conf.Quotas, redisClient); err != nil {
+log.Warnf("Failed to initialize quota manager: %v", err)
+}
+
router := setupRouter() // Assuming setupRouter is defined (likely in this file or router.go

+// Setup Admin API routes
+SetupAdminRoutes(router, &conf.Admin)
+
// Initialize enhancements and enhance the router
InitializeEnhancements(router)

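The initialization block above degrades gracefully: a failed optional subsystem (audit logger, quota manager) logs a warning and the server keeps starting. A minimal standalone sketch of that pattern — names here are illustrative, not the repository's actual API:

```go
package main

import (
	"errors"
	"log"
)

// initStep pairs an optional subsystem with its initializer.
type initStep struct {
	name string
	fn   func() error
}

// initOptional mirrors the startup pattern above: a failing optional
// subsystem logs a warning and is skipped instead of aborting startup.
// It returns the number of subsystems that initialized successfully.
func initOptional(steps []initStep) int {
	ok := 0
	for _, s := range steps {
		if err := s.fn(); err != nil {
			log.Printf("warn: failed to initialize %s: %v", s.name, err)
			continue
		}
		ok++
	}
	return ok
}

func succeeding() error { return nil }
func failing() error    { return errors.New("redis unreachable") }

func main() {
	steps := []initStep{{"audit logger", succeeding}, {"quota manager", failing}}
	log.Printf("%d/%d optional subsystems initialized", initOptional(steps), len(steps))
}
```

The trade-off is deliberate: audit logging and quotas are enhancements, so their absence should not take uploads offline.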
@@ -1658,7 +1677,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
query := r.URL.Query()
user := query.Get("user")
expiryStr := query.Get("expiry")

if user == "" {
return nil, errors.New("missing user parameter")
}
@@ -1674,10 +1693,10 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er

// ULTRA-FLEXIBLE GRACE PERIODS FOR NETWORK SWITCHING AND STANDBY SCENARIOS
now := time.Now().Unix()

// Base grace period: 8 hours (increased from 4 hours for better WiFi ↔ LTE reliability)
gracePeriod := int64(28800) // 8 hours base grace period for all scenarios

// Detect mobile XMPP clients and apply enhanced grace periods
userAgent := r.Header.Get("User-Agent")
isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") ||
@@ -1688,12 +1707,12 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
strings.Contains(strings.ToLower(userAgent), "xmpp") ||
strings.Contains(strings.ToLower(userAgent), "client") ||
strings.Contains(strings.ToLower(userAgent), "bot")

// Enhanced XMPP client detection and grace period management
// Desktop XMPP clients (Dino, Gajim) need extended grace for session restoration after restart
isDesktopXMPP := strings.Contains(strings.ToLower(userAgent), "dino") ||
strings.Contains(strings.ToLower(userAgent), "gajim")

if isMobileXMPP || isDesktopXMPP {
if isDesktopXMPP {
gracePeriod = int64(86400) // 24 hours for desktop XMPP clients (session restoration)
@@ -1703,32 +1722,32 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
log.Infof("📱 Mobile XMPP client detected (%s), using extended 12-hour grace period", userAgent)
}
}

// Network resilience parameters for session recovery
sessionId := query.Get("session_id")
networkResilience := query.Get("network_resilience")
resumeAllowed := query.Get("resume_allowed")

// Maximum grace period for network resilience scenarios
if sessionId != "" || networkResilience == "true" || resumeAllowed == "true" {
gracePeriod = int64(86400) // 24 hours for explicit network resilience scenarios
log.Infof("🌐 Network resilience mode activated (session_id: %s, network_resilience: %s), using 24-hour grace period",
sessionId, networkResilience)
}

// Detect potential network switching scenarios
clientIP := getClientIP(r)
xForwardedFor := r.Header.Get("X-Forwarded-For")
xRealIP := r.Header.Get("X-Real-IP")

// Check for client IP change indicators (WiFi ↔ LTE switching detection)
if xForwardedFor != "" || xRealIP != "" {
// Client is behind proxy/NAT - likely mobile switching between networks
gracePeriod = int64(86400) // 24 hours for proxy/NAT scenarios
log.Infof("📱 Network switching detected (client IP: %s, X-Forwarded-For: %s, X-Real-IP: %s), using 24-hour grace period",
clientIP, xForwardedFor, xRealIP)
}

// Check Content-Length to identify large uploads that need extra time
contentLength := r.Header.Get("Content-Length")
var size int64 = 0
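The grace-period cascade above can be condensed into a pure function for testing. This is a hedged sketch, not code from the repository: the 12-hour mobile value is inferred from the log message (12 h = 43200 s), and the 48-hour cap matches `maxAbsoluteGrace` later in the diff.

```go
package main

import "fmt"

// gracePeriodSeconds mirrors the tiered cascade above: an 8-hour base,
// 12 h for mobile XMPP clients, 24 h for desktop clients and explicit
// network-resilience requests, capped at 48 h.
func gracePeriodSeconds(mobileXMPP, desktopXMPP, networkResilience bool) int64 {
	grace := int64(28800) // 8-hour base grace period
	if mobileXMPP {
		grace = 43200 // 12 hours for mobile XMPP clients
	}
	if desktopXMPP || networkResilience {
		grace = 86400 // 24 hours for desktop restoration / resilience mode
	}
	if max := int64(172800); grace > max { // 48-hour absolute cap
		grace = max
	}
	return grace
}

func main() {
	fmt.Println(gracePeriodSeconds(true, false, false))  // mobile client: 43200
	fmt.Println(gracePeriodSeconds(false, true, false))  // desktop client: 86400
}
```

Separating the policy from the request-parsing code like this makes each tier trivially unit-testable.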
@@ -1741,27 +1760,27 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
log.Infof("📁 Large file detected (%d bytes), extending grace period by %d seconds", size, additionalTime)
}
}

// ABSOLUTE MAXIMUM: 48 hours for extreme scenarios
maxAbsoluteGrace := int64(172800) // 48 hours absolute maximum
if gracePeriod > maxAbsoluteGrace {
gracePeriod = maxAbsoluteGrace
log.Infof("⚠️ Grace period capped at 48 hours maximum")
}

// STANDBY RECOVERY: Special handling for device standby scenarios
isLikelyStandbyRecovery := false
standbyGraceExtension := int64(86400) // Additional 24 hours for standby recovery

if now > expiry {
expiredTime := now - expiry

// If token expired more than grace period but less than standby window, allow standby recovery
-if expiredTime > gracePeriod && expiredTime < (gracePeriod + standbyGraceExtension) {
+if expiredTime > gracePeriod && expiredTime < (gracePeriod+standbyGraceExtension) {
isLikelyStandbyRecovery = true
log.Infof("💤 STANDBY RECOVERY: Token expired %d seconds ago, within standby recovery window", expiredTime)
}

// Apply grace period check
if expiredTime > gracePeriod && !isLikelyStandbyRecovery {
// DESKTOP XMPP CLIENT SESSION RESTORATION: Special handling for Dino/Gajim restart scenarios
@@ -1770,7 +1789,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
isDesktopSessionRestore = true
log.Infof("🖥️ DESKTOP SESSION RESTORE: %s token expired %d seconds ago, allowing within 48-hour desktop restoration window", userAgent, expiredTime)
}

// Still apply ultra-generous final check for mobile scenarios
ultraMaxGrace := int64(259200) // 72 hours ultra-maximum for critical mobile scenarios
if (isMobileXMPP && expiredTime < ultraMaxGrace) || isDesktopSessionRestore {
@@ -1778,9 +1797,9 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
log.Warnf("⚡ ULTRA-GRACE: Mobile XMPP client token expired %d seconds ago, allowing within 72-hour ultra-grace window", expiredTime)
}
} else {
log.Warnf("❌ Bearer token expired beyond all grace periods: now=%d, expiry=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s",
now, expiry, expiredTime, gracePeriod, userAgent)
return nil, fmt.Errorf("token has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)",
expiredTime, gracePeriod)
}
} else if isLikelyStandbyRecovery {
@@ -1797,7 +1816,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
if len(pathParts) < 1 {
return nil, errors.New("invalid upload path format")
}

// Handle different path formats from various ejabberd modules
filename := ""
if len(pathParts) >= 3 {
@@ -1805,7 +1824,7 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
} else if len(pathParts) >= 1 {
filename = pathParts[len(pathParts)-1] // Simplified format: /filename
}

if filename == "" {
filename = "upload" // Fallback filename
}
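The filename-extraction fallback above (nested ejabberd paths vs. the simplified `/filename` form, with `"upload"` as the last resort) amounts to taking the last non-empty path segment. A standalone sketch — the function name is illustrative:

```go
package main

import (
	"fmt"
	"strings"
)

// filenameFromPath reproduces the fallback logic above: take the last
// path segment, or "upload" when the path yields no usable name.
func filenameFromPath(p string) string {
	// Split on '/' while discarding empty segments (leading slash, "//").
	parts := strings.FieldsFunc(p, func(r rune) bool { return r == '/' })
	if len(parts) == 0 {
		return "upload" // fallback filename
	}
	return parts[len(parts)-1]
}

func main() {
	fmt.Println(filenameFromPath("/user/abc123/photo.jpg")) // photo.jpg
	fmt.Println(filenameFromPath("/photo.jpg"))             // photo.jpg
	fmt.Println(filenameFromPath("/"))                      // upload
}
```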
@@ -1813,71 +1832,71 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
// ENHANCED HMAC VALIDATION: Try multiple payload formats for maximum compatibility
var validPayload bool
var payloadFormat string

// Format 1: Network-resilient payload (mod_http_upload_hmac_network_resilient)
extendedPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%d\x00network_resilient",
user, filename, size, expiry-86400, expiry)
h1 := hmac.New(sha256.New, []byte(secret))
h1.Write([]byte(extendedPayload))
expectedMAC1 := h1.Sum(nil)

if hmac.Equal(tokenBytes, expectedMAC1) {
validPayload = true
payloadFormat = "network_resilient"
}

// Format 2: Extended payload with session support
if !validPayload {
sessionPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s", user, filename, size, expiry, sessionId)
h2 := hmac.New(sha256.New, []byte(secret))
h2.Write([]byte(sessionPayload))
expectedMAC2 := h2.Sum(nil)

if hmac.Equal(tokenBytes, expectedMAC2) {
validPayload = true
payloadFormat = "session_based"
}
}

// Format 3: Standard payload (original mod_http_upload_hmac)
if !validPayload {
standardPayload := fmt.Sprintf("%s\x00%s\x00%d\x00%d", user, filename, size, expiry-3600)
h3 := hmac.New(sha256.New, []byte(secret))
h3.Write([]byte(standardPayload))
expectedMAC3 := h3.Sum(nil)

if hmac.Equal(tokenBytes, expectedMAC3) {
validPayload = true
payloadFormat = "standard"
}
}

// Format 4: Simplified payload (fallback compatibility)
if !validPayload {
simplePayload := fmt.Sprintf("%s\x00%s\x00%d", user, filename, size)
h4 := hmac.New(sha256.New, []byte(secret))
h4.Write([]byte(simplePayload))
expectedMAC4 := h4.Sum(nil)

if hmac.Equal(tokenBytes, expectedMAC4) {
validPayload = true
payloadFormat = "simple"
}
}

// Format 5: User-only payload (maximum fallback)
if !validPayload {
userPayload := fmt.Sprintf("%s\x00%d", user, expiry)
h5 := hmac.New(sha256.New, []byte(secret))
h5.Write([]byte(userPayload))
expectedMAC5 := h5.Sum(nil)

if hmac.Equal(tokenBytes, expectedMAC5) {
validPayload = true
payloadFormat = "user_only"
}
}

if !validPayload {
log.Warnf("❌ Invalid Bearer token HMAC for user %s, file %s (tried all 5 payload formats)", user, filename)
return nil, errors.New("invalid Bearer token HMAC")
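The five-format cascade above is one operation repeated: HMAC-SHA256 the candidate payload and compare against the presented token with the constant-time `hmac.Equal`. A generic sketch of that loop — the types and names are illustrative, not the repository's:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"fmt"
)

// payloadCandidate pairs a format name with the payload string that
// format would have signed.
type payloadCandidate struct {
	format  string
	payload string
}

// sign computes the raw HMAC-SHA256 of a payload, as each format does above.
func sign(secret []byte, payload string) []byte {
	mac := hmac.New(sha256.New, secret)
	mac.Write([]byte(payload))
	return mac.Sum(nil)
}

// matchPayloadFormat tries candidates in order and returns the first
// format whose MAC matches the presented token bytes.
func matchPayloadFormat(token, secret []byte, candidates []payloadCandidate) (string, bool) {
	for _, c := range candidates {
		if hmac.Equal(token, sign(secret, c.payload)) {
			return c.format, true
		}
	}
	return "", false
}

func main() {
	secret := []byte("s3cret")
	simple := fmt.Sprintf("%s\x00%s\x00%d", "alice@example.org", "photo.jpg", int64(1024))
	token := sign(secret, simple)

	candidates := []payloadCandidate{
		{"user_only", fmt.Sprintf("%s\x00%d", "alice@example.org", int64(1700000000))},
		{"simple", simple},
	}
	format, ok := matchPayloadFormat(token, secret, candidates)
	fmt.Println(format, ok) // simple true
}
```

Trying formats in a fixed order keeps the check deterministic, and `hmac.Equal` avoids timing side channels on each comparison.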
@@ -1890,16 +1909,16 @@ func validateBearerToken(r *http.Request, secret string) (*BearerTokenClaims, er
Expiry: expiry,
}

log.Infof("✅ Bearer token authentication SUCCESSFUL: user=%s, file=%s, format=%s, grace_period=%d seconds",
user, filename, payloadFormat, gracePeriod)

return claims, nil
}

// evaluateSecurityLevel determines the required security level based on network changes and standby detection
func evaluateSecurityLevel(session *NetworkResilientSession, currentIP string, userAgent string) int {
now := time.Now()

// Initialize if this is the first check
if session.LastSecurityCheck.IsZero() {
session.LastSecurityCheck = now
@@ -1907,50 +1926,50 @@ func evaluateSecurityLevel(session *NetworkResilientSession, currentIP string, u
session.SecurityLevel = 1 // Normal level
return 1
}

// Detect potential standby scenario
timeSinceLastActivity := now.Sub(session.LastActivity)
standbyThreshold := 30 * time.Minute

if timeSinceLastActivity > standbyThreshold {
session.StandbyDetected = true
log.Infof("🔒 STANDBY DETECTED: %v since last activity for session %s", timeSinceLastActivity, session.SessionID)

// Long standby requires full re-authentication
if timeSinceLastActivity > 2*time.Hour {
log.Warnf("🔐 SECURITY LEVEL 3: Long standby (%v) requires full re-authentication", timeSinceLastActivity)
return 3
}

// Medium standby requires challenge-response
log.Infof("🔐 SECURITY LEVEL 2: Medium standby (%v) requires challenge-response", timeSinceLastActivity)
return 2
}

// Detect network changes
if session.LastIP != "" && session.LastIP != currentIP {
session.NetworkChangeCount++
log.Infof("🌐 NETWORK CHANGE #%d: %s → %s for session %s",
session.NetworkChangeCount, session.LastIP, currentIP, session.SessionID)

// Multiple rapid network changes are suspicious
if session.NetworkChangeCount > 3 {
log.Warnf("🔐 SECURITY LEVEL 3: Multiple network changes (%d) requires full re-authentication",
session.NetworkChangeCount)
return 3
}

// Single network change requires challenge-response
log.Infof("🔐 SECURITY LEVEL 2: Network change requires challenge-response")
return 2
}

// Check for suspicious user agent changes
if session.UserAgent != "" && session.UserAgent != userAgent {
log.Warnf("🔐 SECURITY LEVEL 3: User agent change detected - potential device hijacking")
return 3
}

// Normal operation
return 1
}
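The decision ladder in `evaluateSecurityLevel` can be summarized as a pure function over inactivity time, IP-change count, and user-agent stability. A sketch under those assumptions — the real function also mutates session state and logs, which is omitted here, and inactivity is passed in minutes for simplicity:

```go
package main

import "fmt"

// securityLevel condenses the ladder above: 3 = full re-authentication,
// 2 = challenge-response, 1 = normal operation. Thresholds follow the
// diff: 30 minutes for standby, 120 minutes for long standby, and more
// than three IP changes counts as suspicious network hopping.
func securityLevel(idleMinutes, ipChanges int, uaChanged bool) int {
	if uaChanged {
		return 3 // user-agent change: potential device hijacking
	}
	if idleMinutes > 120 {
		return 3 // long standby requires full re-authentication
	}
	if idleMinutes > 30 {
		return 2 // medium standby requires challenge-response
	}
	if ipChanges > 3 {
		return 3 // rapid network hopping is suspicious
	}
	if ipChanges > 0 {
		return 2 // a single network change requires challenge-response
	}
	return 1 // normal operation
}

func main() {
	fmt.Println(securityLevel(45, 0, false)) // 2
	fmt.Println(securityLevel(5, 4, false))  // 3
	fmt.Println(securityLevel(0, 0, false))  // 1
}
```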
@@ -1960,11 +1979,11 @@ func generateSecurityChallenge(session *NetworkResilientSession, secret string)
// Create a time-based challenge using session data
timestamp := time.Now().Unix()
challengeData := fmt.Sprintf("%s:%s:%d", session.SessionID, session.UserJID, timestamp)

h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(challengeData))
challenge := hex.EncodeToString(h.Sum(nil))

log.Infof("🔐 Generated security challenge for session %s", session.SessionID)
return challenge, nil
}
@@ -1974,22 +1993,22 @@ func validateSecurityChallenge(session *NetworkResilientSession, providedRespons
// This would validate against the expected response
// For now, we'll implement a simple time-window validation
timestamp := time.Now().Unix()

// Allow 5-minute window for challenge responses
for i := int64(0); i <= 300; i += 60 {
testTimestamp := timestamp - i
challengeData := fmt.Sprintf("%s:%s:%d", session.SessionID, session.UserJID, testTimestamp)

h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(challengeData))
expectedResponse := hex.EncodeToString(h.Sum(nil))

if expectedResponse == providedResponse {
log.Infof("✅ Security challenge validated for session %s", session.SessionID)
return true
}
}

log.Warnf("❌ Security challenge failed for session %s", session.SessionID)
return false
}
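The challenge/response pair above round-trips as follows: the challenge is a hex HMAC over `sessionID:userJID:timestamp`, and validation re-derives it at 60-second steps over the last five minutes. A self-contained sketch — note that, exactly as in the loop above, only responses derived at timestamps an exact multiple of 60 seconds before validation time can match:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// challengeAt derives the hex HMAC challenge for a session at a given
// unix timestamp, matching the "%s:%s:%d" layout above.
func challengeAt(sessionID, userJID, secret string, ts int64) string {
	h := hmac.New(sha256.New, []byte(secret))
	h.Write([]byte(fmt.Sprintf("%s:%s:%d", sessionID, userJID, ts)))
	return hex.EncodeToString(h.Sum(nil))
}

// validateChallenge accepts a response matching any 60-second step
// within the last five minutes, mirroring the validation loop above.
func validateChallenge(sessionID, userJID, secret, response string, now int64) bool {
	for i := int64(0); i <= 300; i += 60 {
		if challengeAt(sessionID, userJID, secret, now-i) == response {
			return true
		}
	}
	return false
}

func main() {
	now := int64(1700000000)
	resp := challengeAt("sess-1", "alice@example.org", "secret", now-120)
	fmt.Println(validateChallenge("sess-1", "alice@example.org", "secret", resp, now)) // true
}
```

A production variant would compare with `hmac.Equal` on raw bytes rather than `==` on hex strings, and would track clock skew explicitly.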
@@ -2029,17 +2048,17 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
session := sessionStore.GetSession(sessionID)
if session == nil {
session = &NetworkResilientSession{
SessionID: sessionID,
UserJID: claims.User,
OriginalToken: getBearerTokenFromRequest(r),
CreatedAt: time.Now(),
MaxRefreshes: 10,
NetworkHistory: []NetworkEvent{},
SecurityLevel: 1,
LastSecurityCheck: time.Now(),
NetworkChangeCount: 0,
StandbyDetected: false,
LastActivity: time.Now(),
}
}

@@ -2069,7 +2088,7 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
log.Errorf("❌ Failed to generate security challenge: %v", err)
return nil, fmt.Errorf("security challenge generation failed")
}

// Check if client provided challenge response
challengeResponse := r.Header.Get("X-Challenge-Response")
if challengeResponse == "" {
@@ -2077,15 +2096,15 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
setSecurityHeaders(w, 2, challenge)
return nil, fmt.Errorf("challenge-response required for network change")
}

// Validate challenge response
if !validateSecurityChallenge(session, challengeResponse, secret) {
setSecurityHeaders(w, 2, challenge)
return nil, fmt.Errorf("invalid challenge response")
}

log.Infof("✅ Challenge-response validated for session %s", sessionID)

case 3:
// Full re-authentication required
setSecurityHeaders(w, 3, "")
@@ -2104,7 +2123,7 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
UserAgent: userAgent,
EventType: "network_switch",
})
log.Infof("🌐 Network switch detected for session %s: %s → %s",
sessionID, session.LastIP, currentIP)
}

@@ -2138,7 +2157,7 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
// Token refresh successful
session.RefreshCount++
session.LastSeen = time.Now()

// Add refresh event to history
session.NetworkHistory = append(session.NetworkHistory, NetworkEvent{
Timestamp: time.Now(),
@@ -2157,12 +2176,12 @@ func validateBearerTokenWithSession(r *http.Request, secret string) (*BearerToke
Expiry: time.Now().Add(24 * time.Hour).Unix(),
}

log.Infof("✅ Session recovery successful: %s (refresh #%d)",
sessionID, session.RefreshCount)
return refreshedClaims, nil
}
} else {
log.Warnf("❌ Session %s exceeded maximum refreshes (%d)",
sessionID, session.MaxRefreshes)
}
} else {
@@ -2191,8 +2210,8 @@ func refreshSessionToken(session *NetworkResilientSession, secret string, r *htt
size := extractSizeFromRequest(r)

// Use session-based payload format for refresh
payload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s\x00session_refresh",
session.UserJID,
filename,
size,
expiry,
@@ -2202,7 +2221,7 @@ func refreshSessionToken(session *NetworkResilientSession, secret string, r *htt
h.Write([]byte(payload))
token := base64.StdEncoding.EncodeToString(h.Sum(nil))

log.Infof("🆕 Generated refresh token for session %s (refresh #%d)",
session.SessionID, session.RefreshCount+1)

return token, nil
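The refresh token above is an HMAC-SHA256 over a NUL-separated payload ending in `session_refresh`, then base64-encoded. A sketch under the assumption that the fifth payload field (cut off by the hunk boundary) is the session ID — treat that field, and the function signature, as hypothetical:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
)

// refreshToken sketches the session-refresh payload above: NUL-separated
// fields with a "session_refresh" suffix, signed with HMAC-SHA256 and
// base64-encoded. The sessionID field is an assumption, not confirmed
// by the visible diff.
func refreshToken(secret, userJID, filename string, size, expiry int64, sessionID string) string {
	payload := fmt.Sprintf("%s\x00%s\x00%d\x00%d\x00%s\x00session_refresh",
		userJID, filename, size, expiry, sessionID)
	h := hmac.New(sha256.New, []byte(secret))
	h.Write([]byte(payload))
	return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

func main() {
	tok := refreshToken("secret", "alice@example.org", "photo.jpg", 1024, 1700086400, "sess-1")
	fmt.Println(len(tok)) // 44: base64 of a 32-byte SHA-256 MAC
}
```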
@@ -2251,7 +2270,7 @@ type BearerTokenClaims struct {
 // ENHANCED FOR 100% WIFI ↔ LTE SWITCHING AND STANDBY RECOVERY RELIABILITY
 func validateHMAC(r *http.Request, secret string) error {
 	log.Debugf("🔍 validateHMAC: Validating request to %s with query: %s", r.URL.Path, r.URL.RawQuery)

 	// Check for X-Signature header (for POST uploads)
 	signature := r.Header.Get("X-Signature")
 	if signature != "" {
@@ -2294,7 +2313,7 @@ func validateHMAC(r *http.Request, secret string) error {
 	// ENHANCED HMAC CALCULATION: Try multiple formats for maximum compatibility
 	var validMAC bool
 	var messageFormat string

 	// Calculate HMAC based on protocol version with enhanced compatibility
 	mac := hmac.New(sha256.New, []byte(secret))

@@ -2305,7 +2324,7 @@ func validateHMAC(r *http.Request, secret string) error {
 		mac.Write([]byte(message1))
 		calculatedMAC1 := mac.Sum(nil)
 		calculatedMACHex1 := hex.EncodeToString(calculatedMAC1)

 		// Decode provided MAC
 		if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
 			if hmac.Equal(calculatedMAC1, providedMAC) {
@@ -2314,14 +2333,14 @@ func validateHMAC(r *http.Request, secret string) error {
 				log.Debugf("✅ Legacy v protocol HMAC validated: %s", calculatedMACHex1)
 			}
 		}

 		// Format 2: Try without content length for compatibility
 		if !validMAC {
 			message2 := fileStorePath
 			mac.Reset()
 			mac.Write([]byte(message2))
 			calculatedMAC2 := mac.Sum(nil)

 			if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
 				if hmac.Equal(calculatedMAC2, providedMAC) {
 					validMAC = true
@@ -2333,14 +2352,14 @@ func validateHMAC(r *http.Request, secret string) error {
 	} else {
 		// v2 and token protocols: Enhanced format compatibility
 		contentType := GetContentType(fileStorePath)

 		// Format 1: Standard format - fileStorePath + "\x00" + contentLength + "\x00" + contentType
 		message1 := fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType
 		mac.Reset()
 		mac.Write([]byte(message1))
 		calculatedMAC1 := mac.Sum(nil)
 		calculatedMACHex1 := hex.EncodeToString(calculatedMAC1)

 		if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
 			if hmac.Equal(calculatedMAC1, providedMAC) {
 				validMAC = true
@@ -2348,14 +2367,14 @@ func validateHMAC(r *http.Request, secret string) error {
 				log.Debugf("✅ %s protocol HMAC validated (standard): %s", protocolVersion, calculatedMACHex1)
 			}
 		}

 		// Format 2: Without content type for compatibility
 		if !validMAC {
 			message2 := fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10)
 			mac.Reset()
 			mac.Write([]byte(message2))
 			calculatedMAC2 := mac.Sum(nil)

 			if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
 				if hmac.Equal(calculatedMAC2, providedMAC) {
 					validMAC = true
@@ -2364,14 +2383,14 @@ func validateHMAC(r *http.Request, secret string) error {
 				}
 			}
 		}

 		// Format 3: Simple path only for maximum compatibility
 		if !validMAC {
 			message3 := fileStorePath
 			mac.Reset()
 			mac.Write([]byte(message3))
 			calculatedMAC3 := mac.Sum(nil)

 			if providedMAC, err := hex.DecodeString(providedMACHex); err == nil {
 				if hmac.Equal(calculatedMAC3, providedMAC) {
 					validMAC = true
@@ -2387,7 +2406,7 @@ func validateHMAC(r *http.Request, secret string) error {
 		return fmt.Errorf("invalid MAC for %s protocol", protocolVersion)
 	}

 	log.Infof("✅ %s HMAC authentication SUCCESSFUL: format=%s, path=%s",
 		protocolVersion, messageFormat, r.URL.Path)
 	return nil
 }
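For a client targeting this endpoint, the v2 "standard" message format accepted above can be sketched in isolation. This is a minimal sketch inferred from the diff (a NUL-separated `path\x00contentLength\x00contentType` payload signed with HMAC-SHA256 and hex-encoded); `signV2` is a hypothetical helper name, and the real server derives the content type itself via `GetContentType`:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strconv"
)

// signV2 is a hypothetical client-side helper mirroring the "standard"
// v2 format checked by validateHMAC: fileStorePath, content length, and
// content type joined by NUL bytes, HMAC-SHA256 over the result, hex output.
func signV2(secret, fileStorePath string, contentLength int64, contentType string) string {
	msg := fileStorePath + "\x00" + strconv.FormatInt(contentLength, 10) + "\x00" + contentType
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(msg))
	return hex.EncodeToString(mac.Sum(nil))
}

func main() {
	// SHA-256 digests are 32 bytes, so the hex signature is 64 characters.
	fmt.Println(len(signV2("secret", "user/photo.jpg", 1024, "image/jpeg"))) // prints 64
}
```

The content type must match what the server computes for the same path, which is why the fallback formats without content type (and finally path-only) exist.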
@@ -2417,11 +2436,11 @@ func validateV3HMAC(r *http.Request, secret string) error {

 	// ULTRA-FLEXIBLE GRACE PERIODS FOR V3 PROTOCOL NETWORK SWITCHING
 	now := time.Now().Unix()

 	if now > expires {
 		// Base grace period: 8 hours (significantly increased for WiFi ↔ LTE reliability)
 		gracePeriod := int64(28800) // 8 hours base grace period

 		// Enhanced mobile XMPP client detection
 		userAgent := r.Header.Get("User-Agent")
 		isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "gajim") ||
@@ -2432,12 +2451,12 @@ func validateV3HMAC(r *http.Request, secret string) error {
 			strings.Contains(strings.ToLower(userAgent), "xmpp") ||
 			strings.Contains(strings.ToLower(userAgent), "client") ||
 			strings.Contains(strings.ToLower(userAgent), "bot")

 		if isMobileXMPP {
 			gracePeriod = int64(43200) // 12 hours for mobile XMPP clients
 			log.Infof("📱 V3: Mobile XMPP client detected (%s), using 12-hour grace period", userAgent)
 		}

 		// Network resilience parameters for V3 protocol
 		sessionId := query.Get("session_id")
 		networkResilience := query.Get("network_resilience")
@@ -2446,19 +2465,19 @@ func validateV3HMAC(r *http.Request, secret string) error {
 			gracePeriod = int64(86400) // 24 hours for network resilience scenarios
 			log.Infof("🌐 V3: Network resilience mode detected, using 24-hour grace period")
 		}

 		// Detect network switching indicators
 		clientIP := getClientIP(r)
 		xForwardedFor := r.Header.Get("X-Forwarded-For")
 		xRealIP := r.Header.Get("X-Real-IP")

 		if xForwardedFor != "" || xRealIP != "" {
 			// Client behind proxy/NAT - likely mobile network switching
 			gracePeriod = int64(86400) // 24 hours for proxy/NAT scenarios
 			log.Infof("🔄 V3: Network switching detected (IP: %s, X-Forwarded-For: %s), using 24-hour grace period",
 				clientIP, xForwardedFor)
 		}

 		// Large file uploads get additional grace time
 		if contentLengthStr := r.Header.Get("Content-Length"); contentLengthStr != "" {
 			if contentLength, parseErr := strconv.ParseInt(contentLengthStr, 10, 64); parseErr == nil {
@@ -2466,33 +2485,33 @@ func validateV3HMAC(r *http.Request, secret string) error {
 				if contentLength > 10*1024*1024 {
 					additionalTime := (contentLength / (10 * 1024 * 1024)) * 3600 // 1 hour per 10MB
 					gracePeriod += additionalTime
 					log.Infof("📁 V3: Large file (%d bytes), extending grace period by %d seconds",
 						contentLength, additionalTime)
 				}
 			}
 		}

 		// Maximum grace period cap: 48 hours
 		maxGracePeriod := int64(172800) // 48 hours absolute maximum
 		if gracePeriod > maxGracePeriod {
 			gracePeriod = maxGracePeriod
 			log.Infof("⚠️ V3: Grace period capped at 48 hours maximum")
 		}

 		// STANDBY RECOVERY: Handle device standby scenarios
 		expiredTime := now - expires
 		standbyGraceExtension := int64(86400) // Additional 24 hours for standby
-		isLikelyStandbyRecovery := expiredTime > gracePeriod && expiredTime < (gracePeriod + standbyGraceExtension)
+		isLikelyStandbyRecovery := expiredTime > gracePeriod && expiredTime < (gracePeriod+standbyGraceExtension)

 		if expiredTime > gracePeriod && !isLikelyStandbyRecovery {
 			// Ultra-generous final check for mobile scenarios
 			ultraMaxGrace := int64(259200) // 72 hours ultra-maximum for critical scenarios
 			if isMobileXMPP && expiredTime < ultraMaxGrace {
 				log.Warnf("⚡ V3 ULTRA-GRACE: Mobile client token expired %d seconds ago, allowing within 72-hour window", expiredTime)
 			} else {
 				log.Warnf("❌ V3 signature expired beyond all grace periods: now=%d, expires=%d, expired_for=%d seconds, grace_period=%d, user_agent=%s",
 					now, expires, expiredTime, gracePeriod, userAgent)
 				return fmt.Errorf("signature has expired beyond grace period (expired %d seconds ago, grace period: %d seconds)",
 					expiredTime, gracePeriod)
 			}
 		} else if isLikelyStandbyRecovery {
@@ -2507,18 +2526,18 @@ func validateV3HMAC(r *http.Request, secret string) error {
 	// ENHANCED MESSAGE CONSTRUCTION: Try multiple formats for compatibility
 	var validSignature bool
 	var messageFormat string

 	// Format 1: Standard v3 format
 	message1 := fmt.Sprintf("%s\n%s\n%s", r.Method, expiresStr, r.URL.Path)
 	h1 := hmac.New(sha256.New, []byte(secret))
 	h1.Write([]byte(message1))
 	expectedSignature1 := hex.EncodeToString(h1.Sum(nil))

 	if hmac.Equal([]byte(signature), []byte(expectedSignature1)) {
 		validSignature = true
 		messageFormat = "standard_v3"
 	}

 	// Format 2: Alternative format with query string
 	if !validSignature {
 		pathWithQuery := r.URL.Path
@@ -2529,32 +2548,32 @@ func validateV3HMAC(r *http.Request, secret string) error {
 		h2 := hmac.New(sha256.New, []byte(secret))
 		h2.Write([]byte(message2))
 		expectedSignature2 := hex.EncodeToString(h2.Sum(nil))

 		if hmac.Equal([]byte(signature), []byte(expectedSignature2)) {
 			validSignature = true
 			messageFormat = "with_query"
 		}
 	}

 	// Format 3: Simplified format (fallback)
 	if !validSignature {
 		message3 := fmt.Sprintf("%s\n%s", r.Method, r.URL.Path)
 		h3 := hmac.New(sha256.New, []byte(secret))
 		h3.Write([]byte(message3))
 		expectedSignature3 := hex.EncodeToString(h3.Sum(nil))

 		if hmac.Equal([]byte(signature), []byte(expectedSignature3)) {
 			validSignature = true
 			messageFormat = "simplified"
 		}
 	}

 	if !validSignature {
 		log.Warnf("❌ Invalid V3 HMAC signature (tried all 3 formats)")
 		return errors.New("invalid v3 HMAC signature")
 	}

 	log.Infof("✅ V3 HMAC authentication SUCCESSFUL: format=%s, method=%s, path=%s",
 		messageFormat, r.Method, r.URL.Path)
 	return nil
 }
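The "standard_v3" format accepted above is newline-separated rather than NUL-separated. A minimal client-side sketch, assuming `signV3` as a hypothetical helper name and a secret shared with the server:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// signV3 is a hypothetical client-side helper for the "standard_v3"
// message format: HTTP method, expiry timestamp, and URL path joined
// by newlines, HMAC-SHA256 signed, hex-encoded.
func signV3(secret, method, expiresStr, path string) string {
	msg := fmt.Sprintf("%s\n%s\n%s", method, expiresStr, path)
	h := hmac.New(sha256.New, []byte(secret))
	h.Write([]byte(msg))
	return hex.EncodeToString(h.Sum(nil))
}

func main() {
	sig := signV3("secret", "PUT", "1700000000", "/upload/photo.jpg")
	fmt.Println(len(sig)) // 64 hex characters
}
```

The "with_query" variant would append the raw query string to the path, and the "simplified" fallback drops the expiry line; the expiry value itself travels in the request's query parameters so the server can apply the grace-period checks shown above.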
@@ -2563,7 +2582,7 @@ func validateV3HMAC(r *http.Request, secret string) error {
 func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSize int64, clientIP string) (int64, error) {
 	var written int64
 	lastLogTime := time.Now()

 	for {
 		n, err := src.Read(buf)
 		if n > 0 {
@@ -2572,12 +2591,12 @@ func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSiz
 			if werr != nil {
 				return written, werr
 			}

 			// Log progress for large files every 10MB or 30 seconds
 			if totalSize > 50*1024*1024 &&
 				(written%10*1024*1024 == 0 || time.Since(lastLogTime) > 30*time.Second) {
 				progress := float64(written) / float64(totalSize) * 100
 				log.Infof("📥 Download progress: %.1f%% (%s/%s) for IP %s",
 					progress, formatBytes(written), formatBytes(totalSize), clientIP)
 				lastLogTime = time.Now()
 			}
@@ -2589,7 +2608,7 @@ func copyWithProgressTracking(dst io.Writer, src io.Reader, buf []byte, totalSiz
 			return written, err
 		}
 	}

 	return written, nil
 }

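One detail in the progress condition above is worth flagging rather than silently rewriting: in Go, `%` and `*` have equal precedence and associate left-to-right, so `written%10*1024*1024 == 0` parses as `(written%10)*1024*1024 == 0` and fires whenever `written` is a multiple of 10 bytes, not every 10 MiB. A minimal demonstration:

```go
package main

import "fmt"

func main() {
	var written int64 = 20 // 20 bytes, nowhere near 10 MiB

	// As written in the loop: % and * share precedence and bind
	// left-to-right, so this is (written % 10) * 1024 * 1024.
	asWritten := written%10*1024*1024 == 0

	// Parenthesized form that checks full 10 MiB multiples.
	perTenMiB := written%(10*1024*1024) == 0

	fmt.Println(asWritten, perTenMiB) // true false
}
```

In practice the 30-second fallback keeps the log volume bounded, but if the every-10MB behavior is intended, the modulus needs parentheses: `written%(10*1024*1024) == 0`.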
@@ -2606,11 +2625,11 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
 		// Generate session ID for multi-upload tracking
 		sessionID = generateUploadSessionID("upload", r.Header.Get("User-Agent"), getClientIP(r))
 	}

 	// Set session headers for client continuation
 	w.Header().Set("X-Session-ID", sessionID)
 	w.Header().Set("X-Upload-Session-Timeout", "3600") // 1 hour

 	// Only allow POST method
 	if r.Method != http.MethodPost {
 		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@@ -2621,22 +2640,22 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
 	// ENHANCED AUTHENTICATION with network switching support
 	var bearerClaims *BearerTokenClaims
 	authHeader := r.Header.Get("Authorization")

 	if strings.HasPrefix(authHeader, "Bearer ") {
 		// Bearer token authentication with session recovery for network switching
 		// Store response writer in context for session headers
 		ctx := context.WithValue(r.Context(), responseWriterKey, w)
 		r = r.WithContext(ctx)

 		claims, err := validateBearerTokenWithSession(r, conf.Security.Secret)
 		if err != nil {
 			// Enhanced error logging for network switching scenarios
 			clientIP := getClientIP(r)
 			userAgent := r.Header.Get("User-Agent")
 			sessionID := getSessionIDFromRequest(r)
 			log.Warnf("🔴 Authentication failed for IP %s, User-Agent: %s, Session: %s, Error: %v",
 				clientIP, userAgent, sessionID, err)

 			// Check if this might be a network switching scenario and provide helpful response
 			if strings.Contains(err.Error(), "expired") || strings.Contains(err.Error(), "invalid") {
 				w.Header().Set("X-Network-Switch-Detected", "true")
@@ -2646,15 +2665,17 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
 					w.Header().Set("X-Session-ID", sessionID)
 				}
 			}

+			AuditAuthFailure(r, "bearer_token", err.Error())
 			http.Error(w, fmt.Sprintf("Bearer Token Authentication failed: %v", err), http.StatusUnauthorized)
 			uploadErrorsTotal.Inc()
 			return
 		}
+		AuditAuthSuccess(r, claims.User, "bearer_token")
 		bearerClaims = claims
 		log.Infof("✅ Bearer token authentication successful: user=%s, file=%s, IP=%s",
 			claims.User, claims.Filename, getClientIP(r))

 		// Add comprehensive response headers for audit logging and client tracking
 		w.Header().Set("X-Authenticated-User", claims.User)
 		w.Header().Set("X-Auth-Method", "Bearer-Token")
@@ -2665,10 +2686,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
 		_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
 		if err != nil {
 			log.Warnf("🔴 JWT Authentication failed for IP %s: %v", getClientIP(r), err)
+			AuditAuthFailure(r, "jwt", err.Error())
 			http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
 			uploadErrorsTotal.Inc()
 			return
 		}
+		AuditAuthSuccess(r, "", "jwt")
 		log.Infof("✅ JWT authentication successful for upload request: %s", r.URL.Path)
 		w.Header().Set("X-Auth-Method", "JWT")
 	} else {
@@ -2676,10 +2699,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
 		err := validateHMAC(r, conf.Security.Secret)
 		if err != nil {
 			log.Warnf("🔴 HMAC Authentication failed for IP %s: %v", getClientIP(r), err)
+			AuditAuthFailure(r, "hmac", err.Error())
 			http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
 			uploadErrorsTotal.Inc()
 			return
 		}
+		AuditAuthSuccess(r, "", "hmac")
 		log.Infof("✅ HMAC authentication successful for upload request: %s", r.URL.Path)
 		w.Header().Set("X-Auth-Method", "HMAC")
 	}
@@ -2699,30 +2724,30 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
 		// Generate new session ID with enhanced entropy
 		sessionID = generateSessionID("", "")
 	}

 	clientIP := getClientIP(r)

 	// Detect potential network switching
 	xForwardedFor := r.Header.Get("X-Forwarded-For")
 	xRealIP := r.Header.Get("X-Real-IP")
 	networkSwitchIndicators := xForwardedFor != "" || xRealIP != ""

 	if networkSwitchIndicators {
 		log.Infof("🔄 Network switching indicators detected: session=%s, client_ip=%s, x_forwarded_for=%s, x_real_ip=%s",
 			sessionID, clientIP, xForwardedFor, xRealIP)
 		w.Header().Set("X-Network-Switch-Detected", "true")
 	}

 	clientSession = clientTracker.TrackClientSession(sessionID, clientIP, r)

 	// Enhanced session response headers for client coordination
 	w.Header().Set("X-Upload-Session-ID", sessionID)
 	w.Header().Set("X-Session-IP-Count", fmt.Sprintf("%d", len(clientSession.ClientIPs)))
 	w.Header().Set("X-Connection-Type", clientSession.ConnectionType)

 	log.Infof("🔗 Client session tracking: %s from IP %s (connection: %s, total_ips: %d)",
 		sessionID, clientIP, clientSession.ConnectionType, len(clientSession.ClientIPs))

 	// Add user context for Bearer token authentication
 	if bearerClaims != nil {
 		log.Infof("👤 Session associated with XMPP user: %s", bearerClaims.User)
@@ -2749,6 +2774,57 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
|
// Get user JID for quota and audit tracking
|
||||||
|
var userJID string
|
||||||
|
if bearerClaims != nil {
|
||||||
|
userJID = bearerClaims.User
|
||||||
|
}
|
||||||
|
r.Header.Set("X-User-JID", userJID)
r.Header.Set("X-File-Name", header.Filename)

// Check quota before upload
if qm := GetQuotaManager(); qm != nil && qm.config.Enabled && userJID != "" {
	canUpload, _ := qm.CanUpload(r.Context(), userJID, header.Size)
	if !canUpload {
		used, limit, _ := qm.GetUsage(r.Context(), userJID)
		AuditQuotaExceeded(r, userJID, used, limit, header.Size)
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Quota-Used", fmt.Sprintf("%d", used))
		w.Header().Set("X-Quota-Limit", fmt.Sprintf("%d", limit))
		w.WriteHeader(http.StatusRequestEntityTooLarge)
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"error":     "quota_exceeded",
			"message":   "Storage quota exceeded",
			"used":      used,
			"limit":     limit,
			"requested": header.Size,
		})
		uploadErrorsTotal.Inc()
		return
	}
}

// Content type validation using magic bytes
var fileReader io.Reader = file
declaredContentType := header.Header.Get("Content-Type")
detectedContentType := declaredContentType

if validator := GetContentValidator(); validator != nil && validator.config.CheckMagicBytes {
	validatedReader, detected, validErr := validator.ValidateContent(file, declaredContentType, header.Size)
	if validErr != nil {
		if valErr, ok := validErr.(*ValidationError); ok {
			AuditValidationFailure(r, userJID, header.Filename, declaredContentType, detected, valErr.Code)
			WriteValidationError(w, valErr)
		} else {
			http.Error(w, validErr.Error(), http.StatusUnsupportedMediaType)
		}
		uploadErrorsTotal.Inc()
		return
	}
	fileReader = validatedReader
	detectedContentType = detected
}

// Validate file size against max_upload_size if configured
if conf.Server.MaxUploadSize != "" {
	maxSizeBytes, err := parseSize(conf.Server.MaxUploadSize)
@@ -2759,9 +2835,9 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
	return
}
if header.Size > maxSizeBytes {
	log.Warnf("⚠️ File size %s exceeds maximum allowed size %s (IP: %s)",
		formatBytes(header.Size), conf.Server.MaxUploadSize, getClientIP(r))
	http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
		formatBytes(header.Size), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
	uploadErrorsTotal.Inc()
	return
@@ -2815,20 +2891,20 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
uploadsTotal.Inc()
uploadSizeBytes.Observe(float64(existingFileInfo.Size()))
filesDeduplicatedTotal.Inc()

w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Deduplication-Hit", "true")
w.WriteHeader(http.StatusOK)
response := map[string]interface{}{
	"success":     true,
	"filename":    filename,
	"size":        existingFileInfo.Size(),
	"message":     "File already exists (deduplication hit)",
	"upload_time": duration.String(),
}
_ = json.NewEncoder(w).Encode(response)

log.Infof("💾 Deduplication hit: file %s already exists (%s), returning success immediately (IP: %s)",
	filename, formatBytes(existingFileInfo.Size()), getClientIP(r))
return
}
@@ -2855,30 +2931,43 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
uploadCtx = networkManager.RegisterUpload(networkSessionID)
defer networkManager.UnregisterUpload(networkSessionID)
log.Infof("🌐 Registered upload with network resilience: session=%s, IP=%s", networkSessionID, getClientIP(r))

// Add network resilience headers
w.Header().Set("X-Network-Resilience", "enabled")
w.Header().Set("X-Upload-Context-ID", networkSessionID)
}

// Copy file content with network resilience support and enhanced progress tracking
// Use fileReader which may be wrapped with content validation
written, err := copyWithNetworkResilience(dst, fileReader, uploadCtx)
if err != nil {
	log.Errorf("🔴 Error saving file %s (IP: %s, session: %s): %v", filename, getClientIP(r), sessionID, err)
	http.Error(w, fmt.Sprintf("Error saving file: %v", err), http.StatusInternalServerError)
	uploadErrorsTotal.Inc()
	// Clean up partial file
	os.Remove(absFilename)
	// Audit the failure
	AuditUploadFailure(r, userJID, header.Filename, header.Size, err.Error())
	return
}

// Update quota after successful upload
if qm := GetQuotaManager(); qm != nil && qm.config.Enabled && userJID != "" {
	if err := qm.RecordUpload(r.Context(), userJID, absFilename, written); err != nil {
		log.Warnf("⚠️ Failed to update quota for user %s: %v", userJID, err)
	}
}

// Audit successful upload
AuditUploadSuccess(r, userJID, filename, written, detectedContentType)

// ✅ CRITICAL FIX: Send immediate success response for large files (>1GB)
// This prevents client timeouts while server does post-processing
isLargeFile := header.Size > 1024*1024*1024 // 1GB threshold

if isLargeFile {
	log.Infof("🚀 Large file detected (%s), sending immediate success response", formatBytes(header.Size))

	// Send immediate success response to client
	duration := time.Since(startTime)
	uploadDuration.Observe(duration.Seconds())
@@ -2893,12 +2982,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

response := map[string]interface{}{
	"success":         true,
	"filename":        filename,
	"size":            written,
	"duration":        duration.String(),
	"client_ip":       getClientIP(r),
	"timestamp":       time.Now().Unix(),
	"post_processing": "background",
}
@@ -2921,7 +3010,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d, "post_processing": "background"}`, filename, written)
}

log.Infof("✅ Immediate response sent for large file %s (%s) in %s from IP %s",
	filename, formatBytes(written), duration, getClientIP(r))

// Process deduplication asynchronously for large files
@@ -2936,7 +3025,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
log.Infof("✅ Background deduplication completed for %s", filename)
}
}

// Add to scan queue for virus scanning if enabled
if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 {
	ext := strings.ToLower(filepath.Ext(header.Filename))
@@ -2958,7 +3047,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
}
}
}()

return
}
@@ -2987,10 +3076,10 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

response := map[string]interface{}{
	"success":   true,
	"filename":  filename,
	"size":      written,
	"duration":  duration.String(),
	"client_ip": getClientIP(r),
	"timestamp": time.Now().Unix(),
}
@@ -3014,7 +3103,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d}`, filename, written)
}

log.Infof("✅ Successfully uploaded %s (%s) in %s from IP %s (session: %s)",
	filename, formatBytes(written), duration, getClientIP(r), sessionID)
}
@@ -3030,20 +3119,24 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
if err != nil {
	log.Warnf("🔴 JWT Authentication failed for download from IP %s: %v", getClientIP(r), err)
	AuditAuthFailure(r, "jwt", err.Error())
	http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
	downloadErrorsTotal.Inc()
	return
}
AuditAuthSuccess(r, "", "jwt")
log.Infof("✅ JWT authentication successful for download request: %s", r.URL.Path)
w.Header().Set("X-Auth-Method", "JWT")
} else {
	err := validateHMAC(r, conf.Security.Secret)
	if err != nil {
		log.Warnf("🔴 HMAC Authentication failed for download from IP %s: %v", getClientIP(r), err)
		AuditAuthFailure(r, "hmac", err.Error())
		http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
		downloadErrorsTotal.Inc()
		return
	}
	AuditAuthSuccess(r, "", "hmac")
	log.Infof("✅ HMAC authentication successful for download request: %s", r.URL.Path)
	w.Header().Set("X-Auth-Method", "HMAC")
}
@@ -3060,13 +3153,13 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
// Enhanced file path validation and construction
var absFilename string
var err error

// Use storage path or ISO mount point
storagePath := conf.Server.StoragePath
if conf.ISO.Enabled {
	storagePath = conf.ISO.MountPoint
}

absFilename, err = sanitizeFilePath(storagePath, filename)
if err != nil {
	log.Warnf("🔴 Invalid file path requested from IP %s: %s, error: %v", getClientIP(r), filename, err)
@@ -3079,12 +3172,12 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
fileInfo, err := os.Stat(absFilename)
if os.IsNotExist(err) {
	log.Warnf("🔴 File not found: %s (requested by IP %s)", absFilename, getClientIP(r))

	// Enhanced 404 response with network switching hints
	w.Header().Set("X-File-Not-Found", "true")
	w.Header().Set("X-Client-IP", getClientIP(r))
	w.Header().Set("X-Network-Switch-Support", "enabled")

	// Check if this might be a network switching issue
	userAgent := r.Header.Get("User-Agent")
	isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") ||
@@ -3093,13 +3186,13 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
	strings.Contains(strings.ToLower(userAgent), "android") ||
	strings.Contains(strings.ToLower(userAgent), "mobile") ||
	strings.Contains(strings.ToLower(userAgent), "xmpp")

if isMobileXMPP {
	w.Header().Set("X-Mobile-Client-Detected", "true")
	w.Header().Set("X-Retry-Suggestion", "30") // Suggest retry after 30 seconds
	log.Infof("📱 Mobile XMPP client file not found - may be network switching issue: %s", userAgent)
}

http.Error(w, "File not found", http.StatusNotFound)
downloadErrorsTotal.Inc()
return
@@ -3126,13 +3219,13 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
if err == nil {
	break
}

if attempt < maxRetries {
	log.Warnf("⚠️ Attempt %d/%d: Error opening file %s from IP %s: %v (retrying...)",
		attempt, maxRetries, absFilename, getClientIP(r), err)
	time.Sleep(time.Duration(attempt) * time.Second) // Progressive backoff
} else {
	log.Errorf("🔴 Failed to open file %s after %d attempts from IP %s: %v",
		absFilename, maxRetries, getClientIP(r), err)
	http.Error(w, fmt.Sprintf("Error opening file: %v", err), http.StatusInternalServerError)
	downloadErrorsTotal.Inc()
@@ -3149,7 +3242,7 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Network-Switch-Support", "enabled")
w.Header().Set("X-File-Path", filename)
w.Header().Set("X-Download-Start-Time", fmt.Sprintf("%d", time.Now().Unix()))

// Add cache control headers for mobile network optimization
userAgent := r.Header.Get("User-Agent")
isMobileXMPP := strings.Contains(strings.ToLower(userAgent), "conversations") ||
@@ -3158,7 +3251,7 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
	strings.Contains(strings.ToLower(userAgent), "android") ||
	strings.Contains(strings.ToLower(userAgent), "mobile") ||
	strings.Contains(strings.ToLower(userAgent), "xmpp")

if isMobileXMPP {
	w.Header().Set("X-Mobile-Client-Detected", "true")
	w.Header().Set("Cache-Control", "public, max-age=86400") // 24 hours cache for mobile
@@ -3173,7 +3266,7 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {

// Track download progress for large files
if fileInfo.Size() > 10*1024*1024 { // Log progress for files > 10MB
	log.Infof("📥 Starting download of %s (%.1f MiB) for IP %s",
		filepath.Base(absFilename), float64(fileInfo.Size())/(1024*1024), getClientIP(r))
}
@@ -3191,8 +3284,11 @@ func handleDownload(w http.ResponseWriter, r *http.Request) {
downloadDuration.Observe(duration.Seconds())
downloadsTotal.Inc()
downloadSizeBytes.Observe(float64(n))

// Audit successful download
AuditDownloadSuccess(r, "", filepath.Base(absFilename), n)

log.Infof("✅ Successfully downloaded %s (%s) in %s for IP %s (session complete)",
	filepath.Base(absFilename), formatBytes(n), duration, getClientIP(r))
}
@@ -3262,7 +3358,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
	return
}
if r.ContentLength > maxSizeBytes {
	http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
		formatBytes(r.ContentLength), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
	uploadErrorsTotal.Inc()
	return
@@ -3298,7 +3394,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
uploadsTotal.Inc()
uploadSizeBytes.Observe(float64(existingFileInfo.Size()))
filesDeduplicatedTotal.Inc()

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
response := map[string]interface{}{
@@ -3308,8 +3404,8 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
	"message": "File already exists (deduplication hit)",
}
_ = json.NewEncoder(w).Encode(response)

log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately",
	filename, formatBytes(existingFileInfo.Size()))
return
}
@@ -3337,10 +3433,10 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
// ✅ CRITICAL FIX: Send immediate success response for large files (>1GB)
// This prevents client timeouts while server does post-processing
isLargeFile := written > 1024*1024*1024 // 1GB threshold

if isLargeFile {
	log.Infof("🚀 Large file detected (%s), sending immediate success response (v3)", formatBytes(written))

	// Send immediate success response to client
	duration := time.Since(startTime)
	uploadDuration.Observe(duration.Seconds())
@@ -3355,11 +3451,11 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

response := map[string]interface{}{
	"success":         true,
	"filename":        filename,
	"size":            written,
	"duration":        duration.String(),
	"protocol":        "v3",
	"post_processing": "background",
}
@@ -3370,7 +3466,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"success": true, "filename": "%s", "size": %d, "post_processing": "background"}`, filename, written)
}

log.Infof("✅ Immediate response sent for large file %s (%s) in %s via v3 protocol",
	filename, formatBytes(written), duration)

// Process deduplication asynchronously for large files
@@ -3385,7 +3481,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
log.Infof("✅ Background deduplication completed for %s (v3)", filename)
}
}

// Add to scan queue for virus scanning if enabled
if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 {
	ext := strings.ToLower(filepath.Ext(originalFilename))
@@ -3407,7 +3503,7 @@ func handleV3Upload(w http.ResponseWriter, r *http.Request) {
}
}
}()

return
}
@@ -3462,7 +3558,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
// Generate session ID for XMPP multi-upload tracking
sessionID = generateUploadSessionID("legacy", r.Header.Get("User-Agent"), getClientIP(r))
}

// Set session headers for XMPP client continuation
w.Header().Set("X-Session-ID", sessionID)
w.Header().Set("X-Upload-Session-Timeout", "3600") // 1 hour
@@ -3531,7 +3627,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
	return
}
if r.ContentLength > maxSizeBytes {
	http.Error(w, fmt.Sprintf("File size %s exceeds maximum allowed size %s",
		formatBytes(r.ContentLength), conf.Server.MaxUploadSize), http.StatusRequestEntityTooLarge)
	uploadErrorsTotal.Inc()
	return
@@ -3582,9 +3678,9 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
uploadsTotal.Inc()
uploadSizeBytes.Observe(float64(existingFileInfo.Size()))
filesDeduplicatedTotal.Inc()

w.WriteHeader(http.StatusCreated) // 201 Created for legacy compatibility
log.Infof("Deduplication hit: file %s already exists (%s), returning success immediately",
	filename, formatBytes(existingFileInfo.Size()))
return
}
@@ -3617,10 +3713,10 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
// ✅ CRITICAL FIX: Send immediate success response for large files (>1GB)
// This prevents client timeouts while server does post-processing
isLargeFile := written > 1024*1024*1024 // 1GB threshold

if isLargeFile {
	log.Infof("🚀 Large file detected (%s), sending immediate success response (legacy)", formatBytes(written))

	// Send immediate success response to client
	duration := time.Since(startTime)
	uploadDuration.Observe(duration.Seconds())
@@ -3634,7 +3730,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Post-Processing", "background")
w.WriteHeader(http.StatusCreated)

log.Infof("✅ Immediate response sent for large file %s (%s) in %s via legacy protocol",
	filename, formatBytes(written), duration)

// Process deduplication asynchronously for large files
@@ -3649,7 +3745,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
log.Infof("✅ Background deduplication completed for %s (legacy)", filename)
}
}

// Add to scan queue for virus scanning if enabled
if conf.ClamAV.ClamAVEnabled && len(conf.ClamAV.ScanFileExtensions) > 0 {
	ext := strings.ToLower(filepath.Ext(fileStorePath))
@@ -3671,7 +3767,7 @@ func handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
}
}
}()

return
}
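The upload handlers above gate each upload on `qm.CanUpload` and account for it with `qm.RecordUpload`. A minimal sketch of the semantics of the memory-tracking branch (the type `memQuota` is hypothetical; the real `QuotaManager` in quota.go below also supports Redis and per-JID custom limits):

```go
package main

import (
	"fmt"
	"sync"
)

// memQuota sketches in-memory quota accounting: CanUpload admits a file
// only if used+size stays within the limit, and RecordUpload adds the
// file's size to the user's running total. A mutex guards concurrent uploads.
type memQuota struct {
	mu    sync.Mutex
	limit int64
	used  map[string]int64
}

func (q *memQuota) CanUpload(jid string, size int64) bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	return q.used[jid]+size <= q.limit
}

func (q *memQuota) RecordUpload(jid string, size int64) {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.used[jid] += size
}

func main() {
	q := &memQuota{limit: 100, used: make(map[string]int64)}
	q.RecordUpload("user@example.org", 80)
	fmt.Println(q.CanUpload("user@example.org", 10), q.CanUpload("user@example.org", 30))
}
```

Note the check-then-record pattern in the handler is not atomic across the two calls; two concurrent uploads can both pass `CanUpload` and together overshoot the limit slightly, which is a common and usually acceptable trade-off for quota enforcement.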
557	cmd/server/quota.go	Normal file
@@ -0,0 +1,557 @@
// quota.go - Per-user storage quota management

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/go-redis/redis/v8"
	"github.com/sirupsen/logrus"
)

// QuotaConfig holds quota configuration
type QuotaConfig struct {
	Enabled  bool              `toml:"enabled" mapstructure:"enabled"`
	Default  string            `toml:"default" mapstructure:"default"`   // Default quota (e.g., "100MB")
	Tracking string            `toml:"tracking" mapstructure:"tracking"` // "redis" | "memory"
	Custom   map[string]string `toml:"custom" mapstructure:"custom"`     // Custom quotas per JID
}

// QuotaInfo contains quota information for a user
type QuotaInfo struct {
	JID       string `json:"jid"`
	Used      int64  `json:"used"`
	Limit     int64  `json:"limit"`
	Remaining int64  `json:"remaining"`
	FileCount int64  `json:"file_count"`
	IsCustom  bool   `json:"is_custom"`
}

// QuotaExceededError represents a quota exceeded error
type QuotaExceededError struct {
	JID       string `json:"jid"`
	Used      int64  `json:"used"`
	Limit     int64  `json:"limit"`
	Requested int64  `json:"requested"`
}

func (e *QuotaExceededError) Error() string {
	return fmt.Sprintf("quota exceeded for %s: used %d, limit %d, requested %d",
		e.JID, e.Used, e.Limit, e.Requested)
}

// QuotaManager handles per-user storage quotas
type QuotaManager struct {
	config       *QuotaConfig
	redisClient  *redis.Client
	defaultQuota int64
	customQuotas map[string]int64

	// In-memory fallback when Redis is unavailable
	memoryUsage map[string]int64
	memoryFiles map[string]map[string]int64 // jid -> filePath -> size
	mutex       sync.RWMutex
}

var (
	quotaManager *QuotaManager
	quotaOnce    sync.Once
)

// Redis key patterns
const (
	quotaUsedKey  = "quota:%s:used"  // quota:{jid}:used -> int64
	quotaFilesKey = "quota:%s:files" // quota:{jid}:files -> HASH {path: size}
	quotaInfoKey  = "quota:%s:info"  // quota:{jid}:info -> JSON
)

// InitQuotaManager initializes the quota manager
func InitQuotaManager(config *QuotaConfig, redisClient *redis.Client) error {
	var initErr error
	quotaOnce.Do(func() {
		quotaManager = &QuotaManager{
			config:       config,
			redisClient:  redisClient,
			customQuotas: make(map[string]int64),
			memoryUsage:  make(map[string]int64),
			memoryFiles:  make(map[string]map[string]int64),
		}

		// Parse default quota
		if config.Default != "" {
			quota, err := parseSize(config.Default)
			if err != nil {
				initErr = fmt.Errorf("invalid default quota: %w", err)
				return
			}
			quotaManager.defaultQuota = quota
		} else {
			quotaManager.defaultQuota = 100 * 1024 * 1024 // 100MB default
		}

		// Parse custom quotas
		for jid, quotaStr := range config.Custom {
			quota, err := parseSize(quotaStr)
			if err != nil {
				log.Warnf("Invalid custom quota for %s: %v", jid, err)
				continue
			}
			quotaManager.customQuotas[strings.ToLower(jid)] = quota
		}

		log.Infof("Quota manager initialized: enabled=%v, default=%s, custom=%d users, tracking=%s",
			config.Enabled, config.Default, len(config.Custom), config.Tracking)
	})

	return initErr
}

// GetQuotaManager returns the singleton quota manager
func GetQuotaManager() *QuotaManager {
	return quotaManager
}

// GetLimit returns the quota limit for a user
func (q *QuotaManager) GetLimit(jid string) int64 {
	if jid == "" {
		return q.defaultQuota
	}

	jidLower := strings.ToLower(jid)
	if custom, ok := q.customQuotas[jidLower]; ok {
		return custom
	}
	return q.defaultQuota
}

// GetUsage returns the current storage usage for a user
func (q *QuotaManager) GetUsage(ctx context.Context, jid string) (used, limit int64, err error) {
	if !q.config.Enabled {
		return 0, 0, nil
	}

	limit = q.GetLimit(jid)

	// Try Redis first
	if q.redisClient != nil && q.config.Tracking == "redis" {
		key := fmt.Sprintf(quotaUsedKey, jid)
		usedStr, err := q.redisClient.Get(ctx, key).Result()
		if err == redis.Nil {
			return 0, limit, nil
		}
		if err != nil {
			log.Warnf("Failed to get quota from Redis, falling back to memory: %v", err)
		} else {
			used, _ = strconv.ParseInt(usedStr, 10, 64)
			return used, limit, nil
		}
	}

	// Fallback to memory
	q.mutex.RLock()
	used = q.memoryUsage[jid]
	q.mutex.RUnlock()

	return used, limit, nil
}

// GetQuotaInfo returns detailed quota information for a user
func (q *QuotaManager) GetQuotaInfo(ctx context.Context, jid string) (*QuotaInfo, error) {
	used, limit, err := q.GetUsage(ctx, jid)
	if err != nil {
		return nil, err
	}

	fileCount := int64(0)

	// Get file count
	if q.redisClient != nil && q.config.Tracking == "redis" {
		key := fmt.Sprintf(quotaFilesKey, jid)
		count, err := q.redisClient.HLen(ctx, key).Result()
		if err == nil {
			fileCount = count
		}
	} else {
		q.mutex.RLock()
		if files, ok := q.memoryFiles[jid]; ok {
			fileCount = int64(len(files))
		}
		q.mutex.RUnlock()
	}

	_, isCustom := q.customQuotas[strings.ToLower(jid)]

	return &QuotaInfo{
		JID:       jid,
		Used:      used,
		Limit:     limit,
		Remaining: limit - used,
		FileCount: fileCount,
		IsCustom:  isCustom,
	}, nil
}

// CanUpload checks if a user can upload a file of the given size
func (q *QuotaManager) CanUpload(ctx context.Context, jid string, size int64) (bool, error) {
	if !q.config.Enabled {
		return true, nil
	}

	used, limit, err := q.GetUsage(ctx, jid)
	if err != nil {
		// On error, allow upload but log warning
		log.Warnf("Failed to check quota for %s, allowing upload: %v", jid, err)
		return true, nil
	}

	return used+size <= limit, nil
}

// RecordUpload records a file upload for quota tracking
func (q *QuotaManager) RecordUpload(ctx context.Context, jid, filePath string, size int64) error {
	if !q.config.Enabled || jid == "" {
		return nil
	}

	// Try Redis first with atomic operation
	if q.redisClient != nil && q.config.Tracking == "redis" {
		pipe := q.redisClient.TxPipeline()

		usedKey := fmt.Sprintf(quotaUsedKey, jid)
		filesKey := fmt.Sprintf(quotaFilesKey, jid)

		pipe.IncrBy(ctx, usedKey, size)
		pipe.HSet(ctx, filesKey, filePath, size)

		_, err := pipe.Exec(ctx)
		if err != nil {
			log.Warnf("Failed to record upload in Redis: %v", err)
		} else {
			return nil
		}
	}

	// Fallback to memory
	q.mutex.Lock()
	defer q.mutex.Unlock()

	q.memoryUsage[jid] += size

	if q.memoryFiles[jid] == nil {
		q.memoryFiles[jid] = make(map[string]int64)
	}
	q.memoryFiles[jid][filePath] = size

	return nil
}

// RecordDelete records a file deletion for quota tracking
func (q *QuotaManager) RecordDelete(ctx context.Context, jid, filePath string, size int64) error {
	if !q.config.Enabled || jid == "" {
		return nil
	}

	// If size is 0, try to get it from tracking
	if size == 0 {
		size = q.getFileSize(ctx, jid, filePath)
	}

	// Try Redis first
	if q.redisClient != nil && q.config.Tracking == "redis" {
		pipe := q.redisClient.TxPipeline()

		usedKey := fmt.Sprintf(quotaUsedKey, jid)
		filesKey := fmt.Sprintf(quotaFilesKey, jid)

		pipe.DecrBy(ctx, usedKey, size)
		pipe.HDel(ctx, filesKey, filePath)

		_, err := pipe.Exec(ctx)
		if err != nil {
			log.Warnf("Failed to record delete in Redis: %v", err)
		} else {
			return nil
		}
	}

	// Fallback to memory
	q.mutex.Lock()
	defer q.mutex.Unlock()

	q.memoryUsage[jid] -= size
	if q.memoryUsage[jid] < 0 {
		q.memoryUsage[jid] = 0
	}

	if q.memoryFiles[jid] != nil {
		delete(q.memoryFiles[jid], filePath)
	}

	return nil
}

// getFileSize retrieves the size of a tracked file
func (q *QuotaManager) getFileSize(ctx context.Context, jid, filePath string) int64 {
	// Try Redis
	if q.redisClient != nil && q.config.Tracking == "redis" {
		key := fmt.Sprintf(quotaFilesKey, jid)
		sizeStr, err := q.redisClient.HGet(ctx, key, filePath).Result()
		if err == nil {
			size, _ := strconv.ParseInt(sizeStr, 10, 64)
			return size
		}
	}

	// Try memory
	q.mutex.RLock()
	defer q.mutex.RUnlock()

	if files, ok := q.memoryFiles[jid]; ok {
		return files[filePath]
	}

	return 0
}

// SetCustomQuota sets a custom quota for a user
func (q *QuotaManager) SetCustomQuota(jid string, quota int64) {
	q.mutex.Lock()
	defer q.mutex.Unlock()
	q.customQuotas[strings.ToLower(jid)] = quota
}

// RemoveCustomQuota removes a custom quota for a user
func (q *QuotaManager) RemoveCustomQuota(jid string) {
	q.mutex.Lock()
	defer q.mutex.Unlock()
	delete(q.customQuotas, strings.ToLower(jid))
}

// GetAllQuotas returns quota info for all tracked users
func (q *QuotaManager) GetAllQuotas(ctx context.Context) ([]QuotaInfo, error) {
	var quotas []QuotaInfo

	// Get from Redis
	if q.redisClient != nil && q.config.Tracking == "redis" {
		// Scan for all quota keys
		iter := q.redisClient.Scan(ctx, 0, "quota:*:used", 100).Iterator()
		for iter.Next(ctx) {
			key := iter.Val()
			// Extract JID from key
			parts := strings.Split(key, ":")
			if len(parts) >= 2 {
				jid := parts[1]
				info, err := q.GetQuotaInfo(ctx, jid)
				if err == nil {
					quotas = append(quotas, *info)
				}
			}
		}
		return quotas, iter.Err()
	}

	// Get from memory
	q.mutex.RLock()
	defer q.mutex.RUnlock()

	for jid, used := range q.memoryUsage {
		limit := q.GetLimit(jid)
		fileCount := int64(0)
		if files, ok := q.memoryFiles[jid]; ok {
			fileCount = int64(len(files))
		}
		_, isCustom := q.customQuotas[strings.ToLower(jid)]

		quotas = append(quotas, QuotaInfo{
			JID:       jid,
			Used:      used,
			Limit:     limit,
			Remaining: limit - used,
			FileCount: fileCount,
			IsCustom:  isCustom,
		})
	}

	return quotas, nil
}

// Reconcile recalculates quota usage from actual file storage
func (q *QuotaManager) Reconcile(ctx context.Context, jid string, files map[string]int64) error {
	if !q.config.Enabled {
		return nil
	}

	var totalSize int64
	for _, size := range files {
		totalSize += size
	}

	// Update Redis
	if q.redisClient != nil && q.config.Tracking == "redis" {
		usedKey := fmt.Sprintf(quotaUsedKey, jid)
		filesKey := fmt.Sprintf(quotaFilesKey, jid)

		pipe := q.redisClient.TxPipeline()
		pipe.Set(ctx, usedKey, totalSize, 0)
		pipe.Del(ctx, filesKey)

		for path, size := range files {
			pipe.HSet(ctx, filesKey, path, size)
		}

		_, err := pipe.Exec(ctx)
		if err != nil {
			return fmt.Errorf("failed to reconcile quota in Redis: %w", err)
		}
		return nil
	}

	// Update memory
	q.mutex.Lock()
	defer q.mutex.Unlock()

	q.memoryUsage[jid] = totalSize
	q.memoryFiles[jid] = files

	return nil
}

// CheckQuotaMiddleware is a middleware that checks quota before upload
func CheckQuotaMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		qm := GetQuotaManager()
		if qm == nil || !qm.config.Enabled {
			next.ServeHTTP(w, r)
			return
		}

		// Only check for upload methods
		if r.Method != http.MethodPut && r.Method != http.MethodPost {
			next.ServeHTTP(w, r)
			return
		}

		// Get JID from context/headers
		jid := r.Header.Get("X-User-JID")
		if jid == "" {
			// Try to get from authorization context
			if claims, ok := r.Context().Value(contextKey("bearerClaims")).(*BearerTokenClaims); ok {
				jid = claims.User
			}
		}

		if jid == "" {
			next.ServeHTTP(w, r)
			return
		}

		// Check quota
		ctx := r.Context()
		canUpload, err := qm.CanUpload(ctx, jid, r.ContentLength)
		if err != nil {
			log.Warnf("Error checking quota: %v", err)
			next.ServeHTTP(w, r)
			return
		}

		if !canUpload {
			used, limit, _ := qm.GetUsage(ctx, jid)

			// Log to audit
			AuditQuotaExceeded(r, jid, used, limit, r.ContentLength)

			// Return 413 with quota info
			w.Header().Set("Content-Type", "application/json")
			w.Header().Set("X-Quota-Used", strconv.FormatInt(used, 10))
			w.Header().Set("X-Quota-Limit", strconv.FormatInt(limit, 10))
			w.Header().Set("X-Quota-Remaining", strconv.FormatInt(limit-used, 10))
			w.WriteHeader(http.StatusRequestEntityTooLarge)

			_ = json.NewEncoder(w).Encode(map[string]interface{}{
				"error":     "quota_exceeded",
				"message":   "Storage quota exceeded",
				"used":      used,
				"limit":     limit,
				"requested": r.ContentLength,
			})
			return
		}

		// Add quota headers
		used, limit, _ := qm.GetUsage(ctx, jid)
		w.Header().Set("X-Quota-Used", strconv.FormatInt(used, 10))
		w.Header().Set("X-Quota-Limit", strconv.FormatInt(limit, 10))
		w.Header().Set("X-Quota-Remaining", strconv.FormatInt(limit-used, 10))

		next.ServeHTTP(w, r)
	})
}

// UpdateQuotaAfterUpload updates quota after successful upload
func UpdateQuotaAfterUpload(ctx context.Context, jid, filePath string, size int64) {
	qm := GetQuotaManager()
	if qm == nil || !qm.config.Enabled || jid == "" {
		return
	}

	if err := qm.RecordUpload(ctx, jid, filePath, size); err != nil {
		log.WithFields(logrus.Fields{
			"jid":   jid,
			"file":  filePath,
			"size":  size,
			"error": err,
		}).Warn("Failed to update quota after upload")
	}
}

// UpdateQuotaAfterDelete updates quota after file deletion
func UpdateQuotaAfterDelete(ctx context.Context, jid, filePath string, size int64) {
	qm := GetQuotaManager()
	if qm == nil || !qm.config.Enabled || jid == "" {
		return
	}

	if err := qm.RecordDelete(ctx, jid, filePath, size); err != nil {
		log.WithFields(logrus.Fields{
			"jid":   jid,
			"file":  filePath,
			"size":  size,
			"error": err,
		}).Warn("Failed to update quota after delete")
	}
}

// DefaultQuotaConfig returns default quota configuration
func DefaultQuotaConfig() QuotaConfig {
	return QuotaConfig{
		Enabled:  false,
		Default:  "100MB",
		Tracking: "redis",
		Custom:   make(map[string]string),
	}
}

// StartQuotaReconciliation starts a background job to reconcile quotas
func StartQuotaReconciliation(interval time.Duration) {
	if quotaManager == nil || !quotaManager.config.Enabled {
		return
	}

	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()

		for range ticker.C {
			log.Debug("Running quota reconciliation")
			// This would scan the storage and update quotas
			// Implementation depends on how files are tracked
		}
	}()
}
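The admission rule in `CanUpload` reduces to a single comparison: an upload is accepted only while `used + size` stays within the limit, otherwise the middleware answers 413 with `X-Quota-*` headers. A minimal standalone sketch of that arithmetic, with illustrative values:

```go
package main

import "fmt"

// canUpload mirrors the check performed by QuotaManager.CanUpload:
// the upload fits only if current usage plus the new file's size
// does not exceed the user's limit.
func canUpload(used, limit, size int64) bool {
	return used+size <= limit
}

func main() {
	limit := int64(100 << 20) // 100MB, the default quota
	used := int64(90 << 20)   // 90MB already stored

	fmt.Println(canUpload(used, limit, 5<<20))  // true: 95MB <= 100MB
	fmt.Println(canUpload(used, limit, 20<<20)) // false: server would reply 413 with X-Quota-* headers
}
```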
340 cmd/server/validation.go Normal file
@@ -0,0 +1,340 @@
// validation.go - Content type validation using magic bytes

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
)

// ValidationConfig holds content validation configuration
type ValidationConfig struct {
	CheckMagicBytes bool     `toml:"check_magic_bytes" mapstructure:"check_magic_bytes"`
	AllowedTypes    []string `toml:"allowed_types" mapstructure:"allowed_types"`
	BlockedTypes    []string `toml:"blocked_types" mapstructure:"blocked_types"`
	MaxFileSize     string   `toml:"max_file_size" mapstructure:"max_file_size"`
	StrictMode      bool     `toml:"strict_mode" mapstructure:"strict_mode"` // Reject if type can't be detected
}

// ValidationResult contains the result of content validation
type ValidationResult struct {
	Valid        bool   `json:"valid"`
	DetectedType string `json:"detected_type"`
	DeclaredType string `json:"declared_type,omitempty"`
	Error        string `json:"error,omitempty"`
	Reason       string `json:"reason,omitempty"`
}

// ValidationError represents a validation failure
type ValidationError struct {
	Code         string `json:"error"`
	Message      string `json:"message"`
	DetectedType string `json:"detected_type"`
	DeclaredType string `json:"declared_type,omitempty"`
}

func (e *ValidationError) Error() string {
	return e.Message
}

// ContentValidator handles content type validation
type ContentValidator struct {
	config        *ValidationConfig
	allowedTypes  map[string]bool
	blockedTypes  map[string]bool
	wildcardAllow []string
	wildcardBlock []string
}

var (
	contentValidator *ContentValidator
	validatorOnce    sync.Once
)

// InitContentValidator initializes the content validator
func InitContentValidator(config *ValidationConfig) {
	validatorOnce.Do(func() {
		contentValidator = &ContentValidator{
			config:        config,
			allowedTypes:  make(map[string]bool),
			blockedTypes:  make(map[string]bool),
			wildcardAllow: []string{},
			wildcardBlock: []string{},
		}

		// Process allowed types
		for _, t := range config.AllowedTypes {
			t = strings.ToLower(strings.TrimSpace(t))
			if strings.HasSuffix(t, "/*") {
				contentValidator.wildcardAllow = append(contentValidator.wildcardAllow, strings.TrimSuffix(t, "/*"))
			} else {
				contentValidator.allowedTypes[t] = true
			}
		}

		// Process blocked types
		for _, t := range config.BlockedTypes {
			t = strings.ToLower(strings.TrimSpace(t))
			if strings.HasSuffix(t, "/*") {
				contentValidator.wildcardBlock = append(contentValidator.wildcardBlock, strings.TrimSuffix(t, "/*"))
			} else {
				contentValidator.blockedTypes[t] = true
			}
		}

		log.Infof("Content validator initialized: magic_bytes=%v, allowed=%d types, blocked=%d types",
			config.CheckMagicBytes, len(config.AllowedTypes), len(config.BlockedTypes))
	})
}

// GetContentValidator returns the singleton content validator
func GetContentValidator() *ContentValidator {
	return contentValidator
}

// isTypeAllowed checks if a content type is in the allowed list
func (v *ContentValidator) isTypeAllowed(contentType string) bool {
	contentType = strings.ToLower(contentType)

	// Extract main type (before any parameters like charset)
	if idx := strings.Index(contentType, ";"); idx != -1 {
		contentType = strings.TrimSpace(contentType[:idx])
	}

	// If no allowed types configured, allow all (except blocked)
	if len(v.allowedTypes) == 0 && len(v.wildcardAllow) == 0 {
		return true
	}

	// Check exact match
	if v.allowedTypes[contentType] {
		return true
	}

	// Check wildcard patterns
	for _, prefix := range v.wildcardAllow {
		if strings.HasPrefix(contentType, prefix+"/") {
			return true
		}
	}

	return false
}

// isTypeBlocked checks if a content type is in the blocked list
func (v *ContentValidator) isTypeBlocked(contentType string) bool {
	contentType = strings.ToLower(contentType)

	// Extract main type (before any parameters)
	if idx := strings.Index(contentType, ";"); idx != -1 {
		contentType = strings.TrimSpace(contentType[:idx])
	}

	// Check exact match
	if v.blockedTypes[contentType] {
		return true
	}

	// Check wildcard patterns
	for _, prefix := range v.wildcardBlock {
		if strings.HasPrefix(contentType, prefix+"/") {
			return true
		}
	}

	return false
}

// ValidateContent validates the content type of a reader.
// Returns a new reader that includes the buffered bytes, the detected type, and any error
func (v *ContentValidator) ValidateContent(reader io.Reader, declaredType string, size int64) (io.Reader, string, error) {
	if v == nil || !v.config.CheckMagicBytes {
		return reader, declaredType, nil
	}

	// Read first 512 bytes for magic byte detection
	buf := make([]byte, 512)
	n, err := io.ReadFull(reader, buf)
	if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
		return nil, "", fmt.Errorf("failed to read content for validation: %w", err)
	}

	// Handle small files
	if n == 0 {
		if v.config.StrictMode {
			return nil, "", &ValidationError{
				Code:         "empty_content",
				Message:      "Cannot validate empty content",
				DetectedType: "",
				DeclaredType: declaredType,
			}
		}
		return bytes.NewReader(buf[:n]), declaredType, nil
	}

	// Detect content type using magic bytes
	detectedType := http.DetectContentType(buf[:n])

	// Normalize detected type
	if idx := strings.Index(detectedType, ";"); idx != -1 {
		detectedType = strings.TrimSpace(detectedType[:idx])
	}

	// Check if type is blocked (highest priority)
	if v.isTypeBlocked(detectedType) {
		return nil, detectedType, &ValidationError{
			Code:         "content_type_blocked",
			Message:      fmt.Sprintf("File type %s is blocked", detectedType),
			DetectedType: detectedType,
			DeclaredType: declaredType,
		}
	}

	// Check if type is allowed
	if !v.isTypeAllowed(detectedType) {
		return nil, detectedType, &ValidationError{
			Code:         "content_type_rejected",
			Message:      fmt.Sprintf("File type %s is not allowed", detectedType),
			DetectedType: detectedType,
			DeclaredType: declaredType,
		}
	}

	// Create a new reader that includes the buffered bytes
	combinedReader := io.MultiReader(bytes.NewReader(buf[:n]), reader)

	return combinedReader, detectedType, nil
}

// ValidateContentType validates a content type without reading content
func (v *ContentValidator) ValidateContentType(contentType string) error {
	if v == nil {
		return nil
	}

	if v.isTypeBlocked(contentType) {
		return &ValidationError{
			Code:         "content_type_blocked",
			Message:      fmt.Sprintf("File type %s is blocked", contentType),
			DetectedType: contentType,
		}
	}

	if !v.isTypeAllowed(contentType) {
		return &ValidationError{
			Code:         "content_type_rejected",
			Message:      fmt.Sprintf("File type %s is not allowed", contentType),
			DetectedType: contentType,
		}
	}

	return nil
}

// WriteValidationError writes a validation error response
func WriteValidationError(w http.ResponseWriter, err *ValidationError) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnsupportedMediaType)
	_ = json.NewEncoder(w).Encode(err)
}

// ValidateUploadContent is a helper function for validating upload content
func ValidateUploadContent(r *http.Request, reader io.Reader, declaredType string, size int64) (io.Reader, string, error) {
	validator := GetContentValidator()
	if validator == nil || !validator.config.CheckMagicBytes {
		return reader, declaredType, nil
	}

	newReader, detectedType, err := validator.ValidateContent(reader, declaredType, size)
	if err != nil {
		// Log validation failure to audit
		jid := r.Header.Get("X-User-JID")
		fileName := r.Header.Get("X-File-Name")
		if fileName == "" {
			fileName = "unknown"
		}

		var reason string
		if validErr, ok := err.(*ValidationError); ok {
			reason = validErr.Code
		} else {
			reason = err.Error()
		}

		AuditValidationFailure(r, jid, fileName, declaredType, detectedType, reason)

		return nil, detectedType, err
	}

	return newReader, detectedType, nil
}

// DefaultValidationConfig returns default validation configuration
func DefaultValidationConfig() ValidationConfig {
	return ValidationConfig{
		CheckMagicBytes: false,
		AllowedTypes: []string{
			"image/*",
			"video/*",
			"audio/*",
			"application/pdf",
			"text/plain",
			"text/html",
			"application/json",
			"application/xml",
			"application/zip",
			"application/x-gzip",
			"application/x-tar",
			"application/x-7z-compressed",
			"application/vnd.openxmlformats-officedocument.*",
			"application/vnd.oasis.opendocument.*",
		},
		BlockedTypes: []string{
			"application/x-executable",
			"application/x-msdos-program",
			"application/x-msdownload",
			"application/x-dosexec",
			"application/x-sh",
			"application/x-shellscript",
		},
		MaxFileSize: "100MB",
		StrictMode:  false,
	}
}

// Extended MIME type detection for better accuracy
var customMagicBytes = map[string][]byte{
	"application/x-executable":    {0x7f, 'E', 'L', 'F'},                            // ELF
	"application/x-msdos-program": {0x4d, 0x5a},                                     // MZ (DOS/Windows)
	"application/pdf":             {0x25, 0x50, 0x44, 0x46},                         // %PDF
	"application/zip":             {0x50, 0x4b, 0x03, 0x04},                         // PK
	"image/png":                   {0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a}, // PNG
	"image/jpeg":                  {0xff, 0xd8, 0xff},                               // JPEG
	"image/gif":                   {0x47, 0x49, 0x46, 0x38},                         // GIF8
	"image/webp":                  {0x52, 0x49, 0x46, 0x46},                         // RIFF (WebP starts with RIFF)
	"video/mp4":                   {0x00, 0x00, 0x00},                               // MP4 (variable, check ftyp)
	"audio/mpeg":                  {0xff, 0xfb},                                     // MP3
	"audio/ogg":                   {0x4f, 0x67, 0x67, 0x53},                         // OggS
}

// DetectContentTypeExtended provides extended content type detection
func DetectContentTypeExtended(data []byte) string {
	// First try standard detection
	detected := http.DetectContentType(data)

	// If generic, try custom detection
	if detected == "application/octet-stream" {
		for mimeType, magic := range customMagicBytes {
			if len(data) >= len(magic) && bytes.Equal(data[:len(magic)], magic) {
				return mimeType
			}
		}
	}

	return detected
}
162 templates/config-enhanced-features.toml Normal file
@@ -0,0 +1,162 @@
# HMAC File Server 3.3.0 "Nexus Infinitum" Configuration
# Enhanced Features Template: Audit Logging, Content Validation, Quotas, Admin API
# Generated on: January 2025

[server]
listen_address = "8080"
storage_path = "/opt/hmac-file-server/data/uploads"
metrics_enabled = true
metrics_port = "9090"
pid_file = "/opt/hmac-file-server/data/hmac-file-server.pid"
max_upload_size = "10GB"
deduplication_enabled = true
min_free_bytes = "1GB"
file_naming = "original"
enable_dynamic_workers = true

[security]
secret = "CHANGE-THIS-SECRET-KEY-MINIMUM-32-CHARACTERS"
enablejwt = false

[uploads]
allowedextensions = [".txt", ".pdf", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".zip", ".tar", ".gz", ".7z", ".mp4", ".webm", ".ogg", ".mp3", ".wav", ".flac", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx", ".odt", ".ods", ".odp"]
maxfilesize = "100MB"
chunkeduploadsenabled = true
chunksize = "10MB"
networkevents = true

[downloads]
chunkeddownloadsenabled = true
chunksize = "10MB"

[logging]
level = "INFO"
file = "/opt/hmac-file-server/data/logs/hmac-file-server.log"
max_size = 100
max_backups = 3
max_age = 30
compress = true

[workers]
numworkers = 10
uploadqueuesize = 1000
autoscaling = true

[timeouts]
readtimeout = "30s"
writetimeout = "30s"
idletimeout = "120s"
shutdown = "30s"

[clamav]
enabled = false

[redis]
enabled = true
address = "127.0.0.1:6379"
db = 0

# ============================================
# NEW ENHANCED FEATURES (v3.3.0)
# ============================================

# Security Audit Logging
# Records security-relevant events for compliance and forensics
[audit]
enabled = true
output = "file"                  # "file" or "stdout"
path = "/var/log/hmac-audit.log" # Log file path (when output = "file")
format = "json"                  # "json" or "text"
max_size = 100                   # Max size in MB before rotation
max_age = 30                     # Max age in days
events = [
    "upload",            # Log all file uploads
    "download",          # Log all file downloads
    "delete",            # Log file deletions
    "auth_success",      # Log successful authentications
    "auth_failure",      # Log failed authentications
    "rate_limited",      # Log rate limiting events
    "banned",            # Log ban events
    "quota_exceeded",    # Log quota exceeded events
    "validation_failure" # Log content validation failures
]

# Magic Bytes Content Validation
# Validates uploaded file content types using magic bytes detection
[validation]
check_magic_bytes = true # Enable magic bytes validation
strict_mode = false      # Strict mode rejects mismatched types
max_peek_size = 65536    # Bytes to read for detection (64KB)

# Allowed content types (supports wildcards like "image/*")
# If empty, all types are allowed (except blocked)
allowed_types = [
    "image/*",                     # All image types
    "video/*",                     # All video types
    "audio/*",                     # All audio types
    "text/plain",                  # Plain text
    "application/pdf",             # PDF documents
    "application/zip",             # ZIP archives
    "application/gzip",            # GZIP archives
    "application/x-tar",           # TAR archives
    "application/x-7z-compressed", # 7-Zip archives
    "application/vnd.openxmlformats-officedocument.*", # MS Office docs
    "application/vnd.oasis.opendocument.*"             # LibreOffice docs
]

# Blocked content types (takes precedence over allowed)
blocked_types = [
    "application/x-executable",    # Executable files
    "application/x-msdos-program", # DOS executables
    "application/x-msdownload",    # Windows executables
    "application/x-elf",           # ELF binaries
    "application/x-shellscript",   # Shell scripts
    "application/javascript",      # JavaScript files
    "text/html",                   # HTML files (potential XSS)
    "application/x-php"            # PHP files
]
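The wildcard patterns above ("image/*") reduce to a prefix comparison. A minimal sketch of the evaluation order described in the comments (blocked wins over allowed; an empty allow list permits everything); the actual matcher in validation.go may differ:

```go
package main

import (
	"fmt"
	"strings"
)

// matchesType reports whether a detected MIME type matches a configured
// pattern, where a trailing "*" acts as a prefix wildcard (e.g. "image/*").
func matchesType(pattern, mimeType string) bool {
	if strings.HasSuffix(pattern, "*") {
		return strings.HasPrefix(mimeType, strings.TrimSuffix(pattern, "*"))
	}
	return pattern == mimeType
}

// allowed applies the rules from the config above: blocked_types take
// precedence over allowed_types, and an empty allow list permits all types.
func allowed(mimeType string, allowedTypes, blockedTypes []string) bool {
	for _, p := range blockedTypes {
		if matchesType(p, mimeType) {
			return false
		}
	}
	if len(allowedTypes) == 0 {
		return true
	}
	for _, p := range allowedTypes {
		if matchesType(p, mimeType) {
			return true
		}
	}
	return false
}

func main() {
	allow := []string{"image/*", "application/pdf"}
	block := []string{"text/html"}
	fmt.Println(allowed("image/png", allow, block)) // true
	fmt.Println(allowed("text/html", allow, block)) // false
	fmt.Println(allowed("video/mp4", allow, block)) // false: not in allow list
}
```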

# Per-User Storage Quotas
# Track and enforce storage limits per XMPP JID
[quotas]
enabled = true     # Enable quota enforcement
default = "100MB"  # Default quota for all users
tracking = "redis" # "redis" or "memory"

# Custom quotas per user (JID -> quota)
[quotas.custom]
"admin@example.com" = "10GB"  # Admin gets 10GB
"premium@example.com" = "1GB" # Premium user gets 1GB
"vip@example.com" = "5GB"     # VIP user gets 5GB
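The pre-upload check described in the commit message (quota checked before upload, usage tracked after) amounts to: resolve the JID's limit (custom override or default), then reject if current usage plus the incoming size would exceed it. A simplified sketch, assuming sizes are already parsed from strings like "100MB" into bytes (the real logic in quota.go also tracks usage in Redis or memory):

```go
package main

import "fmt"

// QuotaConfig mirrors the [quotas] section above: a default limit
// plus per-JID overrides, with values pre-parsed into bytes.
type QuotaConfig struct {
	Default int64
	Custom  map[string]int64
}

// limitFor returns the quota for a JID, falling back to the default.
func (q QuotaConfig) limitFor(jid string) int64 {
	if limit, ok := q.Custom[jid]; ok {
		return limit
	}
	return q.Default
}

// checkQuota reports whether an upload of size bytes fits within the
// user's remaining quota, given their current usage.
func checkQuota(q QuotaConfig, jid string, used, size int64) bool {
	return used+size <= q.limitFor(jid)
}

func main() {
	q := QuotaConfig{
		Default: 100 << 20,                                      // 100MB
		Custom:  map[string]int64{"admin@example.com": 10 << 30}, // 10GB
	}
	fmt.Println(checkQuota(q, "user@example.com", 90<<20, 20<<20))  // false: 110MB > 100MB
	fmt.Println(checkQuota(q, "admin@example.com", 90<<20, 20<<20)) // true: well under 10GB
}
```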

# Admin API for Operations and Monitoring
# Protected endpoints for system management
[admin]
enabled = true         # Enable admin API
path_prefix = "/admin" # URL prefix for admin endpoints

# Available endpoints (when enabled):
# GET    /admin/stats            - Server statistics and metrics
# GET    /admin/files            - List all uploaded files
# GET    /admin/files/:id        - Get file details
# DELETE /admin/files/:id        - Delete a file
# GET    /admin/users            - List users and quota usage
# GET    /admin/users/:jid       - Get user details and quota
# POST   /admin/users/:jid/quota - Set user quota
# GET    /admin/bans             - List banned IPs/users
# POST   /admin/bans             - Ban an IP or user
# DELETE /admin/bans/:id         - Unban

# Admin authentication
[admin.auth]
type = "bearer"          # "bearer" or "basic"
token = "${ADMIN_TOKEN}" # Bearer token (from environment variable)
# For basic auth:
# type = "basic"
# username = "admin"
# password_hash = "$2a$12$..." # bcrypt hash

# Rate limiting for admin endpoints
[admin.rate_limit]
enabled = true
requests_per_minute = 60 # Max requests per minute per IP