extract proxy session management and add clear proxy sessions when updating a proxy config

Signed-off-by: Ronni Skansing <rskansing@gmail.com>
This commit is contained in:
Ronni Skansing
2025-11-04 21:42:31 +01:00
parent a5b2622317
commit 4ce58fb235
5 changed files with 362 additions and 161 deletions

View File

@@ -71,6 +71,7 @@ func NewServer(
// setup goproxy-based proxy server
proxyServer := proxy.NewProxyHandler(
logger,
services.ProxySessionManager,
repositories.Page,
repositories.CampaignRecipient,
repositories.Campaign,

View File

@@ -9,34 +9,35 @@ import (
// Services is a collection of services
type Services struct {
Asset *service.Asset
Attachment *service.Attachment
File *service.File
Company *service.Company
InstallSetup *service.InstallSetup
Option *service.Option
Page *service.Page
Proxy *service.Proxy
Session *service.Session
User *service.User
Domain *service.Domain
Recipient *service.Recipient
RecipientGroup *service.RecipientGroup
SMTPConfiguration *service.SMTPConfiguration
Email *service.Email
CampaignTemplate *service.CampaignTemplate
Campaign *service.Campaign
Template *service.Template
APISender *service.APISender
AllowDeny *service.AllowDeny
Webhook *service.Webhook
Identifier *service.Identifier
Version *service.Version
SSO *service.SSO
Update *service.Update
Import *service.Import
Backup *service.Backup
IPAllowList *service.IPAllowListService
Asset *service.Asset
Attachment *service.Attachment
File *service.File
Company *service.Company
InstallSetup *service.InstallSetup
Option *service.Option
Page *service.Page
Proxy *service.Proxy
Session *service.Session
User *service.User
Domain *service.Domain
Recipient *service.Recipient
RecipientGroup *service.RecipientGroup
SMTPConfiguration *service.SMTPConfiguration
Email *service.Email
CampaignTemplate *service.CampaignTemplate
Campaign *service.Campaign
Template *service.Template
APISender *service.APISender
AllowDeny *service.AllowDeny
Webhook *service.Webhook
Identifier *service.Identifier
Version *service.Version
SSO *service.SSO
Update *service.Update
Import *service.Import
Backup *service.Backup
IPAllowList *service.IPAllowListService
ProxySessionManager *service.ProxySessionManager
}
// NewServices creates a collection of services
@@ -157,6 +158,7 @@ func NewServices(
FileService: file,
TemplateService: templateService,
}
proxySessionManager := service.NewProxySessionManager(logger)
proxy := &service.Proxy{
Common: common,
ProxyRepository: repositories.Proxy,
@@ -164,6 +166,7 @@ func NewServices(
CampaignRepository: repositories.Campaign,
CampaignTemplateService: campaignTemplate,
DomainService: domain,
ProxySessionManager: proxySessionManager,
}
ipAllowListService := service.NewIPAllowListService(logger, repositories.Proxy)
email := &service.Email{
@@ -247,33 +250,34 @@ func NewServices(
}
return &Services{
Asset: asset,
Attachment: attachment,
Company: companyService,
File: file,
InstallSetup: installSetup,
Option: optionService,
Page: page,
Proxy: proxy,
Session: sessionService,
User: userService,
Domain: domain,
Recipient: recipient,
RecipientGroup: recipientGroup,
SMTPConfiguration: smtpConfiguration,
Email: email,
Template: templateService,
CampaignTemplate: campaignTemplate,
Campaign: campaign,
APISender: apiSender,
AllowDeny: allowDeny,
Webhook: webhook,
Identifier: identifier,
Version: versionService,
SSO: ssoService,
Update: updateService,
Import: importService,
Backup: backupService,
IPAllowList: ipAllowListService,
Asset: asset,
Attachment: attachment,
Company: companyService,
File: file,
InstallSetup: installSetup,
Option: optionService,
Page: page,
Proxy: proxy,
Session: sessionService,
User: userService,
Domain: domain,
Recipient: recipient,
RecipientGroup: recipientGroup,
SMTPConfiguration: smtpConfiguration,
Email: email,
Template: templateService,
CampaignTemplate: campaignTemplate,
Campaign: campaign,
APISender: apiSender,
AllowDeny: allowDeny,
Webhook: webhook,
Identifier: identifier,
Version: versionService,
SSO: ssoService,
Update: updateService,
Import: importService,
Backup: backupService,
IPAllowList: ipAllowListService,
ProxySessionManager: proxySessionManager,
}
}

View File

@@ -19,7 +19,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/PuerkitoBio/goquery"
@@ -71,23 +70,6 @@ var (
MATCH_URL_REGEXP_WITHOUT_SCHEME = regexp.MustCompile(`\b(([A-Za-z0-9-]{1,63}\.)?[A-Za-z0-9]+(-[a-z0-9]+)*\.)+(arpa|root|aero|biz|cat|com|coop|edu|gov|info|int|jobs|mil|mobi|museum|name|net|org|pro|tel|travel|bot|inc|game|xyz|cloud|live|today|online|shop|tech|art|site|wiki|ink|vip|lol|club|click|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cu|cv|cx|cy|cz|dev|de|dj|dk|dm|do|dz|ec|ee|eg|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kr|kw|ky|kz|la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|md|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|ru|rw|sa|sb|sc|sd|se|sg|sh|si|sj|sk|sl|sm|sn|so|sr|st|su|sv|sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|uk|um|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw)|([0-9]{1,3}\.{3}[0-9]{1,3})\b`)
)
type ProxySession struct {
ID string
CampaignRecipientID *uuid.UUID
CampaignID *uuid.UUID
RecipientID *uuid.UUID
Campaign *model.Campaign
Domain *database.Domain
TargetDomain string
Config sync.Map // map[string]service.ProxyServiceDomainConfig
CreatedAt time.Time
RequiredCaptures sync.Map // map[string]bool
CapturedData sync.Map // map[string]map[string]string
NextPageType atomic.Value // string
IsComplete atomic.Bool
CookieBundleSubmitted atomic.Bool
}
// RequestContext holds all the context data for a proxy request
type RequestContext struct {
SessionID string
@@ -96,7 +78,7 @@ type RequestContext struct {
TargetDomain string
Domain *database.Domain
ProxyConfig *service.ProxyServiceConfigYAML
Session *ProxySession
Session *service.ProxySession
ConfigMap map[string]service.ProxyServiceDomainConfig
CampaignRecipientID *uuid.UUID
ParamName string
@@ -112,9 +94,7 @@ type RequestContext struct {
type ProxyHandler struct {
logger *zap.SugaredLogger
sessions sync.Map // map[string]*ProxySession
campaignRecipientSessions sync.Map // map[string]string (campaignRecipientID -> sessionID)
urlMappings sync.Map // map[string]string (rewritten URL -> original URL)
SessionManager *service.ProxySessionManager
PageRepository *repository.Page
CampaignRecipientRepository *repository.CampaignRecipient
CampaignRepository *repository.Campaign
@@ -131,6 +111,7 @@ type ProxyHandler struct {
func NewProxyHandler(
logger *zap.SugaredLogger,
sessionManager *service.ProxySessionManager,
pageRepo *repository.Page,
campaignRecipientRepo *repository.CampaignRecipient,
campaignRepo *repository.Campaign,
@@ -151,6 +132,7 @@ func NewProxyHandler(
return &ProxyHandler{
logger: logger,
SessionManager: sessionManager,
PageRepository: pageRepo,
CampaignRecipientRepository: campaignRecipientRepo,
CampaignRepository: campaignRepo,
@@ -458,8 +440,7 @@ func (m *ProxyHandler) processRequestWithContext(req *http.Request, reqCtx *Requ
func (m *ProxyHandler) cleanupExistingSession(campaignRecipientID *uuid.UUID, reqURL string) {
if existingSessionID := m.findSessionByCampaignRecipient(campaignRecipientID); existingSessionID != "" {
m.sessions.Delete(existingSessionID)
m.campaignRecipientSessions.Delete(campaignRecipientID.String())
m.SessionManager.DeleteSession(existingSessionID)
}
}
@@ -471,7 +452,7 @@ func (m *ProxyHandler) prepareRequestWithoutSession(req *http.Request, reqCtx *R
req.URL.Scheme = "https"
// create a dummy session for header normalization (no campaign/session data)
dummySession := &ProxySession{
dummySession := &service.ProxySession{
Config: sync.Map{},
}
// populate dummy config for normalization
@@ -603,14 +584,10 @@ func (m *ProxyHandler) resolveSessionContext(req *http.Request, reqCtx *RequestC
}
} else {
// load existing session
sessionVal, exists := m.sessions.Load(reqCtx.SessionID)
session, exists := m.SessionManager.GetSession(reqCtx.SessionID)
if !exists {
return fmt.Errorf("session not found")
}
session, ok := sessionVal.(*ProxySession)
if !ok {
return fmt.Errorf("invalid session type")
}
reqCtx.Session = session
// copy campaign from session to reqCtx for existing sessions
@@ -897,7 +874,7 @@ func (m *ProxyHandler) rewriteResponseHeadersWithContext(resp *http.Response, re
}
}
func (m *ProxyHandler) applyCustomResponseHeaderReplacements(resp *http.Response, session *ProxySession) {
func (m *ProxyHandler) applyCustomResponseHeaderReplacements(resp *http.Response, session *service.ProxySession) {
// get all headers as a string
var buf bytes.Buffer
resp.Header.Write(&buf)
@@ -1220,7 +1197,7 @@ func (m *ProxyHandler) replaceHostWithPhished(hostname string, config map[string
func (m *ProxyHandler) createNewSession(
req *http.Request,
reqCtx *RequestContext,
) (*ProxySession, error) {
) (*service.ProxySession, error) {
// use cached campaign data from request context
campaign := reqCtx.Campaign
recipientID := reqCtx.RecipientID
@@ -1235,7 +1212,7 @@ func (m *ProxyHandler) createNewSession(
sessionConfig := m.buildSessionConfig(reqCtx.TargetDomain, reqCtx.Domain.Name, reqCtx.ProxyConfig)
// create session
session := &ProxySession{
session := &service.ProxySession{
ID: uuid.New().String(),
CampaignRecipientID: campaignRecipientID,
CampaignID: campaignID,
@@ -1251,9 +1228,9 @@ func (m *ProxyHandler) createNewSession(
m.initializeSession(session, sessionConfig)
// store session
m.sessions.Store(session.ID, session)
m.SessionManager.StoreSession(session.ID, session)
if campaignRecipientID != nil {
m.campaignRecipientSessions.Store(campaignRecipientID.String(), session.ID)
m.SessionManager.StoreCampaignRecipientSession(campaignRecipientID.String(), session.ID)
}
return session, nil
@@ -1310,7 +1287,7 @@ func (m *ProxyHandler) buildSessionConfig(targetDomain, phishDomain string, prox
return sessionConfig
}
func (m *ProxyHandler) initializeSession(session *ProxySession, sessionConfig map[string]service.ProxyServiceDomainConfig) {
func (m *ProxyHandler) initializeSession(session *service.ProxySession, sessionConfig map[string]service.ProxyServiceDomainConfig) {
// store configuration in sync.map
for host, config := range sessionConfig {
session.Config.Store(host, config)
@@ -1330,24 +1307,21 @@ func (m *ProxyHandler) findSessionByCampaignRecipient(campaignRecipientID *uuid.
return ""
}
sessionIDVal, exists := m.campaignRecipientSessions.Load(campaignRecipientID.String())
sessionID, exists := m.SessionManager.GetSessionByCampaignRecipient(campaignRecipientID.String())
if !exists {
return ""
}
sessionID := sessionIDVal.(string)
if sessionVal, sessionExists := m.sessions.Load(sessionID); sessionExists {
if _, ok := sessionVal.(*ProxySession); ok {
return sessionID
}
if _, sessionExists := m.SessionManager.GetSession(sessionID); sessionExists {
return sessionID
}
// cleanup orphaned mapping
m.campaignRecipientSessions.Delete(campaignRecipientID.String())
m.SessionManager.DeleteSession(sessionID)
return ""
}
func (m *ProxyHandler) initializeRequiredCaptures(session *ProxySession) {
func (m *ProxyHandler) initializeRequiredCaptures(session *service.ProxySession) {
// only apply capture rules for the current host
if hostConfig, ok := session.Config.Load(session.TargetDomain); ok {
hCfg := hostConfig.(service.ProxyServiceDomainConfig)
@@ -1361,7 +1335,7 @@ func (m *ProxyHandler) initializeRequiredCaptures(session *ProxySession) {
}
}
func (m *ProxyHandler) onRequestBody(req *http.Request, session *ProxySession) {
func (m *ProxyHandler) onRequestBody(req *http.Request, session *service.ProxySession) {
if req.Body == nil {
return
}
@@ -1383,7 +1357,7 @@ func (m *ProxyHandler) onRequestBody(req *http.Request, session *ProxySession) {
m.applyRequestBodyReplacements(req, session)
}
func (m *ProxyHandler) onRequestHeader(req *http.Request, session *ProxySession) {
func (m *ProxyHandler) onRequestHeader(req *http.Request, session *service.ProxySession) {
hostConfig, exists := m.getHostConfig(session, req.Host)
if !exists {
return
@@ -1436,7 +1410,7 @@ func (m *ProxyHandler) onRequestHeader(req *http.Request, session *ProxySession)
}
}
func (m *ProxyHandler) onResponseBody(resp *http.Response, body []byte, session *ProxySession) {
func (m *ProxyHandler) onResponseBody(resp *http.Response, body []byte, session *service.ProxySession) {
originalHost := resp.Request.Host
if originalHost == "" {
originalHost = session.TargetDomain
@@ -1460,7 +1434,7 @@ func (m *ProxyHandler) onResponseBody(resp *http.Response, body []byte, session
}
}
func (m *ProxyHandler) onResponseCookies(resp *http.Response, session *ProxySession) {
func (m *ProxyHandler) onResponseCookies(resp *http.Response, session *service.ProxySession) {
hostConfig, exists := m.getHostConfig(session, resp.Request.Host)
if !exists {
return
@@ -1496,7 +1470,7 @@ func (m *ProxyHandler) onResponseCookies(resp *http.Response, session *ProxySess
m.checkAndSubmitCookieBundleWhenComplete(session, resp.Request)
}
func (m *ProxyHandler) onResponseHeader(resp *http.Response, session *ProxySession) {
func (m *ProxyHandler) onResponseHeader(resp *http.Response, session *service.ProxySession) {
hostConfig, exists := m.getHostConfig(session, resp.Request.Host)
if !exists {
return
@@ -1546,7 +1520,7 @@ func (m *ProxyHandler) matchesPath(capture service.ProxyServiceCaptureRule, req
return capture.PathRe.MatchString(req.URL.Path)
}
func (m *ProxyHandler) handlePathBasedCapture(capture service.ProxyServiceCaptureRule, session *ProxySession, resp *http.Response) {
func (m *ProxyHandler) handlePathBasedCapture(capture service.ProxyServiceCaptureRule, session *service.ProxySession, resp *http.Response) {
// only mark as complete if path AND method match exactly
methodMatches := capture.Method == "" || capture.Method == resp.Request.Method
pathMatches := m.matchesPath(capture, resp.Request)
@@ -1637,7 +1611,7 @@ func (m *ProxyHandler) readRequestBody(req *http.Request) []byte {
return body
}
func (m *ProxyHandler) captureFromText(text string, capture service.ProxyServiceCaptureRule, session *ProxySession, req *http.Request, captureContext string) {
func (m *ProxyHandler) captureFromText(text string, capture service.ProxyServiceCaptureRule, session *service.ProxySession, req *http.Request, captureContext string) {
if capture.Find == "" {
return
}
@@ -1667,7 +1641,7 @@ func (m *ProxyHandler) captureFromText(text string, capture service.ProxyService
m.handleCampaignFlowProgression(session, req)
}
func (m *ProxyHandler) buildCapturedData(matches []string, capture service.ProxyServiceCaptureRule, session *ProxySession, req *http.Request, captureContext string) map[string]string {
func (m *ProxyHandler) buildCapturedData(matches []string, capture service.ProxyServiceCaptureRule, session *service.ProxySession, req *http.Request, captureContext string) map[string]string {
capturedData := make(map[string]string)
// add capture name to the captured data
@@ -1685,7 +1659,7 @@ func (m *ProxyHandler) buildCapturedData(matches []string, capture service.Proxy
return capturedData
}
func (m *ProxyHandler) formatCapturedData(capturedData map[string]string, capture service.ProxyServiceCaptureRule, matches []string, session *ProxySession, req *http.Request, captureContext string) {
func (m *ProxyHandler) formatCapturedData(capturedData map[string]string, capture service.ProxyServiceCaptureRule, matches []string, session *service.ProxySession, req *http.Request, captureContext string) {
captureName := strings.ToLower(capture.Name)
switch {
@@ -1713,7 +1687,7 @@ func (m *ProxyHandler) formatCapturedData(capturedData map[string]string, captur
}
}
func (m *ProxyHandler) checkCaptureCompletion(session *ProxySession, captureName string) {
func (m *ProxyHandler) checkCaptureCompletion(session *service.ProxySession, captureName string) {
if _, exists := session.RequiredCaptures.Load(captureName); exists {
// only mark as complete if we actually have captured data for this capture
if _, hasData := session.CapturedData.Load(captureName); hasData {
@@ -1733,7 +1707,7 @@ func (m *ProxyHandler) checkCaptureCompletion(session *ProxySession, captureName
}
}
func (m *ProxyHandler) checkAndSubmitCookieBundleWhenComplete(session *ProxySession, req *http.Request) {
func (m *ProxyHandler) checkAndSubmitCookieBundleWhenComplete(session *service.ProxySession, req *http.Request) {
if session.CampaignRecipientID == nil || session.CampaignID == nil {
return
}
@@ -1757,7 +1731,7 @@ func (m *ProxyHandler) checkAndSubmitCookieBundleWhenComplete(session *ProxySess
}
}
func (m *ProxyHandler) collectCookieCaptures(session *ProxySession) (map[string]map[string]string, map[string]bool) {
func (m *ProxyHandler) collectCookieCaptures(session *service.ProxySession) (map[string]map[string]string, map[string]bool) {
cookieCaptures := make(map[string]map[string]string)
requiredCookieCaptures := make(map[string]bool)
@@ -1799,7 +1773,7 @@ func (m *ProxyHandler) areAllCookieCapturesComplete(requiredCookieCaptures map[s
return true
}
func (m *ProxyHandler) createCookieBundle(cookieCaptures map[string]map[string]string, session *ProxySession) map[string]interface{} {
func (m *ProxyHandler) createCookieBundle(cookieCaptures map[string]map[string]string, session *service.ProxySession) map[string]interface{} {
bundledData := map[string]interface{}{
"capture_type": "cookie",
"cookie_count": len(cookieCaptures),
@@ -1817,7 +1791,7 @@ func (m *ProxyHandler) createCookieBundle(cookieCaptures map[string]map[string]s
return bundledData
}
func (m *ProxyHandler) applyRequestBodyReplacements(req *http.Request, session *ProxySession) {
func (m *ProxyHandler) applyRequestBodyReplacements(req *http.Request, session *service.ProxySession) {
if req.Body == nil {
return
}
@@ -1839,7 +1813,7 @@ func (m *ProxyHandler) applyRequestBodyReplacements(req *http.Request, session *
req.Body = io.NopCloser(bytes.NewBuffer(body))
}
func (m *ProxyHandler) applyCustomReplacements(body []byte, session *ProxySession) []byte {
func (m *ProxyHandler) applyCustomReplacements(body []byte, session *service.ProxySession) []byte {
// only apply rewrite rules for the current host
if hostConfig, ok := session.Config.Load(session.TargetDomain); ok {
hCfg := hostConfig.(service.ProxyServiceDomainConfig)
@@ -2028,7 +2002,7 @@ func (m *ProxyHandler) applyTargetFilter(selection *goquery.Selection, target st
return selection
}
func (m *ProxyHandler) processCookiesForPhishingDomain(resp *http.Response, ps *ProxySession) {
func (m *ProxyHandler) processCookiesForPhishingDomain(resp *http.Response, ps *service.ProxySession) {
cookies := resp.Cookies()
if len(cookies) == 0 {
return
@@ -2053,7 +2027,7 @@ func (m *ProxyHandler) processCookiesForPhishingDomain(resp *http.Response, ps *
}
}
func (m *ProxyHandler) adjustCookieSettings(ck *http.Cookie, session *ProxySession, resp *http.Response) {
func (m *ProxyHandler) adjustCookieSettings(ck *http.Cookie, session *service.ProxySession, resp *http.Response) {
if ck.Secure {
ck.SameSite = http.SameSiteNoneMode
} else if ck.SameSite == http.SameSiteDefaultMode {
@@ -2129,7 +2103,7 @@ func (m *ProxyHandler) getCampaignRecipientIDFromURLParams(req *http.Request) (*
}
// Header normalization methods
func (m *ProxyHandler) normalizeRequestHeaders(req *http.Request, session *ProxySession) {
func (m *ProxyHandler) normalizeRequestHeaders(req *http.Request, session *service.ProxySession) {
configMap := m.configToMap(&session.Config)
// fix origin header
@@ -2246,7 +2220,7 @@ func (m *ProxyHandler) shouldProcessContent(contentType string) bool {
return false
}
func (m *ProxyHandler) handleImmediateCampaignRedirect(session *ProxySession, resp *http.Response, req *http.Request, captureLocation string) {
func (m *ProxyHandler) handleImmediateCampaignRedirect(session *service.ProxySession, resp *http.Response, req *http.Request, captureLocation string) {
m.handleCampaignFlowProgression(session, req)
nextPageType := session.NextPageType.Load().(string)
@@ -2268,7 +2242,7 @@ func (m *ProxyHandler) handleImmediateCampaignRedirect(session *ProxySession, re
session.NextPageType.Store("")
}
func (m *ProxyHandler) handleCampaignFlowProgression(session *ProxySession, req *http.Request) {
func (m *ProxyHandler) handleCampaignFlowProgression(session *service.ProxySession, req *http.Request) {
if session.CampaignRecipientID == nil || session.CampaignID == nil {
return
}
@@ -2294,7 +2268,7 @@ func (m *ProxyHandler) handleCampaignFlowProgression(session *ProxySession, req
}
}
func (m *ProxyHandler) getCurrentPageType(req *http.Request, template *model.CampaignTemplate, session *ProxySession) string {
func (m *ProxyHandler) getCurrentPageType(req *http.Request, template *model.CampaignTemplate, session *service.ProxySession) string {
if template.StateIdentifier != nil {
stateParamKey := template.StateIdentifier.Name.MustGet()
encryptedParam := req.URL.Query().Get(stateParamKey)
@@ -2350,12 +2324,12 @@ func (m *ProxyHandler) getNextPageType(currentPageType string, template *model.C
}
}
func (m *ProxyHandler) shouldRedirectForCampaignFlow(session *ProxySession, req *http.Request) bool {
func (m *ProxyHandler) shouldRedirectForCampaignFlow(session *service.ProxySession, req *http.Request) bool {
nextPageTypeStr := session.NextPageType.Load().(string)
return nextPageTypeStr != "" && nextPageTypeStr != data.PAGE_TYPE_DONE && session.IsComplete.Load()
}
func (m *ProxyHandler) createCampaignFlowRedirect(session *ProxySession, resp *http.Response) *http.Response {
func (m *ProxyHandler) createCampaignFlowRedirect(session *service.ProxySession, resp *http.Response) *http.Response {
if resp == nil {
return nil
}
@@ -2381,7 +2355,7 @@ func (m *ProxyHandler) createCampaignFlowRedirect(session *ProxySession, resp *h
return redirectResp
}
func (m *ProxyHandler) buildCampaignFlowRedirectURL(session *ProxySession, nextPageType string) string {
func (m *ProxyHandler) buildCampaignFlowRedirectURL(session *service.ProxySession, nextPageType string) string {
if session.CampaignRecipientID == nil || session.Campaign == nil {
return ""
}
@@ -2469,7 +2443,7 @@ func (m *ProxyHandler) buildCampaignFlowRedirectURL(session *ProxySession, nextP
return targetURL
}
func (m *ProxyHandler) createCampaignSubmitEvent(session *ProxySession, capturedData interface{}, req *http.Request) {
func (m *ProxyHandler) createCampaignSubmitEvent(session *service.ProxySession, capturedData interface{}, req *http.Request) {
if session.CampaignID == nil || session.CampaignRecipientID == nil {
return
}
@@ -2746,35 +2720,8 @@ func (m *ProxyHandler) createResponseFromRule(rule service.ProxyServiceResponseR
}
func (m *ProxyHandler) CleanupExpiredSessions() {
now := time.Now()
cleanedCount := 0
m.sessions.Range(func(key, value interface{}) bool {
sessionID, ok := key.(string)
if !ok {
return true
}
session, ok := value.(*ProxySession)
if !ok {
m.sessions.Delete(sessionID)
cleanedCount++
return true
}
sessionAge := now.Sub(session.CreatedAt)
if sessionAge > time.Duration(PROXY_COOKIE_MAX_AGE)*time.Second {
m.sessions.Delete(sessionID)
if session.CampaignRecipientID != nil {
m.campaignRecipientSessions.Delete(session.CampaignRecipientID.String())
}
cleanedCount++
}
return true
})
if cleanedCount > 0 {
m.logger.Debugw("cleaned up expired sessions", "count", cleanedCount)
}
// cleanup expired sessions
m.SessionManager.CleanupExpiredSessions(time.Duration(PROXY_COOKIE_MAX_AGE) * time.Second)
// cleanup expired IP allow listed entries
ipCleanedCount := m.IPAllowListService.ClearExpired()
@@ -2816,7 +2763,7 @@ func (m *ProxyHandler) isValidSessionCookie(cookie string) bool {
if cookie == "" {
return false
}
_, exists := m.sessions.Load(cookie)
_, exists := m.SessionManager.GetSession(cookie)
return exists
}
@@ -3065,7 +3012,7 @@ func (m *ProxyHandler) createStatusResponse(statusCode int) *http.Response {
}
// registerPageVisitEvent registers a page visit event when a new MITM session is created
func (m *ProxyHandler) registerPageVisitEvent(req *http.Request, session *ProxySession) {
func (m *ProxyHandler) registerPageVisitEvent(req *http.Request, session *service.ProxySession) {
if session.CampaignRecipientID == nil || session.CampaignID == nil || session.RecipientID == nil {
return
}
@@ -3372,7 +3319,7 @@ func (m *ProxyHandler) serveDenyPageResponseDirect(req *http.Request, reqCtx *Re
if err == nil && currentDomainName != templateDomainName.String() {
// we're on mitm domain, redirect to campaign template domain
campaignID := campaign.ID.MustGet()
redirectURL := m.buildCampaignFlowRedirectURL(&ProxySession{
redirectURL := m.buildCampaignFlowRedirectURL(&service.ProxySession{
CampaignRecipientID: reqCtx.CampaignRecipientID,
Campaign: campaign,
CampaignID: &campaignID,
@@ -3877,7 +3824,7 @@ func contains(slice []string, item string) bool {
// storeURLMapping stores the mapping between rewritten and original URLs
func (m *ProxyHandler) storeURLMapping(rewrittenURL, originalURL string) {
m.urlMappings.Store(rewrittenURL, originalURL)
m.SessionManager.StoreURLMapping(rewrittenURL, originalURL)
}
// getReverseURLMapping gets the original URL for a rewritten URL
@@ -3887,8 +3834,8 @@ func (m *ProxyHandler) getReverseURLMapping(path, query string) string {
rewrittenURL += "?" + query
}
if originalURL, exists := m.urlMappings.Load(rewrittenURL); exists {
return originalURL.(string)
if originalURL, exists := m.SessionManager.GetURLMapping(rewrittenURL); exists {
return originalURL
}
return ""
}
@@ -3952,7 +3899,7 @@ func (m *ProxyHandler) rewritePathsInContent(content string, rule service.ProxyS
}
// getHostConfig is a helper function to safely load and cast host configuration
func (m *ProxyHandler) getHostConfig(session *ProxySession, host string) (service.ProxyServiceDomainConfig, bool) {
func (m *ProxyHandler) getHostConfig(session *service.ProxySession, host string) (service.ProxyServiceDomainConfig, bool) {
hostConfigInterface, exists := session.Config.Load(host)
if !exists {
return service.ProxyServiceDomainConfig{}, false

View File

@@ -30,6 +30,7 @@ type Proxy struct {
CampaignRepository *repository.Campaign
CampaignTemplateService *CampaignTemplate
DomainService *Domain
ProxySessionManager *ProxySessionManager
}
// ProxyServiceConfig represents the YAML configuration for proxy
@@ -623,6 +624,15 @@ func (m *Proxy) UpdateByID(
return err
}
// clear all sessions for this proxy since config has changed
// this ensures sessions get the new config (capture rules, rewrite rules, etc.)
if m.ProxySessionManager != nil {
m.Logger.Debugw("clearing all sessions for updated proxy",
"proxyID", id.String(),
)
m.ProxySessionManager.ClearSessionsForProxy(id.String())
}
ae.Details["id"] = id.String()
m.AuditLogAuthorized(ae)

View File

@@ -0,0 +1,239 @@
package service
import (
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/phishingclub/phishingclub/database"
"github.com/phishingclub/phishingclub/model"
"go.uber.org/zap"
)
// ProxySession represents an active MITM proxy session
type ProxySession struct {
ID string
CampaignRecipientID *uuid.UUID
CampaignID *uuid.UUID
RecipientID *uuid.UUID
Campaign *model.Campaign
Domain *database.Domain
TargetDomain string
Config sync.Map // map[string]ProxyServiceDomainConfig
CreatedAt time.Time
RequiredCaptures sync.Map // map[string]bool
CapturedData sync.Map // map[string]map[string]string
NextPageType atomic.Value // string - accessed concurrently by multiple requests
IsComplete atomic.Bool // accessed concurrently when checking capture completion
CookieBundleSubmitted atomic.Bool // accessed concurrently to prevent duplicate submissions
}
// ProxySessionManager manages proxy session lifecycle and storage
type ProxySessionManager struct {
Common
sessions sync.Map // map[sessionID]*ProxySession
campaignRecipientSessions sync.Map // map[campaignRecipientID]sessionID
urlMappings sync.Map // map[rewritten URL]original URL
}
// NewProxySessionManager creates a new proxy session manager
func NewProxySessionManager(logger *zap.SugaredLogger) *ProxySessionManager {
return &ProxySessionManager{
Common: Common{
Logger: logger,
},
}
}
// GetSession retrieves a session by ID
func (m *ProxySessionManager) GetSession(sessionID string) (*ProxySession, bool) {
val, ok := m.sessions.Load(sessionID)
if !ok {
return nil, false
}
session, ok := val.(*ProxySession)
return session, ok
}
// StoreSession stores a session
func (m *ProxySessionManager) StoreSession(sessionID string, session *ProxySession) {
m.sessions.Store(sessionID, session)
}
// DeleteSession deletes a session and its associated campaign recipient mapping
func (m *ProxySessionManager) DeleteSession(sessionID string) {
if val, ok := m.sessions.Load(sessionID); ok {
if session, ok := val.(*ProxySession); ok {
if session.CampaignRecipientID != nil {
m.campaignRecipientSessions.Delete(session.CampaignRecipientID.String())
}
}
m.sessions.Delete(sessionID)
}
}
// GetSessionByCampaignRecipient retrieves a session ID by campaign recipient ID
func (m *ProxySessionManager) GetSessionByCampaignRecipient(campaignRecipientID string) (string, bool) {
val, ok := m.campaignRecipientSessions.Load(campaignRecipientID)
if !ok {
return "", false
}
sessionID, ok := val.(string)
return sessionID, ok
}
// StoreCampaignRecipientSession stores the mapping between campaign recipient and session
func (m *ProxySessionManager) StoreCampaignRecipientSession(campaignRecipientID string, sessionID string) {
m.campaignRecipientSessions.Store(campaignRecipientID, sessionID)
}
// StoreURLMapping stores a URL mapping for rewrite tracking
func (m *ProxySessionManager) StoreURLMapping(rewrittenURL string, originalURL string) {
m.urlMappings.Store(rewrittenURL, originalURL)
}
// GetURLMapping retrieves the original URL for a rewritten URL
func (m *ProxySessionManager) GetURLMapping(rewrittenURL string) (string, bool) {
val, ok := m.urlMappings.Load(rewrittenURL)
if !ok {
return "", false
}
originalURL, ok := val.(string)
return originalURL, ok
}
// RangeSessions iterates over all sessions
func (m *ProxySessionManager) RangeSessions(fn func(sessionID string, session *ProxySession) bool) {
m.sessions.Range(func(key, value interface{}) bool {
sessionID, ok := key.(string)
if !ok {
return true
}
session, ok := value.(*ProxySession)
if !ok {
return true
}
return fn(sessionID, session)
})
}
// ClearSessionsForProxy clears all sessions associated with a proxy configuration
func (m *ProxySessionManager) ClearSessionsForProxy(proxyID string) {
if proxyID == "" {
return
}
clearedCount := 0
m.sessions.Range(func(key, value interface{}) bool {
sessionID, ok := key.(string)
if !ok {
return true
}
session, ok := value.(*ProxySession)
if !ok {
return true
}
// check if this session's domain belongs to the proxy
if session.Domain != nil && session.Domain.ProxyID != nil {
if session.Domain.ProxyID.String() == proxyID {
m.DeleteSession(sessionID)
clearedCount++
m.Logger.Debugw("cleared session for proxy",
"sessionID", sessionID,
"proxyID", proxyID,
"domain", session.Domain.Name,
)
}
}
return true
})
if clearedCount > 0 {
m.Logger.Infow("cleared all sessions for proxy",
"count", clearedCount,
"proxyID", proxyID,
)
}
}
// ClearSessionsForDomains clears all sessions associated with specific phishing domains
func (m *ProxySessionManager) ClearSessionsForDomains(phishingDomains []string) {
if len(phishingDomains) == 0 {
return
}
// create a map for fast lookup
domainMap := make(map[string]bool)
for _, domain := range phishingDomains {
domainMap[domain] = true
}
clearedCount := 0
m.sessions.Range(func(key, value interface{}) bool {
sessionID, ok := key.(string)
if !ok {
return true
}
session, ok := value.(*ProxySession)
if !ok {
return true
}
// check if this session's domain matches any of the affected domains
if session.Domain != nil {
domainName := session.Domain.Name
if domainMap[domainName] {
m.DeleteSession(sessionID)
clearedCount++
m.Logger.Debugw("cleared session for affected domain",
"sessionID", sessionID,
"domain", domainName,
)
}
}
return true
})
if clearedCount > 0 {
m.Logger.Infow("cleared sessions for affected domains",
"count", clearedCount,
"domains", phishingDomains,
)
}
}
// CleanupExpiredSessions removes sessions older than maxAge
func (m *ProxySessionManager) CleanupExpiredSessions(maxAge time.Duration) int {
now := time.Now()
cleanedCount := 0
m.sessions.Range(func(key, value interface{}) bool {
sessionID, ok := key.(string)
if !ok {
return true
}
session, ok := value.(*ProxySession)
if !ok {
m.sessions.Delete(sessionID)
cleanedCount++
return true
}
sessionAge := now.Sub(session.CreatedAt)
if sessionAge > maxAge {
m.DeleteSession(sessionID)
cleanedCount++
}
return true
})
if cleanedCount > 0 {
m.Logger.Debugw("cleaned up expired sessions", "count", cleanedCount)
}
return cleanedCount
}