Implement network switching improvements for HMAC file server

- Added support for chunked and resumable uploads to enhance resilience against network interruptions.
- Introduced a new upload session management system to track progress and handle retries.
- Enhanced connection management with improved timeout settings for mobile networks.
- Implemented network change detection and handling to pause and resume uploads seamlessly.
- Developed client-side retry logic for uploads to improve reliability.
- Updated configuration options to enable new features and set recommended defaults for timeouts and chunk sizes.
- Created integration layer to add new features without modifying existing core functionality.
- Established a network resilience manager to monitor network changes and manage active uploads.
This commit is contained in:
2025-07-17 18:00:14 +02:00
parent 71a73bc514
commit 6a90fa6e30
8 changed files with 1553 additions and 1 deletion

View File

@ -0,0 +1,304 @@
// chunked_upload_handler.go - New chunked upload handler without modifying existing ones
package main
import (
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"
)
// Global upload session store
var uploadSessionStore *UploadSessionStore
// handleChunkedUpload handles chunked/resumable uploads.
//
// Protocol (all state carried in headers):
//   - Without X-Upload-Session-ID: create a new session. Requires
//     X-Filename and X-Total-Size; authenticates via JWT or HMAC and
//     responds 201 with JSON {session_id, chunk_size, total_chunks}.
//   - With X-Upload-Session-ID and X-Chunk-Number: store one chunk.
//     When the last chunk arrives the file is assembled, optionally
//     deduplicated, and a completion response is returned; otherwise a
//     progress response is returned.
func handleChunkedUpload(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	activeConnections.Inc()
	defer activeConnections.Dec()

	// Only allow PUT and POST methods.
	if r.Method != http.MethodPut && r.Method != http.MethodPost {
		http.Error(w, "Method not allowed for chunked uploads", http.StatusMethodNotAllowed)
		uploadErrorsTotal.Inc()
		return
	}

	// Extract headers for chunked upload. The original also read
	// X-Total-Chunks and Content-Range into locals that were never used —
	// an unused-variable compile error in Go — so those reads are dropped.
	sessionID := r.Header.Get("X-Upload-Session-ID")
	chunkNumberStr := r.Header.Get("X-Chunk-Number")
	filename := r.Header.Get("X-Filename")

	// Session-creation path: no session ID yet.
	if sessionID == "" {
		totalSizeStr := r.Header.Get("X-Total-Size")
		if totalSizeStr == "" || filename == "" {
			http.Error(w, "Missing required headers for new upload session", http.StatusBadRequest)
			uploadErrorsTotal.Inc()
			return
		}
		totalSize, err := strconv.ParseInt(totalSizeStr, 10, 64)
		if err != nil {
			http.Error(w, "Invalid total size", http.StatusBadRequest)
			uploadErrorsTotal.Inc()
			return
		}
		// Authentication (reuses the existing JWT/HMAC validators).
		if conf.Security.EnableJWT {
			_, err := validateJWTFromRequest(r, conf.Security.JWTSecret)
			if err != nil {
				http.Error(w, fmt.Sprintf("JWT Authentication failed: %v", err), http.StatusUnauthorized)
				uploadErrorsTotal.Inc()
				return
			}
		} else {
			err := validateHMAC(r, conf.Security.Secret)
			if err != nil {
				http.Error(w, fmt.Sprintf("HMAC Authentication failed: %v", err), http.StatusUnauthorized)
				uploadErrorsTotal.Inc()
				return
			}
		}
		// Create and announce the new session.
		clientIP := getClientIP(r)
		session := uploadSessionStore.CreateSession(filename, totalSize, clientIP)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		response := map[string]interface{}{
			"session_id":   session.ID,
			"chunk_size":   session.ChunkSize,
			"total_chunks": (totalSize + session.ChunkSize - 1) / session.ChunkSize,
		}
		writeJSONResponse(w, response)
		return
	}

	// Chunk-upload path: the session must already exist.
	session, exists := uploadSessionStore.GetSession(sessionID)
	if !exists {
		http.Error(w, "Upload session not found", http.StatusNotFound)
		uploadErrorsTotal.Inc()
		return
	}
	chunkNumber, err := strconv.Atoi(chunkNumberStr)
	if err != nil {
		http.Error(w, "Invalid chunk number", http.StatusBadRequest)
		uploadErrorsTotal.Inc()
		return
	}
	// Idempotency: a chunk that is already complete is acknowledged
	// without re-writing it.
	// NOTE(review): session.Chunks is read here without the store mutex
	// while UpdateSession mutates it — confirm the locking discipline.
	if chunkInfo, exists := session.Chunks[chunkNumber]; exists && chunkInfo.Completed {
		w.WriteHeader(http.StatusOK)
		writeJSONResponse(w, map[string]interface{}{
			"message": "Chunk already uploaded",
			"chunk":   chunkNumber,
		})
		return
	}
	// Stream the chunk body into a per-chunk temp file.
	chunkPath := filepath.Join(session.TempDir, fmt.Sprintf("chunk_%d", chunkNumber))
	chunkFile, err := os.Create(chunkPath)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error creating chunk file: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}
	defer chunkFile.Close()
	written, err := copyChunkWithResilience(chunkFile, r.Body, r.ContentLength, sessionID, chunkNumber)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error saving chunk: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		os.Remove(chunkPath) // Clean up failed chunk
		return
	}
	// Record progress in the session store.
	err = uploadSessionStore.UpdateSession(sessionID, chunkNumber, written)
	if err != nil {
		http.Error(w, fmt.Sprintf("Error updating session: %v", err), http.StatusInternalServerError)
		uploadErrorsTotal.Inc()
		return
	}
	if uploadSessionStore.IsSessionComplete(sessionID) {
		// All bytes received: assemble the chunks into the final file.
		finalPath, err := uploadSessionStore.AssembleFile(sessionID)
		if err != nil {
			http.Error(w, fmt.Sprintf("Error assembling file: %v", err), http.StatusInternalServerError)
			uploadErrorsTotal.Inc()
			return
		}
		// Deduplication failures are deliberately non-fatal: the file is
		// already stored; dedup is only an optimization.
		if conf.Server.DeduplicationEnabled {
			err = handleDeduplication(r.Context(), finalPath)
			if err != nil {
				log.Warnf("Deduplication failed for %s: %v", finalPath, err)
			}
		}
		// Update metrics.
		duration := time.Since(startTime)
		uploadDuration.Observe(duration.Seconds())
		uploadsTotal.Inc()
		uploadSizeBytes.Observe(float64(session.TotalSize))
		// Return completion response.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		response := map[string]interface{}{
			"success":   true,
			"filename":  session.Filename,
			"size":      session.TotalSize,
			"duration":  duration.String(),
			"completed": true,
		}
		writeJSONResponse(w, response)
		log.Infof("Successfully completed chunked upload %s (%s) in %s",
			session.Filename, formatBytes(session.TotalSize), duration)
	} else {
		// Partial success: report progress so the client can continue.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		response := map[string]interface{}{
			"success":        true,
			"chunk":          chunkNumber,
			"uploaded_bytes": session.UploadedBytes,
			"total_size":     session.TotalSize,
			"progress":       float64(session.UploadedBytes) / float64(session.TotalSize),
			"completed":      false,
		}
		writeJSONResponse(w, response)
	}
}
// copyChunkWithResilience streams one chunk from src to dst using the
// shared buffer pool, honoring pause/resume signals from the network
// resilience manager when one is active. It returns the number of
// bytes written and the first error encountered.
func copyChunkWithResilience(dst io.Writer, src io.Reader, contentLength int64, sessionID string, chunkNumber int) (int64, error) {
	// Register with the network resilience manager when it exists so a
	// detected network change can pause this copy between reads.
	var uploadCtx *UploadContext
	if networkManager != nil {
		uploadID := fmt.Sprintf("%s_chunk_%d", sessionID, chunkNumber)
		uploadCtx = networkManager.RegisterUpload(uploadID)
		defer networkManager.UnregisterUpload(fmt.Sprintf("%s_chunk_%d", sessionID, chunkNumber))
	}

	bufPtr := bufferPool.Get().(*[]byte)
	defer bufferPool.Put(bufPtr)
	buf := *bufPtr

	var total int64
	for {
		// A pending pause signal blocks this loop until resume arrives.
		if uploadCtx != nil {
			select {
			case <-uploadCtx.PauseChan:
				<-uploadCtx.ResumeChan
			default:
			}
		}
		n, readErr := src.Read(buf)
		if n > 0 {
			nw, writeErr := dst.Write(buf[:n])
			total += int64(nw)
			if writeErr != nil {
				return total, writeErr
			}
		}
		if readErr == io.EOF {
			return total, nil
		}
		if readErr != nil {
			return total, readErr
		}
	}
}
// handleUploadStatus returns the status of an upload session as JSON.
// The session is identified by the session_id query parameter.
func handleUploadStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	sessionID := r.URL.Query().Get("session_id")
	if sessionID == "" {
		http.Error(w, "Missing session_id parameter", http.StatusBadRequest)
		return
	}
	session, exists := uploadSessionStore.GetSession(sessionID)
	if !exists {
		http.Error(w, "Session not found", http.StatusNotFound)
		return
	}
	// Guard the division: a zero TotalSize would yield NaN, which
	// json.Marshal cannot encode, turning this endpoint into a 500.
	progress := 1.0
	if session.TotalSize > 0 {
		progress = float64(session.UploadedBytes) / float64(session.TotalSize)
	}
	w.Header().Set("Content-Type", "application/json")
	response := map[string]interface{}{
		"session_id":     session.ID,
		"filename":       session.Filename,
		"total_size":     session.TotalSize,
		"uploaded_bytes": session.UploadedBytes,
		"progress":       progress,
		"completed":      uploadSessionStore.IsSessionComplete(sessionID),
		"last_activity":  session.LastActivity,
		"chunks":         len(session.Chunks),
	}
	writeJSONResponse(w, response)
}
// Helper functions
// getClientIP extracts the originating client IP of a request,
// preferring reverse-proxy headers over the socket address.
func getClientIP(r *http.Request) string {
	// X-Forwarded-For may hold a comma-separated chain; the first entry
	// is the original client.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		parts := strings.Split(xff, ",")
		return strings.TrimSpace(parts[0])
	}
	// Check X-Real-IP header.
	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return xri
	}
	// Fall back to the socket address. The original cut at the first ':'
	// which mangles IPv6 ("[::1]:80" -> "["); SplitHostPort handles both
	// address families and strips the brackets.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// writeJSONResponse encodes data as JSON onto w with an
// application/json content type; on encoding failure it responds with
// a 500 instead.
func writeJSONResponse(w http.ResponseWriter, data interface{}) {
	payload, err := json.Marshal(data)
	if err != nil {
		http.Error(w, "Error encoding JSON response", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.Write(payload)
}

View File

@ -602,6 +602,12 @@ func setupRouter() *http.ServeMux {
})
log.Info("HTTP router configured successfully with full protocol support (v, v2, token, v3)")
// Enhance router with network resilience features (non-intrusive)
if conf.Uploads.ChunkedUploadsEnabled {
EnhanceExistingRouter(mux)
}
return mux
}

134
cmd/server/integration.go Normal file
View File

@ -0,0 +1,134 @@
// integration.go - Integration layer to add new features without modifying core
package main
import (
"net/http"
"path/filepath"
"time"
)
// InitializeUploadResilience wires up the upload resilience subsystem:
// the session store (persisted under <storage>/.upload_sessions) and
// the network change monitor.
func InitializeUploadResilience() {
	sessionDir := filepath.Join(conf.Server.StoragePath, ".upload_sessions")
	uploadSessionStore = NewUploadSessionStore(sessionDir)
	InitializeNetworkResilience()
	log.Info("Upload resilience system initialized")
}
// EnhanceExistingRouter registers the chunked-upload endpoints on mux
// without touching the routes configured elsewhere.
func EnhanceExistingRouter(mux *http.ServeMux) {
	const (
		chunkedPath = "/upload/chunked"
		statusPath  = "/upload/status"
	)
	mux.HandleFunc(chunkedPath, ResilientHTTPHandler(handleChunkedUpload, networkManager))
	mux.HandleFunc(statusPath, handleUploadStatus)
	if conf.Uploads.ChunkedUploadsEnabled {
		log.Info("Enhanced upload endpoints added:")
		log.Info(" POST/PUT /upload/chunked - Chunked/resumable uploads")
		log.Info(" GET /upload/status - Upload status check")
	}
}
// UpdateConfigurationDefaults logs recommended timeout/chunking values
// for network-switching resilience next to the currently configured
// ones. It never modifies the running configuration.
func UpdateConfigurationDefaults() {
	log.Info("Network resilience recommendations:")
	// Pair each setting with its recommendation (and the live value).
	recommendations := map[string]string{
		"ReadTimeout":   "300s (current: " + conf.Timeouts.Read + ")",
		"WriteTimeout":  "300s (current: " + conf.Timeouts.Write + ")",
		"IdleTimeout":   "600s (current: " + conf.Timeouts.Idle + ")",
		"ChunkSize":     "5MB for mobile networks",
		"RetryAttempts": "3-5 for network switching scenarios",
	}
	log.Info("Recommended configuration changes for network switching resilience:")
	for name, hint := range recommendations {
		log.Infof(" %s: %s", name, hint)
	}
}
// MonitorUploadPerformance periodically logs upload-session and
// network-resilience statistics. It runs forever and is intended to be
// launched on its own goroutine.
//
// NOTE(review): there is no stop channel/context, so this goroutine and
// its ticker live for the process lifetime — acceptable for a
// process-wide monitor, but worth confirming.
func MonitorUploadPerformance() {
	ticker := time.NewTicker(60 * time.Second)
	defer ticker.Stop()
	// `for range ticker.C` replaces the original single-case select loop:
	// identical behavior, simpler shape.
	for range ticker.C {
		// Log upload session statistics.
		if uploadSessionStore != nil {
			uploadSessionStore.mutex.RLock()
			activeSessionsCount := len(uploadSessionStore.sessions)
			uploadSessionStore.mutex.RUnlock()
			if activeSessionsCount > 0 {
				log.Infof("Active upload sessions: %d", activeSessionsCount)
			}
		}
		// Log network resilience status.
		if networkManager != nil {
			networkManager.mutex.RLock()
			activeUploadsCount := len(networkManager.activeUploads)
			isPaused := networkManager.isPaused
			networkManager.mutex.RUnlock()
			if activeUploadsCount > 0 {
				status := "active"
				if isPaused {
					status = "paused"
				}
				log.Infof("Network resilience: %d uploads %s", activeUploadsCount, status)
			}
		}
	}
}
// GetResilienceStatus reports a snapshot of the resilience subsystem
// for monitoring: which components are enabled plus live counters.
func GetResilienceStatus() map[string]interface{} {
	status := map[string]interface{}{
		"upload_sessions_enabled": uploadSessionStore != nil,
		"network_monitoring":      networkManager != nil,
		"active_sessions":         0,
		"active_uploads":          0,
		"network_paused":          false,
	}
	if store := uploadSessionStore; store != nil {
		store.mutex.RLock()
		status["active_sessions"] = len(store.sessions)
		store.mutex.RUnlock()
	}
	if mgr := networkManager; mgr != nil {
		mgr.mutex.RLock()
		status["active_uploads"] = len(mgr.activeUploads)
		status["network_paused"] = mgr.isPaused
		mgr.mutex.RUnlock()
	}
	return status
}
// InitializeEnhancements boots the optional resilience features from
// main(). It is a no-op (beyond a hint log) unless chunked uploads are
// enabled in the configuration.
func InitializeEnhancements() {
	if !conf.Uploads.ChunkedUploadsEnabled {
		log.Info("Chunked uploads disabled. Enable 'chunkeduploadsenabled = true' for network resilience features")
		return
	}
	InitializeUploadResilience()
	// Background statistics logger.
	go MonitorUploadPerformance()
	// Advise on recommended settings.
	UpdateConfigurationDefaults()
}

View File

@ -140,7 +140,8 @@ type UploadsConfig struct {
ChunkedUploadsEnabled bool `toml:"chunkeduploadsenabled" mapstructure:"chunkeduploadsenabled"`
ChunkSize string `toml:"chunksize" mapstructure:"chunksize"`
ResumableUploadsEnabled bool `toml:"resumableuploadsenabled" mapstructure:"resumableuploadsenabled"`
MaxResumableAge string `toml:"max_resumable_age" mapstructure:"max_resumable_age"`
SessionTimeout string `toml:"sessiontimeout" mapstructure:"sessiontimeout"`
MaxRetries int `toml:"maxretries" mapstructure:"maxretries"`
}
type DownloadsConfig struct {
@ -760,6 +761,9 @@ func main() {
}
log.Infof("Running version: %s", versionString)
// Initialize network resilience features (non-intrusive)
InitializeEnhancements()
log.Infof("Starting HMAC file server %s...", versionString)
if conf.Server.UnixSocket {
socketPath := "/tmp/hmac-file-server.sock" // Use a default socket path since ListenAddress is now a port

View File

@ -0,0 +1,338 @@
// network_resilience.go - Network resilience middleware without modifying core functions
package main
import (
"context"
"net"
"net/http"
"sync"
"time"
)
// NetworkResilienceManager handles network change detection and upload pausing.
type NetworkResilienceManager struct {
	activeUploads  map[string]*UploadContext // uploads registered for pause/resume, keyed by session/chunk ID
	mutex          sync.RWMutex              // guards activeUploads and isPaused
	isPaused       bool                      // true while a detected network change has uploads paused
	pauseChannel   chan bool                 // NOTE(review): allocated but never read in this file — possibly vestigial
	resumeChannel  chan bool                 // NOTE(review): allocated but never read in this file — possibly vestigial
	lastInterfaces []net.Interface           // interface snapshot from the previous monitor tick
}
// UploadContext tracks active upload state.
type UploadContext struct {
	SessionID  string    // identifier this context was registered under
	PauseChan  chan bool // receives a signal when the upload should pause
	ResumeChan chan bool // receives a signal when a paused upload may continue
	CancelChan chan bool // receives a signal when the upload should abort
	IsPaused   bool      // last pause state recorded by the manager (guarded by the manager's mutex)
}
// NewNetworkResilienceManager constructs a manager and, when network
// event monitoring is enabled in the configuration, starts the
// background interface watcher goroutine.
func NewNetworkResilienceManager() *NetworkResilienceManager {
	m := &NetworkResilienceManager{
		activeUploads: make(map[string]*UploadContext),
		pauseChannel:  make(chan bool, 100),
		resumeChannel: make(chan bool, 100),
	}
	if conf.Server.NetworkEvents {
		go m.monitorNetworkChanges()
	}
	return m
}
// RegisterUpload creates and tracks a pause/resume context for an
// active upload. If a network-change pause is already in effect, the
// new upload is signalled to pause immediately.
func (m *NetworkResilienceManager) RegisterUpload(sessionID string) *UploadContext {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	upload := &UploadContext{
		SessionID:  sessionID,
		PauseChan:  make(chan bool, 1),
		ResumeChan: make(chan bool, 1),
		CancelChan: make(chan bool, 1),
	}
	m.activeUploads[sessionID] = upload
	if m.isPaused {
		// Non-blocking: the 1-slot buffer holds the pending pause signal.
		select {
		case upload.PauseChan <- true:
			upload.IsPaused = true
		default:
		}
	}
	return upload
}
// UnregisterUpload removes an upload from tracking and closes its
// signal channels.
//
// NOTE(review): closing PauseChan/ResumeChan here means any goroutine
// still selecting on them (e.g. the monitor loop in
// ResilientHTTPHandler) sees the closed channels as permanently ready
// and may spin until its context is cancelled. Closing only CancelChan
// — or signalling cancellation via context — would be safer; confirm
// before changing, since copyChunkWithResilience also reads these
// channels.
func (m *NetworkResilienceManager) UnregisterUpload(sessionID string) {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	if ctx, exists := m.activeUploads[sessionID]; exists {
		close(ctx.PauseChan)
		close(ctx.ResumeChan)
		close(ctx.CancelChan)
		delete(m.activeUploads, sessionID)
	}
}
// PauseAllUploads signals every tracked upload to pause (best-effort,
// non-blocking) and records the global paused state so later
// registrations are paused on arrival.
func (m *NetworkResilienceManager) PauseAllUploads() {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	m.isPaused = true
	log.Info("Pausing all active uploads due to network change")
	for _, upload := range m.activeUploads {
		if upload.IsPaused {
			continue
		}
		select {
		case upload.PauseChan <- true:
			upload.IsPaused = true
		default:
			// Channel full: a pause signal is already pending.
		}
	}
}
// ResumeAllUploads signals every paused upload to continue
// (best-effort, non-blocking) and clears the global paused state.
func (m *NetworkResilienceManager) ResumeAllUploads() {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	m.isPaused = false
	log.Info("Resuming all paused uploads")
	for _, upload := range m.activeUploads {
		if !upload.IsPaused {
			continue
		}
		select {
		case upload.ResumeChan <- true:
			upload.IsPaused = false
		default:
			// Channel full: a resume signal is already pending.
		}
	}
}
// monitorNetworkChanges polls the OS interface list every 5 seconds
// and, when a change is detected, pauses all uploads, waits briefly
// for the network to stabilize, then resumes them. It runs on its own
// goroutine for the life of the process.
func (m *NetworkResilienceManager) monitorNetworkChanges() {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	// Get initial interface state.
	m.lastInterfaces, _ = net.Interfaces()
	// `for range ticker.C` replaces the original single-case select —
	// identical behavior, simpler shape.
	for range ticker.C {
		currentInterfaces, err := net.Interfaces()
		if err != nil {
			log.Warnf("Failed to get network interfaces: %v", err)
			continue
		}
		if m.hasNetworkChanges(m.lastInterfaces, currentInterfaces) {
			log.Info("Network change detected")
			m.PauseAllUploads()
			// Give the new network path a moment to stabilize.
			time.Sleep(2 * time.Second)
			m.ResumeAllUploads()
		}
		m.lastInterfaces = currentInterfaces
	}
}
// hasNetworkChanges compares two interface snapshots and reports
// whether any non-loopback interface was added, removed, or changed
// its up/down state.
//
// The parameters are renamed from the original `old, new` — `new`
// shadows the builtin. The raw-slice length comparison is also gone:
// it counted loopback interfaces, so a loopback-only change would have
// falsely reported a network change, contradicting the skip-loopback
// filtering below.
func (m *NetworkResilienceManager) hasNetworkChanges(oldIfaces, newIfaces []net.Interface) bool {
	// Build name->flags maps for non-loopback interfaces only.
	oldMap := make(map[string]net.Flags, len(oldIfaces))
	newMap := make(map[string]net.Flags, len(newIfaces))
	for _, iface := range oldIfaces {
		if iface.Flags&net.FlagLoopback == 0 {
			oldMap[iface.Name] = iface.Flags
		}
	}
	for _, iface := range newIfaces {
		if iface.Flags&net.FlagLoopback == 0 {
			newMap[iface.Name] = iface.Flags
		}
	}
	// Different interface counts mean an addition or removal.
	if len(oldMap) != len(newMap) {
		return true
	}
	// Equal counts: any removal implies an addition, so checking the
	// old set against the new covers both directions.
	for name, oldFlags := range oldMap {
		newFlags, exists := newMap[name]
		if !exists || (oldFlags&net.FlagUp) != (newFlags&net.FlagUp) {
			return true
		}
	}
	return false
}
// ResilientHTTPHandler wraps existing handlers with network resilience.
// Requests carrying an X-Upload-Session-ID header are registered with
// the manager so a network change can signal pause/resume; all other
// requests pass straight through to the wrapped handler.
//
// NOTE(review): `manager` is dereferenced unconditionally — if this
// wrapper is ever installed before InitializeNetworkResilience runs, a
// request with a session header nil-panics. Confirm the init order.
// NOTE(review): once UnregisterUpload closes PauseChan/ResumeChan, the
// monitor goroutine below sees them as permanently ready and can spin
// until ctx is cancelled by the deferred cancel().
func ResilientHTTPHandler(handler http.HandlerFunc, manager *NetworkResilienceManager) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Check for chunked upload headers.
		sessionID := r.Header.Get("X-Upload-Session-ID")
		if sessionID != "" {
			// This is a chunked upload, register for pause/resume.
			uploadCtx := manager.RegisterUpload(sessionID)
			defer manager.UnregisterUpload(sessionID)
			// Create a context that can be cancelled.
			ctx, cancel := context.WithCancel(r.Context())
			defer cancel()
			// Monitor for pause/resume signals in a goroutine.
			go func() {
				for {
					select {
					case <-uploadCtx.PauseChan:
						// Pause by setting a short timeout
						log.Debugf("Upload %s paused", sessionID)
						// Note: We can't actually pause an ongoing HTTP request,
						// but we can ensure the next chunk upload waits
					case <-uploadCtx.ResumeChan:
						log.Debugf("Upload %s resumed", sessionID)
					case <-uploadCtx.CancelChan:
						cancel()
						return
					case <-ctx.Done():
						return
					}
				}
			}()
			// Pass the context-aware request to the handler.
			r = r.WithContext(ctx)
		}
		// Call the original handler.
		handler(w, r)
	}
}
// RetryableUploadWrapper adds retry logic around upload operations,
// re-invoking the wrapped handler up to maxRetries extra times on 5xx
// status codes with quadratic backoff.
//
// NOTE(review): several structural issues to confirm before relying on
// this wrapper:
//   - ResponseRecorder passes every Write through to the real
//     ResponseWriter, so a failed attempt's status and body have
//     already reached the client before a retry runs; retried attempts
//     append to — rather than replace — the response.
//   - r.Body is a stream consumed by the first attempt, so retries
//     read an empty body.
//   - the jitter expression `2*time.Now().UnixNano()%2 - 1` evaluates
//     as (2n)%2-1 == -1 by precedence (no randomness), and it mixes
//     float64 and int64 operands, which does not compile as written.
func RetryableUploadWrapper(originalHandler http.HandlerFunc, maxRetries int) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var lastErr error
		for attempt := 0; attempt <= maxRetries; attempt++ {
			if attempt > 0 {
				// Exponential backoff with jitter (see NOTE above: the
				// jitter term is effectively a constant).
				delay := time.Duration(attempt*attempt) * time.Second
				jitter := time.Duration(float64(delay) * 0.1 * (2*time.Now().UnixNano()%2 - 1))
				time.Sleep(delay + jitter)
				log.Infof("Retrying upload attempt %d/%d", attempt+1, maxRetries+1)
			}
			// Create a custom ResponseWriter that captures errors.
			recorder := &ResponseRecorder{
				ResponseWriter: w,
				statusCode:     200,
			}
			// Call the original handler.
			originalHandler(recorder, r)
			// Check if the request was successful.
			if recorder.statusCode < 400 {
				return // Success
			}
			lastErr = recorder.lastError
			// Don't retry on client errors (4xx).
			if recorder.statusCode >= 400 && recorder.statusCode < 500 {
				break
			}
		}
		// All retries failed.
		if lastErr != nil {
			log.Errorf("Upload failed after %d retries: %v", maxRetries+1, lastErr)
			http.Error(w, lastErr.Error(), http.StatusInternalServerError)
		} else {
			http.Error(w, "Upload failed after retries", http.StatusInternalServerError)
		}
	}
}
// ResponseRecorder captures response information for retry logic. It
// forwards all writes to the wrapped ResponseWriter while remembering
// the last status code set and the most recent write error observed.
type ResponseRecorder struct {
	http.ResponseWriter
	statusCode int
	lastError  error
}

// WriteHeader records the status code before delegating to the
// underlying writer.
func (rec *ResponseRecorder) WriteHeader(statusCode int) {
	rec.statusCode = statusCode
	rec.ResponseWriter.WriteHeader(statusCode)
}

// Write delegates to the underlying writer, recording any error it
// returns.
func (rec *ResponseRecorder) Write(data []byte) (int, error) {
	n, err := rec.ResponseWriter.Write(data)
	if err != nil {
		rec.lastError = err
	}
	return n, err
}
// ConfigureEnhancedTimeouts logs the configured HTTP timeouts
// alongside the values recommended for mobile / network-switching
// scenarios. It only advises; the live configuration is not modified.
func ConfigureEnhancedTimeouts() {
	log.Info("Applying enhanced timeout configuration for mobile/network switching scenarios")
	// Current settings.
	log.Infof("Current ReadTimeout: %s", conf.Timeouts.Read)
	log.Infof("Current WriteTimeout: %s", conf.Timeouts.Write)
	log.Infof("Current IdleTimeout: %s", conf.Timeouts.Idle)
	// Recommendations.
	for _, line := range []string{
		"Recommended timeouts for mobile scenarios:",
		" ReadTimeout: 300s (5 minutes)",
		" WriteTimeout: 300s (5 minutes)",
		" IdleTimeout: 600s (10 minutes)",
		" Update your configuration file to apply these settings",
	} {
		log.Info(line)
	}
}
// Global network resilience manager
var networkManager *NetworkResilienceManager
// InitializeNetworkResilience creates the global network resilience
// manager and logs the timeout recommendations.
func InitializeNetworkResilience() {
	networkManager = NewNetworkResilienceManager()
	ConfigureEnhancedTimeouts()
	log.Info("Network resilience system initialized")
}

View File

@ -0,0 +1,308 @@
// upload_session.go - Resumable upload session management without modifying core functions
package main
import (
	"context"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"sync"
	"time"
)
// ChunkedUploadSession represents an ongoing upload session. It is
// persisted (to Redis or disk) between requests so an interrupted
// upload can be resumed.
type ChunkedUploadSession struct {
	ID            string                 `json:"id"`             // unique session identifier
	Filename      string                 `json:"filename"`       // client-supplied target filename
	TotalSize     int64                  `json:"total_size"`     // expected size of the assembled file, in bytes
	ChunkSize     int64                  `json:"chunk_size"`     // negotiated chunk size, in bytes
	UploadedBytes int64                  `json:"uploaded_bytes"` // running total of bytes received across chunks
	Chunks        map[int]ChunkInfo      `json:"chunks"`         // per-chunk bookkeeping, keyed by chunk number
	LastActivity  time.Time              `json:"last_activity"`  // time of the most recent update, used for expiry
	ClientIP      string                 `json:"client_ip"`      // IP the session was created from
	TempDir       string                 `json:"temp_dir"`       // directory holding this session's chunk files
	Metadata      map[string]interface{} `json:"metadata"`       // free-form extra data
}
// ChunkInfo represents information about an uploaded chunk.
type ChunkInfo struct {
	Number    int    `json:"number"`    // zero-based chunk index
	Size      int64  `json:"size"`      // bytes received for this chunk
	Hash      string `json:"hash"`      // content hash; never populated in this file — TODO confirm intended producer
	Completed bool   `json:"completed"` // true once the chunk is fully written
}
// UploadSessionStore manages upload sessions in memory, with
// write-through persistence to Redis (when connected) or disk.
type UploadSessionStore struct {
	sessions map[string]*ChunkedUploadSession // live sessions keyed by session ID
	mutex    sync.RWMutex                     // guards sessions
	tempDir  string                           // root for per-session temp dirs and .session files
}
// NewUploadSessionStore creates a session store rooted at tempDir and
// starts the background expiry sweep.
func NewUploadSessionStore(tempDir string) *UploadSessionStore {
	store := &UploadSessionStore{
		sessions: make(map[string]*ChunkedUploadSession),
		tempDir:  tempDir,
	}
	// Best-effort: a failure here surfaces later when sessions are created.
	os.MkdirAll(tempDir, 0755)
	// Periodically drop sessions idle for too long.
	go store.cleanupExpiredSessions()
	return store
}
// CreateSession registers a new upload session for filename,
// allocating a per-session temp directory and persisting the initial
// state.
func (s *UploadSessionStore) CreateSession(filename string, totalSize int64, clientIP string) *ChunkedUploadSession {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	id := generateSessionID()
	sessionDir := filepath.Join(s.tempDir, id)
	os.MkdirAll(sessionDir, 0755)
	session := &ChunkedUploadSession{
		ID:           id,
		Filename:     filename,
		TotalSize:    totalSize,
		ChunkSize:    getChunkSize(),
		Chunks:       make(map[int]ChunkInfo),
		LastActivity: time.Now(),
		ClientIP:     clientIP,
		TempDir:      sessionDir,
		Metadata:     make(map[string]interface{}),
	}
	s.sessions[id] = session
	s.persistSession(session)
	return session
}
// GetSession retrieves an existing session, falling back to persisted
// state (Redis or disk) when it is not in memory.
func (s *UploadSessionStore) GetSession(sessionID string) (*ChunkedUploadSession, bool) {
	// Fast path: read lock for the common in-memory hit.
	s.mutex.RLock()
	session, exists := s.sessions[sessionID]
	s.mutex.RUnlock()
	if exists {
		return session, true
	}
	// Slow path: restore from persistence. The original inserted into
	// s.sessions while holding only the read lock — a data race on the
	// map; the insert needs the write lock.
	loaded := s.loadSession(sessionID)
	if loaded == nil {
		return nil, false
	}
	s.mutex.Lock()
	defer s.mutex.Unlock()
	// Another goroutine may have restored it meanwhile; keep that copy
	// so callers share a single instance.
	if existing, ok := s.sessions[sessionID]; ok {
		return existing, true
	}
	s.sessions[sessionID] = loaded
	return loaded, true
}
// UpdateSession records a completed chunk and advances the session's
// progress, persisting the new state. Returns an error when the
// session is unknown.
func (s *UploadSessionStore) UpdateSession(sessionID string, chunkNumber int, chunkSize int64) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	session, exists := s.sessions[sessionID]
	if !exists {
		return fmt.Errorf("session not found")
	}
	// If this chunk was already recorded (e.g. a client retry), subtract
	// its previous size first so UploadedBytes is not double-counted —
	// otherwise a retried chunk inflates progress past TotalSize.
	if prev, ok := session.Chunks[chunkNumber]; ok {
		session.UploadedBytes -= prev.Size
	}
	session.Chunks[chunkNumber] = ChunkInfo{
		Number:    chunkNumber,
		Size:      chunkSize,
		Completed: true,
	}
	session.UploadedBytes += chunkSize
	session.LastActivity = time.Now()
	s.persistSession(session)
	return nil
}
// IsSessionComplete reports whether all expected bytes of the session
// have been received.
func (s *UploadSessionStore) IsSessionComplete(sessionID string) bool {
	if session, ok := s.GetSession(sessionID); ok {
		return session.UploadedBytes >= session.TotalSize
	}
	return false
}
// AssembleFile combines all uploaded chunks, in order, into the final
// file under the storage path, then removes the session and its temp
// files. Returns the final path.
func (s *UploadSessionStore) AssembleFile(sessionID string) (string, error) {
	session, exists := s.GetSession(sessionID)
	if !exists {
		return "", fmt.Errorf("session not found")
	}
	if !s.IsSessionComplete(sessionID) {
		return "", fmt.Errorf("upload not complete")
	}
	// The filename originates from a client-supplied header; use only
	// its base name so a value like "../../etc/cron.d/x" cannot escape
	// the storage directory (path traversal).
	finalPath := filepath.Join(conf.Server.StoragePath, filepath.Base(session.Filename))
	finalFile, err := os.Create(finalPath)
	if err != nil {
		return "", err
	}
	defer finalFile.Close()
	// Concatenate chunks in numeric order.
	totalChunks := int((session.TotalSize + session.ChunkSize - 1) / session.ChunkSize)
	for i := 0; i < totalChunks; i++ {
		chunkPath := filepath.Join(session.TempDir, fmt.Sprintf("chunk_%d", i))
		chunkFile, err := os.Open(chunkPath)
		if err != nil {
			return "", err
		}
		_, err = copyFileContent(finalFile, chunkFile)
		chunkFile.Close()
		if err != nil {
			return "", err
		}
	}
	// Drop the session state and temporary chunk files.
	s.CleanupSession(sessionID)
	return finalPath, nil
}
// CleanupSession removes session and temporary files.
//
// It acquires s.mutex itself, so it must not be called by code that
// already holds the lock (sync.RWMutex is not reentrant).
func (s *UploadSessionStore) CleanupSession(sessionID string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if session, exists := s.sessions[sessionID]; exists {
		os.RemoveAll(session.TempDir)       // delete chunk files
		delete(s.sessions, sessionID)       // drop in-memory state
		s.removePersistedSession(sessionID) // drop Redis/disk copy
	}
}
// persistSession writes the session to Redis when available, otherwise
// to an on-disk .session file. Persistence is best-effort: errors are
// ignored, so a failed write only costs resumability.
func (s *UploadSessionStore) persistSession(session *ChunkedUploadSession) {
	data, _ := json.Marshal(session)
	if redisClient != nil && redisConnected {
		redisClient.Set(context.Background(), "upload_session:"+session.ID, data, 24*time.Hour)
		return
	}
	sessionFile := filepath.Join(s.tempDir, session.ID+".session")
	os.WriteFile(sessionFile, data, 0644)
}
// loadSession restores a session from Redis (preferred) or its on-disk
// .session file; it returns nil when no valid persisted copy exists.
func (s *UploadSessionStore) loadSession(sessionID string) *ChunkedUploadSession {
	// Try Redis first.
	if redisClient != nil && redisConnected {
		if raw, err := redisClient.Get(context.Background(), "upload_session:"+sessionID).Result(); err == nil {
			var session ChunkedUploadSession
			if json.Unmarshal([]byte(raw), &session) == nil {
				return &session
			}
		}
	}
	// Fall back to disk.
	raw, err := os.ReadFile(filepath.Join(s.tempDir, sessionID+".session"))
	if err != nil {
		return nil
	}
	var session ChunkedUploadSession
	if json.Unmarshal(raw, &session) == nil {
		return &session
	}
	return nil
}
// removePersistedSession deletes any persisted copy of the session
// from both Redis and disk.
func (s *UploadSessionStore) removePersistedSession(sessionID string) {
	if redisClient != nil && redisConnected {
		redisClient.Del(context.Background(), "upload_session:"+sessionID)
	}
	os.Remove(filepath.Join(s.tempDir, sessionID+".session"))
}
// cleanupExpiredSessions periodically removes sessions idle for more
// than 24 hours. It runs on its own goroutine for the process
// lifetime.
func (s *UploadSessionStore) cleanupExpiredSessions() {
	ticker := time.NewTicker(1 * time.Hour)
	defer ticker.Stop()
	for range ticker.C {
		// Collect expired IDs under the lock, then clean up after
		// releasing it: CleanupSession takes s.mutex itself, and calling
		// it while holding the lock (as the original did) self-deadlocks
		// on the non-reentrant sync.RWMutex.
		now := time.Now()
		var expired []string
		s.mutex.RLock()
		for sessionID, session := range s.sessions {
			if now.Sub(session.LastActivity) > 24*time.Hour {
				expired = append(expired, sessionID)
			}
		}
		s.mutex.RUnlock()
		for _, sessionID := range expired {
			s.CleanupSession(sessionID)
		}
	}
}
// Helper functions
// generateSessionID builds a session identifier from the current Unix
// time plus a random 16-character suffix.
func generateSessionID() string {
	suffix := randomString(16)
	return fmt.Sprintf("%d_%s", time.Now().Unix(), suffix)
}
// getChunkSize resolves the configured chunk size in bytes, defaulting
// to 5MB when unset or unparsable.
func getChunkSize() int64 {
	const defaultChunkSize = 5 * 1024 * 1024 // 5MB
	if conf.Uploads.ChunkSize == "" {
		return defaultChunkSize
	}
	size, err := parseSize(conf.Uploads.ChunkSize)
	if err != nil {
		return defaultChunkSize
	}
	return size
}
// randomString returns n characters drawn from [a-z0-9] using
// crypto/rand. The original indexed the charset with
// time.Now().UnixNano() inside the loop; the clock barely advances
// between iterations, producing runs of identical characters and
// predictable session IDs.
func randomString(n int) string {
	const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		// Extremely unlikely; fall back to a time-derived sequence that
		// still varies per position.
		seed := time.Now().UnixNano()
		for i := range b {
			b[i] = charset[(seed+int64(i))%int64(len(charset))]
		}
		return string(b)
	}
	for i := range b {
		b[i] = charset[int(b[i])%len(charset)]
	}
	return string(b)
}
// copyFileContent copies all remaining data from src to dst, reusing
// the shared buffer pool, and returns the bytes written.
//
// io.CopyBuffer performs the read/write loop the original wrote by
// hand and correctly treats io.EOF as end-of-input; the original
// compared err.Error() == "EOF", a fragile string match that breaks on
// wrapped or non-canonical EOF errors.
func copyFileContent(dst, src *os.File) (int64, error) {
	bufPtr := bufferPool.Get().(*[]byte)
	defer bufferPool.Put(bufPtr)
	return io.CopyBuffer(dst, src, *bufPtr)
}