Skip to content

Commit b8a3f30

Browse files
committed
refactor(interceptor): extract shared proxy handler for testability
Extract building the handler chain for better testability and refactor the interceptor/main.go file to make it cleaner. Changes: - Extract BuildProxyHandler for shared use by production and tests - Move TLS configuration logic to dedicated tls_config.go and add tests - Replace main_test.go with proxy_test.go using shared builder - Simplify config parsing to return values instead of pointers - Remove deprecated DeploymentCachePollIntervalMS config option Signed-off-by: Vincent Link <vlink@redhat.com>
1 parent 32d2556 commit b8a3f30

File tree

17 files changed

+1004
-801
lines changed

17 files changed

+1004
-801
lines changed

interceptor/config/serving.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,6 @@ type Serving struct {
2424
// ConfigMapCacheRsyncPeriod is the time interval
2525
// for the configmap informer to rsync the local cache.
2626
ConfigMapCacheRsyncPeriod time.Duration `envconfig:"KEDA_HTTP_SCALER_CONFIG_MAP_INFORMER_RSYNC_PERIOD" default:"60m"`
27-
// Deprecated: The interceptor has an internal process that periodically fetches the state
28-
// of deployment that is running the servers it forwards to.
29-
//
30-
// This is the interval (in milliseconds) representing how often to do a fetch
31-
DeploymentCachePollIntervalMS int `envconfig:"KEDA_HTTP_DEPLOYMENT_CACHE_POLLING_INTERVAL_MS" default:"250"`
3227
// The interceptor has an internal process that periodically fetches the state
3328
// of endpoints that is running the servers it forwards to.
3429
//
@@ -55,10 +50,10 @@ type Serving struct {
5550
LogRequests bool `envconfig:"KEDA_HTTP_LOG_REQUESTS" default:"false"`
5651
}
5752

58-
// Parse parses standard configs using envconfig and returns a pointer to the
59-
// newly created config. Returns nil and a non-nil error if parsing failed
60-
func MustParseServing() *Serving {
61-
ret := new(Serving)
62-
envconfig.MustProcess("", ret)
53+
// MustParseServing parses standard configs using envconfig and returns the
54+
// newly created config. It panics if parsing fails.
55+
func MustParseServing() Serving {
56+
var ret Serving
57+
envconfig.MustProcess("", &ret)
6358
return ret
6459
}

interceptor/config/timeouts.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ func (t Timeouts) DefaultBackoff() wait.Backoff {
5252
}
5353
}
5454

55-
// MustParseTimeouts parses standard configs using envconfig and returns a pointer to the
55+
// MustParseTimeouts parses standard configs using envconfig and returns the
5656
// newly created config. It panics if parsing fails.
57-
func MustParseTimeouts() *Timeouts {
58-
ret := new(Timeouts)
59-
envconfig.MustProcess("", ret)
57+
func MustParseTimeouts() Timeouts {
58+
var ret Timeouts
59+
envconfig.MustProcess("", &ret)
6060
return ret
6161
}

interceptor/config/tracing.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ type Tracing struct {
1212
Exporter string `envconfig:"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL" default:"console"`
1313
}
1414

15-
// Parse parses standard configs using envconfig and returns a pointer to the
16-
// newly created config. Returns nil and a non-nil error if parsing failed
17-
func MustParseTracing() *Tracing {
18-
ret := new(Tracing)
19-
envconfig.MustProcess("", ret)
15+
// MustParseTracing parses standard configs using envconfig and returns the
16+
// newly created config. It panics if parsing fails.
17+
func MustParseTracing() Tracing {
18+
var ret Tracing
19+
envconfig.MustProcess("", &ret)
2020
return ret
2121
}

interceptor/config/validate.go

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,10 @@ package config
22

33
import (
44
"fmt"
5-
"os"
65
"time"
7-
8-
"github.com/go-logr/logr"
96
)
107

11-
func Validate(srvCfg *Serving, timeoutsCfg Timeouts, lggr logr.Logger) error {
12-
// TODO(jorturfer): delete this for v0.9.0
13-
_, deploymentEnvExist := os.LookupEnv("KEDA_HTTP_DEPLOYMENT_CACHE_POLLING_INTERVAL_MS")
14-
_, endpointsEnvExist := os.LookupEnv("KEDA_HTTP_ENDPOINTS_CACHE_POLLING_INTERVAL_MS")
15-
if deploymentEnvExist && endpointsEnvExist {
16-
return fmt.Errorf(
17-
"%s and %s are mutual exclusive",
18-
"KEDA_HTTP_DEPLOYMENT_CACHE_POLLING_INTERVAL_MS",
19-
"KEDA_HTTP_ENDPOINTS_CACHE_POLLING_INTERVAL_MS",
20-
)
21-
}
22-
if deploymentEnvExist && !endpointsEnvExist {
23-
srvCfg.EndpointsCachePollIntervalMS = srvCfg.DeploymentCachePollIntervalMS
24-
srvCfg.DeploymentCachePollIntervalMS = 0
25-
lggr.Info("WARNING: KEDA_HTTP_DEPLOYMENT_CACHE_POLLING_INTERVAL_MS has been deprecated in favor of KEDA_HTTP_ENDPOINTS_CACHE_POLLING_INTERVAL_MS and wil be removed for v0.9.0")
26-
}
27-
// END TODO
28-
8+
func Validate(srvCfg Serving, timeoutsCfg Timeouts) error {
299
endpointsCachePollInterval := time.Duration(srvCfg.EndpointsCachePollIntervalMS) * time.Millisecond
3010
if timeoutsCfg.WorkloadReplicas < endpointsCachePollInterval {
3111
return fmt.Errorf(

interceptor/handler/upstream.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ var (
2222

2323
type Upstream struct {
2424
roundTripper http.RoundTripper
25-
tracingCfg *config.Tracing
25+
tracingCfg config.Tracing
2626
shouldFailover bool
2727
}
2828

29-
func NewUpstream(roundTripper http.RoundTripper, tracingCfg *config.Tracing, shouldFailover bool) *Upstream {
29+
func NewUpstream(roundTripper http.RoundTripper, tracingCfg config.Tracing, shouldFailover bool) *Upstream {
3030
return &Upstream{
3131
roundTripper: roundTripper,
3232
tracingCfg: tracingCfg,

interceptor/handler/upstream_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ func TestForwarderSuccess(t *testing.T) {
260260
timeouts := defaultTimeouts()
261261
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
262262
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
263-
uh := NewUpstream(rt, &config.Tracing{}, false)
263+
uh := NewUpstream(rt, config.Tracing{}, false)
264264
uh.ServeHTTP(res, req)
265265

266266
r.True(
@@ -304,7 +304,7 @@ func TestForwarderHeaderTimeout(t *testing.T) {
304304
r.NoError(err)
305305
req = util.RequestWithStream(req, originURL)
306306
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
307-
uh := NewUpstream(rt, &config.Tracing{}, false)
307+
uh := NewUpstream(rt, config.Tracing{}, false)
308308
uh.ServeHTTP(res, req)
309309

310310
forwardedRequests := hdl.IncomingRequests()
@@ -354,7 +354,7 @@ func TestForwarderWaitsForSlowOrigin(t *testing.T) {
354354
r.NoError(err)
355355
req = util.RequestWithStream(req, originURL)
356356
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
357-
uh := NewUpstream(rt, &config.Tracing{}, false)
357+
uh := NewUpstream(rt, config.Tracing{}, false)
358358
uh.ServeHTTP(res, req)
359359
// wait for the goroutine above to finish, with a little cusion
360360
ensureSignalBeforeTimeout(originWaitCh, originDelay*2)
@@ -377,7 +377,7 @@ func TestForwarderConnectionRetryAndTimeout(t *testing.T) {
377377
r.NoError(err)
378378
req = util.RequestWithStream(req, noSuchURL)
379379
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
380-
uh := NewUpstream(rt, &config.Tracing{}, false)
380+
uh := NewUpstream(rt, config.Tracing{}, false)
381381

382382
start := time.Now()
383383
uh.ServeHTTP(res, req)
@@ -431,7 +431,7 @@ func TestForwardRequestRedirectAndHeaders(t *testing.T) {
431431
r.NoError(err)
432432
req = util.RequestWithStream(req, srvURL)
433433
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
434-
uh := NewUpstream(rt, &config.Tracing{}, false)
434+
uh := NewUpstream(rt, config.Tracing{}, false)
435435
uh.ServeHTTP(res, req)
436436
r.Equal(301, res.Code)
437437
r.Equal("abc123.com", res.Header().Get("Location"))
@@ -488,7 +488,7 @@ func TestUpstreamPreservesXForwardedHeaders(t *testing.T) {
488488
}
489489

490490
// Configure the Upstream and send a dummy request
491-
upstream := NewUpstream(http.DefaultTransport, &config.Tracing{}, false)
491+
upstream := NewUpstream(http.DefaultTransport, config.Tracing{}, false)
492492

493493
req := httptest.NewRequest("GET", "/test", nil)
494494
if tt.forwardedFor != "" {
@@ -604,7 +604,7 @@ func serveHTTP(w http.ResponseWriter, r *http.Request) {
604604
timeouts := defaultTimeouts()
605605
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
606606
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
607-
upstream := NewUpstream(rt, &config.Tracing{Enabled: true}, false)
607+
upstream := NewUpstream(rt, config.Tracing{Enabled: true}, false)
608608

609609
upstream.ServeHTTP(w, r)
610610
}

0 commit comments

Comments
 (0)