Skip to content

Commit cc0cd7e

Browse files
authored
add more custom config vars (#9)
Signed-off-by: Christian Troelsen <[email protected]> fix suggestions by coderabbit Signed-off-by: Christian Troelsen <[email protected]> a few more changes add logging some fixes to metadata endpoints add more token logging try to update resource value try another update more updates remove unnnecessary stuff
1 parent 30df967 commit cc0cd7e

File tree

4 files changed

+96
-30
lines changed

4 files changed

+96
-30
lines changed

config.go

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package oauth
22

33
import (
44
"fmt"
5+
"strconv"
6+
"strings"
57

68
"github.com/tuannvm/oauth-mcp-proxy/provider"
79
)
@@ -18,6 +20,7 @@ type Config struct {
1820
Audience string
1921
ClientID string
2022
ClientSecret string
23+
Scopes []string // OIDC scopes
2124

2225
// Server configuration
2326
ServerURL string // Full URL of the MCP server
@@ -31,6 +34,11 @@ type Config struct {
3134
// Implement the Logger interface (Debug, Info, Warn, Error methods) to
3235
// integrate with your application's logging system (e.g., zap, logrus).
3336
Logger Logger
37+
38+
SkipAudienceCheck bool // whether to skip audience validation
39+
// The issuer URL to use for issuer validation.
40+
// This should only be set if the issuer in the token differs from the standard issuer URL.
41+
ValidatorIssuer string
3442
}
3543

3644
// Validate validates the configuration
@@ -119,11 +127,13 @@ func SetupOAuth(cfg *Config) (provider.TokenValidator, error) {
119127
func createValidator(cfg *Config, logger Logger) (provider.TokenValidator, error) {
120128
// Convert root Config to provider.Config
121129
providerCfg := &provider.Config{
122-
Provider: cfg.Provider,
123-
Issuer: cfg.Issuer,
124-
Audience: cfg.Audience,
125-
JWTSecret: cfg.JWTSecret,
126-
Logger: logger,
130+
Provider: cfg.Provider,
131+
Issuer: cfg.Issuer,
132+
Audience: cfg.Audience,
133+
JWTSecret: cfg.JWTSecret,
134+
Logger: logger,
135+
SkipAudienceCheck: cfg.SkipAudienceCheck,
136+
ValidatorIssuer: cfg.ValidatorIssuer,
127137
}
128138

129139
var validator provider.TokenValidator
@@ -217,12 +227,30 @@ func (b *ConfigBuilder) WithJWTSecret(secret []byte) *ConfigBuilder {
217227
return b
218228
}
219229

230+
// WithScopes sets the OIDC scopes
231+
func (b *ConfigBuilder) WithScopes(scopes []string) *ConfigBuilder {
232+
b.config.Scopes = scopes
233+
return b
234+
}
235+
220236
// WithLogger sets the logger
221237
func (b *ConfigBuilder) WithLogger(logger Logger) *ConfigBuilder {
222238
b.config.Logger = logger
223239
return b
224240
}
225241

242+
// WithSkipAudienceCheck sets audience check toggle
243+
func (b *ConfigBuilder) WithSkipAudienceCheck(skipAudienceCheck bool) *ConfigBuilder {
244+
b.config.SkipAudienceCheck = skipAudienceCheck
245+
return b
246+
}
247+
248+
// WithValidatorIssuer sets the validator issuer URL
249+
func (b *ConfigBuilder) WithValidatorIssuer(validatorIssuer string) *ConfigBuilder {
250+
b.config.ValidatorIssuer = validatorIssuer
251+
return b
252+
}
253+
226254
// WithServerURL sets the full server URL directly
227255
func (b *ConfigBuilder) WithServerURL(url string) *ConfigBuilder {
228256
b.config.ServerURL = url
@@ -281,6 +309,12 @@ func FromEnv() (*Config, error) {
281309

282310
jwtSecret := getEnv("JWT_SECRET", "")
283311

312+
scopes := []string{}
313+
scopesEnv := getEnv("OIDC_SCOPES", "")
314+
if scopesEnv != "" {
315+
scopes = strings.Split(scopesEnv, " ")
316+
}
317+
284318
return NewConfigBuilder().
285319
WithMode(getEnv("OAUTH_MODE", "")).
286320
WithProvider(getEnv("OAUTH_PROVIDER", "")).
@@ -289,7 +323,23 @@ func FromEnv() (*Config, error) {
289323
WithAudience(getEnv("OIDC_AUDIENCE", "")).
290324
WithClientID(getEnv("OIDC_CLIENT_ID", "")).
291325
WithClientSecret(getEnv("OIDC_CLIENT_SECRET", "")).
326+
WithScopes(scopes).
327+
WithSkipAudienceCheck(parseBoolEnv("OIDC_SKIP_AUDIENCE_CHECK", false)).
328+
WithValidatorIssuer(getEnv("OIDC_VALIDATOR_ISSUER", "")).
292329
WithServerURL(serverURL).
293330
WithJWTSecret([]byte(jwtSecret)).
294331
Build()
295332
}
333+
334+
// parseBoolEnv parses a boolean environment variable
335+
func parseBoolEnv(key string, defaultVal bool) bool {
336+
val := getEnv(key, "")
337+
if val == "" {
338+
return defaultVal
339+
}
340+
parsed, err := strconv.ParseBool(val)
341+
if err != nil {
342+
return defaultVal
343+
}
344+
return parsed
345+
}

handlers.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type OAuth2Config struct {
4646
Audience string
4747
ClientID string
4848
ClientSecret string
49+
Scopes []string // OIDC scopes
4950

5051
// Server configuration
5152
MCPHost string
@@ -96,7 +97,7 @@ func NewOAuth2Handler(cfg *OAuth2Config, logger Logger) *OAuth2Handler {
9697
ClientID: cfg.ClientID,
9798
ClientSecret: cfg.ClientSecret,
9899
Endpoint: endpoint,
99-
Scopes: []string{"openid", "profile", "email"},
100+
Scopes: cfg.Scopes,
100101
}
101102

102103
// Log client configuration type for debugging
@@ -177,6 +178,11 @@ func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config {
177178
mcpURL = getEnv("MCP_URL", fmt.Sprintf("%s://%s:%s", scheme, mcpHost, mcpPort))
178179
}
179180

181+
scopes := cfg.Scopes
182+
if len(scopes) == 0 {
183+
scopes = []string{"openid", "profile", "email"}
184+
}
185+
180186
return &OAuth2Config{
181187
Enabled: true,
182188
Mode: cfg.Mode,
@@ -186,6 +192,7 @@ func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config {
186192
Audience: cfg.Audience,
187193
ClientID: cfg.ClientID,
188194
ClientSecret: cfg.ClientSecret,
195+
Scopes: scopes,
189196
MCPHost: mcpHost,
190197
MCPPort: mcpPort,
191198
MCPURL: mcpURL,
@@ -287,7 +294,7 @@ func (h *OAuth2Handler) HandleAuthorize(w http.ResponseWriter, r *http.Request)
287294
clientID := query.Get("client_id")
288295

289296
h.logger.Info("OAuth2: Authorization request - client_id: %s, redirect_uri: %s, code_challenge: %s",
290-
clientID, clientRedirectURI, truncateString(codeChallenge, 10))
297+
clientID, clientRedirectURI, codeChallenge)
291298

292299
// Determine redirect URI strategy based on configuration
293300
var redirectURI string

metadata.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func (h *OAuth2Handler) HandleProtectedResourceMetadata(w http.ResponseWriter, r
138138
"resource_documentation": fmt.Sprintf("%s/docs", h.config.MCPURL),
139139
"resource_policy_uri": fmt.Sprintf("%s/policy", h.config.MCPURL),
140140
"resource_tos_uri": fmt.Sprintf("%s/tos", h.config.MCPURL),
141-
"scopes_supported": []string{"openid", "profile", "email"},
141+
"scopes_supported": h.config.Scopes,
142142
}
143143

144144
// Encode and send response
@@ -243,7 +243,7 @@ func (h *OAuth2Handler) HandleOIDCDiscovery(w http.ResponseWriter, r *http.Reque
243243
"token_endpoint_auth_methods_supported": []string{"none"},
244244
"code_challenge_methods_supported": []string{"plain", "S256"},
245245
"subject_types_supported": []string{"public"},
246-
"scopes_supported": []string{"openid", "profile", "email"},
246+
"scopes_supported": h.config.Scopes,
247247
}
248248

249249
// Add provider-specific fields
@@ -283,7 +283,7 @@ func (h *OAuth2Handler) GetAuthorizationServerMetadata() map[string]interface{}
283283
"grant_types_supported": []string{"authorization_code"},
284284
"token_endpoint_auth_methods_supported": []string{"none"},
285285
"code_challenge_methods_supported": []string{"plain", "S256"},
286-
"scopes_supported": []string{"openid", "profile", "email"},
286+
"scopes_supported": h.config.Scopes,
287287
}
288288

289289
// Add provider-specific endpoints
@@ -315,7 +315,7 @@ func (h *OAuth2Handler) GetAuthorizationServerMetadata() map[string]interface{}
315315
"grant_types_supported": []string{"authorization_code"},
316316
"token_endpoint_auth_methods_supported": []string{"none"},
317317
"code_challenge_methods_supported": []string{"plain", "S256"},
318-
"scopes_supported": []string{"openid", "profile", "email"},
318+
"scopes_supported": h.config.Scopes,
319319
}
320320
}
321321

provider/provider.go

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ type Logger interface {
3030

3131
// Config holds OAuth configuration (subset needed by provider)
3232
type Config struct {
33-
Provider string
34-
Issuer string
35-
Audience string
36-
JWTSecret []byte
37-
Logger Logger
33+
Provider string
34+
Issuer string
35+
Audience string
36+
JWTSecret []byte
37+
Logger Logger
38+
SkipAudienceCheck bool
39+
ValidatorIssuer string
3840
}
3941

4042
// TokenValidator interface for OAuth token validation
@@ -52,10 +54,11 @@ type HMACValidator struct {
5254

5355
// OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure)
5456
type OIDCValidator struct {
55-
verifier *oidc.IDTokenVerifier
56-
provider *oidc.Provider
57-
audience string
58-
logger Logger
57+
verifier *oidc.IDTokenVerifier
58+
provider *oidc.Provider
59+
audience string
60+
TokenValidators []func(claims jwt.MapClaims) error
61+
logger Logger
5962
}
6063

6164
// Initialize sets up the HMAC validator with JWT secret and audience
@@ -90,7 +93,6 @@ func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) (
9093
}
9194
return []byte(v.secret), nil
9295
})
93-
9496
if err != nil {
9597
return nil, fmt.Errorf("failed to parse and validate token: %w", err)
9698
}
@@ -190,7 +192,9 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
190192
MaxIdleConnsPerHost: 10,
191193
},
192194
}
193-
195+
if cfg.ValidatorIssuer != "" {
196+
ctx = oidc.InsecureIssuerURLContext(ctx, cfg.ValidatorIssuer)
197+
}
194198
// Create OIDC provider with custom HTTP client
195199
provider, err := oidc.NewProvider(
196200
oidc.ClientContext(ctx, httpClient),
@@ -204,15 +208,17 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
204208
verifier := provider.Verifier(&oidc.Config{
205209
ClientID: cfg.Audience, // Note: go-oidc uses ClientID field for audience validation - see https://github.com/coreos/go-oidc/blob/v3/oidc/verify.go#L85
206210
SupportedSigningAlgs: []string{oidc.RS256, oidc.ES256},
207-
SkipClientIDCheck: false, // Always validate if ClientID is provided
208-
SkipExpiryCheck: false, // Verify expiration
209-
SkipIssuerCheck: false, // Verify issuer
211+
SkipClientIDCheck: cfg.SkipAudienceCheck,
212+
SkipExpiryCheck: false,
213+
SkipIssuerCheck: false,
210214
})
211215

212-
v.logger.Info("OAuth: OIDC validator initialized with audience validation: %s", cfg.Audience)
213-
214216
v.provider = provider
215217
v.verifier = verifier
218+
if !cfg.SkipAudienceCheck {
219+
v.logger.Info("OAuth: OIDC validator initialized with audience validation: %s", cfg.Audience)
220+
v.TokenValidators = append(v.TokenValidators, v.validateAudience)
221+
}
216222
return nil
217223
}
218224

@@ -256,9 +262,12 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
256262
return nil, fmt.Errorf("failed to extract raw claims: %w", err)
257263
}
258264

259-
// Validate audience claim for security (explicit check)
260-
if err := v.validateAudience(rawClaims); err != nil {
261-
return nil, fmt.Errorf("audience validation failed: %w", err)
265+
// Run extra validation functions
266+
for i, fn := range v.TokenValidators {
267+
err := fn(rawClaims)
268+
if err != nil {
269+
return nil, fmt.Errorf("validation function %d failed with error: %w", i, err)
270+
}
262271
}
263272

264273
return &User{

0 commit comments

Comments
 (0)