Skip to content

Commit f0da1eb

Browse files
committed
Refactor code
Signed-off-by: dusan <[email protected]>
1 parent a795887 commit f0da1eb

File tree

10 files changed

+662
-397
lines changed

10 files changed

+662
-397
lines changed

users/api/oauth.go

Lines changed: 17 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,16 @@
44
package api
55

66
import (
7-
"context"
87
"encoding/json"
98
"fmt"
109
"net/http"
1110
"strings"
1211

1312
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
14-
smqauth "github.com/absmach/supermq/auth"
1513
"github.com/absmach/supermq/pkg/oauth2"
1614
"github.com/absmach/supermq/users"
15+
useroauth "github.com/absmach/supermq/users/oauth"
1716
"github.com/go-chi/chi/v5"
18-
goauth2 "golang.org/x/oauth2"
1917
)
2018

2119
var (
@@ -43,18 +41,18 @@ type authURLResponse struct {
4341
// - GET /oauth/authorize/{provider} - Returns the authorization URL
4442
// - GET /oauth/callback/{provider} - Handles OAuth2 callback and sets cookies
4543
// - POST /oauth/cli/callback/{provider} - Handles CLI OAuth2 callback and returns JSON.
46-
func oauthHandler(r *chi.Mux, svc users.Service, tokenClient grpcTokenV1.TokenServiceClient, deviceStore DeviceCodeStore, providers ...oauth2.Provider) *chi.Mux {
44+
func oauthHandler(r *chi.Mux, svc users.Service, tokenClient grpcTokenV1.TokenServiceClient, oauthSvc useroauth.Service, providers ...oauth2.Provider) *chi.Mux {
4745
for _, provider := range providers {
48-
r.HandleFunc("/oauth/callback/"+provider.Name(), oauth2CallbackHandler(provider, svc, tokenClient, deviceStore))
46+
r.HandleFunc("/oauth/callback/"+provider.Name(), oauth2CallbackHandler(provider, oauthSvc))
4947
r.Get("/oauth/authorize/"+provider.Name(), oauth2AuthorizeHandler(provider))
50-
r.Post("/oauth/cli/callback/"+provider.Name(), oauth2CLICallbackHandler(provider, svc, tokenClient))
48+
r.Post("/oauth/cli/callback/"+provider.Name(), oauth2CLICallbackHandler(provider, oauthSvc))
5149
}
5250

5351
return r
5452
}
5553

5654
// oauth2CallbackHandler is a http.HandlerFunc that handles OAuth2 callbacks.
57-
func oauth2CallbackHandler(oauth oauth2.Provider, svc users.Service, tokenClient grpcTokenV1.TokenServiceClient, deviceStore DeviceCodeStore) http.HandlerFunc {
55+
func oauth2CallbackHandler(oauth oauth2.Provider, oauthSvc useroauth.Service) http.HandlerFunc {
5856
return func(w http.ResponseWriter, r *http.Request) {
5957
if !oauth.IsEnabled() {
6058
redirectWithError(w, r, oauth.ErrorURL(), "oauth provider is disabled")
@@ -64,8 +62,8 @@ func oauth2CallbackHandler(oauth oauth2.Provider, svc users.Service, tokenClient
6462
state := r.FormValue("state")
6563

6664
// Check if this is a device flow callback (state contains device: prefix)
67-
if strings.HasPrefix(state, "device:") {
68-
handleDeviceFlowCallback(w, r, oauth, svc, tokenClient, deviceStore)
65+
if useroauth.IsDeviceFlowState(state) {
66+
handleDeviceFlowCallback(w, r, oauth, oauthSvc)
6967
return
7068
}
7169

@@ -80,13 +78,7 @@ func oauth2CallbackHandler(oauth oauth2.Provider, svc users.Service, tokenClient
8078
return
8179
}
8280

83-
token, err := oauth.Exchange(r.Context(), code)
84-
if err != nil {
85-
redirectWithError(w, r, oauth.ErrorURL(), err.Error())
86-
return
87-
}
88-
89-
jwt, err := processOAuthUser(r.Context(), oauth, token.AccessToken, svc, tokenClient)
81+
jwt, err := oauthSvc.ProcessWebCallback(r.Context(), oauth, code, "")
9082
if err != nil {
9183
redirectWithError(w, r, oauth.ErrorURL(), err.Error())
9284
return
@@ -125,7 +117,7 @@ func oauth2AuthorizeHandler(oauth oauth2.Provider) http.HandlerFunc {
125117
}
126118

127119
// oauth2CLICallbackHandler handles OAuth2 callbacks for CLI and returns JSON tokens.
128-
func oauth2CLICallbackHandler(oauth oauth2.Provider, svc users.Service, tokenClient grpcTokenV1.TokenServiceClient) http.HandlerFunc {
120+
func oauth2CLICallbackHandler(oauth oauth2.Provider, oauthSvc useroauth.Service) http.HandlerFunc {
129121
return func(w http.ResponseWriter, r *http.Request) {
130122
w.Header().Set("Content-Type", "application/json")
131123

@@ -155,13 +147,7 @@ func oauth2CLICallbackHandler(oauth oauth2.Provider, svc users.Service, tokenCli
155147
return
156148
}
157149

158-
token, err := exchangeCode(r.Context(), oauth, req.Code, req.RedirectURL)
159-
if err != nil {
160-
respondWithJSON(w, http.StatusUnauthorized, newErrorResponse(err.Error()))
161-
return
162-
}
163-
164-
jwt, err := processOAuthUser(r.Context(), oauth, token.AccessToken, svc, tokenClient)
150+
jwt, err := oauthSvc.ProcessWebCallback(r.Context(), oauth, req.Code, req.RedirectURL)
165151
if err != nil {
166152
status := http.StatusInternalServerError
167153
if err.Error() == "unauthorized" {
@@ -176,45 +162,6 @@ func oauth2CLICallbackHandler(oauth oauth2.Provider, svc users.Service, tokenCli
176162
}
177163
}
178164

179-
// exchangeCode exchanges an authorization code for an access token.
180-
// If redirectURL is provided, it uses ExchangeWithRedirect, otherwise uses Exchange.
181-
func exchangeCode(ctx context.Context, provider oauth2.Provider, code, redirectURL string) (goauth2.Token, error) {
182-
if redirectURL != "" {
183-
return provider.ExchangeWithRedirect(ctx, code, redirectURL)
184-
}
185-
return provider.Exchange(ctx, code)
186-
}
187-
188-
// processOAuthUser retrieves user info from the OAuth provider, creates or updates the user,
189-
// adds user policies, and issues a JWT token.
190-
func processOAuthUser(ctx context.Context, provider oauth2.Provider, accessToken string, svc users.Service, tokenClient grpcTokenV1.TokenServiceClient) (*grpcTokenV1.Token, error) {
191-
user, err := provider.UserInfo(accessToken)
192-
if err != nil {
193-
return nil, err
194-
}
195-
196-
user.AuthProvider = provider.Name()
197-
if user.AuthProvider == "" {
198-
user.AuthProvider = "oauth"
199-
}
200-
201-
user, err = svc.OAuthCallback(ctx, user)
202-
if err != nil {
203-
return nil, err
204-
}
205-
206-
if err := svc.OAuthAddUserPolicy(ctx, user); err != nil {
207-
return nil, err
208-
}
209-
210-
return tokenClient.Issue(ctx, &grpcTokenV1.IssueReq{
211-
UserId: user.ID,
212-
Type: uint32(smqauth.AccessKey),
213-
UserRole: uint32(smqauth.UserRole),
214-
Verified: !user.VerifiedAt.IsZero(),
215-
})
216-
}
217-
218165
// respondWithJSON writes a JSON response with the given status code and data.
219166
func respondWithJSON(w http.ResponseWriter, status int, data any) {
220167
w.WriteHeader(status)
@@ -248,13 +195,13 @@ func setTokenCookies(w http.ResponseWriter, jwt *grpcTokenV1.Token) {
248195
}
249196

250197
// handleDeviceFlowCallback processes OAuth callback for device authorization flow.
251-
func handleDeviceFlowCallback(w http.ResponseWriter, r *http.Request, oauth oauth2.Provider, svc users.Service, tokenClient grpcTokenV1.TokenServiceClient, deviceStore DeviceCodeStore) {
198+
func handleDeviceFlowCallback(w http.ResponseWriter, r *http.Request, oauth oauth2.Provider, oauthSvc useroauth.Service) {
252199
// Extract user code from state (format: "device:ABCD-EFGH")
253200
state := r.FormValue("state")
254-
userCode := strings.TrimPrefix(state, "device:")
201+
userCode := useroauth.ExtractUserCodeFromState(state)
255202

256-
// Get device code by user code
257-
deviceCode, err := deviceStore.GetByUserCode(userCode)
203+
// Get device code by user code to validate it exists
204+
_, err := oauthSvc.GetDeviceCodeByUserCode(r.Context(), userCode)
258205
if err != nil {
259206
w.Header().Set("Content-Type", "text/html")
260207
fmt.Fprint(w, strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", "The device code is invalid or has expired.", 1))
@@ -269,20 +216,10 @@ func handleDeviceFlowCallback(w http.ResponseWriter, r *http.Request, oauth oaut
269216
return
270217
}
271218

272-
// Exchange OAuth code for token
273-
token, err := oauth.Exchange(r.Context(), code)
274-
if err != nil {
275-
w.Header().Set("Content-Type", "text/html")
276-
fmt.Fprint(w, strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", fmt.Sprintf("Failed to exchange code: %s.", err.Error()), 1))
277-
return
278-
}
279-
280-
// Mark device code as approved with access token
281-
deviceCode.Approved = true
282-
deviceCode.AccessToken = token.AccessToken
283-
if err := deviceStore.Update(deviceCode); err != nil {
219+
// Process the device callback
220+
if err := oauthSvc.ProcessDeviceCallback(r.Context(), oauth, userCode, code); err != nil {
284221
w.Header().Set("Content-Type", "text/html")
285-
fmt.Fprint(w, strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", fmt.Sprintf("Failed to approve device: %s.", err.Error()), 1))
222+
fmt.Fprint(w, strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", fmt.Sprintf("Failed to process callback: %s.", err.Error()), 1))
286223
return
287224
}
288225

0 commit comments

Comments
 (0)