From 4ce58fb235d2cdaad43e1c3ff8c48def84f46554 Mon Sep 17 00:00:00 2001 From: Ronni Skansing Date: Tue, 4 Nov 2025 21:42:31 +0100 Subject: [PATCH] extract proxy session management and add clear proxy sessions when updating a proxy config Signed-off-by: Ronni Skansing --- backend/app/server.go | 1 + backend/app/services.go | 116 ++++++------ backend/proxy/proxy.go | 157 ++++++---------- backend/service/proxy.go | 10 ++ backend/service/proxySessionManager.go | 239 +++++++++++++++++++++++++ 5 files changed, 362 insertions(+), 161 deletions(-) create mode 100644 backend/service/proxySessionManager.go diff --git a/backend/app/server.go b/backend/app/server.go index 801b59b..b7ef49a 100644 --- a/backend/app/server.go +++ b/backend/app/server.go @@ -71,6 +71,7 @@ func NewServer( // setup goproxy-based proxy server proxyServer := proxy.NewProxyHandler( logger, + services.ProxySessionManager, repositories.Page, repositories.CampaignRecipient, repositories.Campaign, diff --git a/backend/app/services.go b/backend/app/services.go index f9a3d14..373da4e 100644 --- a/backend/app/services.go +++ b/backend/app/services.go @@ -9,34 +9,35 @@ import ( // Services is a collection of services type Services struct { - Asset *service.Asset - Attachment *service.Attachment - File *service.File - Company *service.Company - InstallSetup *service.InstallSetup - Option *service.Option - Page *service.Page - Proxy *service.Proxy - Session *service.Session - User *service.User - Domain *service.Domain - Recipient *service.Recipient - RecipientGroup *service.RecipientGroup - SMTPConfiguration *service.SMTPConfiguration - Email *service.Email - CampaignTemplate *service.CampaignTemplate - Campaign *service.Campaign - Template *service.Template - APISender *service.APISender - AllowDeny *service.AllowDeny - Webhook *service.Webhook - Identifier *service.Identifier - Version *service.Version - SSO *service.SSO - Update *service.Update - Import *service.Import - Backup *service.Backup - IPAllowList *service.IPAllowListService + Asset *service.Asset + Attachment *service.Attachment + File *service.File + Company *service.Company + InstallSetup *service.InstallSetup + Option *service.Option + Page *service.Page + Proxy *service.Proxy + Session *service.Session + User *service.User + Domain *service.Domain + Recipient *service.Recipient + RecipientGroup *service.RecipientGroup + SMTPConfiguration *service.SMTPConfiguration + Email *service.Email + CampaignTemplate *service.CampaignTemplate + Campaign *service.Campaign + Template *service.Template + APISender *service.APISender + AllowDeny *service.AllowDeny + Webhook *service.Webhook + Identifier *service.Identifier + Version *service.Version + SSO *service.SSO + Update *service.Update + Import *service.Import + Backup *service.Backup + IPAllowList *service.IPAllowListService + ProxySessionManager *service.ProxySessionManager } // NewServices creates a collection of services @@ -157,6 +158,7 @@ func NewServices( FileService: file, TemplateService: templateService, } + proxySessionManager := service.NewProxySessionManager(logger) proxy := &service.Proxy{ Common: common, ProxyRepository: repositories.Proxy, @@ -164,6 +166,7 @@ func NewServices( CampaignRepository: repositories.Campaign, CampaignTemplateService: campaignTemplate, DomainService: domain, + ProxySessionManager: proxySessionManager, } ipAllowListService := service.NewIPAllowListService(logger, repositories.Proxy) email := &service.Email{ @@ -247,33 +250,34 @@ func NewServices( } return &Services{ - Asset: asset, - Attachment: attachment, - Company: companyService, - File: file, - InstallSetup: installSetup, - Option: optionService, - Page: page, - Proxy: proxy, - Session: sessionService, - User: userService, - Domain: domain, - Recipient: recipient, - RecipientGroup: recipientGroup, - SMTPConfiguration: smtpConfiguration, - Email: email, - Template: templateService, - CampaignTemplate: campaignTemplate, - Campaign: campaign, - APISender: apiSender, - AllowDeny: allowDeny, - Webhook: webhook, - Identifier: identifier, - Version: versionService, - SSO: ssoService, - Update: updateService, - Import: importService, - Backup: backupService, - IPAllowList: ipAllowListService, + Asset: asset, + Attachment: attachment, + Company: companyService, + File: file, + InstallSetup: installSetup, + Option: optionService, + Page: page, + Proxy: proxy, + Session: sessionService, + User: userService, + Domain: domain, + Recipient: recipient, + RecipientGroup: recipientGroup, + SMTPConfiguration: smtpConfiguration, + Email: email, + Template: templateService, + CampaignTemplate: campaignTemplate, + Campaign: campaign, + APISender: apiSender, + AllowDeny: allowDeny, + Webhook: webhook, + Identifier: identifier, + Version: versionService, + SSO: ssoService, + Update: updateService, + Import: importService, + Backup: backupService, + IPAllowList: ipAllowListService, + ProxySessionManager: proxySessionManager, } } diff --git a/backend/proxy/proxy.go b/backend/proxy/proxy.go index 838b070..2343dac 100644 --- a/backend/proxy/proxy.go +++ b/backend/proxy/proxy.go @@ -19,7 +19,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/PuerkitoBio/goquery" @@ -71,23 +70,6 @@ var ( MATCH_URL_REGEXP_WITHOUT_SCHEME = regexp.MustCompile(`\b(([A-Za-z0-9-]{1,63}\.)?[A-Za-z0-9]+(-[a-z0-9]+)*\.)+(arpa|root|aero|biz|cat|com|coop|edu|gov|info|int|jobs|mil|mobi|museum|name|net|org|pro|tel|travel|bot|inc|game|xyz|cloud|live|today|online|shop|tech|art|site|wiki|ink|vip|lol|club|click|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cu|cv|cx|cy|cz|dev|de|dj|dk|dm|do|dz|ec|ee|eg|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kr|kw|ky|kz|la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|md|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|ru|rw|sa|sb|sc|sd|se|sg|sh|si|sj|sk|sl|sm|sn|so|sr|st|su|sv|sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|uk|um|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw)|([0-9]{1,3}\.{3}[0-9]{1,3})\b`) ) -type ProxySession struct { - ID string - CampaignRecipientID *uuid.UUID - CampaignID *uuid.UUID - RecipientID *uuid.UUID - Campaign *model.Campaign - Domain *database.Domain - TargetDomain string - Config sync.Map // map[string]service.ProxyServiceDomainConfig - CreatedAt time.Time - RequiredCaptures sync.Map // map[string]bool - CapturedData sync.Map // map[string]map[string]string - NextPageType atomic.Value // string - IsComplete atomic.Bool - CookieBundleSubmitted atomic.Bool -} - // RequestContext holds all the context data for a proxy request type RequestContext struct { SessionID string @@ -96,7 +78,7 @@ type RequestContext struct { TargetDomain string Domain *database.Domain ProxyConfig *service.ProxyServiceConfigYAML - Session *ProxySession + Session *service.ProxySession ConfigMap map[string]service.ProxyServiceDomainConfig CampaignRecipientID *uuid.UUID ParamName string @@ -112,9 +94,7 @@ type RequestContext struct { type ProxyHandler struct { logger *zap.SugaredLogger - sessions sync.Map // map[string]*ProxySession - campaignRecipientSessions sync.Map // map[string]string (campaignRecipientID -> sessionID) - urlMappings sync.Map // map[string]string (rewritten URL -> original URL) + SessionManager *service.ProxySessionManager PageRepository *repository.Page CampaignRecipientRepository *repository.CampaignRecipient CampaignRepository *repository.Campaign @@ -131,6 +111,7 @@ type ProxyHandler struct { func NewProxyHandler( logger *zap.SugaredLogger, + sessionManager *service.ProxySessionManager, pageRepo *repository.Page, campaignRecipientRepo *repository.CampaignRecipient, campaignRepo *repository.Campaign, @@ -151,6 +132,7 @@ func NewProxyHandler( return &ProxyHandler{ logger: logger, + SessionManager: sessionManager, PageRepository: pageRepo, CampaignRecipientRepository: campaignRecipientRepo, CampaignRepository: campaignRepo, @@ -458,8 +440,7 @@ func (m *ProxyHandler) processRequestWithContext(req *http.Request, reqCtx *Requ func (m *ProxyHandler) cleanupExistingSession(campaignRecipientID *uuid.UUID, reqURL string) { if existingSessionID := m.findSessionByCampaignRecipient(campaignRecipientID); existingSessionID != "" { - m.sessions.Delete(existingSessionID) - m.campaignRecipientSessions.Delete(campaignRecipientID.String()) + m.SessionManager.DeleteSession(existingSessionID) } } @@ -471,7 +452,7 @@ func (m *ProxyHandler) prepareRequestWithoutSession(req *http.Request, reqCtx *R req.URL.Scheme = "https" // create a dummy session for header normalization (no campaign/session data) - dummySession := &ProxySession{ + dummySession := &service.ProxySession{ Config: sync.Map{}, } // populate dummy config for normalization @@ -603,14 +584,10 @@ func (m *ProxyHandler) resolveSessionContext(req *http.Request, reqCtx *RequestC } } else { // load existing session - sessionVal, exists := m.sessions.Load(reqCtx.SessionID) + session, exists := m.SessionManager.GetSession(reqCtx.SessionID) if !exists { return fmt.Errorf("session not found") } - session, ok := sessionVal.(*ProxySession) - if !ok { - return fmt.Errorf("invalid session type") - } reqCtx.Session = session // copy campaign from session to reqCtx for existing sessions @@ -897,7 +874,7 @@ func (m *ProxyHandler) rewriteResponseHeadersWithContext(resp *http.Response, re } } -func (m *ProxyHandler) applyCustomResponseHeaderReplacements(resp *http.Response, session *ProxySession) { +func (m *ProxyHandler) applyCustomResponseHeaderReplacements(resp *http.Response, session *service.ProxySession) { // get all headers as a string var buf bytes.Buffer resp.Header.Write(&buf) @@ -1220,7 +1197,7 @@ func (m *ProxyHandler) replaceHostWithPhished(hostname string, config map[string func (m *ProxyHandler) createNewSession( req *http.Request, reqCtx *RequestContext, -) (*ProxySession, error) { +) (*service.ProxySession, error) { // use cached campaign data from request context campaign := reqCtx.Campaign recipientID := reqCtx.RecipientID @@ -1235,7 +1212,7 @@ func (m *ProxyHandler) createNewSession( sessionConfig := m.buildSessionConfig(reqCtx.TargetDomain, reqCtx.Domain.Name, reqCtx.ProxyConfig) // create session - session := &ProxySession{ + session := &service.ProxySession{ ID: uuid.New().String(), CampaignRecipientID: campaignRecipientID, CampaignID: campaignID, @@ -1251,9 +1228,9 @@ func (m *ProxyHandler) createNewSession( m.initializeSession(session, sessionConfig) // store session - m.sessions.Store(session.ID, session) + m.SessionManager.StoreSession(session.ID, session) if campaignRecipientID != nil { - m.campaignRecipientSessions.Store(campaignRecipientID.String(), session.ID) + m.SessionManager.StoreCampaignRecipientSession(campaignRecipientID.String(), session.ID) } return session, nil @@ -1310,7 +1287,7 @@ func (m *ProxyHandler) buildSessionConfig(targetDomain, phishDomain string, prox return sessionConfig } -func (m *ProxyHandler) initializeSession(session *ProxySession, sessionConfig map[string]service.ProxyServiceDomainConfig) { +func (m *ProxyHandler) initializeSession(session *service.ProxySession, sessionConfig map[string]service.ProxyServiceDomainConfig) { // store configuration in sync.map for host, config := range sessionConfig { session.Config.Store(host, config) @@ -1330,24 +1307,21 @@ func (m *ProxyHandler) findSessionByCampaignRecipient(campaignRecipientID *uuid. return "" } - sessionIDVal, exists := m.campaignRecipientSessions.Load(campaignRecipientID.String()) + sessionID, exists := m.SessionManager.GetSessionByCampaignRecipient(campaignRecipientID.String()) if !exists { return "" } - sessionID := sessionIDVal.(string) - if sessionVal, sessionExists := m.sessions.Load(sessionID); sessionExists { - if _, ok := sessionVal.(*ProxySession); ok { - return sessionID - } + if _, sessionExists := m.SessionManager.GetSession(sessionID); sessionExists { + return sessionID } // cleanup orphaned mapping - m.campaignRecipientSessions.Delete(campaignRecipientID.String()) + m.SessionManager.DeleteSession(sessionID) return "" } -func (m *ProxyHandler) initializeRequiredCaptures(session *ProxySession) { +func (m *ProxyHandler) initializeRequiredCaptures(session *service.ProxySession) { // only apply capture rules for the current host if hostConfig, ok := session.Config.Load(session.TargetDomain); ok { hCfg := hostConfig.(service.ProxyServiceDomainConfig) @@ -1361,7 +1335,7 @@ func (m *ProxyHandler) initializeRequiredCaptures(session *ProxySession) { } } -func (m *ProxyHandler) onRequestBody(req *http.Request, session *ProxySession) { +func (m *ProxyHandler) onRequestBody(req *http.Request, session *service.ProxySession) { if req.Body == nil { return } @@ -1383,7 +1357,7 @@ func (m *ProxyHandler) onRequestBody(req *http.Request, session *ProxySession) { m.applyRequestBodyReplacements(req, session) } -func (m *ProxyHandler) onRequestHeader(req *http.Request, session *ProxySession) { +func (m *ProxyHandler) onRequestHeader(req *http.Request, session *service.ProxySession) { hostConfig, exists := m.getHostConfig(session, req.Host) if !exists { return @@ -1436,7 +1410,7 @@ func (m *ProxyHandler) onRequestHeader(req *http.Request, session *ProxySession) } } -func (m *ProxyHandler) onResponseBody(resp *http.Response, body []byte, session *ProxySession) { +func (m *ProxyHandler) onResponseBody(resp *http.Response, body []byte, session *service.ProxySession) { originalHost := resp.Request.Host if originalHost == "" { originalHost = session.TargetDomain @@ -1460,7 +1434,7 @@ func (m *ProxyHandler) onResponseBody(resp *http.Response, body []byte, session } } -func (m *ProxyHandler) onResponseCookies(resp *http.Response, session *ProxySession) { +func (m *ProxyHandler) onResponseCookies(resp *http.Response, session *service.ProxySession) { hostConfig, exists := m.getHostConfig(session, resp.Request.Host) if !exists { return @@ -1496,7 +1470,7 @@ func (m *ProxyHandler) onResponseCookies(resp *http.Response, session *ProxySess m.checkAndSubmitCookieBundleWhenComplete(session, resp.Request) } -func (m *ProxyHandler) onResponseHeader(resp *http.Response, session *ProxySession) { +func (m *ProxyHandler) onResponseHeader(resp *http.Response, session *service.ProxySession) { hostConfig, exists := m.getHostConfig(session, resp.Request.Host) if !exists { return @@ -1546,7 +1520,7 @@ func (m *ProxyHandler) matchesPath(capture service.ProxyServiceCaptureRule, req return capture.PathRe.MatchString(req.URL.Path) } -func (m *ProxyHandler) handlePathBasedCapture(capture service.ProxyServiceCaptureRule, session *ProxySession, resp *http.Response) { +func (m *ProxyHandler) handlePathBasedCapture(capture service.ProxyServiceCaptureRule, session *service.ProxySession, resp *http.Response) { // only mark as complete if path AND method match exactly methodMatches := capture.Method == "" || capture.Method == resp.Request.Method pathMatches := m.matchesPath(capture, resp.Request) @@ -1637,7 +1611,7 @@ func (m *ProxyHandler) readRequestBody(req *http.Request) []byte { return body } -func (m *ProxyHandler) captureFromText(text string, capture service.ProxyServiceCaptureRule, session *ProxySession, req *http.Request, captureContext string) { +func (m *ProxyHandler) captureFromText(text string, capture service.ProxyServiceCaptureRule, session *service.ProxySession, req *http.Request, captureContext string) { if capture.Find == "" { return } @@ -1667,7 +1641,7 @@ func (m *ProxyHandler) captureFromText(text string, capture service.ProxyService m.handleCampaignFlowProgression(session, req) } -func (m *ProxyHandler) buildCapturedData(matches []string, capture service.ProxyServiceCaptureRule, session *ProxySession, req *http.Request, captureContext string) map[string]string { +func (m *ProxyHandler) buildCapturedData(matches []string, capture service.ProxyServiceCaptureRule, session *service.ProxySession, req *http.Request, captureContext string) map[string]string { capturedData := make(map[string]string) // add capture name to the captured data @@ -1685,7 +1659,7 @@ func (m *ProxyHandler) buildCapturedData(matches []string, capture service.Proxy return capturedData } -func (m *ProxyHandler) formatCapturedData(capturedData map[string]string, capture service.ProxyServiceCaptureRule, matches []string, session *ProxySession, req *http.Request, captureContext string) { +func (m *ProxyHandler) formatCapturedData(capturedData map[string]string, capture service.ProxyServiceCaptureRule, matches []string, session *service.ProxySession, req *http.Request, captureContext string) { captureName := strings.ToLower(capture.Name) switch { @@ -1713,7 +1687,7 @@ func (m *ProxyHandler) formatCapturedData(capturedData map[string]string, captur } } -func (m *ProxyHandler) checkCaptureCompletion(session *ProxySession, captureName string) { +func (m *ProxyHandler) checkCaptureCompletion(session *service.ProxySession, captureName string) { if _, exists := session.RequiredCaptures.Load(captureName); exists { // only mark as complete if we actually have captured data for this capture if _, hasData := session.CapturedData.Load(captureName); hasData { @@ -1733,7 +1707,7 @@ func (m *ProxyHandler) checkCaptureCompletion(session *ProxySession, captureName } } -func (m *ProxyHandler) checkAndSubmitCookieBundleWhenComplete(session *ProxySession, req *http.Request) { +func (m *ProxyHandler) checkAndSubmitCookieBundleWhenComplete(session *service.ProxySession, req *http.Request) { if session.CampaignRecipientID == nil || session.CampaignID == nil { return } @@ -1757,7 +1731,7 @@ func (m *ProxyHandler) checkAndSubmitCookieBundleWhenComplete(session *ProxySess } } -func (m *ProxyHandler) collectCookieCaptures(session *ProxySession) (map[string]map[string]string, map[string]bool) { +func (m *ProxyHandler) collectCookieCaptures(session *service.ProxySession) (map[string]map[string]string, map[string]bool) { cookieCaptures := make(map[string]map[string]string) requiredCookieCaptures := make(map[string]bool) @@ -1799,7 +1773,7 @@ func (m *ProxyHandler) areAllCookieCapturesComplete(requiredCookieCaptures map[s return true } -func (m *ProxyHandler) createCookieBundle(cookieCaptures map[string]map[string]string, session *ProxySession) map[string]interface{} { +func (m *ProxyHandler) createCookieBundle(cookieCaptures map[string]map[string]string, session *service.ProxySession) map[string]interface{} { bundledData := map[string]interface{}{ "capture_type": "cookie", "cookie_count": len(cookieCaptures), @@ -1817,7 +1791,7 @@ func (m *ProxyHandler) createCookieBundle(cookieCaptures map[string]map[string]s return bundledData } -func (m *ProxyHandler) applyRequestBodyReplacements(req *http.Request, session *ProxySession) { +func (m *ProxyHandler) applyRequestBodyReplacements(req *http.Request, session *service.ProxySession) { if req.Body == nil { return } @@ -1839,7 +1813,7 @@ func (m *ProxyHandler) applyRequestBodyReplacements(req *http.Request, session * req.Body = io.NopCloser(bytes.NewBuffer(body)) } -func (m *ProxyHandler) applyCustomReplacements(body []byte, session *ProxySession) []byte { +func (m *ProxyHandler) applyCustomReplacements(body []byte, session *service.ProxySession) []byte { // only apply rewrite rules for the current host if hostConfig, ok := session.Config.Load(session.TargetDomain); ok { hCfg := hostConfig.(service.ProxyServiceDomainConfig) @@ -2028,7 +2002,7 @@ func (m *ProxyHandler) applyTargetFilter(selection *goquery.Selection, target st return selection } -func (m *ProxyHandler) processCookiesForPhishingDomain(resp *http.Response, ps *ProxySession) { +func (m *ProxyHandler) processCookiesForPhishingDomain(resp *http.Response, ps *service.ProxySession) { cookies := resp.Cookies() if len(cookies) == 0 { return @@ -2053,7 +2027,7 @@ func (m *ProxyHandler) processCookiesForPhishingDomain(resp *http.Response, ps * } } -func (m *ProxyHandler) adjustCookieSettings(ck *http.Cookie, session *ProxySession, resp *http.Response) { +func (m *ProxyHandler) adjustCookieSettings(ck *http.Cookie, session *service.ProxySession, resp *http.Response) { if ck.Secure { ck.SameSite = http.SameSiteNoneMode } else if ck.SameSite == http.SameSiteDefaultMode { @@ -2129,7 +2103,7 @@ func (m *ProxyHandler) getCampaignRecipientIDFromURLParams(req *http.Request) (* } // Header normalization methods -func (m *ProxyHandler) normalizeRequestHeaders(req *http.Request, session *ProxySession) { +func (m *ProxyHandler) normalizeRequestHeaders(req *http.Request, session *service.ProxySession) { configMap := m.configToMap(&session.Config) // fix origin header @@ -2246,7 +2220,7 @@ func (m *ProxyHandler) shouldProcessContent(contentType string) bool { return false } -func (m *ProxyHandler) handleImmediateCampaignRedirect(session *ProxySession, resp *http.Response, req *http.Request, captureLocation string) { +func (m *ProxyHandler) handleImmediateCampaignRedirect(session *service.ProxySession, resp *http.Response, req *http.Request, captureLocation string) { m.handleCampaignFlowProgression(session, req) nextPageType := session.NextPageType.Load().(string) @@ -2268,7 +2242,7 @@ func (m *ProxyHandler) handleImmediateCampaignRedirect(session *ProxySession, re session.NextPageType.Store("") } -func (m *ProxyHandler) handleCampaignFlowProgression(session *ProxySession, req *http.Request) { +func (m *ProxyHandler) handleCampaignFlowProgression(session *service.ProxySession, req *http.Request) { if session.CampaignRecipientID == nil || session.CampaignID == nil { return } @@ -2294,7 +2268,7 @@ func (m *ProxyHandler) handleCampaignFlowProgression(session *ProxySession, req } } -func (m *ProxyHandler) getCurrentPageType(req *http.Request, template *model.CampaignTemplate, session *ProxySession) string { +func (m *ProxyHandler) getCurrentPageType(req *http.Request, template *model.CampaignTemplate, session *service.ProxySession) string { if template.StateIdentifier != nil { stateParamKey := template.StateIdentifier.Name.MustGet() encryptedParam := req.URL.Query().Get(stateParamKey) @@ -2350,12 +2324,12 @@ func (m *ProxyHandler) getNextPageType(currentPageType string, template *model.C } } -func (m *ProxyHandler) shouldRedirectForCampaignFlow(session *ProxySession, req *http.Request) bool { +func (m *ProxyHandler) shouldRedirectForCampaignFlow(session *service.ProxySession, req *http.Request) bool { nextPageTypeStr := session.NextPageType.Load().(string) return nextPageTypeStr != "" && nextPageTypeStr != data.PAGE_TYPE_DONE && session.IsComplete.Load() } -func (m *ProxyHandler) createCampaignFlowRedirect(session *ProxySession, resp *http.Response) *http.Response { +func (m *ProxyHandler) createCampaignFlowRedirect(session *service.ProxySession, resp *http.Response) *http.Response { if resp == nil { return nil } @@ -2381,7 +2355,7 @@ func (m *ProxyHandler) createCampaignFlowRedirect(session *ProxySession, resp *h return redirectResp } -func (m *ProxyHandler) buildCampaignFlowRedirectURL(session *ProxySession, nextPageType string) string { +func (m *ProxyHandler) buildCampaignFlowRedirectURL(session *service.ProxySession, nextPageType string) string { if session.CampaignRecipientID == nil || session.Campaign == nil { return "" } @@ -2469,7 +2443,7 @@ func (m *ProxyHandler) buildCampaignFlowRedirectURL(session *ProxySession, nextP return targetURL } -func (m *ProxyHandler) createCampaignSubmitEvent(session *ProxySession, capturedData interface{}, req *http.Request) { +func (m *ProxyHandler) createCampaignSubmitEvent(session *service.ProxySession, capturedData interface{}, req *http.Request) { if session.CampaignID == nil || session.CampaignRecipientID == nil { return } @@ -2746,35 +2720,8 @@ func (m *ProxyHandler) createResponseFromRule(rule service.ProxyServiceResponseR } func (m *ProxyHandler) CleanupExpiredSessions() { - now := time.Now() - cleanedCount := 0 - - m.sessions.Range(func(key, value interface{}) bool { - sessionID, ok := key.(string) - if !ok { - return true - } - session, ok := value.(*ProxySession) - if !ok { - m.sessions.Delete(sessionID) - cleanedCount++ - return true - } - - sessionAge := now.Sub(session.CreatedAt) - if sessionAge > time.Duration(PROXY_COOKIE_MAX_AGE)*time.Second { - m.sessions.Delete(sessionID) - if session.CampaignRecipientID != nil { - m.campaignRecipientSessions.Delete(session.CampaignRecipientID.String()) - } - cleanedCount++ - } - return true - }) - - if cleanedCount > 0 { - m.logger.Debugw("cleaned up expired sessions", "count", cleanedCount) - } + // cleanup expired sessions + m.SessionManager.CleanupExpiredSessions(time.Duration(PROXY_COOKIE_MAX_AGE) * time.Second) // cleanup expired IP allow listed entries ipCleanedCount := m.IPAllowListService.ClearExpired() @@ -2816,7 +2763,7 @@ func (m *ProxyHandler) isValidSessionCookie(cookie string) bool { if cookie == "" { return false } - _, exists := m.sessions.Load(cookie) + _, exists := m.SessionManager.GetSession(cookie) return exists } @@ -3065,7 +3012,7 @@ func (m *ProxyHandler) createStatusResponse(statusCode int) *http.Response { } // registerPageVisitEvent registers a page visit event when a new MITM session is created -func (m *ProxyHandler) registerPageVisitEvent(req *http.Request, session *ProxySession) { +func (m *ProxyHandler) registerPageVisitEvent(req *http.Request, session *service.ProxySession) { if session.CampaignRecipientID == nil || session.CampaignID == nil || session.RecipientID == nil { return } @@ -3372,7 +3319,7 @@ func (m *ProxyHandler) serveDenyPageResponseDirect(req *http.Request, reqCtx *Re if err == nil && currentDomainName != templateDomainName.String() { // we're on mitm domain, redirect to campaign template domain campaignID := campaign.ID.MustGet() - redirectURL := m.buildCampaignFlowRedirectURL(&ProxySession{ + redirectURL := m.buildCampaignFlowRedirectURL(&service.ProxySession{ CampaignRecipientID: reqCtx.CampaignRecipientID, Campaign: campaign, CampaignID: &campaignID, @@ -3877,7 +3824,7 @@ func contains(slice []string, item string) bool { // storeURLMapping stores the mapping between rewritten and original URLs func (m *ProxyHandler) storeURLMapping(rewrittenURL, originalURL string) { - m.urlMappings.Store(rewrittenURL, originalURL) + m.SessionManager.StoreURLMapping(rewrittenURL, originalURL) } // getReverseURLMapping gets the original URL for a rewritten URL @@ -3887,8 +3834,8 @@ func (m *ProxyHandler) getReverseURLMapping(path, query string) string { rewrittenURL += "?" + query } - if originalURL, exists := m.urlMappings.Load(rewrittenURL); exists { - return originalURL.(string) + if originalURL, exists := m.SessionManager.GetURLMapping(rewrittenURL); exists { + return originalURL } return "" } @@ -3952,7 +3899,7 @@ func (m *ProxyHandler) rewritePathsInContent(content string, rule service.ProxyS } // getHostConfig is a helper function to safely load and cast host configuration -func (m *ProxyHandler) getHostConfig(session *ProxySession, host string) (service.ProxyServiceDomainConfig, bool) { +func (m *ProxyHandler) getHostConfig(session *service.ProxySession, host string) (service.ProxyServiceDomainConfig, bool) { hostConfigInterface, exists := session.Config.Load(host) if !exists { return service.ProxyServiceDomainConfig{}, false diff --git a/backend/service/proxy.go b/backend/service/proxy.go index f90e855..00d7d38 100644 --- a/backend/service/proxy.go +++ b/backend/service/proxy.go @@ -30,6 +30,7 @@ type Proxy struct { CampaignRepository *repository.Campaign CampaignTemplateService *CampaignTemplate DomainService *Domain + ProxySessionManager *ProxySessionManager } // ProxyServiceConfig represents the YAML configuration for proxy @@ -623,6 +624,15 @@ func (m *Proxy) UpdateByID( return err } + // clear all sessions for this proxy since config has changed + // this ensures sessions get the new config (capture rules, rewrite rules, etc.) + if m.ProxySessionManager != nil { + m.Logger.Debugw("clearing all sessions for updated proxy", + "proxyID", id.String(), + ) + m.ProxySessionManager.ClearSessionsForProxy(id.String()) + } + ae.Details["id"] = id.String() m.AuditLogAuthorized(ae) diff --git a/backend/service/proxySessionManager.go b/backend/service/proxySessionManager.go new file mode 100644 index 0000000..ed7cc9d --- /dev/null +++ b/backend/service/proxySessionManager.go @@ -0,0 +1,239 @@ +package service + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/phishingclub/phishingclub/database" + "github.com/phishingclub/phishingclub/model" + "go.uber.org/zap" +) + +// ProxySession represents an active MITM proxy session +type ProxySession struct { + ID string + CampaignRecipientID *uuid.UUID + CampaignID *uuid.UUID + RecipientID *uuid.UUID + Campaign *model.Campaign + Domain *database.Domain + TargetDomain string + Config sync.Map // map[string]ProxyServiceDomainConfig + CreatedAt time.Time + RequiredCaptures sync.Map // map[string]bool + CapturedData sync.Map // map[string]map[string]string + NextPageType atomic.Value // string - accessed concurrently by multiple requests + IsComplete atomic.Bool // accessed concurrently when checking capture completion + CookieBundleSubmitted atomic.Bool // accessed concurrently to prevent duplicate submissions +} + +// ProxySessionManager manages proxy session lifecycle and storage +type ProxySessionManager struct { + Common + sessions sync.Map // map[sessionID]*ProxySession + campaignRecipientSessions sync.Map // map[campaignRecipientID]sessionID + urlMappings sync.Map // map[rewritten URL]original URL +} + +// NewProxySessionManager creates a new proxy session manager +func NewProxySessionManager(logger *zap.SugaredLogger) *ProxySessionManager { + return &ProxySessionManager{ + Common: Common{ + Logger: logger, + }, + } +} + +// GetSession retrieves a session by ID +func (m *ProxySessionManager) GetSession(sessionID string) (*ProxySession, bool) { + val, ok := m.sessions.Load(sessionID) + if !ok { + return nil, false + } + session, ok := val.(*ProxySession) + return session, ok +} + +// StoreSession stores a session +func (m *ProxySessionManager) StoreSession(sessionID string, session *ProxySession) { + m.sessions.Store(sessionID, session) +} + +// DeleteSession deletes a session and its associated campaign recipient mapping +func (m *ProxySessionManager) DeleteSession(sessionID string) { + if val, ok := m.sessions.Load(sessionID); ok { + if session, ok := val.(*ProxySession); ok { + if session.CampaignRecipientID != nil { + m.campaignRecipientSessions.Delete(session.CampaignRecipientID.String()) + } + } + m.sessions.Delete(sessionID) + } +} + +// GetSessionByCampaignRecipient retrieves a session ID by campaign recipient ID +func (m *ProxySessionManager) GetSessionByCampaignRecipient(campaignRecipientID string) (string, bool) { + val, ok := m.campaignRecipientSessions.Load(campaignRecipientID) + if !ok { + return "", false + } + sessionID, ok := val.(string) + return sessionID, ok +} + +// StoreCampaignRecipientSession stores the mapping between campaign recipient and session +func (m *ProxySessionManager) StoreCampaignRecipientSession(campaignRecipientID string, sessionID string) { + m.campaignRecipientSessions.Store(campaignRecipientID, sessionID) +} + +// StoreURLMapping stores a URL mapping for rewrite tracking +func (m *ProxySessionManager) StoreURLMapping(rewrittenURL string, originalURL string) { + m.urlMappings.Store(rewrittenURL, originalURL) +} + +// GetURLMapping retrieves the original URL for a rewritten URL +func (m *ProxySessionManager) GetURLMapping(rewrittenURL string) (string, bool) { + val, ok := m.urlMappings.Load(rewrittenURL) + if !ok { + return "", false + } + originalURL, ok := val.(string) + return originalURL, ok +} + +// RangeSessions iterates over all sessions +func (m *ProxySessionManager) RangeSessions(fn func(sessionID string, session *ProxySession) bool) { + m.sessions.Range(func(key, value interface{}) bool { + sessionID, ok := key.(string) + if !ok { + return true + } + session, ok := value.(*ProxySession) + if !ok { + return true + } + return fn(sessionID, session) + }) +} + +// ClearSessionsForProxy clears all sessions associated with a proxy configuration +func (m *ProxySessionManager) ClearSessionsForProxy(proxyID string) { + if proxyID == "" { + return + } + + clearedCount := 0 + + m.sessions.Range(func(key, value interface{}) bool { + sessionID, ok := key.(string) + if !ok { + return true + } + session, ok := value.(*ProxySession) + if !ok { + return true + } + + // check if this session's domain belongs to the proxy + if session.Domain != nil && session.Domain.ProxyID != nil { + if session.Domain.ProxyID.String() == proxyID { + m.DeleteSession(sessionID) + clearedCount++ + m.Logger.Debugw("cleared session for proxy", + "sessionID", sessionID, + "proxyID", proxyID, + "domain", session.Domain.Name, + ) + } + } + return true + }) + + if clearedCount > 0 { + m.Logger.Infow("cleared all sessions for proxy", + "count", clearedCount, + "proxyID", proxyID, + ) + } +} + +// ClearSessionsForDomains clears all sessions associated with specific phishing domains +func (m *ProxySessionManager) ClearSessionsForDomains(phishingDomains []string) { + if len(phishingDomains) == 0 { + return + } + + // create a map for fast lookup + domainMap := make(map[string]bool) + for _, domain := range phishingDomains { + domainMap[domain] = true + } + + clearedCount := 0 + + m.sessions.Range(func(key, value interface{}) bool { + sessionID, ok := key.(string) + if !ok { + return true + } + session, ok := value.(*ProxySession) + if !ok { + return true + } + + // check if this session's domain matches any of the affected domains + if session.Domain != nil { + domainName := session.Domain.Name + if domainMap[domainName] { + m.DeleteSession(sessionID) + clearedCount++ + m.Logger.Debugw("cleared session for affected domain", + "sessionID", sessionID, + "domain", domainName, + ) + } + } + return true + }) + + if clearedCount > 0 { + m.Logger.Infow("cleared sessions for affected domains", + "count", clearedCount, + "domains", phishingDomains, + ) + } +} + +// CleanupExpiredSessions removes sessions older than maxAge +func (m *ProxySessionManager) CleanupExpiredSessions(maxAge time.Duration) int { + now := time.Now() + cleanedCount := 0 + + m.sessions.Range(func(key, value interface{}) bool { + sessionID, ok := key.(string) + if !ok { + return true + } + session, ok := value.(*ProxySession) + if !ok { + m.sessions.Delete(sessionID) + cleanedCount++ + return true + } + + sessionAge := now.Sub(session.CreatedAt) + if sessionAge > maxAge { + m.DeleteSession(sessionID) + cleanedCount++ + } + return true + }) + + if cleanedCount > 0 { + m.Logger.Debugw("cleaned up expired sessions", "count", cleanedCount) + } + + return cleanedCount +}