Skip to content

Commit 32d2556

Browse files
authored
fix(interceptor): preserve upstream X-Forwarded headers (#1434)
1 parent b222c4e commit 32d2556

File tree

3 files changed

+72
-18
lines changed

3 files changed

+72
-18
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ This changelog keeps track of work items that have been completed and are ready
3636

3737
### Fixes
3838

39-
- **General**: TODO ([#TODO](https://github.com/kedacore/http-add-on/issues/TODO))
39+
- **General**: Preserve `X-Forwarded-Proto` and `X-Forwarded-Host` headers from upstream proxies ([#1432](https://github.com/kedacore/http-add-on/issues/1432))
4040

4141
### Deprecations
4242

interceptor/handler/upstream.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,16 @@ func (uh *Upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
7272
pr.SetURL(stream)
7373
// Preserve original Host header (SetURL rewrites it by default).
7474
pr.Out.Host = pr.In.Host
75-
// Preserve X-Forwarded-For chain before appending client IP.
75+
76+
// Preserve and extend X-Forwarded-... headers from upstream proxies
7677
pr.Out.Header["X-Forwarded-For"] = pr.In.Header["X-Forwarded-For"]
7778
pr.SetXForwarded()
79+
if host := pr.In.Header.Get("X-Forwarded-Host"); host != "" {
80+
pr.Out.Header.Set("X-Forwarded-Host", host)
81+
}
82+
if proto := pr.In.Header.Get("X-Forwarded-Proto"); proto != "" {
83+
pr.Out.Header.Set("X-Forwarded-Proto", proto)
84+
}
7885
},
7986
BufferPool: bufferPool,
8087
Transport: uh.roundTripper,

interceptor/handler/upstream_test.go

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -440,15 +440,36 @@ func TestForwardRequestRedirectAndHeaders(t *testing.T) {
440440
r.Equal("Hello from srv", res.Body.String())
441441
}
442442

443-
func TestUpstreamSetsXForwardedFor(t *testing.T) {
443+
func TestUpstreamPreservesXForwardedHeaders(t *testing.T) {
444444
tests := map[string]struct {
445-
forwardedIPs []string
445+
forwardedFor string
446+
forwardedHost string
447+
forwardedProto string
448+
forwardedPort string
446449
}{
447-
"appends to existing header chain": {
448-
forwardedIPs: []string{"1.2.3.4", "5.6.7.8"},
450+
"preserves and extends forwarded IPs": {
451+
forwardedFor: "198.51.100.1",
452+
},
453+
"preserves forwarded host": {
454+
forwardedHost: "example.org",
455+
},
456+
"preserves forwarded proto": {
457+
forwardedProto: "http",
458+
},
459+
"preserves forwarded port": {
460+
forwardedPort: "443",
461+
},
462+
"preserves and extends existing headers": {
463+
forwardedFor: "1.2.3.4, 5.6.7.8",
464+
forwardedHost: "keda.sh",
465+
forwardedProto: "https",
466+
forwardedPort: "8443",
449467
},
450468
"sets header when not present": {
451-
forwardedIPs: nil,
469+
forwardedFor: "",
470+
forwardedHost: "",
471+
forwardedProto: "",
472+
forwardedPort: "",
452473
},
453474
}
454475

@@ -470,28 +491,54 @@ func TestUpstreamSetsXForwardedFor(t *testing.T) {
470491
upstream := NewUpstream(http.DefaultTransport, &config.Tracing{}, false)
471492

472493
req := httptest.NewRequest("GET", "/test", nil)
473-
forwardedIPsStr := strings.Join(tt.forwardedIPs, ", ")
474-
if tt.forwardedIPs != nil {
475-
req.Header.Set("X-Forwarded-For", forwardedIPsStr)
494+
if tt.forwardedFor != "" {
495+
req.Header.Set("X-Forwarded-For", tt.forwardedFor)
496+
}
497+
if tt.forwardedHost != "" {
498+
req.Header.Set("X-Forwarded-Host", tt.forwardedHost)
499+
}
500+
if tt.forwardedProto != "" {
501+
req.Header.Set("X-Forwarded-Proto", tt.forwardedProto)
502+
}
503+
if tt.forwardedPort != "" {
504+
req.Header.Set("X-Forwarded-Port", tt.forwardedPort)
476505
}
477506
req = util.RequestWithStream(req, backendURL)
478507

479508
upstream.ServeHTTP(httptest.NewRecorder(), req)
480509

481510
// Verify the test conditions
482511
xff := receivedHeaders.Get("X-Forwarded-For")
483-
if xff == "" {
484-
t.Fatal("X-Forwarded-For should not be empty")
512+
if tt.forwardedFor != "" {
513+
if !strings.HasPrefix(xff, tt.forwardedFor+", ") {
514+
t.Errorf("expected X-Forwarded-For to start with %q, got: %q", tt.forwardedFor+", ", xff)
515+
}
516+
} else if xff == "" {
517+
t.Error("X-Forwarded-For should contain at least the client IP")
518+
}
519+
520+
xfh := receivedHeaders.Get("X-Forwarded-Host")
521+
if tt.forwardedHost != "" {
522+
if tt.forwardedHost != xfh {
523+
t.Errorf("expected forwarded host %q, got %q", tt.forwardedHost, xfh)
524+
}
525+
} else if xfh != req.Host {
526+
t.Errorf("expected default forwarded host %q, got %q", req.Host, xfh)
485527
}
486528

487-
if tt.forwardedIPs != nil && !strings.HasPrefix(xff, forwardedIPsStr) {
488-
t.Errorf("expected X-Forwarded-For to contain %q, got: %q", forwardedIPsStr, xff)
529+
xfproto := receivedHeaders.Get("X-Forwarded-Proto")
530+
if tt.forwardedProto != "" {
531+
if tt.forwardedProto != xfproto {
532+
t.Errorf("expected forwarded proto %q, got %q", tt.forwardedProto, xfproto)
533+
}
534+
} else if xfproto != "http" {
535+
t.Errorf("expected default forwarded proto %q, got %q", "http", xfproto)
489536
}
490537

491-
ips := strings.Split(xff, ", ")
492-
expectedLen := len(tt.forwardedIPs) + 1
493-
if len(ips) != expectedLen {
494-
t.Errorf("expected %d IPs in X-Forwarded-For, got %d: %q", expectedLen, len(ips), xff)
538+
// Ensure that X-Forwarded-Port is preserved even if we don't set a default for it
539+
xfport := receivedHeaders.Get("X-Forwarded-Port")
540+
if xfport != tt.forwardedPort {
541+
t.Errorf("expected forwarded port %q, got %q", tt.forwardedPort, xfport)
495542
}
496543
})
497544
}

0 commit comments

Comments
 (0)